static_call: Add some validation
[linux-block.git] / kernel / static_call.c
CommitLineData
9183c3f9
JP
1// SPDX-License-Identifier: GPL-2.0
2#include <linux/init.h>
3#include <linux/static_call.h>
4#include <linux/bug.h>
5#include <linux/smp.h>
6#include <linux/sort.h>
7#include <linux/slab.h>
8#include <linux/module.h>
9#include <linux/cpu.h>
10#include <linux/processor.h>
11#include <asm/sections.h>
12
13extern struct static_call_site __start_static_call_sites[],
14 __stop_static_call_sites[];
15
16static bool static_call_initialized;
17
9183c3f9
JP
18/* mutex to protect key modules/sites */
19static DEFINE_MUTEX(static_call_mutex);
20
21static void static_call_lock(void)
22{
23 mutex_lock(&static_call_mutex);
24}
25
26static void static_call_unlock(void)
27{
28 mutex_unlock(&static_call_mutex);
29}
30
31static inline void *static_call_addr(struct static_call_site *site)
32{
33 return (void *)((long)site->addr + (long)&site->addr);
34}
35
36
37static inline struct static_call_key *static_call_key(const struct static_call_site *site)
38{
39 return (struct static_call_key *)
5b06fd3b 40 (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
9183c3f9
JP
41}
42
43/* These assume the key is word-aligned. */
44static inline bool static_call_is_init(struct static_call_site *site)
45{
5b06fd3b
PZ
46 return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
47}
48
49static inline bool static_call_is_tail(struct static_call_site *site)
50{
51 return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
9183c3f9
JP
52}
53
54static inline void static_call_set_init(struct static_call_site *site)
55{
5b06fd3b 56 site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
9183c3f9
JP
57 (long)&site->key;
58}
59
/*
 * sort() comparator: order call sites by the address of the key they
 * reference, so that all sites sharing a key become adjacent.
 * Returns -1, 0 or 1.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *ka = static_call_key(_a);
	const struct static_call_key *kb = static_call_key(_b);

	return (ka > kb) - (ka < kb);
}
75
76static void static_call_site_swap(void *_a, void *_b, int size)
77{
78 long delta = (unsigned long)_a - (unsigned long)_b;
79 struct static_call_site *a = _a;
80 struct static_call_site *b = _b;
81 struct static_call_site tmp = *a;
82
83 a->addr = b->addr - delta;
84 a->key = b->key - delta;
85
86 b->addr = tmp.addr + delta;
87 b->key = tmp.key + delta;
88}
89
90static inline void static_call_sort_entries(struct static_call_site *start,
91 struct static_call_site *stop)
92{
93 sort(start, stop - start, sizeof(struct static_call_site),
94 static_call_site_cmp, static_call_site_swap);
95}
96
97void __static_call_update(struct static_call_key *key, void *tramp, void *func)
98{
99 struct static_call_site *site, *stop;
100 struct static_call_mod *site_mod;
101
102 cpus_read_lock();
103 static_call_lock();
104
105 if (key->func == func)
106 goto done;
107
108 key->func = func;
109
5b06fd3b 110 arch_static_call_transform(NULL, tramp, func, false);
9183c3f9
JP
111
112 /*
113 * If uninitialized, we'll not update the callsites, but they still
114 * point to the trampoline and we just patched that.
115 */
116 if (WARN_ON_ONCE(!static_call_initialized))
117 goto done;
118
119 for (site_mod = key->mods; site_mod; site_mod = site_mod->next) {
120 struct module *mod = site_mod->mod;
121
122 if (!site_mod->sites) {
123 /*
124 * This can happen if the static call key is defined in
125 * a module which doesn't use it.
126 */
127 continue;
128 }
129
130 stop = __stop_static_call_sites;
131
132#ifdef CONFIG_MODULES
133 if (mod) {
134 stop = mod->static_call_sites +
135 mod->num_static_call_sites;
136 }
137#endif
138
139 for (site = site_mod->sites;
140 site < stop && static_call_key(site) == key; site++) {
141 void *site_addr = static_call_addr(site);
142
143 if (static_call_is_init(site)) {
144 /*
145 * Don't write to call sites which were in
146 * initmem and have since been freed.
147 */
148 if (!mod && system_state >= SYSTEM_RUNNING)
149 continue;
150 if (mod && !within_module_init((unsigned long)site_addr, mod))
151 continue;
152 }
153
154 if (!kernel_text_address((unsigned long)site_addr)) {
155 WARN_ONCE(1, "can't patch static call site at %pS",
156 site_addr);
157 continue;
158 }
159
5b06fd3b
PZ
160 arch_static_call_transform(site_addr, NULL, func,
161 static_call_is_tail(site));
9183c3f9
JP
162 }
163 }
164
165done:
166 static_call_unlock();
167 cpus_read_unlock();
168}
169EXPORT_SYMBOL_GPL(__static_call_update);
170
171static int __static_call_init(struct module *mod,
172 struct static_call_site *start,
173 struct static_call_site *stop)
174{
175 struct static_call_site *site;
176 struct static_call_key *key, *prev_key = NULL;
177 struct static_call_mod *site_mod;
178
179 if (start == stop)
180 return 0;
181
182 static_call_sort_entries(start, stop);
183
184 for (site = start; site < stop; site++) {
185 void *site_addr = static_call_addr(site);
186
187 if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
188 (!mod && init_section_contains(site_addr, 1)))
189 static_call_set_init(site);
190
191 key = static_call_key(site);
192 if (key != prev_key) {
193 prev_key = key;
194
195 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
196 if (!site_mod)
197 return -ENOMEM;
198
199 site_mod->mod = mod;
200 site_mod->sites = site;
201 site_mod->next = key->mods;
202 key->mods = site_mod;
203 }
204
5b06fd3b
PZ
205 arch_static_call_transform(site_addr, NULL, key->func,
206 static_call_is_tail(site));
9183c3f9
JP
207 }
208
209 return 0;
210}
211
6333e8f7
PZ
212static int addr_conflict(struct static_call_site *site, void *start, void *end)
213{
214 unsigned long addr = (unsigned long)static_call_addr(site);
215
216 if (addr <= (unsigned long)end &&
217 addr + CALL_INSN_SIZE > (unsigned long)start)
218 return 1;
219
220 return 0;
221}
222
223static int __static_call_text_reserved(struct static_call_site *iter_start,
224 struct static_call_site *iter_stop,
225 void *start, void *end)
226{
227 struct static_call_site *iter = iter_start;
228
229 while (iter < iter_stop) {
230 if (addr_conflict(iter, start, end))
231 return 1;
232 iter++;
233 }
234
235 return 0;
236}
237
9183c3f9
JP
238#ifdef CONFIG_MODULES
239
6333e8f7
PZ
240static int __static_call_mod_text_reserved(void *start, void *end)
241{
242 struct module *mod;
243 int ret;
244
245 preempt_disable();
246 mod = __module_text_address((unsigned long)start);
247 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
248 if (!try_module_get(mod))
249 mod = NULL;
250 preempt_enable();
251
252 if (!mod)
253 return 0;
254
255 ret = __static_call_text_reserved(mod->static_call_sites,
256 mod->static_call_sites + mod->num_static_call_sites,
257 start, end);
258
259 module_put(mod);
260
261 return ret;
262}
263
9183c3f9
JP
264static int static_call_add_module(struct module *mod)
265{
266 return __static_call_init(mod, mod->static_call_sites,
267 mod->static_call_sites + mod->num_static_call_sites);
268}
269
270static void static_call_del_module(struct module *mod)
271{
272 struct static_call_site *start = mod->static_call_sites;
273 struct static_call_site *stop = mod->static_call_sites +
274 mod->num_static_call_sites;
275 struct static_call_key *key, *prev_key = NULL;
276 struct static_call_mod *site_mod, **prev;
277 struct static_call_site *site;
278
279 for (site = start; site < stop; site++) {
280 key = static_call_key(site);
281 if (key == prev_key)
282 continue;
283
284 prev_key = key;
285
286 for (prev = &key->mods, site_mod = key->mods;
287 site_mod && site_mod->mod != mod;
288 prev = &site_mod->next, site_mod = site_mod->next)
289 ;
290
291 if (!site_mod)
292 continue;
293
294 *prev = site_mod->next;
295 kfree(site_mod);
296 }
297}
298
299static int static_call_module_notify(struct notifier_block *nb,
300 unsigned long val, void *data)
301{
302 struct module *mod = data;
303 int ret = 0;
304
305 cpus_read_lock();
306 static_call_lock();
307
308 switch (val) {
309 case MODULE_STATE_COMING:
310 ret = static_call_add_module(mod);
311 if (ret) {
312 WARN(1, "Failed to allocate memory for static calls");
313 static_call_del_module(mod);
314 }
315 break;
316 case MODULE_STATE_GOING:
317 static_call_del_module(mod);
318 break;
319 }
320
321 static_call_unlock();
322 cpus_read_unlock();
323
324 return notifier_from_errno(ret);
325}
326
327static struct notifier_block static_call_module_nb = {
328 .notifier_call = static_call_module_notify,
329};
330
6333e8f7
PZ
331#else
332
/* !CONFIG_MODULES: there are no module call sites, so nothing conflicts. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	return 0;
}
337
9183c3f9
JP
338#endif /* CONFIG_MODULES */
339
6333e8f7
PZ
340int static_call_text_reserved(void *start, void *end)
341{
342 int ret = __static_call_text_reserved(__start_static_call_sites,
343 __stop_static_call_sites, start, end);
344
345 if (ret)
346 return ret;
347
348 return __static_call_mod_text_reserved(start, end);
349}
350
9183c3f9
JP
351static void __init static_call_init(void)
352{
353 int ret;
354
355 if (static_call_initialized)
356 return;
357
358 cpus_read_lock();
359 static_call_lock();
360 ret = __static_call_init(NULL, __start_static_call_sites,
361 __stop_static_call_sites);
362 static_call_unlock();
363 cpus_read_unlock();
364
365 if (ret) {
366 pr_err("Failed to allocate memory for static_call!\n");
367 BUG();
368 }
369
370 static_call_initialized = true;
371
372#ifdef CONFIG_MODULES
373 register_module_notifier(&static_call_module_nb);
374#endif
375}
376early_initcall(static_call_init);
f03c4129
PZ
377
378#ifdef CONFIG_STATIC_CALL_SELFTEST
379
/* Selftest targets; they are told apart by how much they add to @x. */
static int func_a(int x)
{
	return x + 1;
}

static int func_b(int x)
{
	return x + 2;
}
389
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest script. Each entry optionally retargets sc_selftest to @func
 * (NULL keeps the current target), calls it with @val and expects
 * @expect back.
 */
static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data[] __initdata = {
	{ .func = NULL,   .val = 2, .expect = 3 },
	{ .func = func_b, .val = 2, .expect = 4 },
	{ .func = func_a, .val = 2, .expect = 3 },
};
401
402static int __init test_static_call_init(void)
403{
404 int i;
405
406 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
407 struct static_call_data *scd = &static_call_data[i];
408
409 if (scd->func)
410 static_call_update(sc_selftest, scd->func);
411
412 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
413 }
414
415 return 0;
416}
417early_initcall(test_static_call_init);
418
419#endif /* CONFIG_STATIC_CALL_SELFTEST */