Commit | Line | Data |
---|---|---|
e7096c13 JD |
1 | // SPDX-License-Identifier: GPL-2.0 |
2 | /* | |
3 | * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. | |
4 | */ | |
5 | ||
6 | #include "allowedips.h" | |
7 | #include "peer.h" | |
8 | ||
/*
 * Capacity of the explicit traversal stacks used by root_free_rcu() and
 * root_remove_peer_lists(): sized for a 128-bit (IPv6) key plus one slot.
 */
enum { MAX_ALLOWEDIPS_DEPTH = 129 };

/* Slab cache that every struct allowedips_node in both tries comes from. */
static struct kmem_cache *node_cache;
12 | ||
e7096c13 JD |
13 | static void swap_endian(u8 *dst, const u8 *src, u8 bits) |
14 | { | |
15 | if (bits == 32) { | |
16 | *(u32 *)dst = be32_to_cpu(*(const __be32 *)src); | |
17 | } else if (bits == 128) { | |
948f991c HD |
18 | ((u64 *)dst)[0] = get_unaligned_be64(src); |
19 | ((u64 *)dst)[1] = get_unaligned_be64(src + 8); | |
e7096c13 JD |
20 | } |
21 | } | |
22 | ||
23 | static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, | |
24 | u8 cidr, u8 bits) | |
25 | { | |
26 | node->cidr = cidr; | |
27 | node->bit_at_a = cidr / 8U; | |
28 | #ifdef __LITTLE_ENDIAN | |
29 | node->bit_at_a ^= (bits / 8U - 1U) % 8U; | |
30 | #endif | |
31 | node->bit_at_b = 7U - (cidr % 8U); | |
32 | node->bitlen = bits; | |
33 | memcpy(node->bits, src, bits / 8U); | |
34 | } | |
bf7b042d JD |
35 | |
36 | static inline u8 choose(struct allowedips_node *node, const u8 *key) | |
37 | { | |
38 | return (key[node->bit_at_a] >> node->bit_at_b) & 1; | |
39 | } | |
e7096c13 | 40 | |
e7096c13 JD |
41 | static void push_rcu(struct allowedips_node **stack, |
42 | struct allowedips_node __rcu *p, unsigned int *len) | |
43 | { | |
44 | if (rcu_access_pointer(p)) { | |
46622219 | 45 | if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_DEPTH)) |
c31b14d8 | 46 | return; |
e7096c13 JD |
47 | stack[(*len)++] = rcu_dereference_raw(p); |
48 | } | |
49 | } | |
50 | ||
dc680de2 JD |
51 | static void node_free_rcu(struct rcu_head *rcu) |
52 | { | |
53 | kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu)); | |
54 | } | |
55 | ||
e7096c13 JD |
56 | static void root_free_rcu(struct rcu_head *rcu) |
57 | { | |
46622219 | 58 | struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = { |
e7096c13 JD |
59 | container_of(rcu, struct allowedips_node, rcu) }; |
60 | unsigned int len = 1; | |
61 | ||
62 | while (len > 0 && (node = stack[--len])) { | |
63 | push_rcu(stack, node->bit[0], &len); | |
64 | push_rcu(stack, node->bit[1], &len); | |
dc680de2 | 65 | kmem_cache_free(node_cache, node); |
e7096c13 JD |
66 | } |
67 | } | |
68 | ||
69 | static void root_remove_peer_lists(struct allowedips_node *root) | |
70 | { | |
46622219 | 71 | struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = { root }; |
e7096c13 JD |
72 | unsigned int len = 1; |
73 | ||
74 | while (len > 0 && (node = stack[--len])) { | |
75 | push_rcu(stack, node->bit[0], &len); | |
76 | push_rcu(stack, node->bit[1], &len); | |
77 | if (rcu_access_pointer(node->peer)) | |
78 | list_del(&node->peer_list); | |
79 | } | |
80 | } | |
81 | ||
e7096c13 JD |
82 | static unsigned int fls128(u64 a, u64 b) |
83 | { | |
84 | return a ? fls64(a) + 64U : fls64(b); | |
85 | } | |
86 | ||
87 | static u8 common_bits(const struct allowedips_node *node, const u8 *key, | |
88 | u8 bits) | |
89 | { | |
90 | if (bits == 32) | |
91 | return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key); | |
92 | else if (bits == 128) | |
93 | return 128U - fls128( | |
94 | *(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], | |
95 | *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]); | |
96 | return 0; | |
97 | } | |
98 | ||
99 | static bool prefix_matches(const struct allowedips_node *node, const u8 *key, | |
100 | u8 bits) | |
101 | { | |
102 | /* This could be much faster if it actually just compared the common | |
103 | * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and | |
104 | * the rest, but it turns out that common_bits is already super fast on | |
105 | * modern processors, even taking into account the unfortunate bswap. | |
106 | * So, we just inline it like this instead. | |
107 | */ | |
108 | return common_bits(node, key, bits) >= node->cidr; | |
109 | } | |
110 | ||
111 | static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, | |
112 | const u8 *key) | |
113 | { | |
114 | struct allowedips_node *node = trie, *found = NULL; | |
115 | ||
116 | while (node && prefix_matches(node, key, bits)) { | |
117 | if (rcu_access_pointer(node->peer)) | |
118 | found = node; | |
119 | if (node->cidr == bits) | |
120 | break; | |
bf7b042d | 121 | node = rcu_dereference_bh(node->bit[choose(node, key)]); |
e7096c13 JD |
122 | } |
123 | return found; | |
124 | } | |
125 | ||
126 | /* Returns a strong reference to a peer */ | |
127 | static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits, | |
128 | const void *be_ip) | |
129 | { | |
130 | /* Aligned so it can be passed to fls/fls64 */ | |
131 | u8 ip[16] __aligned(__alignof(u64)); | |
132 | struct allowedips_node *node; | |
133 | struct wg_peer *peer = NULL; | |
134 | ||
135 | swap_endian(ip, be_ip, bits); | |
136 | ||
137 | rcu_read_lock_bh(); | |
138 | retry: | |
139 | node = find_node(rcu_dereference_bh(root), bits, ip); | |
140 | if (node) { | |
141 | peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer)); | |
142 | if (!peer) | |
143 | goto retry; | |
144 | } | |
145 | rcu_read_unlock_bh(); | |
146 | return peer; | |
147 | } | |
148 | ||
149 | static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, | |
150 | u8 cidr, u8 bits, struct allowedips_node **rnode, | |
151 | struct mutex *lock) | |
152 | { | |
bf7b042d | 153 | struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); |
e7096c13 JD |
154 | struct allowedips_node *parent = NULL; |
155 | bool exact = false; | |
156 | ||
157 | while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) { | |
158 | parent = node; | |
159 | if (parent->cidr == cidr) { | |
160 | exact = true; | |
161 | break; | |
162 | } | |
bf7b042d | 163 | node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock)); |
e7096c13 JD |
164 | } |
165 | *rnode = parent; | |
166 | return exact; | |
167 | } | |
168 | ||
ae928781 | 169 | static inline void connect_node(struct allowedips_node __rcu **parent, u8 bit, struct allowedips_node *node) |
bf7b042d JD |
170 | { |
171 | node->parent_bit_packed = (unsigned long)parent | bit; | |
172 | rcu_assign_pointer(*parent, node); | |
173 | } | |
174 | ||
175 | static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node) | |
176 | { | |
177 | u8 bit = choose(parent, node->bits); | |
178 | connect_node(&parent->bit[bit], bit, node); | |
179 | } | |
180 | ||
e7096c13 JD |
181 | static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, |
182 | u8 cidr, struct wg_peer *peer, struct mutex *lock) | |
183 | { | |
184 | struct allowedips_node *node, *parent, *down, *newnode; | |
185 | ||
186 | if (unlikely(cidr > bits || !peer)) | |
187 | return -EINVAL; | |
188 | ||
189 | if (!rcu_access_pointer(*trie)) { | |
dc680de2 | 190 | node = kmem_cache_zalloc(node_cache, GFP_KERNEL); |
e7096c13 JD |
191 | if (unlikely(!node)) |
192 | return -ENOMEM; | |
193 | RCU_INIT_POINTER(node->peer, peer); | |
194 | list_add_tail(&node->peer_list, &peer->allowedips_list); | |
195 | copy_and_assign_cidr(node, key, cidr, bits); | |
bf7b042d | 196 | connect_node(trie, 2, node); |
e7096c13 JD |
197 | return 0; |
198 | } | |
199 | if (node_placement(*trie, key, cidr, bits, &node, lock)) { | |
200 | rcu_assign_pointer(node->peer, peer); | |
201 | list_move_tail(&node->peer_list, &peer->allowedips_list); | |
202 | return 0; | |
203 | } | |
204 | ||
dc680de2 | 205 | newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL); |
e7096c13 JD |
206 | if (unlikely(!newnode)) |
207 | return -ENOMEM; | |
208 | RCU_INIT_POINTER(newnode->peer, peer); | |
209 | list_add_tail(&newnode->peer_list, &peer->allowedips_list); | |
210 | copy_and_assign_cidr(newnode, key, cidr, bits); | |
211 | ||
212 | if (!node) { | |
213 | down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); | |
214 | } else { | |
bf7b042d JD |
215 | const u8 bit = choose(node, key); |
216 | down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock)); | |
e7096c13 | 217 | if (!down) { |
bf7b042d | 218 | connect_node(&node->bit[bit], bit, newnode); |
e7096c13 JD |
219 | return 0; |
220 | } | |
221 | } | |
222 | cidr = min(cidr, common_bits(down, key, bits)); | |
223 | parent = node; | |
224 | ||
225 | if (newnode->cidr == cidr) { | |
bf7b042d JD |
226 | choose_and_connect_node(newnode, down); |
227 | if (!parent) | |
228 | connect_node(trie, 2, newnode); | |
229 | else | |
230 | choose_and_connect_node(parent, newnode); | |
f634f418 JD |
231 | return 0; |
232 | } | |
233 | ||
dc680de2 | 234 | node = kmem_cache_zalloc(node_cache, GFP_KERNEL); |
f634f418 JD |
235 | if (unlikely(!node)) { |
236 | list_del(&newnode->peer_list); | |
dc680de2 | 237 | kmem_cache_free(node_cache, newnode); |
f634f418 JD |
238 | return -ENOMEM; |
239 | } | |
240 | INIT_LIST_HEAD(&node->peer_list); | |
241 | copy_and_assign_cidr(node, newnode->bits, cidr, bits); | |
242 | ||
bf7b042d JD |
243 | choose_and_connect_node(node, down); |
244 | choose_and_connect_node(node, newnode); | |
245 | if (!parent) | |
246 | connect_node(trie, 2, node); | |
247 | else | |
248 | choose_and_connect_node(parent, node); | |
e7096c13 JD |
249 | return 0; |
250 | } | |
251 | ||
252 | void wg_allowedips_init(struct allowedips *table) | |
253 | { | |
254 | table->root4 = table->root6 = NULL; | |
255 | table->seq = 1; | |
256 | } | |
257 | ||
258 | void wg_allowedips_free(struct allowedips *table, struct mutex *lock) | |
259 | { | |
260 | struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6; | |
261 | ||
262 | ++table->seq; | |
263 | RCU_INIT_POINTER(table->root4, NULL); | |
264 | RCU_INIT_POINTER(table->root6, NULL); | |
265 | if (rcu_access_pointer(old4)) { | |
266 | struct allowedips_node *node = rcu_dereference_protected(old4, | |
267 | lockdep_is_held(lock)); | |
268 | ||
269 | root_remove_peer_lists(node); | |
270 | call_rcu(&node->rcu, root_free_rcu); | |
271 | } | |
272 | if (rcu_access_pointer(old6)) { | |
273 | struct allowedips_node *node = rcu_dereference_protected(old6, | |
274 | lockdep_is_held(lock)); | |
275 | ||
276 | root_remove_peer_lists(node); | |
277 | call_rcu(&node->rcu, root_free_rcu); | |
278 | } | |
279 | } | |
280 | ||
281 | int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, | |
282 | u8 cidr, struct wg_peer *peer, struct mutex *lock) | |
283 | { | |
284 | /* Aligned so it can be passed to fls */ | |
285 | u8 key[4] __aligned(__alignof(u32)); | |
286 | ||
287 | ++table->seq; | |
288 | swap_endian(key, (const u8 *)ip, 32); | |
289 | return add(&table->root4, 32, key, cidr, peer, lock); | |
290 | } | |
291 | ||
292 | int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, | |
293 | u8 cidr, struct wg_peer *peer, struct mutex *lock) | |
294 | { | |
295 | /* Aligned so it can be passed to fls64 */ | |
296 | u8 key[16] __aligned(__alignof(u64)); | |
297 | ||
298 | ++table->seq; | |
299 | swap_endian(key, (const u8 *)ip, 128); | |
300 | return add(&table->root6, 128, key, cidr, peer, lock); | |
301 | } | |
302 | ||
303 | void wg_allowedips_remove_by_peer(struct allowedips *table, | |
304 | struct wg_peer *peer, struct mutex *lock) | |
305 | { | |
bf7b042d JD |
306 | struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; |
307 | bool free_parent; | |
f634f418 JD |
308 | |
309 | if (list_empty(&peer->allowedips_list)) | |
310 | return; | |
e7096c13 | 311 | ++table->seq; |
f634f418 JD |
312 | list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { |
313 | list_del_init(&node->peer_list); | |
314 | RCU_INIT_POINTER(node->peer, NULL); | |
315 | if (node->bit[0] && node->bit[1]) | |
316 | continue; | |
bf7b042d JD |
317 | child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], |
318 | lockdep_is_held(lock)); | |
f634f418 | 319 | if (child) |
bf7b042d JD |
320 | child->parent_bit_packed = node->parent_bit_packed; |
321 | parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); | |
322 | *parent_bit = child; | |
323 | parent = (void *)parent_bit - | |
324 | offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); | |
325 | free_parent = !rcu_access_pointer(node->bit[0]) && | |
326 | !rcu_access_pointer(node->bit[1]) && | |
327 | (node->parent_bit_packed & 3) <= 1 && | |
328 | !rcu_access_pointer(parent->peer); | |
329 | if (free_parent) | |
330 | child = rcu_dereference_protected( | |
331 | parent->bit[!(node->parent_bit_packed & 1)], | |
332 | lockdep_is_held(lock)); | |
dc680de2 | 333 | call_rcu(&node->rcu, node_free_rcu); |
bf7b042d JD |
334 | if (!free_parent) |
335 | continue; | |
336 | if (child) | |
337 | child->parent_bit_packed = parent->parent_bit_packed; | |
338 | *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; | |
339 | call_rcu(&parent->rcu, node_free_rcu); | |
f634f418 | 340 | } |
e7096c13 JD |
341 | } |
342 | ||
343 | int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) | |
344 | { | |
345 | const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); | |
346 | swap_endian(ip, node->bits, node->bitlen); | |
347 | memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes); | |
348 | if (node->cidr) | |
349 | ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); | |
350 | ||
351 | *cidr = node->cidr; | |
352 | return node->bitlen == 32 ? AF_INET : AF_INET6; | |
353 | } | |
354 | ||
355 | /* Returns a strong reference to a peer */ | |
356 | struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table, | |
357 | struct sk_buff *skb) | |
358 | { | |
359 | if (skb->protocol == htons(ETH_P_IP)) | |
360 | return lookup(table->root4, 32, &ip_hdr(skb)->daddr); | |
361 | else if (skb->protocol == htons(ETH_P_IPV6)) | |
362 | return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr); | |
363 | return NULL; | |
364 | } | |
365 | ||
366 | /* Returns a strong reference to a peer */ | |
367 | struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, | |
368 | struct sk_buff *skb) | |
369 | { | |
370 | if (skb->protocol == htons(ETH_P_IP)) | |
371 | return lookup(table->root4, 32, &ip_hdr(skb)->saddr); | |
372 | else if (skb->protocol == htons(ETH_P_IPV6)) | |
373 | return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr); | |
374 | return NULL; | |
375 | } | |
376 | ||
dc680de2 JD |
377 | int __init wg_allowedips_slab_init(void) |
378 | { | |
379 | node_cache = KMEM_CACHE(allowedips_node, 0); | |
380 | return node_cache ? 0 : -ENOMEM; | |
381 | } | |
382 | ||
383 | void wg_allowedips_slab_uninit(void) | |
384 | { | |
385 | rcu_barrier(); | |
386 | kmem_cache_destroy(node_cache); | |
387 | } | |
388 | ||
e7096c13 | 389 | #include "selftest/allowedips.c" |