Merge tag 'net-6.7-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[linux-2.6-block.git] / tools / net / ynl / ynl-gen-c.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4 import argparse
5 import collections
6 import filecmp
7 import os
8 import re
9 import shutil
10 import tempfile
11 import yaml
12
13 from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
14
15
def c_upper(name):
    """Return *name* as an upper-case C identifier ('foo-bar' -> 'FOO_BAR')."""
    return '_'.join(name.upper().split('-'))
18
19
def c_lower(name):
    """Return *name* as a lower-case C identifier ('Foo-Bar' -> 'foo_bar')."""
    return name.lower().translate(str.maketrans('-', '_'))
22
23
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value

    The accepted form is '<u|s><bits>-<min|max>': 'u' for unsigned, 's' for
    signed (two's complement), followed by the bit width of the type.

    Raises:
        ValueError: if *name* is not in that form (previously such strings
            were silently mis-parsed, e.g. 'u32-foo' was read as u32-max).
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit "{name}", expected e.g. u32-max or s16-min')
    sign, bits, bound = match.groups()

    if sign == 'u' and bound == 'min':
        return 0
    width = int(bits)
    if sign == 's':
        # one bit is taken by the sign
        width -= 1
    value = (1 << width) - 1
    if sign == 's' and bound == 'min':
        value = -value - 1
    return value
37
38
class BaseNlLib:
    """Produces C snippets for the libmnl-based user-space YNL library."""

    def get_family_id(self):
        """Return the C expression that holds the resolved family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a C call to mnl_cb_run2() processing ys->rx_buf.

        Args:
            cb: C expression for the message callback.
            data: C expression for the callback's data argument.
            is_dump: dumps pass 0/0 for seq/portid; otherwise ys->seq and
                ys->portid are passed.
            indent: extra tab levels for the continuation lines.
        """
        wrap = '\n\t\t' + '\t' * indent + ' '
        tail = f"{wrap}ynl_cb_array, NLMSG_MIN_TYPE)"
        if not is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{wrap}{cb}, {data}," + tail
        return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data}," + tail
50
51
class Type(SpecAttr):
    """One attribute of an attribute set, with hooks used to generate C code.

    The base class tracks presence with a single bit; subclasses override the
    hooks (presence type, struct member, policy, typol, put/get/setter lines)
    for their netlink attribute type.

    Most hooks take ``ri``, a render context whose ``ri.cw`` is the code
    writer used to emit lines.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE: when the spec has 'checks' this is the spec's own dict, and
        # subclasses (e.g. TypeScalar) store derived keys in it in place;
        # wrappers built from the same spec entry (TypeMultiAttr) see them too.
        self.checks = attr.get('checks', {})

        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

            # Sets that share their name with an entry of family.consts get a
            # trailing '_' on the struct name (presumably to avoid a clash).
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Names listed in _C_KW (presumably C keywords) get a trailing '_'
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve(); declared here, then deleted so that any use
        # before resolve() raises AttributeError.
        self.enum_name = None
        delattr(self, "enum_name")

    def get_limit(self, limit, default=None):
        """Return check *limit* as a number, or *default* if it is not set.

        String limits such as 'u32-max' are converted by limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if not isinstance(value, int):
            value = limit_to_number(value)
        return value

    def resolve(self):
        """Compute the C enum name of this attribute (e.g. FAMILY_A_FOO)."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Return a truthy value if the member holds an array of values."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """How presence is tracked: 'bit', 'len', 'count' or '' (none)."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the `_present` struct member line for this attribute.

        Returns None unless presence_type() equals *type_filter* and is
        'bit' or 'len'. *space* == 'user' selects the '__u32' spelling.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if kind == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """C type of the struct member for non-trivial types, or None."""
        return None

    def free_needs_iter(self):
        """Whether free() emits a loop that needs an iterator variable `i`."""
        return False

    def free(self, ri, var, ref):
        """Emit C freeing heap memory owned by this member of *var*."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declarations used to pass this attribute."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the policy initializer for *policy* (e.g. NLA_U32)."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel nla_policy array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _mnl_type(self):
        # mnl does not have helpers for signed integer types
        # turn signed type into unsigned
        # this only makes sense for scalar types
        t = self.type
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry in the user-space type-policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by the presence check for 'bit'/'len' types."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars); each may be str/list/None."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the `if (type == ENUM)` branch that parses this attribute.

        *first* selects `if` instead of `else if`. Single-valued attributes
        are validated with ynl_attr_validate() (returning MNL_CB_ERROR on
        failure) and, for 'bit' presence, marked present.
        Always returns True.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a static inline C setter `<op prefix>_set_<path>`.

        *ref* is the list of enclosing member names for nested attributes.
        With 'bit' presence, every level's presence bit is set.
        The setter name gets a '__' prefix when its body frees memory without
        allocating any (presumably marking it as an internal helper).
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
249
250
class TypeUnused(Type):
    """Attribute slot that is declared but never used.

    It gets no presence tracking, no arguments and no kernel policy entry.
    The user-space typol rejects it, and parsing it fails the callback.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def attr_policy(self, cw):
        # intentionally emits nothing
        return None

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        get_lines = ['return MNL_CB_ERROR;']
        return get_lines, None, None
266
267
class TypePad(Type):
    """Padding attribute.

    The user-space typol ignores it. No put, get, kernel policy or setter
    code is emitted for it.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The code-emitting hooks below are deliberate no-ops.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
289
290
class TypeScalar(Type):
    """Integer attribute (types listed in the module-level `scalars`).

    Range checks are derived from the spec's 'checks' and, when the attribute
    refers to an enum, from that enum's value range.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = (f" /* {attr['byte-order']} */"
                                   if 'byte-order' in attr else '')

        # Default min/max from the referenced enum unless given explicitly;
        # min is only added when it is not the implicit 0 of unsigned types.
        if 'enum' in self.attr:
            enum_low, enum_high = self.family.consts[self.attr['enum']].value_range()
            if 'min' not in self.checks and (enum_low != 0 or self.type[0] == 's'):
                self.checks['min'] = enum_low
            if 'max' not in self.checks:
                self.checks['max'] = enum_high

        if 'min' in self.checks and 'max' in self.checks:
            lim_min = self.get_limit('min')
            lim_max = self.get_limit('max')
            if lim_min > lim_max:
                raise Exception(f'Invalid limit for "{self.name}" min: {lim_min} max: {lim_max}')
            self.checks['range'] = True

        bounds = (self.get_limit('min', 0), self.get_limit('max', 0))
        low, high = min(bounds), max(bounds)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range are emitted via a full-range policy
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve(); deleted so early use raises AttributeError.
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        named_enum = (not self.is_bitfield and 'enum' in self.attr
                      and self.family.consts[self.attr['enum']].enum_name)
        if named_enum:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        elif self.is_auto_scalar:
            # auto-sized scalars are held as 64-bit values of the same sign
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def mnl_type(self):
        return self._mnl_type()

    def _attr_policy(self, policy):
        checks = self.checks
        if self.is_bitfield or 'flags-mask' in checks:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                flag_set = self.family.consts[checks['flags-mask']]
                mask = (1 << len(flag_set['entries'])) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit('min')}, {self.get_limit('max')})"
        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit('min')})"
        if 'max' in checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.mnl_type())

    def _attr_get(self, ri, var):
        getter = f"mnl_attr_get_{self.mnl_type()}(attr)"
        return f"{var}->{self.c_name} = {getter};", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
381
382
class TypeFlag(Type):
    """Attribute without payload; only whether it is present matters."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # a flag setter takes no value argument
        return []

    def attr_put(self, ri, var):
        put = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        # nothing to copy; the presence bit is set by the base attr_get()
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
398
399
class TypeString(Type):
    """String attribute stored as a heap-allocated, NUL-terminated char *.

    Presence is tracked as a length (`_present.<name>_len`).
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.checks['exact-len']})"
        fields = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.get_limit('max-len')))
        return ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel policy to NLA_STRING
        policy = 'NLA_STRING' if self.checks.get('unterminated-ok', False) else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_present.{self.c_name}_len = len;",
                     f"{field} = malloc(len + 1);",
                     f"memcpy({field}, mnl_attr_get_str(attr), len);",
                     f"{field}[len] = 0;"]
        # length is bounded by the payload, so unterminated input is safe
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f"memcpy({member}, {self.c_name}, {length});",
                f"{member}[{length}] = 0;"]
450
451
class TypeBinary(Type):
    """Opaque binary attribute stored as a heap-allocated void * blob.

    Presence is tracked as a length (`_present.<name>_len`).
    """

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return f"NLA_POLICY_EXACT_LEN({self.checks['exact-len']})"
        # Only "no checks" or "min-len alone" are supported otherwise
        if not self.checks:
            body = '.type = NLA_BINARY'
        elif len(self.checks) == 1 and 'min-len' in self.checks:
            body = '.len = ' + str(self.get_limit('min-len'))
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + body + ', }'

    def attr_put(self, ri, var):
        put = (f"mnl_attr_put(nlh, {self.enum_name}, "
               f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        field = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_present.{self.c_name}_len = len;",
                     f"{field} = malloc(len);",
                     f"memcpy({field}, mnl_attr_get_payload(attr), len);"]
        init_lines = ['len = mnl_attr_get_payload_len(attr);']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = len;",
                f"{member} = malloc({length});",
                f"memcpy({member}, {self.c_name}, {length});"]
496
497
class TypeBitfield32(Type):
    """Attribute carried as a `struct nla_bitfield32` (value + selector)."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        """Return NLA_POLICY_BITFIELD32 with the mask of the referenced enum.

        Raises:
            Exception: if the attribute does not reference an enum.
        """
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        enum = self.family.consts[self.attr['enum']]
        mask = enum.get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        line = f"mnl_attr_put(nlh, {self.enum_name}, sizeof(struct nla_bitfield32), &{var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def _attr_get(self, ri, var):
        return f"memcpy(&{var}->{self.c_name}, mnl_attr_get_payload(attr), sizeof(struct nla_bitfield32));", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"]
521
522
class TypeNest(Type):
    """Attribute containing one instance of a nested attribute set.

    The generated C delegates to the nested set's own _free/_put/_parse
    helpers and its _nest / _nl_policy tables.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def free(self, ri, var, ref):
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        put_call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name})"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        nest = self.nested_render_name
        get_lines = [f"if ({nest}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        init_lines = [f"parg.rsp_policy = &{nest}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # One setter per member of the nested struct, prefixed by our path
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member_attr in nested.member_list():
            member_attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
552
553
class TypeMultiAttr(Type):
    """Attribute that may repeat; stored as an array plus an `n_<name>` count.

    Wraps *base_type*, the Type object for a single occurrence, whose policy
    and typol are reused.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def mnl_type(self):
        return self._mnl_type()

    def _is_nest(self):
        """True when elements are nested sets (an absent 'type' is treated as nest)."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def _complex_member_type(self, ri):
        if self._is_nest():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        return self._is_nest()

    def free(self, ri, var, ref):
        """Emit C freeing the array (and, for nests, each element first).

        Checks nest before indexing self.attr['type'] so the "type absent"
        case matches _complex_member_type() instead of raising KeyError.
        """
        if self._is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        # Only counts occurrences; elements are presumably filled in elsewhere
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Emit a loop putting one attribute per array element."""
        if self._is_nest():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self.mnl_type()
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence:
        # "...->_present.<name>" becomes "...->n_<name>"
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
618
619
class TypeArrayNest(Type):
    """Nest whose children form an array of 'sub-type' elements.

    A missing 'sub-type' is treated as 'nest'.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest and count its children; no element is parsed here
        count_field = f'{var}->n_{self.c_name}'
        get_lines = [f'attr_{self.c_name} = attr;',
                     'mnl_attr_for_each_nested(attr2, attr)',
                     f'\t{count_field}++;']
        return get_lines, None, ['const struct nlattr *attr2;']
645
646
class TypeNestTypeValue(Type):
    """Nest whose outer levels encode values in their attribute types.

    For each name in 'type-value', the generated code descends one nest level,
    records that level's attribute type in a local `__u32`, and passes those
    values to the nested set's _parse() helper.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        current = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            names = [c_lower(level) for level in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(names)};')
            local_vars.append(f'__u32 {", ".join(names)};')
            for name in names:
                get_lines.append(f'attr_{name} = mnl_attr_get_payload({current});')
                get_lines.append(f'{name} = mnl_attr_get_type(attr_{name});')
                current = 'attr_' + name
            extra_args = f", {', '.join(names)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {current}{extra_args});")
        return get_lines, init_lines, local_vars
675
676
class Struct:
    """C struct rendered for (a subset of) an attribute set.

    When *type_list* is None the struct is a nested one and covers every
    attribute in the set; otherwise it covers only the named attributes.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        lowered = c_lower(space_name)
        if family.name == lowered:
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{lowered}"

        struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            struct_name += '_'
        self.struct_name = struct_name
        self.ptr_name = struct_name + ' *'

        self.request = False
        self.reply = False

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]

        # Index by name and remember the last attribute holding the highest value
        self.attrs = {}
        self.attr_max_val = None
        top = 0
        for name, attr in self.attr_list:
            if attr.value >= top:
                top = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        return iter(self.attrs)

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited member names (sorted, C-lowercased).

        Raises:
            Exception: if *new_inherited* differs from the constructor's value.
        """
        if new_inherited != self._inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(name) for name in sorted(self._inherited)]
729
730
class EnumEntry(SpecEnumEntry):
    """Enum entry with the C-side name and a flag for explicit values."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # value_change: the value does not follow implicitly from the
        # previous entry (or 0 for the first one); always set for flags sets.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve; deleted so early use raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
749
750
class EnumSet(SpecEnumSet):
    """Enum/flags definition from the spec, with the names used in C."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        # enum_name is None when the spec sets 'enum-name' to an empty value
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return (low, high), the smallest and largest entry value.

        Raises:
            Exception: if the enum has no entries, or if some integer between
                low and high is not the value of any entry.
        """
        values = {x.value for x in self.entries.values()}
        if not values:
            raise Exception("Can't get value range for an empty enum")

        low = min(values)
        high = max(values)

        # Count distinct values: with duplicates, comparing against the
        # number of entries could both hide a gap and report a false one.
        if high - low + 1 != len(values):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
778
779
class AttrSet(SpecAttrSet):
    """Attribute set with its C naming (enum prefix, max and count names)."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are a subset of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                prefix = yaml['name-prefix']
            elif self.name == family.name:
                prefix = family.name + '-a-'
            else:
                prefix = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(prefix)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', f"__{self.name_prefix}max"))

        # Added by resolve; deleted so early use raises AttributeError.
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        # The set named after the family gets an empty c_name
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Build the Type subclass matching elem['type'].

        Multi-attr elements are wrapped in TypeMultiAttr.
        """
        kind = elem['type']
        if kind in scalars:
            attr = TypeScalar(self.family, self, elem, value)
        else:
            type_classes = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'bitfield32': TypeBitfield32,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }
            cls = type_classes.get(kind)
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")
            attr = cls(self.family, self, elem, value)

        if elem.get('multi-attr'):
            attr = TypeMultiAttr(self.family, self, elem, value, attr)

        return attr
838
839
class Operation(SpecOperation):
    """One netlink operation together with the C names the generator needs.

    render_name is the stem of generated function names (family name plus
    the lower-cased op name).  dual_policy is True when both the 'do' and
    the 'dump' mode declare a request.  has_ntf starts False and is flipped
    by mark_has_ntf().  enum_name, the C command constant, only exists once
    resolve() has run.
    """
    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        def takes_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        # Short-circuits: the 'dump' section is only inspected when 'do'
        # takes a request.
        self.dual_policy = takes_request('do') and takes_request('dump')

        self.has_ntf = False

        # Added by resolve (declared for readers, then removed again):
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        # Async ops get the family's async prefix, everything else the
        # regular command prefix.
        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
865
866
class Family(SpecFamily):
    """C code generator view of a netlink family spec.

    On top of SpecFamily this computes the C naming (op / async-op enum
    prefixes, the family-name and version identifiers, the uapi header) and,
    in resolve(), an analysis of how attribute sets are used: which sets are
    the root of some message, which are only ever nested, which attributes
    appear in requests and/or replies, and which pre/post hooks ops name.
    """
    def __init__(self, file_name, exclude_ops):
        # The attributes below are assigned and immediately deleted purely to
        # declare them; resolve() is what creates them.  This is done before
        # super().__init__(), presumably because SpecFamily's constructor
        # ends up invoking resolve() - TODO confirm against lib.
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Upper-case C identifiers for the family name / version constants;
        # the spec may override them via 'c-family-name' / 'c-version-name'.
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        # uapi header defaults to linux/<family>.h; uapi_header_name is the
        # bare name between "linux/" and ".h" when the header has that shape,
        # otherwise the family name.
        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.name

    def resolve(self):
        """Derive C naming and attribute-usage information from the spec.

        Raises Exception for protocols other than genetlink, genetlink-c
        and genetlink-legacy.  The _load_* helpers run in a fixed order:
        root sets are collected first (nested-set discovery starts from
        them), and attribute usage is marked last because it reads both
        root_sets and pure_nested_structs.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        # Async ops fall back to the regular op prefix when no
        # 'async-prefix' is given.
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] = {'set': names seen, 'list': names in
        # first-seen order}, for when in pre/post and op_mode in do/dump.
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> set('request', 'reply')
        # NOTE(review): the values actually stored are Struct objects
        # (see _load_nested_sets), not sets.
        self.pure_nested_structs = dict()

        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        # 'split' (the default) or 'global'; only 'global' needs extra work.
        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op that some other op names in its 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        # Note: this rewrites the raw yaml in place - every op with an
        # 'event' section gains a 'do' section whose reply lists the same
        # attributes as the event.
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Record, per attribute set used directly by a message, which of its
        attributes appear in requests and which in replies (including event
        payloads).  Messages without an 'attribute-set' are ignored."""
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            # Several ops can share one attribute set: union their usage.
            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _load_nested_sets(self):
        """Discover every attribute set reachable through 'nested-attributes'.

        Each such set gets a Struct in pure_nested_structs.  Raises Exception
        when a set is used both as a message root and as a nest.  The structs
        are then reordered so that (where possible) a struct comes after the
        structs it nests, and request/reply usage is pushed down from parent
        structs to the structs they contain.
        """
        # Breadth-first walk starting from the root sets.
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                # Members the nested struct receives from its parent:
                # the listed 'type-value' names, or the element index for
                # array-nests.
                if 'type-value' in spec:
                    # NOTE(review): unreachable - a nested set that is also a
                    # root set already raised just above.
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Seed request/reply flags from how root sets use their nests.
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        # Bounds the loop so cyclic nesting cannot spin forever.
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts keep insertion order: popping and re-inserting
                        # moves this struct to the end.
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)
        # Propagate the request / reply
        # (walking in reverse visits containing structs before the structs
        # they nest, given the ordering established above)
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply

    def _load_attr_use(self):
        """Copy request/reply usage onto the individual attribute specs:
        every member of a nested struct inherits the struct's flags, and
        root-set attributes are flagged from root_sets."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.request = True
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.reply = True

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.request = True
                if attr in rs_members['reply']:
                    spec.reply = True

    def _load_global_policy(self):
        """Build the single kernel policy used when kernel-policy is 'global'.

        All ops that name an attribute set must name the same one (raises
        Exception otherwise).  global_policy lists, in attribute-set order,
        the attributes used by any do/dump request.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        # NOTE(review): if no op names an attribute set, attr_set_name is
        # still None here and self.attr_sets[None] is looked up - presumably
        # a KeyError; confirm whether that case needs a clearer error.
        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect the distinct pre/post hook names per op mode, keeping the
        order in which they are first referenced."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1112
1113
class RenderInfo:
    """Everything needed to render code for one op in one mode.

    op may be falsy (the callers pass an empty value) when a bare attribute
    set is being rendered; attr_set must then be supplied.  For real ops,
    struct['request'] / struct['reply'] hold Struct objects for the
    attributes of the selected mode ('notify' reuses the 'do' definition,
    'event' takes its reply from the event attributes).
    """
    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.cw = cw
        self.nl = cw.nlib
        self.family = family
        self.ku_space = ku_space
        self.op = op
        self.op_mode = op_mode

        # 'do' and 'dump' response parsing is identical when both modes
        # declare the same reply (or neither declares one).
        consistent = True
        if op_mode != 'do' and 'dump' in op:
            if 'do' not in op:
                consistent = False
            else:
                do_has_reply = 'reply' in op['do']
                if do_has_reply != ('reply' in op["dump"]):
                    consistent = False
                elif do_has_reply and op["do"]["reply"] != op["dump"]["reply"]:
                    consistent = False
        self.type_consistent = consistent

        self.attr_set = attr_set or op['attribute-set']

        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        struct_mode = 'do' if op_mode == 'notify' else op_mode
        self.struct = {}
        if op:
            mode_spec = op[struct_mode]
            for op_dir in ('request', 'reply'):
                members = mode_spec[op_dir]['attributes'] if op_dir in mode_spec else []
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=members)
        if struct_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1158
1159
class CodeWriter:
    """Writer for tab-indented C source.

    Output goes to stdout or, when out_file is given, to a temporary file
    that close_out_file() copies over out_file only if the content differs,
    so unchanged generated files are left untouched.

    Beyond plain lines (p()), it tracks brace blocks.  It delays a closing
    '}' so that a following 'else' can join it, and it indents the single
    statement that follows a brace-less if/while/for condition.
    """

    def __init__(self, nlib, out_file=None):
        self.nlib = nlib

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # '}' pending (may merge with an 'else')
        self._silent_block = False  # previous line was a brace-less condition
        self._ind = 0
        self._ifdef_block = None    # currently open #ifdef CONFIG_*, if any
        if out_file is None:
            self._out = os.sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        # __init__ may have failed before _out was assigned.
        if hasattr(self, '_out'):
            self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to out_file if it changed, then drop the temp file.

        Safe to call repeatedly: afterwards output falls back to stdout and
        further calls return immediately.
        """
        if self._out == os.sys.stdout:
            return
        self._out.flush()
        # Avoid modifying the file if contents didn't change
        unchanged = os.path.isfile(self._out_file) and \
            filecmp.cmp(self._out.name, self._out_file, shallow=False)
        if not unchanged:
            with open(self._out_file, 'w') as out_file:
                self._out.seek(0)
                shutil.copyfileobj(self._out, out_file)
        # Always release the temporary file (NamedTemporaryFile deletes it on
        # close), including when out_file was left as-is.
        self._out.close()
        self._out = os.sys.stdout

    @classmethod
    def _is_cond(cls, line):
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Any pending '}' is emitted first, or merged into the line when it
        starts with 'else'.  Then any pending blank line is emitted.  Lines
        ending in ':' are outdented one level, lines starting with '#' go to
        column 0, and the line after a brace-less if/while/for condition gets
        one extra level.  add_ind adds further levels.  An empty line is
        written as a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        if line:
            self._out.write('\t' * ind + line + '\n')
        else:
            self._out.write('\n')

    def nl(self):
        """Request a blank line before the next emitted line."""
        self._nl = True

    def block_start(self, line=''):
        """Emit '<line> {' and indent one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close a block; *line* (e.g. ';') is appended to the '}'.

        A bare '}' is deferred so that a following 'else' can join it.  A
        pending blank line is discarded.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Word-wrap *doc* into ' *' comment lines shorter than ~80 columns."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Emit a function prototype, wrapping arguments past 80 columns.

        An empty argument list renders as (void).  *doc* becomes a block
        comment above the prototype and *suffix* (e.g. ';') follows the ')'.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Long return types go on their own line; continuation lines are
        # aligned under the first argument using tabs plus spaces.
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Emit local variable declarations, longest first, then a blank line.

        Note: a list argument is sorted in place.
        """
        if not local_vars:
            return

        if type(local_vars) is str:
            local_vars = [local_vars]

        local_vars.sort(key=len, reverse=True)
        for var in local_vars:
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Emit a complete function whose body is the given list of lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Emit tab-aligned '#define NAME value' lines from (name, value) pairs.

        int values are written as-is, str values quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Emit tab-aligned designated initializers '.name = value,'.

        An empty member list emits nothing.
        """
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the surrounding '#ifdef CONFIG_<config>' guard.

        The previously open guard is closed first; a falsy *config* just
        closes it.  Repeating the current config is a no-op.
        """
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1353
1354
1355 scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64', 'uint', 'sint'}
1356
1357 direction_to_suffix = {
1358     'reply': '_rsp',
1359     'request': '_req',
1360     '': ''
1361 }
1362
1363 op_mode_to_wrapper = {
1364     'do': '',
1365     'dump': '_list',
1366     'notify': '_ntf',
1367     'event': '',
1368 }
1369
1370 _C_KW = {
1371     'auto',
1372     'bool',
1373     'break',
1374     'case',
1375     'char',
1376     'const',
1377     'continue',
1378     'default',
1379     'do',
1380     'double',
1381     'else',
1382     'enum',
1383     'extern',
1384     'float',
1385     'for',
1386     'goto',
1387     'if',
1388     'inline',
1389     'int',
1390     'long',
1391     'register',
1392     'return',
1393     'short',
1394     'signed',
1395     'sizeof',
1396     'static',
1397     'struct',
1398     'switch',
1399     'typedef',
1400     'union',
1401     'unsigned',
1402     'void',
1403     'volatile',
1404     'while'
1405 }
1406
1407
def rdir(direction):
    """Return the opposite message direction ('request' <-> 'reply').

    Any other value (e.g. '') is returned unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1414
1415
def op_prefix(ri, direction, deref=False):
    """Build the C identifier stem '<family>_<type><suffix>' for ri.

    The 'do' (or unset) mode uses the plain direction suffix.  Other modes
    use '_req_dump' for requests.  For responses they use the mode's wrapper
    suffix when do/dump replies are consistent, and '_rsp_dump' / '_rsp_list'
    otherwise.  deref selects the element type rather than the list wrapper.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family['name']}_{ri.type_name}{tail}"
1435
1436
def type_name(ri, direction, deref=False):
    """Return the C struct type spelled 'struct <op_prefix(...)>'."""
    return "struct " + op_prefix(ri, direction, deref=deref)
1439
1440
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the user-space request function for ri.op.

    The function takes the socket plus, when the mode has a request, a
    pointer to the request struct.  It returns a pointer to the reply struct
    when the mode has a reply, otherwise int.  Dump functions get a '_dump'
    name suffix.  terminate=True appends ';' (a declaration).
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_type = type_name(ri, direction)
        args.append(f"{req_type} *{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1457
1458
def print_req_prototype(ri):
    """Emit the ';'-terminated request prototype, documented with the op's doc."""
    op_doc = ri.op['doc']
    print_prototype(ri, direction="request", doc=op_doc)
1461
1462
def print_dump_prototype(ri):
    """Emit the ';'-terminated request prototype without a doc comment."""
    print_prototype(ri, direction="request")
1465
1466
def put_typol(cw, struct):
    """Emit the ynl policy table for *struct* and the nest descriptor that
    points at it (<render_name>_policy and <render_name>_nest)."""
    name = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for entry in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(entry)
    cw.block_end(line=';')
    cw.nl()
1482
1483
1484 def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1485     args = [f'int {arg_name}']
1486     if enum and not ('enum-name' in enum and not enum['enum-name']):
1487         args = [f'enum {render_name} {arg_name}']
1488     cw.write_func_prot('const char *', f'{render_name}_str', args)
1489     cw.block_start()
1490     if enum and enum.type == 'flags':
1491         cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1492     cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
1493     cw.p('return NULL;')
1494     cw.p(f'return {map_name}[{arg_name}];')
1495     cw.block_end()
1496     cw.nl()
1497
1498
def put_op_name_fwd(family, cw):
    """Emit the declaration of '<family>_op_str(int op)'."""
    cw.write_func_prot(qual_ret='const char *', name=family.name + '_op_str',
                       args=['int op'], suffix=';')
1501
1502
def put_op_name(family, cw):
    """Emit the op-number -> op-name string table and its lookup function.

    Only messages with a response value are listed.  They are indexed by
    their enum constant when request and response values match, otherwise
    by the raw response value.
    """
    map_name = f'{family.name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1516
1517
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of '<enum>_str()'.

    The argument is a plain int when the spec gives an empty 'enum-name'.
    """
    anonymous = 'enum-name' in enum and not enum['enum-name']
    arg = 'int value' if anonymous else f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [arg], suffix=';')
1523
1524
def put_enum_to_str(family, cw, enum):
    """Emit the value -> entry-name string table for *enum* and its
    '<enum>_str()' lookup function."""
    map_name = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1534
1535
def put_req_nested(ri, struct):
    """Emit '<render_name>_put()', which writes *struct* as one nest.

    The generated function opens a nest of type attr_type, lets each member
    emit its own put code against 'obj', closes the nest and returns 0.
    """
    cw = ri.cw
    prototype_args = ['struct nlmsghdr *nlh',
                      'unsigned int attr_type',
                      f'{struct.ptr_name}obj']

    cw.write_func_prot('int', f'{struct.render_name}_put', prototype_args)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1556
1557
1558 def _multi_parse(ri, struct, init_lines, local_vars):
1559     if struct.nested:
1560         iter_line = "mnl_attr_for_each_nested(attr, nested)"
1561     else:
1562         iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"
1563
1564     array_nests = set()
1565     multi_attrs = set()
1566     needs_parg = False
1567     for arg, aspec in struct.member_list():
1568         if aspec['type'] == 'array-nest':
1569             local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1570             array_nests.add(arg)
1571         if 'multi-attr' in aspec:
1572             multi_attrs.add(arg)
1573         needs_parg |= 'nested-attributes' in aspec
1574     if array_nests or multi_attrs:
1575         local_vars.append('int i;')
1576     if needs_parg:
1577         local_vars.append('struct ynl_parse_arg parg;')
1578         init_lines.append('parg.ys = yarg->ys;')
1579
1580     all_multi = array_nests | multi_attrs
1581
1582     for anest in sorted(all_multi):
1583         local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1584
1585     ri.cw.block_start()
1586     ri.cw.write_func_lvar(local_vars)
1587
1588     for line in init_lines:
1589         ri.cw.p(line)
1590     ri.cw.nl()
1591
1592     for arg in struct.inherited:
1593         ri.cw.p(f'dst->{arg} = {arg};')
1594
1595     for anest in sorted(all_multi):
1596         aspec = struct[anest]
1597         ri.cw.p(f"if (dst->{aspec.c_name})")
1598         ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1599
1600     ri.cw.nl()
1601     ri.cw.block_start(line=iter_line)
1602     ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
1603     ri.cw.nl()
1604
1605     first = True
1606     for _, arg in struct.member_list():
1607         good = arg.attr_get(ri, 'dst', first=first)
1608         # First may be 'unused' or 'pad', ignore those
1609         first &= not good
1610
1611     ri.cw.block_end()
1612     ri.cw.nl()
1613
1614     for anest in sorted(array_nests):
1615         aspec = struct[anest]
1616
1617         ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1618         ri.cw.p(f"dst->{aspec.c_name} = calloc({aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1619         ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1620         ri.cw.p('i = 0;')
1621         ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1622         ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1623         ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1624         ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
1625         ri.cw.p('return MNL_CB_ERROR;')
1626         ri.cw.p('i++;')
1627         ri.cw.block_end()
1628         ri.cw.block_end()
1629     ri.cw.nl()
1630
1631     for anest in sorted(multi_attrs):
1632         aspec = struct[anest]
1633         ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1634         ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1635         ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1636         ri.cw.p('i = 0;')
1637         if 'nested-attributes' in aspec:
1638             ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1639         ri.cw.block_start(line=iter_line)
1640         ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
1641         if 'nested-attributes' in aspec:
1642             ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1643             ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1644             ri.cw.p('return MNL_CB_ERROR;')
1645         elif aspec.type in scalars:
1646             ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{aspec.mnl_type()}(attr);")
1647         else:
1648             raise Exception('Nest parsing type not supported yet')
1649         ri.cw.p('i++;')
1650         ri.cw.block_end()
1651         ri.cw.block_end()
1652         ri.cw.block_end()
1653     ri.cw.nl()
1654
1655     if struct.nested:
1656         ri.cw.p('return 0;')
1657     else:
1658         ri.cw.p('return MNL_CB_OK;')
1659     ri.cw.block_end()
1660     ri.cw.nl()
1661
1662
def parse_rsp_nested(ri, struct):
    """Emit '<render_name>_parse()' for a nested attribute set.

    Besides the parse argument and the nest attribute, the generated function
    takes one __u32 parameter per member inherited from the parent.
    """
    inherited_params = ['__u32 ' + name for name in struct.inherited]
    prototype_args = ['struct ynl_parse_arg *yarg',
                      'const struct nlattr *nested',
                      *inherited_params]
    declarations = ['const struct nlattr *attr;',
                    f'{struct.ptr_name}dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', prototype_args)

    _multi_parse(ri, struct, [], declarations)
1676
1677
def parse_rsp_msg(ri, deref=False):
    """Emit the mnl callback that parses a reply message for ri.

    Nothing is emitted when the mode has no reply (events always do).  An
    empty reply produces a callback that just returns MNL_CB_OK.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    dst_type = type_name(ri, "reply", deref=deref)
    declarations = [f'{dst_type} *dst;',
                    'struct ynl_parse_arg *yarg = data;',
                    'const struct nlattr *attr;']

    parser_name = op_prefix(ri, "reply", deref=deref) + '_parse'
    ri.cw.write_func_prot('int', parser_name, ['const struct nlmsghdr *nlh', 'void *data'])

    reply = ri.struct["reply"]
    if not reply.member_list():
        # Empty reply
        ri.cw.block_start()
        ri.cw.p('return MNL_CB_OK;')
        ri.cw.block_end()
        ri.cw.nl()
        return

    _multi_parse(ri, reply, ['dst = yarg->data;'], declarations)
1700
1701
def print_req(ri):
    """Emit the user-space C function that issues a (non-dump) request.

    The generated function starts a request message, puts every attribute of
    the request struct and runs it with ynl_exec(). If the op mode has a
    'reply' section, a response struct is calloc()ed, filled in by the
    reply's _parse callback and returned; on error it is freed and NULL is
    returned. Without a reply the function returns 0 on success and -1 on
    error.
    """
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        # Ops with a reply return the parsed response object instead of 0/-1.
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };']

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    # One put statement per request attribute.
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    # ynl_exec() gets no parse state (NULL) when there is no reply to parse.
    parse_arg = "NULL"
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # When the op has no value of its own, match replies on rsp_value.
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
        parse_arg = '&yrs'
    ri.cw.p(f"err = ynl_exec(ys, nlh, {parse_arg});")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    # Error label that releases the partially filled response.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
1757
1758
def print_dump(ri):
    """Emit the user-space C function that issues a dump request.

    The generated code fills a ynl_dump_state: per-element allocation size,
    reply parse callback, expected response command and reply policy. It
    starts a dump message and, if the op mode has a 'request' section, puts
    the request attributes. It then runs ynl_exec_dump(). On success the head
    of the reply list (yds.first) is returned. On error the partially built
    list is freed and NULL is returned.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    for var in local_vars:
        ri.cw.p(f'{var}')
    ri.cw.nl()

    ri.cw.p('yds.ys = ys;')
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # When the op has no value of its own, match replies on rsp_value.
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    # Dumps may be unfiltered; only emit attribute puts if a request exists.
    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1800
1801
def call_free(ri, direction, var):
    """Return the C statement that releases `var` via the op's _free()."""
    free_fn = f"{op_prefix(ri, direction)}_free"
    return f"{free_fn}({var});"
1804
1805
def free_arg_name(direction):
    """Return the parameter name used by generated _free() functions.

    With no direction (nested structs) the name is 'obj'. Otherwise it is
    the direction's suffix from direction_to_suffix with its first character
    dropped (presumably a leading '_').
    """
    if not direction:
        return 'obj'
    suffix = direction_to_suffix[direction]
    return suffix[1:]
1810
1811
def print_alloc_wrapper(ri, direction):
    """Emit a `static inline` <name>_alloc() helper that calloc()s the struct."""
    struct_name = op_prefix(ri, direction)
    return_type = f'static inline struct {struct_name} *'

    ri.cw.write_func_prot(return_type, f"{struct_name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    ri.cw.block_end()
1818
1819
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of the op's <name>_free() function.

    If ri.type_name_conflict is set, the argument's struct tag gets an extra
    trailing '_'. `suffix` ends the prototype: ';' for a declaration, or ''
    when a function body follows.
    """
    fn_prefix = op_prefix(ri, direction)
    tag = fn_prefix + '_' if ri.type_name_conflict else fn_prefix
    param = f"struct {tag} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{fn_prefix}_free", [param], suffix=suffix)
1827
1828
def _print_type(ri, direction, struct):
    """Emit the C struct definition for `struct`.

    The struct tag is the family name, then '_<type_name>', then the
    direction suffix. Direction-less structs get an extra '_' when the type
    name conflicts, and dump-mode structs get '_dump'. The body contains, in
    order:
      - an inner '_present' struct holding the 'len' and 'bit' presence
        members (only if any member has one);
      - one __u32 field per inherited attribute;
      - each member's own field.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'

    if ri.op_mode == 'dump':
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    # Open the inner presence struct lazily, only once a member needs it.
    meta_started = False
    for _, attr in struct.member_list():
        for type_filter in ['len', 'bit']:
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                if not meta_started:
                    ri.cw.block_start(line=f"struct")
                    meta_started = True
                ri.cw.p(line)
    if meta_started:
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1860
1861
def print_type(ri, direction):
    """Emit the C struct for `direction` (e.g. 'request' or 'reply')."""
    directional_struct = ri.struct[direction]
    _print_type(ri, direction, directional_struct)
1864
1865
def print_type_full(ri, struct):
    """Emit the C struct for `struct` with no direction suffix."""
    _print_type(ri, direction="", struct=struct)
1868
1869
def print_type_helpers(ri, direction, deref=False):
    """Emit the _free() prototype and, for user-space requests, the setters.

    Args:
        direction: 'request' or 'reply'; setters are only emitted for
            'request' when generating user-space code.
        deref: forwarded to each member's setter().
    """
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _name, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1878
1879
def print_req_type_helpers(ri):
    """Emit the alloc/free/setter helpers for a request with attributes."""
    if len(ri.struct["request"].attr_list) > 0:
        print_alloc_wrapper(ri, "request")
        print_type_helpers(ri, "request")
1885
1886
def print_rsp_type_helpers(ri):
    """Emit the reply's helpers, if the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1891
1892
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of <op>_{req,rsp}_parse(tb, req).

    A 'reply' direction selects the _rsp variant; any other direction gives
    _req. With `terminate` the prototype is closed with ';'.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    struct_name = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb',
              f"struct {struct_name} *req"]
    ri.cw.write_func_prot('void', f"{struct_name}_parse", params,
                          suffix=';' if terminate else '')
1901
1902
def print_req_type(ri):
    """Emit the request struct, unless the request has no attributes."""
    if len(ri.struct["request"].attr_list) > 0:
        print_type(ri, "request")
1907
1908
def print_req_free(ri):
    """Emit the request's _free() function if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1913
1914
def print_rsp_type(ri):
    """Emit the reply struct for do/dump ops with a reply, and for events."""
    if ri.op_mode in ('do', 'dump'):
        wanted = 'reply' in ri.op[ri.op_mode]
    else:
        wanted = ri.op_mode == 'event'

    if wanted:
        print_type(ri, 'reply')
1923
1924
def print_wrapped_type(ri):
    """Emit the wrapper struct around a reply object, plus its _free() proto.

    Dump replies get a `next` pointer to the following wrapper. Notify and
    event replies get `family`, `cmd`, a `next` pointer of type
    ynl_ntf_base_type and a `free` callback. In every case the unwrapped
    reply is embedded last as `obj`, aligned to 8 bytes.
    """
    wrapper = type_name(ri, 'reply')
    ri.cw.block_start(line=f"{wrapper}")

    if ri.op_mode == 'dump':
        header = [f"{wrapper} *next;"]
    elif ri.op_mode in ('notify', 'event'):
        header = ['__u16 family;',
                  '__u8 cmd;',
                  'struct ynl_ntf_base_type *next;',
                  f"void (*free)({wrapper} *ntf);"]
    else:
        header = []
    for field in header:
        ri.cw.p(field)

    inner = type_name(ri, 'reply', deref=True)
    ri.cw.p(f"{inner} obj __attribute__((aligned(8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()

    print_free_prototype(ri, 'reply')
    ri.cw.nl()
1939
1940
def _free_type_members_iter(ri, struct):
    """Declare the C loop counter `i` if any member's free code iterates."""
    needs_counter = any(member.free_needs_iter()
                        for _name, member in struct.member_list())
    if needs_counter:
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
1947
1948
def _free_type_members(ri, var, struct, ref=''):
    """Emit free code for every member of `struct`.

    Members are reached through the C variable `var`; `ref` is an extra
    access prefix (e.g. 'obj.') inserted before each member name.
    """
    for _name, member in struct.member_list():
        member.free(ri, var, ref)
1952
1953
def _free_type(ri, direction, struct):
    """Emit a complete _free() function for `struct`.

    All members are released first. With a non-empty direction, the object
    pointer itself is then passed to free(). Without a direction, only the
    members are released.
    """
    arg = free_arg_name(direction)
    release_self = bool(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if release_self:
        ri.cw.p(f'free({arg});')
    ri.cw.block_end()
    ri.cw.nl()
1965
1966
def free_rsp_nested(ri, struct):
    """Emit the direction-less _free() function for a nested struct."""
    _free_type(ri, direction="", struct=struct)
1969
1970
def print_rsp_free(ri):
    """Emit the reply's _free() function if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1975
1976
def print_dump_type_free(ri):
    """Emit the _free() function for a dump reply list.

    The generated loop walks the list through each element's `next` pointer
    until YNL_LIST_END. For each element it releases the reply members, which
    live under `obj.` in the wrapper, and then free()s the element itself.
    """
    sub_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{sub_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # Advance before freeing, since rsp is released at the end of the body.
    ri.cw.p('rsp = next;')
    ri.cw.p('next = rsp->next;')
    ri.cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p(f'free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
1995
1996
def print_ntf_type_free(ri):
    """Emit the _free() function for a single notification wrapper.

    The reply members live under `obj.` inside the wrapper. The wrapper
    pointer itself is released with free() at the end.
    """
    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    reply_struct = ri.struct['reply']
    _free_type_members_iter(ri, reply_struct)
    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2005
2006
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the declaration, or definition opener, of an nla_policy array.

    With `terminate` this is an `extern` declaration ending in ';'. Nothing
    is emitted when `ri` is given and policy_should_be_static() holds. Without
    `terminate` the line opens the definition with ' = {'; it is marked
    `static` under that same condition. The array name comes from the op
    (plus '_<op_mode>' for dual-policy ops) when `ri` is given, otherwise
    from the struct.
    """
    is_static = ri and policy_should_be_static(struct.family)
    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    elif ri.op.dual_policy:
        name = ri.op.render_name + '_' + ri.op_mode
    else:
        name = ri.op.render_name

    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2029
2030
def print_req_policy(cw, struct, ri=None):
    """Emit a complete nla_policy array with one entry per struct member.

    Op-specific policies are wrapped in the op's 'config-cond' #ifdef, if it
    has one. The ifdef state is always reset to None afterwards.
    """
    op = ri.op if ri else None
    if op:
        cw.ifdef_block(op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2040
2041
def kernel_can_gen_family_struct(family):
    """Return True if the genl_family struct is generated (plain 'genetlink')."""
    protocol = family.proto
    return protocol == 'genetlink'
2044
2045
def policy_should_be_static(family):
    """Return True when generated policies can be file-local (static).

    This holds for split-policy families, and for any family for which
    kernel_can_gen_family_struct() is true.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2048
2049
def print_kernel_policy_ranges(family, cw):
    """Emit netlink_range_validation structs for 'full-range' checked attrs.

    Only request attributes of non-subset attribute sets are considered.
    Types starting with 'u' use the plain struct; all others use the
    _signed variant. The 'min'/'max' members are filled from the attr's
    checks when present. A header comment precedes the first struct, if any.
    """
    def _ranged_attrs():
        # Lazily yield candidates so emission interleaves with iteration.
        for _, attr_set in family.attr_sets.items():
            if attr_set.subset_of:
                continue
            for _, attr in attr_set.items():
                if attr.request and 'full-range' in attr.checks:
                    yield attr

    for index, attr in enumerate(_ranged_attrs()):
        if index == 0:
            cw.p('/* Integer value ranges */')

        sign = '_signed' if attr.type[0] != 'u' else ''
        cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
        bounds = [(bound, attr.get_limit(bound))
                  for bound in ('min', 'max') if bound in attr.checks]
        cw.write_struct_init(bounds)
        cw.block_end(line=';')
        cw.nl()
2076
2077
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the opening of the kernel ops table, or its header declarations.

    terminate=False (C source): print a comment and open the ops array
    definition as a ' =' block, then return. The caller fills in the entries
    and closes the block.

    terminate=True (header): print an `extern` declaration of the array only
    when it is exported, i.e. when the genl_family struct is not generated
    here. Then print prototypes for the family's pre/post do/dump hooks and
    for the doit/dumpit handler of every non-async op.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Exported arrays need an explicit size for the extern declaration.
        # Split policy gets one entry per do/dump mode.
        # NOTE(review): the non-split count uses len(family.ops), which also
        # includes async ops that print_kernel_op_table() skips - confirm
        # the size matches the emitted entries for exported families.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    # Hook prototypes: pre/post hooks for doit and for dumpit.
    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Handler prototypes for every synchronous op.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2143
2144
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops table declaration and handler prototypes."""
    header_mode = True
    print_kernel_op_table_fwd(family, cw, terminate=header_mode)
2147
2148
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition for `family`.

    For 'global' and 'per-op' policies, each non-async op gets one entry
    with its do/dump handlers. 'per-op' entries also get the op's policy and
    maxattr. For 'split' policy, each non-async op gets one entry per
    do/dump mode. Those entries carry the mode-specific pre/post callbacks,
    a policy when the mode has a request, the filtered 'dont-validate' flags
    and a GENL_CMD_CAP_* flag for the mode. Every entry is wrapped in the
    op's 'config-cond' #ifdef (if any).
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): reads op['do']['request'] unconditionally;
                # presumably every per-op op has a do with a request.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # genl_split_ops member names for the pre/post callbacks of each mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the dont-validate flags that apply to this
                    # mode: dump flags are dropped for do, 'strict' for dump.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2227
2228
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum with one enumerator per multicast group.

    The enumerators are used as array indices by print_kernel_mcgrp_src().
    Nothing is emitted when the family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{group['name']}") + ',')
    cw.block_end(';')
    cw.nl()
2239
2240
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the NLGRP enum.

    Nothing is emitted when the family has no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        enumerator = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{enumerator}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2252
2253
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern declaration of the generated genl_family struct."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2260
2261
def print_kernel_family_struct_src(family, cw):
    """Emit the genl_family definition when it can be generated.

    'per-op' families register their table via .ops/.n_ops, and 'split'
    families via .split_ops/.n_split_ops. Multicast groups are wired in
    when the family has any.
    """
    if not kernel_can_gen_family_struct(family):
        return

    ops_table = f'{family.name}_nl_ops'
    mcgrp_table = f'{family.name}_nl_mcgrps'
    fields = ['.name\t\t= ' + family.fam_key + ',',
              '.version\t= ' + family.ver_key + ',',
              '.netnsok\t= true,',
              '.parallel_ops\t= true,',
              '.module\t\t= THIS_MODULE,']
    if family.kernel_policy == 'per-op':
        fields += [f'.ops\t\t= {ops_table},',
                   f'.n_ops\t\t= ARRAY_SIZE({ops_table}),']
    elif family.kernel_policy == 'split':
        fields += [f'.split_ops\t= {ops_table},',
                   f'.n_split_ops\t= ARRAY_SIZE({ops_table}),']
    if family.mcgrps['list']:
        fields += [f'.mcgrps\t\t= {mcgrp_table},',
                   f'.n_mcgrps\t= ARRAY_SIZE({mcgrp_table}),']

    cw.block_start(f"struct genl_family {family.name}_nl_family __ro_after_init =")
    for field in fields:
        cw.p(field)
    cw.block_end(';')
2282
2283
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI enum block, choosing its tag.

    If `obj` has the `enum_name` key, its value names the enum; an empty
    value there forces an anonymous enum. Otherwise, if `ckey` is a key of
    `obj`, the tag is '<family name>_<obj[ckey]>'. With neither, the enum is
    anonymous.
    """
    if enum_name in obj:
        tag = obj[enum_name]
        line = 'enum ' + c_lower(tag) if tag else 'enum'
    elif ckey and ckey in obj:
        line = 'enum ' + family.name + '_' + c_lower(obj[ckey])
    else:
        line = 'enum'
    cw.block_start(line=line)
2292
2293
def render_uapi(family, cw):
    """Render the uAPI header for `family` through code writer `cw`.

    Emits, in order:
      - the include guard;
      - the family name/version defines;
      - the 'definitions' section: enums and flags (with kdoc when
        documented) and consts as #defines;
      - one attribute enum per non-subset attribute set;
      - the command enum, plus a separate notification enum when the
        operations have an 'async-prefix';
      - the multicast group name defines;
      - the closing #endif.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    defines = []
    for const in family['definitions']:
        # Consecutive consts are batched into one run of #defines; flush
        # the pending batch whenever a non-const definition interrupts it.
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # 'c-define-name' overrides the name of this particular const,
            # like the per-group override used for multicast groups below.
            # It must be looked up on the const, not on the family.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    # Attribute enums, one per attribute set; subsets are skipped.
    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            # Spell out the value only when it breaks the implicit sequence.
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        # Notifications get their own enum when 'async-prefix' is set.
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2437
2438
def _render_user_ntf_entry(ri, op):
    """Emit one designated-initializer entry of the ynl_ntf_info table.

    The entry is indexed by op.enum_name. It records the allocation size,
    the reply parse callback, the reply policy and the _free helper.
    """
    fields = [f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
              f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
              f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
              f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,"]

    ri.cw.block_start(line=f"[{op.enum_name}] = ")
    for field in fields:
        ri.cw.p(field)
    ri.cw.block_end(line=',')
2446
2447
def render_user_family(family, cw, prototype):
    """Emit the user-space ynl_family descriptor for `family`.

    With `prototype`, only the extern declaration is emitted. Otherwise the
    definition is emitted. If the family has notifications, it is preceded
    by a static ynl_ntf_info table and references that table.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            # 'notify' entries render using the op they notify about;
            # 'event' entries render from the notification itself.
            if 'notify' in ntf_op:
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        # NOTE(review): events are rendered again here. If family.ntfs also
        # holds the 'event' ops, their entries would appear twice - confirm
        # what family.ntfs contains.
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family['name']}_ntf_info,")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({family['name']}_ntf_info),")
    cw.block_end(line=';')
2479
2480
def family_contains_bitfield32(family):
    """Return True if any non-subset attribute set has a bitfield32 attribute."""
    return any(attr.type == "bitfield32"
               for _, attr_set in family.attr_sets.items()
               if not attr_set.subset_of
               for _, attr in attr_set.items())
2489
2490
def find_kernel_root(full_path):
    """Locate the kernel source tree that contains `full_path`.

    Walks up the parent directories of `full_path` until one of them holds a
    MAINTAINERS file.

    Returns:
        (root, sub_path): the directory containing MAINTAINERS, and the path
        of `full_path` relative to it.

    Raises:
        FileNotFoundError: no ancestor contains MAINTAINERS. The walk stops
            once os.path.dirname() no longer makes progress, i.e. at the
            filesystem root, or at '' for relative paths. Without this check
            the search would loop forever.
    """
    start = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path carries a trailing separator from the join above.
            return parent, sub_path[:-1]
        if parent == full_path:
            raise FileNotFoundError(
                f"no MAINTAINERS file found above {start!r}; "
                "is the spec inside a kernel source tree?")
        full_path = parent
2499
2500
def main():
    """
    Command-line entry point of the generator.

    Parses the command line, loads the YAML spec given by --spec and emits
    C code through a CodeWriter (args.out_file is handed to it): UAPI
    definitions (--mode uapi), kernel policies and op tables (--mode kernel),
    or user-space helpers (--mode user), each as a header (--header) or a
    source (--source).

    Exits with status 1 if the spec fails to parse as YAML, carries a
    license other than the required one, or uses a message id model the
    selected mode cannot render.
    """
    import sys  # function-local: the module header does not import sys

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    # --header and --source share one destination; None means neither was given.
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    required_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    if parsed.license != required_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print(f'License must be: {required_license}', file=sys.stderr)
        sys.exit(1)

    # 'directional' message ids are only accepted for user and kernel output;
    # uapi output supports the 'unified' model alone.
    supported_models = ['unified']
    if args.mode in ['user', 'kernel']:
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation',
              file=sys.stderr)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file)

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    # Echo the extra command-line options into the output
    # (presumably so the file can be regenerated with the same options).
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    # c_upper() also maps '-' to '_' so the include guard remains a valid C
    # identifier for family names containing dashes.
    hdr_prot = f"_LINUX_{c_upper(parsed.name)}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # Include the matching header: same file name, .h extension.
                out_stem = os.path.splitext(os.path.basename(args.out_file))[0]
                cw.p(f'#include "{out_stem}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            if any(struct.request for _, struct in parsed.pure_nested_structs.items()):
                cw.p('/* Common nested types */')
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for _, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            if any(struct.request for _, struct in parsed.pure_nested_structs.items()):
                cw.p('/* Common nested types */')
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            # The policy kind does not depend on the op, so test it once
            # instead of on every iteration.
            if parsed.kernel_policy in {'per-op', 'split'}:
                for _, op in parsed.ops.items():
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for _, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for _, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for _, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for _, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            cw.p('/* Policies */')
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for _, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for _, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()