Commit | Line | Data |
---|---|---|
e7096c13 JD |
// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 */
5 | ||
6 | #include "allowedips.h" | |
7 | #include "peer.h" | |
8 | ||
dc680de2 JD |
/* Slab cache used for every allowedips_node allocation and free in this file;
 * set up by wg_allowedips_slab_init() and torn down by
 * wg_allowedips_slab_uninit().
 */
static struct kmem_cache *node_cache;
10 | ||
e7096c13 JD |
11 | static void swap_endian(u8 *dst, const u8 *src, u8 bits) |
12 | { | |
13 | if (bits == 32) { | |
14 | *(u32 *)dst = be32_to_cpu(*(const __be32 *)src); | |
15 | } else if (bits == 128) { | |
16 | ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]); | |
17 | ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]); | |
18 | } | |
19 | } | |
20 | ||
21 | static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, | |
22 | u8 cidr, u8 bits) | |
23 | { | |
24 | node->cidr = cidr; | |
25 | node->bit_at_a = cidr / 8U; | |
26 | #ifdef __LITTLE_ENDIAN | |
27 | node->bit_at_a ^= (bits / 8U - 1U) % 8U; | |
28 | #endif | |
29 | node->bit_at_b = 7U - (cidr % 8U); | |
30 | node->bitlen = bits; | |
31 | memcpy(node->bits, src, bits / 8U); | |
32 | } | |
/* Select the child slot of @parent that @key descends into: the bit of @key
 * found in byte parent->bit_at_a at position parent->bit_at_b indexes
 * parent->bit[].
 *
 * The expansion is an lvalue, so it can be read, assigned to, or have its
 * address taken. Every argument use is parenthesized so that expressions such
 * as `&node` or `key + off` expand correctly. Note that @parent is evaluated
 * more than once, so it must not have side effects.
 */
#define CHOOSE_NODE(parent, key) \
	((parent)->bit[((key)[(parent)->bit_at_a] >> (parent)->bit_at_b) & 1])
35 | ||
e7096c13 JD |
36 | static void push_rcu(struct allowedips_node **stack, |
37 | struct allowedips_node __rcu *p, unsigned int *len) | |
38 | { | |
39 | if (rcu_access_pointer(p)) { | |
40 | WARN_ON(IS_ENABLED(DEBUG) && *len >= 128); | |
41 | stack[(*len)++] = rcu_dereference_raw(p); | |
42 | } | |
43 | } | |
44 | ||
dc680de2 JD |
45 | static void node_free_rcu(struct rcu_head *rcu) |
46 | { | |
47 | kmem_cache_free(node_cache, container_of(rcu, struct allowedips_node, rcu)); | |
48 | } | |
49 | ||
e7096c13 JD |
50 | static void root_free_rcu(struct rcu_head *rcu) |
51 | { | |
52 | struct allowedips_node *node, *stack[128] = { | |
53 | container_of(rcu, struct allowedips_node, rcu) }; | |
54 | unsigned int len = 1; | |
55 | ||
56 | while (len > 0 && (node = stack[--len])) { | |
57 | push_rcu(stack, node->bit[0], &len); | |
58 | push_rcu(stack, node->bit[1], &len); | |
dc680de2 | 59 | kmem_cache_free(node_cache, node); |
e7096c13 JD |
60 | } |
61 | } | |
62 | ||
63 | static void root_remove_peer_lists(struct allowedips_node *root) | |
64 | { | |
65 | struct allowedips_node *node, *stack[128] = { root }; | |
66 | unsigned int len = 1; | |
67 | ||
68 | while (len > 0 && (node = stack[--len])) { | |
69 | push_rcu(stack, node->bit[0], &len); | |
70 | push_rcu(stack, node->bit[1], &len); | |
71 | if (rcu_access_pointer(node->peer)) | |
72 | list_del(&node->peer_list); | |
73 | } | |
74 | } | |
75 | ||
e7096c13 JD |
76 | static unsigned int fls128(u64 a, u64 b) |
77 | { | |
78 | return a ? fls64(a) + 64U : fls64(b); | |
79 | } | |
80 | ||
81 | static u8 common_bits(const struct allowedips_node *node, const u8 *key, | |
82 | u8 bits) | |
83 | { | |
84 | if (bits == 32) | |
85 | return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key); | |
86 | else if (bits == 128) | |
87 | return 128U - fls128( | |
88 | *(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], | |
89 | *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]); | |
90 | return 0; | |
91 | } | |
92 | ||
93 | static bool prefix_matches(const struct allowedips_node *node, const u8 *key, | |
94 | u8 bits) | |
95 | { | |
96 | /* This could be much faster if it actually just compared the common | |
97 | * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and | |
98 | * the rest, but it turns out that common_bits is already super fast on | |
99 | * modern processors, even taking into account the unfortunate bswap. | |
100 | * So, we just inline it like this instead. | |
101 | */ | |
102 | return common_bits(node, key, bits) >= node->cidr; | |
103 | } | |
104 | ||
105 | static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, | |
106 | const u8 *key) | |
107 | { | |
108 | struct allowedips_node *node = trie, *found = NULL; | |
109 | ||
110 | while (node && prefix_matches(node, key, bits)) { | |
111 | if (rcu_access_pointer(node->peer)) | |
112 | found = node; | |
113 | if (node->cidr == bits) | |
114 | break; | |
115 | node = rcu_dereference_bh(CHOOSE_NODE(node, key)); | |
116 | } | |
117 | return found; | |
118 | } | |
119 | ||
120 | /* Returns a strong reference to a peer */ | |
121 | static struct wg_peer *lookup(struct allowedips_node __rcu *root, u8 bits, | |
122 | const void *be_ip) | |
123 | { | |
124 | /* Aligned so it can be passed to fls/fls64 */ | |
125 | u8 ip[16] __aligned(__alignof(u64)); | |
126 | struct allowedips_node *node; | |
127 | struct wg_peer *peer = NULL; | |
128 | ||
129 | swap_endian(ip, be_ip, bits); | |
130 | ||
131 | rcu_read_lock_bh(); | |
132 | retry: | |
133 | node = find_node(rcu_dereference_bh(root), bits, ip); | |
134 | if (node) { | |
135 | peer = wg_peer_get_maybe_zero(rcu_dereference_bh(node->peer)); | |
136 | if (!peer) | |
137 | goto retry; | |
138 | } | |
139 | rcu_read_unlock_bh(); | |
140 | return peer; | |
141 | } | |
142 | ||
143 | static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, | |
144 | u8 cidr, u8 bits, struct allowedips_node **rnode, | |
145 | struct mutex *lock) | |
146 | { | |
147 | struct allowedips_node *node = rcu_dereference_protected(trie, | |
148 | lockdep_is_held(lock)); | |
149 | struct allowedips_node *parent = NULL; | |
150 | bool exact = false; | |
151 | ||
152 | while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) { | |
153 | parent = node; | |
154 | if (parent->cidr == cidr) { | |
155 | exact = true; | |
156 | break; | |
157 | } | |
158 | node = rcu_dereference_protected(CHOOSE_NODE(parent, key), | |
159 | lockdep_is_held(lock)); | |
160 | } | |
161 | *rnode = parent; | |
162 | return exact; | |
163 | } | |
164 | ||
165 | static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, | |
166 | u8 cidr, struct wg_peer *peer, struct mutex *lock) | |
167 | { | |
168 | struct allowedips_node *node, *parent, *down, *newnode; | |
169 | ||
170 | if (unlikely(cidr > bits || !peer)) | |
171 | return -EINVAL; | |
172 | ||
173 | if (!rcu_access_pointer(*trie)) { | |
dc680de2 | 174 | node = kmem_cache_zalloc(node_cache, GFP_KERNEL); |
e7096c13 JD |
175 | if (unlikely(!node)) |
176 | return -ENOMEM; | |
177 | RCU_INIT_POINTER(node->peer, peer); | |
178 | list_add_tail(&node->peer_list, &peer->allowedips_list); | |
179 | copy_and_assign_cidr(node, key, cidr, bits); | |
f634f418 | 180 | rcu_assign_pointer(node->parent_bit, trie); |
e7096c13 JD |
181 | rcu_assign_pointer(*trie, node); |
182 | return 0; | |
183 | } | |
184 | if (node_placement(*trie, key, cidr, bits, &node, lock)) { | |
185 | rcu_assign_pointer(node->peer, peer); | |
186 | list_move_tail(&node->peer_list, &peer->allowedips_list); | |
187 | return 0; | |
188 | } | |
189 | ||
dc680de2 | 190 | newnode = kmem_cache_zalloc(node_cache, GFP_KERNEL); |
e7096c13 JD |
191 | if (unlikely(!newnode)) |
192 | return -ENOMEM; | |
193 | RCU_INIT_POINTER(newnode->peer, peer); | |
194 | list_add_tail(&newnode->peer_list, &peer->allowedips_list); | |
195 | copy_and_assign_cidr(newnode, key, cidr, bits); | |
196 | ||
197 | if (!node) { | |
198 | down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); | |
199 | } else { | |
f634f418 | 200 | down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock)); |
e7096c13 | 201 | if (!down) { |
f634f418 | 202 | rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, key)); |
e7096c13 JD |
203 | rcu_assign_pointer(CHOOSE_NODE(node, key), newnode); |
204 | return 0; | |
205 | } | |
206 | } | |
207 | cidr = min(cidr, common_bits(down, key, bits)); | |
208 | parent = node; | |
209 | ||
210 | if (newnode->cidr == cidr) { | |
f634f418 | 211 | rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(newnode, down->bits)); |
e7096c13 | 212 | rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down); |
f634f418 JD |
213 | if (!parent) { |
214 | rcu_assign_pointer(newnode->parent_bit, trie); | |
e7096c13 | 215 | rcu_assign_pointer(*trie, newnode); |
f634f418 JD |
216 | } else { |
217 | rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(parent, newnode->bits)); | |
218 | rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode); | |
e7096c13 | 219 | } |
f634f418 JD |
220 | return 0; |
221 | } | |
222 | ||
dc680de2 | 223 | node = kmem_cache_zalloc(node_cache, GFP_KERNEL); |
f634f418 JD |
224 | if (unlikely(!node)) { |
225 | list_del(&newnode->peer_list); | |
dc680de2 | 226 | kmem_cache_free(node_cache, newnode); |
f634f418 JD |
227 | return -ENOMEM; |
228 | } | |
229 | INIT_LIST_HEAD(&node->peer_list); | |
230 | copy_and_assign_cidr(node, newnode->bits, cidr, bits); | |
231 | ||
232 | rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(node, down->bits)); | |
233 | rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); | |
234 | rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, newnode->bits)); | |
235 | rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode); | |
236 | if (!parent) { | |
237 | rcu_assign_pointer(node->parent_bit, trie); | |
238 | rcu_assign_pointer(*trie, node); | |
239 | } else { | |
240 | rcu_assign_pointer(node->parent_bit, &CHOOSE_NODE(parent, node->bits)); | |
241 | rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node); | |
e7096c13 JD |
242 | } |
243 | return 0; | |
244 | } | |
245 | ||
246 | void wg_allowedips_init(struct allowedips *table) | |
247 | { | |
248 | table->root4 = table->root6 = NULL; | |
249 | table->seq = 1; | |
250 | } | |
251 | ||
252 | void wg_allowedips_free(struct allowedips *table, struct mutex *lock) | |
253 | { | |
254 | struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6; | |
255 | ||
256 | ++table->seq; | |
257 | RCU_INIT_POINTER(table->root4, NULL); | |
258 | RCU_INIT_POINTER(table->root6, NULL); | |
259 | if (rcu_access_pointer(old4)) { | |
260 | struct allowedips_node *node = rcu_dereference_protected(old4, | |
261 | lockdep_is_held(lock)); | |
262 | ||
263 | root_remove_peer_lists(node); | |
264 | call_rcu(&node->rcu, root_free_rcu); | |
265 | } | |
266 | if (rcu_access_pointer(old6)) { | |
267 | struct allowedips_node *node = rcu_dereference_protected(old6, | |
268 | lockdep_is_held(lock)); | |
269 | ||
270 | root_remove_peer_lists(node); | |
271 | call_rcu(&node->rcu, root_free_rcu); | |
272 | } | |
273 | } | |
274 | ||
275 | int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, | |
276 | u8 cidr, struct wg_peer *peer, struct mutex *lock) | |
277 | { | |
278 | /* Aligned so it can be passed to fls */ | |
279 | u8 key[4] __aligned(__alignof(u32)); | |
280 | ||
281 | ++table->seq; | |
282 | swap_endian(key, (const u8 *)ip, 32); | |
283 | return add(&table->root4, 32, key, cidr, peer, lock); | |
284 | } | |
285 | ||
286 | int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, | |
287 | u8 cidr, struct wg_peer *peer, struct mutex *lock) | |
288 | { | |
289 | /* Aligned so it can be passed to fls64 */ | |
290 | u8 key[16] __aligned(__alignof(u64)); | |
291 | ||
292 | ++table->seq; | |
293 | swap_endian(key, (const u8 *)ip, 128); | |
294 | return add(&table->root6, 128, key, cidr, peer, lock); | |
295 | } | |
296 | ||
297 | void wg_allowedips_remove_by_peer(struct allowedips *table, | |
298 | struct wg_peer *peer, struct mutex *lock) | |
299 | { | |
f634f418 JD |
300 | struct allowedips_node *node, *child, *tmp; |
301 | ||
302 | if (list_empty(&peer->allowedips_list)) | |
303 | return; | |
e7096c13 | 304 | ++table->seq; |
f634f418 JD |
305 | list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { |
306 | list_del_init(&node->peer_list); | |
307 | RCU_INIT_POINTER(node->peer, NULL); | |
308 | if (node->bit[0] && node->bit[1]) | |
309 | continue; | |
310 | child = rcu_dereference_protected( | |
311 | node->bit[!rcu_access_pointer(node->bit[0])], | |
312 | lockdep_is_held(lock)); | |
313 | if (child) | |
314 | child->parent_bit = node->parent_bit; | |
315 | *rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child; | |
dc680de2 | 316 | call_rcu(&node->rcu, node_free_rcu); |
f634f418 JD |
317 | |
318 | /* TODO: Note that we currently don't walk up and down in order to | |
319 | * free any potential filler nodes. This means that this function | |
320 | * doesn't free up as much as it could, which could be revisited | |
321 | * at some point. | |
322 | */ | |
323 | } | |
e7096c13 JD |
324 | } |
325 | ||
326 | int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) | |
327 | { | |
328 | const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); | |
329 | swap_endian(ip, node->bits, node->bitlen); | |
330 | memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes); | |
331 | if (node->cidr) | |
332 | ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); | |
333 | ||
334 | *cidr = node->cidr; | |
335 | return node->bitlen == 32 ? AF_INET : AF_INET6; | |
336 | } | |
337 | ||
338 | /* Returns a strong reference to a peer */ | |
339 | struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table, | |
340 | struct sk_buff *skb) | |
341 | { | |
342 | if (skb->protocol == htons(ETH_P_IP)) | |
343 | return lookup(table->root4, 32, &ip_hdr(skb)->daddr); | |
344 | else if (skb->protocol == htons(ETH_P_IPV6)) | |
345 | return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr); | |
346 | return NULL; | |
347 | } | |
348 | ||
349 | /* Returns a strong reference to a peer */ | |
350 | struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, | |
351 | struct sk_buff *skb) | |
352 | { | |
353 | if (skb->protocol == htons(ETH_P_IP)) | |
354 | return lookup(table->root4, 32, &ip_hdr(skb)->saddr); | |
355 | else if (skb->protocol == htons(ETH_P_IPV6)) | |
356 | return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr); | |
357 | return NULL; | |
358 | } | |
359 | ||
dc680de2 JD |
360 | int __init wg_allowedips_slab_init(void) |
361 | { | |
362 | node_cache = KMEM_CACHE(allowedips_node, 0); | |
363 | return node_cache ? 0 : -ENOMEM; | |
364 | } | |
365 | ||
366 | void wg_allowedips_slab_uninit(void) | |
367 | { | |
368 | rcu_barrier(); | |
369 | kmem_cache_destroy(node_cache); | |
370 | } | |
371 | ||
e7096c13 | 372 | #include "selftest/allowedips.c" |