Merge branch 'master' of git://blackhole.kfki.hu/nf-next
[linux-2.6-block.git] / net / netfilter / ipset / ip_set_core.c
1 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
2  *                         Patrick Schaaf <bof@bof.de>
3  * Copyright (C) 2003-2013 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  */
9
10 /* Kernel module for IP set management */
11
12 #include <linux/init.h>
13 #include <linux/module.h>
14 #include <linux/moduleparam.h>
15 #include <linux/ip.h>
16 #include <linux/skbuff.h>
17 #include <linux/spinlock.h>
18 #include <linux/rculist.h>
19 #include <net/netlink.h>
20 #include <net/net_namespace.h>
21 #include <net/netns/generic.h>
22
23 #include <linux/netfilter.h>
24 #include <linux/netfilter/x_tables.h>
25 #include <linux/netfilter/nfnetlink.h>
26 #include <linux/netfilter/ipset/ip_set.h>
27
28 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
29 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
30 static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
31
32 struct ip_set_net {
33         struct ip_set * __rcu *ip_set_list;     /* all individual sets */
34         ip_set_id_t     ip_set_max;     /* max number of sets */
35         bool            is_deleted;     /* deleted by ip_set_net_exit */
36         bool            is_destroyed;   /* all sets are destroyed */
37 };
38
39 static unsigned int ip_set_net_id __read_mostly;
40
41 static inline struct ip_set_net *ip_set_pernet(struct net *net)
42 {
43         return net_generic(net, ip_set_net_id);
44 }
45
46 #define IP_SET_INC      64
47 #define STRNCMP(a, b)   (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
48
49 static unsigned int max_sets;
50
51 module_param(max_sets, int, 0600);
52 MODULE_PARM_DESC(max_sets, "maximal number of sets");
53 MODULE_LICENSE("GPL");
54 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
55 MODULE_DESCRIPTION("core IP set support");
56 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
57
58 /* When the nfnl mutex or ip_set_ref_lock is held: */
59 #define ip_set_dereference(p)           \
60         rcu_dereference_protected(p,    \
61                 lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET) || \
62                 lockdep_is_held(&ip_set_ref_lock))
63 #define ip_set(inst, id)                \
64         ip_set_dereference((inst)->ip_set_list)[id]
65 #define ip_set_ref_netlink(inst,id)     \
66         rcu_dereference_raw((inst)->ip_set_list)[id]
67
68 /* The set types are implemented in modules and registered set types
69  * can be found in ip_set_type_list. Adding/deleting types is
70  * serialized by ip_set_type_mutex.
71  */
72
73 static inline void
74 ip_set_type_lock(void)
75 {
76         mutex_lock(&ip_set_type_mutex);
77 }
78
79 static inline void
80 ip_set_type_unlock(void)
81 {
82         mutex_unlock(&ip_set_type_mutex);
83 }
84
85 /* Register and deregister settype */
86
87 static struct ip_set_type *
88 find_set_type(const char *name, u8 family, u8 revision)
89 {
90         struct ip_set_type *type;
91
92         list_for_each_entry_rcu(type, &ip_set_type_list, list)
93                 if (STRNCMP(type->name, name) &&
94                     (type->family == family ||
95                      type->family == NFPROTO_UNSPEC) &&
96                     revision >= type->revision_min &&
97                     revision <= type->revision_max)
98                         return type;
99         return NULL;
100 }
101
102 /* Unlock, try to load a set type module and lock again */
103 static bool
104 load_settype(const char *name)
105 {
106         nfnl_unlock(NFNL_SUBSYS_IPSET);
107         pr_debug("try to load ip_set_%s\n", name);
108         if (request_module("ip_set_%s", name) < 0) {
109                 pr_warn("Can't find ip_set type %s\n", name);
110                 nfnl_lock(NFNL_SUBSYS_IPSET);
111                 return false;
112         }
113         nfnl_lock(NFNL_SUBSYS_IPSET);
114         return true;
115 }
116
117 /* Find a set type and reference it */
118 #define find_set_type_get(name, family, revision, found)        \
119         __find_set_type_get(name, family, revision, found, false)
120
121 static int
122 __find_set_type_get(const char *name, u8 family, u8 revision,
123                     struct ip_set_type **found, bool retry)
124 {
125         struct ip_set_type *type;
126         int err;
127
128         if (retry && !load_settype(name))
129                 return -IPSET_ERR_FIND_TYPE;
130
131         rcu_read_lock();
132         *found = find_set_type(name, family, revision);
133         if (*found) {
134                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
135                 goto unlock;
136         }
137         /* Make sure the type is already loaded
138          * but we don't support the revision
139          */
140         list_for_each_entry_rcu(type, &ip_set_type_list, list)
141                 if (STRNCMP(type->name, name)) {
142                         err = -IPSET_ERR_FIND_TYPE;
143                         goto unlock;
144                 }
145         rcu_read_unlock();
146
147         return retry ? -IPSET_ERR_FIND_TYPE :
148                 __find_set_type_get(name, family, revision, found, true);
149
150 unlock:
151         rcu_read_unlock();
152         return err;
153 }
154
155 /* Find a given set type by name and family.
156  * If we succeeded, the supported minimal and maximum revisions are
157  * filled out.
158  */
159 #define find_set_type_minmax(name, family, min, max) \
160         __find_set_type_minmax(name, family, min, max, false)
161
162 static int
163 __find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max,
164                        bool retry)
165 {
166         struct ip_set_type *type;
167         bool found = false;
168
169         if (retry && !load_settype(name))
170                 return -IPSET_ERR_FIND_TYPE;
171
172         *min = 255; *max = 0;
173         rcu_read_lock();
174         list_for_each_entry_rcu(type, &ip_set_type_list, list)
175                 if (STRNCMP(type->name, name) &&
176                     (type->family == family ||
177                      type->family == NFPROTO_UNSPEC)) {
178                         found = true;
179                         if (type->revision_min < *min)
180                                 *min = type->revision_min;
181                         if (type->revision_max > *max)
182                                 *max = type->revision_max;
183                 }
184         rcu_read_unlock();
185         if (found)
186                 return 0;
187
188         return retry ? -IPSET_ERR_FIND_TYPE :
189                 __find_set_type_minmax(name, family, min, max, true);
190 }
191
192 #define family_name(f)  ((f) == NFPROTO_IPV4 ? "inet" : \
193                          (f) == NFPROTO_IPV6 ? "inet6" : "any")
194
195 /* Register a set type structure. The type is identified by
196  * the unique triple of name, family and revision.
197  */
198 int
199 ip_set_type_register(struct ip_set_type *type)
200 {
201         int ret = 0;
202
203         if (type->protocol != IPSET_PROTOCOL) {
204                 pr_warn("ip_set type %s, family %s, revision %u:%u uses wrong protocol version %u (want %u)\n",
205                         type->name, family_name(type->family),
206                         type->revision_min, type->revision_max,
207                         type->protocol, IPSET_PROTOCOL);
208                 return -EINVAL;
209         }
210
211         ip_set_type_lock();
212         if (find_set_type(type->name, type->family, type->revision_min)) {
213                 /* Duplicate! */
214                 pr_warn("ip_set type %s, family %s with revision min %u already registered!\n",
215                         type->name, family_name(type->family),
216                         type->revision_min);
217                 ip_set_type_unlock();
218                 return -EINVAL;
219         }
220         list_add_rcu(&type->list, &ip_set_type_list);
221         pr_debug("type %s, family %s, revision %u:%u registered.\n",
222                  type->name, family_name(type->family),
223                  type->revision_min, type->revision_max);
224         ip_set_type_unlock();
225
226         return ret;
227 }
228 EXPORT_SYMBOL_GPL(ip_set_type_register);
229
230 /* Unregister a set type. There's a small race with ip_set_create */
231 void
232 ip_set_type_unregister(struct ip_set_type *type)
233 {
234         ip_set_type_lock();
235         if (!find_set_type(type->name, type->family, type->revision_min)) {
236                 pr_warn("ip_set type %s, family %s with revision min %u not registered\n",
237                         type->name, family_name(type->family),
238                         type->revision_min);
239                 ip_set_type_unlock();
240                 return;
241         }
242         list_del_rcu(&type->list);
243         pr_debug("type %s, family %s with revision min %u unregistered.\n",
244                  type->name, family_name(type->family), type->revision_min);
245         ip_set_type_unlock();
246
247         synchronize_rcu();
248 }
249 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
250
251 /* Utility functions */
252 void *
253 ip_set_alloc(size_t size)
254 {
255         void *members = NULL;
256
257         if (size < KMALLOC_MAX_SIZE)
258                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
259
260         if (members) {
261                 pr_debug("%p: allocated with kmalloc\n", members);
262                 return members;
263         }
264
265         members = vzalloc(size);
266         if (!members)
267                 return NULL;
268         pr_debug("%p: allocated with vmalloc\n", members);
269
270         return members;
271 }
272 EXPORT_SYMBOL_GPL(ip_set_alloc);
273
274 void
275 ip_set_free(void *members)
276 {
277         pr_debug("%p: free with %s\n", members,
278                  is_vmalloc_addr(members) ? "vfree" : "kfree");
279         kvfree(members);
280 }
281 EXPORT_SYMBOL_GPL(ip_set_free);
282
283 static inline bool
284 flag_nested(const struct nlattr *nla)
285 {
286         return nla->nla_type & NLA_F_NESTED;
287 }
288
289 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
290         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
291         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
292                                             .len = sizeof(struct in6_addr) },
293 };
294
295 int
296 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
297 {
298         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
299
300         if (unlikely(!flag_nested(nla)))
301                 return -IPSET_ERR_PROTOCOL;
302         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla,
303                              ipaddr_policy, NULL))
304                 return -IPSET_ERR_PROTOCOL;
305         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
306                 return -IPSET_ERR_PROTOCOL;
307
308         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
309         return 0;
310 }
311 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
312
313 int
314 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
315 {
316         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
317
318         if (unlikely(!flag_nested(nla)))
319                 return -IPSET_ERR_PROTOCOL;
320
321         if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla,
322                              ipaddr_policy, NULL))
323                 return -IPSET_ERR_PROTOCOL;
324         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
325                 return -IPSET_ERR_PROTOCOL;
326
327         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
328                sizeof(struct in6_addr));
329         return 0;
330 }
331 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
332
333 typedef void (*destroyer)(struct ip_set *, void *);
334 /* ipset data extension types, in size order */
335
336 const struct ip_set_ext_type ip_set_extensions[] = {
337         [IPSET_EXT_ID_COUNTER] = {
338                 .type   = IPSET_EXT_COUNTER,
339                 .flag   = IPSET_FLAG_WITH_COUNTERS,
340                 .len    = sizeof(struct ip_set_counter),
341                 .align  = __alignof__(struct ip_set_counter),
342         },
343         [IPSET_EXT_ID_TIMEOUT] = {
344                 .type   = IPSET_EXT_TIMEOUT,
345                 .len    = sizeof(unsigned long),
346                 .align  = __alignof__(unsigned long),
347         },
348         [IPSET_EXT_ID_SKBINFO] = {
349                 .type   = IPSET_EXT_SKBINFO,
350                 .flag   = IPSET_FLAG_WITH_SKBINFO,
351                 .len    = sizeof(struct ip_set_skbinfo),
352                 .align  = __alignof__(struct ip_set_skbinfo),
353         },
354         [IPSET_EXT_ID_COMMENT] = {
355                 .type    = IPSET_EXT_COMMENT | IPSET_EXT_DESTROY,
356                 .flag    = IPSET_FLAG_WITH_COMMENT,
357                 .len     = sizeof(struct ip_set_comment),
358                 .align   = __alignof__(struct ip_set_comment),
359                 .destroy = (destroyer) ip_set_comment_free,
360         },
361 };
362 EXPORT_SYMBOL_GPL(ip_set_extensions);
363
364 static inline bool
365 add_extension(enum ip_set_ext_id id, u32 flags, struct nlattr *tb[])
366 {
367         return ip_set_extensions[id].flag ?
368                 (flags & ip_set_extensions[id].flag) :
369                 !!tb[IPSET_ATTR_TIMEOUT];
370 }
371
372 size_t
373 ip_set_elem_len(struct ip_set *set, struct nlattr *tb[], size_t len,
374                 size_t align)
375 {
376         enum ip_set_ext_id id;
377         u32 cadt_flags = 0;
378
379         if (tb[IPSET_ATTR_CADT_FLAGS])
380                 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
381         if (cadt_flags & IPSET_FLAG_WITH_FORCEADD)
382                 set->flags |= IPSET_CREATE_FLAG_FORCEADD;
383         if (!align)
384                 align = 1;
385         for (id = 0; id < IPSET_EXT_ID_MAX; id++) {
386                 if (!add_extension(id, cadt_flags, tb))
387                         continue;
388                 len = ALIGN(len, ip_set_extensions[id].align);
389                 set->offset[id] = len;
390                 set->extensions |= ip_set_extensions[id].type;
391                 len += ip_set_extensions[id].len;
392         }
393         return ALIGN(len, align);
394 }
395 EXPORT_SYMBOL_GPL(ip_set_elem_len);
396
397 int
398 ip_set_get_extensions(struct ip_set *set, struct nlattr *tb[],
399                       struct ip_set_ext *ext)
400 {
401         u64 fullmark;
402
403         if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
404                      !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) ||
405                      !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) ||
406                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) ||
407                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) ||
408                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE)))
409                 return -IPSET_ERR_PROTOCOL;
410
411         if (tb[IPSET_ATTR_TIMEOUT]) {
412                 if (!SET_WITH_TIMEOUT(set))
413                         return -IPSET_ERR_TIMEOUT;
414                 ext->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
415         }
416         if (tb[IPSET_ATTR_BYTES] || tb[IPSET_ATTR_PACKETS]) {
417                 if (!SET_WITH_COUNTER(set))
418                         return -IPSET_ERR_COUNTER;
419                 if (tb[IPSET_ATTR_BYTES])
420                         ext->bytes = be64_to_cpu(nla_get_be64(
421                                                  tb[IPSET_ATTR_BYTES]));
422                 if (tb[IPSET_ATTR_PACKETS])
423                         ext->packets = be64_to_cpu(nla_get_be64(
424                                                    tb[IPSET_ATTR_PACKETS]));
425         }
426         if (tb[IPSET_ATTR_COMMENT]) {
427                 if (!SET_WITH_COMMENT(set))
428                         return -IPSET_ERR_COMMENT;
429                 ext->comment = ip_set_comment_uget(tb[IPSET_ATTR_COMMENT]);
430         }
431         if (tb[IPSET_ATTR_SKBMARK]) {
432                 if (!SET_WITH_SKBINFO(set))
433                         return -IPSET_ERR_SKBINFO;
434                 fullmark = be64_to_cpu(nla_get_be64(tb[IPSET_ATTR_SKBMARK]));
435                 ext->skbinfo.skbmark = fullmark >> 32;
436                 ext->skbinfo.skbmarkmask = fullmark & 0xffffffff;
437         }
438         if (tb[IPSET_ATTR_SKBPRIO]) {
439                 if (!SET_WITH_SKBINFO(set))
440                         return -IPSET_ERR_SKBINFO;
441                 ext->skbinfo.skbprio =
442                         be32_to_cpu(nla_get_be32(tb[IPSET_ATTR_SKBPRIO]));
443         }
444         if (tb[IPSET_ATTR_SKBQUEUE]) {
445                 if (!SET_WITH_SKBINFO(set))
446                         return -IPSET_ERR_SKBINFO;
447                 ext->skbinfo.skbqueue =
448                         be16_to_cpu(nla_get_be16(tb[IPSET_ATTR_SKBQUEUE]));
449         }
450         return 0;
451 }
452 EXPORT_SYMBOL_GPL(ip_set_get_extensions);
453
454 int
455 ip_set_put_extensions(struct sk_buff *skb, const struct ip_set *set,
456                       const void *e, bool active)
457 {
458         if (SET_WITH_TIMEOUT(set)) {
459                 unsigned long *timeout = ext_timeout(e, set);
460
461                 if (nla_put_net32(skb, IPSET_ATTR_TIMEOUT,
462                         htonl(active ? ip_set_timeout_get(timeout)
463                                 : *timeout)))
464                         return -EMSGSIZE;
465         }
466         if (SET_WITH_COUNTER(set) &&
467             ip_set_put_counter(skb, ext_counter(e, set)))
468                 return -EMSGSIZE;
469         if (SET_WITH_COMMENT(set) &&
470             ip_set_put_comment(skb, ext_comment(e, set)))
471                 return -EMSGSIZE;
472         if (SET_WITH_SKBINFO(set) &&
473             ip_set_put_skbinfo(skb, ext_skbinfo(e, set)))
474                 return -EMSGSIZE;
475         return 0;
476 }
477 EXPORT_SYMBOL_GPL(ip_set_put_extensions);
478
479 bool
480 ip_set_match_extensions(struct ip_set *set, const struct ip_set_ext *ext,
481                         struct ip_set_ext *mext, u32 flags, void *data)
482 {
483         if (SET_WITH_TIMEOUT(set) &&
484             ip_set_timeout_expired(ext_timeout(data, set)))
485                 return false;
486         if (SET_WITH_COUNTER(set)) {
487                 struct ip_set_counter *counter = ext_counter(data, set);
488
489                 if (flags & IPSET_FLAG_MATCH_COUNTERS &&
490                     !(ip_set_match_counter(ip_set_get_packets(counter),
491                                 mext->packets, mext->packets_op) &&
492                       ip_set_match_counter(ip_set_get_bytes(counter),
493                                 mext->bytes, mext->bytes_op)))
494                         return false;
495                 ip_set_update_counter(counter, ext, flags);
496         }
497         if (SET_WITH_SKBINFO(set))
498                 ip_set_get_skbinfo(ext_skbinfo(data, set),
499                                    ext, mext, flags);
500         return true;
501 }
502 EXPORT_SYMBOL_GPL(ip_set_match_extensions);
503
504 /* Creating/destroying/renaming/swapping affect the existence and
505  * the properties of a set. All of these can be executed from userspace
506  * only and serialized by the nfnl mutex indirectly from nfnetlink.
507  *
508  * Sets are identified by their index in ip_set_list and the index
509  * is used by the external references (set/SET netfilter modules).
510  *
511  * The set behind an index may change by swapping only, from userspace.
512  */
513
514 static inline void
515 __ip_set_get(struct ip_set *set)
516 {
517         write_lock_bh(&ip_set_ref_lock);
518         set->ref++;
519         write_unlock_bh(&ip_set_ref_lock);
520 }
521
522 static inline void
523 __ip_set_put(struct ip_set *set)
524 {
525         write_lock_bh(&ip_set_ref_lock);
526         BUG_ON(set->ref == 0);
527         set->ref--;
528         write_unlock_bh(&ip_set_ref_lock);
529 }
530
531 /* set->ref can be swapped out by ip_set_swap, netlink events (like dump) need
532  * a separate reference counter
533  */
534 static inline void
535 __ip_set_put_netlink(struct ip_set *set)
536 {
537         write_lock_bh(&ip_set_ref_lock);
538         BUG_ON(set->ref_netlink == 0);
539         set->ref_netlink--;
540         write_unlock_bh(&ip_set_ref_lock);
541 }
542
543 /* Add, del and test set entries from kernel.
544  *
545  * The set behind the index must exist and must be referenced
546  * so it can't be destroyed (or changed) under our foot.
547  */
548
549 static inline struct ip_set *
550 ip_set_rcu_get(struct net *net, ip_set_id_t index)
551 {
552         struct ip_set *set;
553         struct ip_set_net *inst = ip_set_pernet(net);
554
555         rcu_read_lock();
556         /* ip_set_list itself needs to be protected */
557         set = rcu_dereference(inst->ip_set_list)[index];
558         rcu_read_unlock();
559
560         return set;
561 }
562
563 int
564 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
565             const struct xt_action_param *par, struct ip_set_adt_opt *opt)
566 {
567         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
568         int ret = 0;
569
570         BUG_ON(!set);
571         pr_debug("set %s, index %u\n", set->name, index);
572
573         if (opt->dim < set->type->dimension ||
574             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
575                 return 0;
576
577         rcu_read_lock_bh();
578         ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
579         rcu_read_unlock_bh();
580
581         if (ret == -EAGAIN) {
582                 /* Type requests element to be completed */
583                 pr_debug("element must be completed, ADD is triggered\n");
584                 spin_lock_bh(&set->lock);
585                 set->variant->kadt(set, skb, par, IPSET_ADD, opt);
586                 spin_unlock_bh(&set->lock);
587                 ret = 1;
588         } else {
589                 /* --return-nomatch: invert matched element */
590                 if ((opt->cmdflags & IPSET_FLAG_RETURN_NOMATCH) &&
591                     (set->type->features & IPSET_TYPE_NOMATCH) &&
592                     (ret > 0 || ret == -ENOTEMPTY))
593                         ret = -ret;
594         }
595
596         /* Convert error codes to nomatch */
597         return (ret < 0 ? 0 : ret);
598 }
599 EXPORT_SYMBOL_GPL(ip_set_test);
600
601 int
602 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
603            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
604 {
605         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
606         int ret;
607
608         BUG_ON(!set);
609         pr_debug("set %s, index %u\n", set->name, index);
610
611         if (opt->dim < set->type->dimension ||
612             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
613                 return -IPSET_ERR_TYPE_MISMATCH;
614
615         spin_lock_bh(&set->lock);
616         ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
617         spin_unlock_bh(&set->lock);
618
619         return ret;
620 }
621 EXPORT_SYMBOL_GPL(ip_set_add);
622
623 int
624 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
625            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
626 {
627         struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
628         int ret = 0;
629
630         BUG_ON(!set);
631         pr_debug("set %s, index %u\n", set->name, index);
632
633         if (opt->dim < set->type->dimension ||
634             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
635                 return -IPSET_ERR_TYPE_MISMATCH;
636
637         spin_lock_bh(&set->lock);
638         ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
639         spin_unlock_bh(&set->lock);
640
641         return ret;
642 }
643 EXPORT_SYMBOL_GPL(ip_set_del);
644
645 /* Find set by name, reference it once. The reference makes sure the
646  * thing pointed to, does not go away under our feet.
647  *
648  */
649 ip_set_id_t
650 ip_set_get_byname(struct net *net, const char *name, struct ip_set **set)
651 {
652         ip_set_id_t i, index = IPSET_INVALID_ID;
653         struct ip_set *s;
654         struct ip_set_net *inst = ip_set_pernet(net);
655
656         rcu_read_lock();
657         for (i = 0; i < inst->ip_set_max; i++) {
658                 s = rcu_dereference(inst->ip_set_list)[i];
659                 if (s && STRNCMP(s->name, name)) {
660                         __ip_set_get(s);
661                         index = i;
662                         *set = s;
663                         break;
664                 }
665         }
666         rcu_read_unlock();
667
668         return index;
669 }
670 EXPORT_SYMBOL_GPL(ip_set_get_byname);
671
672 /* If the given set pointer points to a valid set, decrement
673  * reference count by 1. The caller shall not assume the index
674  * to be valid, after calling this function.
675  *
676  */
677
678 static inline void
679 __ip_set_put_byindex(struct ip_set_net *inst, ip_set_id_t index)
680 {
681         struct ip_set *set;
682
683         rcu_read_lock();
684         set = rcu_dereference(inst->ip_set_list)[index];
685         if (set)
686                 __ip_set_put(set);
687         rcu_read_unlock();
688 }
689
690 void
691 ip_set_put_byindex(struct net *net, ip_set_id_t index)
692 {
693         struct ip_set_net *inst = ip_set_pernet(net);
694
695         __ip_set_put_byindex(inst, index);
696 }
697 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
698
699 /* Get the name of a set behind a set index.
700  * Set itself is protected by RCU, but its name isn't: to protect against
701  * renaming, grab ip_set_ref_lock as reader (see ip_set_rename()) and copy the
702  * name.
703  */
704 void
705 ip_set_name_byindex(struct net *net, ip_set_id_t index, char *name)
706 {
707         struct ip_set *set = ip_set_rcu_get(net, index);
708
709         BUG_ON(!set);
710
711         read_lock_bh(&ip_set_ref_lock);
712         strncpy(name, set->name, IPSET_MAXNAMELEN);
713         read_unlock_bh(&ip_set_ref_lock);
714 }
715 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
716
717 /* Routines to call by external subsystems, which do not
718  * call nfnl_lock for us.
719  */
720
721 /* Find set by index, reference it once. The reference makes sure the
722  * thing pointed to, does not go away under our feet.
723  *
724  * The nfnl mutex is used in the function.
725  */
726 ip_set_id_t
727 ip_set_nfnl_get_byindex(struct net *net, ip_set_id_t index)
728 {
729         struct ip_set *set;
730         struct ip_set_net *inst = ip_set_pernet(net);
731
732         if (index >= inst->ip_set_max)
733                 return IPSET_INVALID_ID;
734
735         nfnl_lock(NFNL_SUBSYS_IPSET);
736         set = ip_set(inst, index);
737         if (set)
738                 __ip_set_get(set);
739         else
740                 index = IPSET_INVALID_ID;
741         nfnl_unlock(NFNL_SUBSYS_IPSET);
742
743         return index;
744 }
745 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
746
747 /* If the given set pointer points to a valid set, decrement
748  * reference count by 1. The caller shall not assume the index
749  * to be valid, after calling this function.
750  *
751  * The nfnl mutex is used in the function.
752  */
753 void
754 ip_set_nfnl_put(struct net *net, ip_set_id_t index)
755 {
756         struct ip_set *set;
757         struct ip_set_net *inst = ip_set_pernet(net);
758
759         nfnl_lock(NFNL_SUBSYS_IPSET);
760         if (!inst->is_deleted) { /* already deleted from ip_set_net_exit() */
761                 set = ip_set(inst, index);
762                 if (set)
763                         __ip_set_put(set);
764         }
765         nfnl_unlock(NFNL_SUBSYS_IPSET);
766 }
767 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
768
769 /* Communication protocol with userspace over netlink.
770  *
771  * The commands are serialized by the nfnl mutex.
772  */
773
774 static inline u8 protocol(const struct nlattr * const tb[])
775 {
776         return nla_get_u8(tb[IPSET_ATTR_PROTOCOL]);
777 }
778
779 static inline bool
780 protocol_failed(const struct nlattr * const tb[])
781 {
782         return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) != IPSET_PROTOCOL;
783 }
784
785 static inline bool
786 protocol_min_failed(const struct nlattr * const tb[])
787 {
788         return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) < IPSET_PROTOCOL_MIN;
789 }
790
791 static inline u32
792 flag_exist(const struct nlmsghdr *nlh)
793 {
794         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
795 }
796
797 static struct nlmsghdr *
798 start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags,
799           enum ipset_cmd cmd)
800 {
801         struct nlmsghdr *nlh;
802         struct nfgenmsg *nfmsg;
803
804         nlh = nlmsg_put(skb, portid, seq, nfnl_msg_type(NFNL_SUBSYS_IPSET, cmd),
805                         sizeof(*nfmsg), flags);
806         if (!nlh)
807                 return NULL;
808
809         nfmsg = nlmsg_data(nlh);
810         nfmsg->nfgen_family = NFPROTO_IPV4;
811         nfmsg->version = NFNETLINK_V0;
812         nfmsg->res_id = 0;
813
814         return nlh;
815 }
816
817 /* Create a set */
818
819 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
820         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
821         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
822                                     .len = IPSET_MAXNAMELEN - 1 },
823         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
824                                     .len = IPSET_MAXNAMELEN - 1},
825         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
826         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
827         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
828 };
829
830 static struct ip_set *
831 find_set_and_id(struct ip_set_net *inst, const char *name, ip_set_id_t *id)
832 {
833         struct ip_set *set = NULL;
834         ip_set_id_t i;
835
836         *id = IPSET_INVALID_ID;
837         for (i = 0; i < inst->ip_set_max; i++) {
838                 set = ip_set(inst, i);
839                 if (set && STRNCMP(set->name, name)) {
840                         *id = i;
841                         break;
842                 }
843         }
844         return (*id == IPSET_INVALID_ID ? NULL : set);
845 }
846
847 static inline struct ip_set *
848 find_set(struct ip_set_net *inst, const char *name)
849 {
850         ip_set_id_t id;
851
852         return find_set_and_id(inst, name, &id);
853 }
854
855 static int
856 find_free_id(struct ip_set_net *inst, const char *name, ip_set_id_t *index,
857              struct ip_set **set)
858 {
859         struct ip_set *s;
860         ip_set_id_t i;
861
862         *index = IPSET_INVALID_ID;
863         for (i = 0;  i < inst->ip_set_max; i++) {
864                 s = ip_set(inst, i);
865                 if (!s) {
866                         if (*index == IPSET_INVALID_ID)
867                                 *index = i;
868                 } else if (STRNCMP(name, s->name)) {
869                         /* Name clash */
870                         *set = s;
871                         return -EEXIST;
872                 }
873         }
874         if (*index == IPSET_INVALID_ID)
875                 /* No free slot remained */
876                 return -IPSET_ERR_MAX_SETS;
877         return 0;
878 }
879
880 static int ip_set_none(struct net *net, struct sock *ctnl, struct sk_buff *skb,
881                        const struct nlmsghdr *nlh,
882                        const struct nlattr * const attr[],
883                        struct netlink_ext_ack *extack)
884 {
885         return -EOPNOTSUPP;
886 }
887
888 static int ip_set_create(struct net *net, struct sock *ctnl,
889                          struct sk_buff *skb, const struct nlmsghdr *nlh,
890                          const struct nlattr * const attr[],
891                          struct netlink_ext_ack *extack)
892 {
893         struct ip_set_net *inst = ip_set_pernet(net);
894         struct ip_set *set, *clash = NULL;
895         ip_set_id_t index = IPSET_INVALID_ID;
896         struct nlattr *tb[IPSET_ATTR_CREATE_MAX + 1] = {};
897         const char *name, *typename;
898         u8 family, revision;
899         u32 flags = flag_exist(nlh);
900         int ret = 0;
901
902         if (unlikely(protocol_min_failed(attr) ||
903                      !attr[IPSET_ATTR_SETNAME] ||
904                      !attr[IPSET_ATTR_TYPENAME] ||
905                      !attr[IPSET_ATTR_REVISION] ||
906                      !attr[IPSET_ATTR_FAMILY] ||
907                      (attr[IPSET_ATTR_DATA] &&
908                       !flag_nested(attr[IPSET_ATTR_DATA]))))
909                 return -IPSET_ERR_PROTOCOL;
910
911         name = nla_data(attr[IPSET_ATTR_SETNAME]);
912         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
913         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
914         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
915         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
916                  name, typename, family_name(family), revision);
917
918         /* First, and without any locks, allocate and initialize
919          * a normal base set structure.
920          */
921         set = kzalloc(sizeof(*set), GFP_KERNEL);
922         if (!set)
923                 return -ENOMEM;
924         spin_lock_init(&set->lock);
925         strlcpy(set->name, name, IPSET_MAXNAMELEN);
926         set->family = family;
927         set->revision = revision;
928
929         /* Next, check that we know the type, and take
930          * a reference on the type, to make sure it stays available
931          * while constructing our new set.
932          *
933          * After referencing the type, we try to create the type
934          * specific part of the set without holding any locks.
935          */
936         ret = find_set_type_get(typename, family, revision, &set->type);
937         if (ret)
938                 goto out;
939
940         /* Without holding any locks, create private part. */
941         if (attr[IPSET_ATTR_DATA] &&
942             nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
943                              set->type->create_policy, NULL)) {
944                 ret = -IPSET_ERR_PROTOCOL;
945                 goto put_out;
946         }
947
948         ret = set->type->create(net, set, tb, flags);
949         if (ret != 0)
950                 goto put_out;
951
952         /* BTW, ret==0 here. */
953
954         /* Here, we have a valid, constructed set and we are protected
955          * by the nfnl mutex. Find the first free index in ip_set_list
956          * and check clashing.
957          */
958         ret = find_free_id(inst, set->name, &index, &clash);
959         if (ret == -EEXIST) {
960                 /* If this is the same set and requested, ignore error */
961                 if ((flags & IPSET_FLAG_EXIST) &&
962                     STRNCMP(set->type->name, clash->type->name) &&
963                     set->type->family == clash->type->family &&
964                     set->type->revision_min == clash->type->revision_min &&
965                     set->type->revision_max == clash->type->revision_max &&
966                     set->variant->same_set(set, clash))
967                         ret = 0;
968                 goto cleanup;
969         } else if (ret == -IPSET_ERR_MAX_SETS) {
970                 struct ip_set **list, **tmp;
971                 ip_set_id_t i = inst->ip_set_max + IP_SET_INC;
972
973                 if (i < inst->ip_set_max || i == IPSET_INVALID_ID)
974                         /* Wraparound */
975                         goto cleanup;
976
977                 list = kvcalloc(i, sizeof(struct ip_set *), GFP_KERNEL);
978                 if (!list)
979                         goto cleanup;
980                 /* nfnl mutex is held, both lists are valid */
981                 tmp = ip_set_dereference(inst->ip_set_list);
982                 memcpy(list, tmp, sizeof(struct ip_set *) * inst->ip_set_max);
983                 rcu_assign_pointer(inst->ip_set_list, list);
984                 /* Make sure all current packets have passed through */
985                 synchronize_net();
986                 /* Use new list */
987                 index = inst->ip_set_max;
988                 inst->ip_set_max = i;
989                 kvfree(tmp);
990                 ret = 0;
991         } else if (ret) {
992                 goto cleanup;
993         }
994
995         /* Finally! Add our shiny new set to the list, and be done. */
996         pr_debug("create: '%s' created with index %u!\n", set->name, index);
997         ip_set(inst, index) = set;
998
999         return ret;
1000
1001 cleanup:
1002         set->variant->destroy(set);
1003 put_out:
1004         module_put(set->type->me);
1005 out:
1006         kfree(set);
1007         return ret;
1008 }
1009
1010 /* Destroy sets */
1011
1012 static const struct nla_policy
1013 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
1014         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1015         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1016                                     .len = IPSET_MAXNAMELEN - 1 },
1017 };
1018
1019 static void
1020 ip_set_destroy_set(struct ip_set *set)
1021 {
1022         pr_debug("set: %s\n",  set->name);
1023
1024         /* Must call it without holding any lock */
1025         set->variant->destroy(set);
1026         module_put(set->type->me);
1027         kfree(set);
1028 }
1029
1030 static int ip_set_destroy(struct net *net, struct sock *ctnl,
1031                           struct sk_buff *skb, const struct nlmsghdr *nlh,
1032                           const struct nlattr * const attr[],
1033                           struct netlink_ext_ack *extack)
1034 {
1035         struct ip_set_net *inst = ip_set_pernet(net);
1036         struct ip_set *s;
1037         ip_set_id_t i;
1038         int ret = 0;
1039
1040         if (unlikely(protocol_min_failed(attr)))
1041                 return -IPSET_ERR_PROTOCOL;
1042
1043         /* Must wait for flush to be really finished in list:set */
1044         rcu_barrier();
1045
1046         /* Commands are serialized and references are
1047          * protected by the ip_set_ref_lock.
1048          * External systems (i.e. xt_set) must call
1049          * ip_set_put|get_nfnl_* functions, that way we
1050          * can safely check references here.
1051          *
1052          * list:set timer can only decrement the reference
1053          * counter, so if it's already zero, we can proceed
1054          * without holding the lock.
1055          */
1056         read_lock_bh(&ip_set_ref_lock);
1057         if (!attr[IPSET_ATTR_SETNAME]) {
1058                 for (i = 0; i < inst->ip_set_max; i++) {
1059                         s = ip_set(inst, i);
1060                         if (s && (s->ref || s->ref_netlink)) {
1061                                 ret = -IPSET_ERR_BUSY;
1062                                 goto out;
1063                         }
1064                 }
1065                 inst->is_destroyed = true;
1066                 read_unlock_bh(&ip_set_ref_lock);
1067                 for (i = 0; i < inst->ip_set_max; i++) {
1068                         s = ip_set(inst, i);
1069                         if (s) {
1070                                 ip_set(inst, i) = NULL;
1071                                 ip_set_destroy_set(s);
1072                         }
1073                 }
1074                 /* Modified by ip_set_destroy() only, which is serialized */
1075                 inst->is_destroyed = false;
1076         } else {
1077                 s = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1078                                     &i);
1079                 if (!s) {
1080                         ret = -ENOENT;
1081                         goto out;
1082                 } else if (s->ref || s->ref_netlink) {
1083                         ret = -IPSET_ERR_BUSY;
1084                         goto out;
1085                 }
1086                 ip_set(inst, i) = NULL;
1087                 read_unlock_bh(&ip_set_ref_lock);
1088
1089                 ip_set_destroy_set(s);
1090         }
1091         return 0;
1092 out:
1093         read_unlock_bh(&ip_set_ref_lock);
1094         return ret;
1095 }
1096
1097 /* Flush sets */
1098
1099 static void
1100 ip_set_flush_set(struct ip_set *set)
1101 {
1102         pr_debug("set: %s\n",  set->name);
1103
1104         spin_lock_bh(&set->lock);
1105         set->variant->flush(set);
1106         spin_unlock_bh(&set->lock);
1107 }
1108
1109 static int ip_set_flush(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1110                         const struct nlmsghdr *nlh,
1111                         const struct nlattr * const attr[],
1112                         struct netlink_ext_ack *extack)
1113 {
1114         struct ip_set_net *inst = ip_set_pernet(net);
1115         struct ip_set *s;
1116         ip_set_id_t i;
1117
1118         if (unlikely(protocol_min_failed(attr)))
1119                 return -IPSET_ERR_PROTOCOL;
1120
1121         if (!attr[IPSET_ATTR_SETNAME]) {
1122                 for (i = 0; i < inst->ip_set_max; i++) {
1123                         s = ip_set(inst, i);
1124                         if (s)
1125                                 ip_set_flush_set(s);
1126                 }
1127         } else {
1128                 s = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1129                 if (!s)
1130                         return -ENOENT;
1131
1132                 ip_set_flush_set(s);
1133         }
1134
1135         return 0;
1136 }
1137
1138 /* Rename a set */
1139
1140 static const struct nla_policy
1141 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
1142         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1143         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1144                                     .len = IPSET_MAXNAMELEN - 1 },
1145         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
1146                                     .len = IPSET_MAXNAMELEN - 1 },
1147 };
1148
1149 static int ip_set_rename(struct net *net, struct sock *ctnl,
1150                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1151                          const struct nlattr * const attr[],
1152                          struct netlink_ext_ack *extack)
1153 {
1154         struct ip_set_net *inst = ip_set_pernet(net);
1155         struct ip_set *set, *s;
1156         const char *name2;
1157         ip_set_id_t i;
1158         int ret = 0;
1159
1160         if (unlikely(protocol_min_failed(attr) ||
1161                      !attr[IPSET_ATTR_SETNAME] ||
1162                      !attr[IPSET_ATTR_SETNAME2]))
1163                 return -IPSET_ERR_PROTOCOL;
1164
1165         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1166         if (!set)
1167                 return -ENOENT;
1168
1169         write_lock_bh(&ip_set_ref_lock);
1170         if (set->ref != 0) {
1171                 ret = -IPSET_ERR_REFERENCED;
1172                 goto out;
1173         }
1174
1175         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
1176         for (i = 0; i < inst->ip_set_max; i++) {
1177                 s = ip_set(inst, i);
1178                 if (s && STRNCMP(s->name, name2)) {
1179                         ret = -IPSET_ERR_EXIST_SETNAME2;
1180                         goto out;
1181                 }
1182         }
1183         strncpy(set->name, name2, IPSET_MAXNAMELEN);
1184
1185 out:
1186         write_unlock_bh(&ip_set_ref_lock);
1187         return ret;
1188 }
1189
1190 /* Swap two sets so that name/index points to the other.
1191  * References and set names are also swapped.
1192  *
1193  * The commands are serialized by the nfnl mutex and references are
1194  * protected by the ip_set_ref_lock. The kernel interfaces
1195  * do not hold the mutex but the pointer settings are atomic
1196  * so the ip_set_list always contains valid pointers to the sets.
1197  */
1198
1199 static int ip_set_swap(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1200                        const struct nlmsghdr *nlh,
1201                        const struct nlattr * const attr[],
1202                        struct netlink_ext_ack *extack)
1203 {
1204         struct ip_set_net *inst = ip_set_pernet(net);
1205         struct ip_set *from, *to;
1206         ip_set_id_t from_id, to_id;
1207         char from_name[IPSET_MAXNAMELEN];
1208
1209         if (unlikely(protocol_min_failed(attr) ||
1210                      !attr[IPSET_ATTR_SETNAME] ||
1211                      !attr[IPSET_ATTR_SETNAME2]))
1212                 return -IPSET_ERR_PROTOCOL;
1213
1214         from = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1215                                &from_id);
1216         if (!from)
1217                 return -ENOENT;
1218
1219         to = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME2]),
1220                              &to_id);
1221         if (!to)
1222                 return -IPSET_ERR_EXIST_SETNAME2;
1223
1224         /* Features must not change.
1225          * Not an artifical restriction anymore, as we must prevent
1226          * possible loops created by swapping in setlist type of sets.
1227          */
1228         if (!(from->type->features == to->type->features &&
1229               from->family == to->family))
1230                 return -IPSET_ERR_TYPE_MISMATCH;
1231
1232         write_lock_bh(&ip_set_ref_lock);
1233
1234         if (from->ref_netlink || to->ref_netlink) {
1235                 write_unlock_bh(&ip_set_ref_lock);
1236                 return -EBUSY;
1237         }
1238
1239         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
1240         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
1241         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
1242
1243         swap(from->ref, to->ref);
1244         ip_set(inst, from_id) = to;
1245         ip_set(inst, to_id) = from;
1246         write_unlock_bh(&ip_set_ref_lock);
1247
1248         return 0;
1249 }
1250
1251 /* List/save set data */
1252
1253 #define DUMP_INIT       0
1254 #define DUMP_ALL        1
1255 #define DUMP_ONE        2
1256 #define DUMP_LAST       3
1257
1258 #define DUMP_TYPE(arg)          (((u32)(arg)) & 0x0000FFFF)
1259 #define DUMP_FLAGS(arg)         (((u32)(arg)) >> 16)
1260
1261 static int
1262 ip_set_dump_done(struct netlink_callback *cb)
1263 {
1264         if (cb->args[IPSET_CB_ARG0]) {
1265                 struct ip_set_net *inst =
1266                         (struct ip_set_net *)cb->args[IPSET_CB_NET];
1267                 ip_set_id_t index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1268                 struct ip_set *set = ip_set_ref_netlink(inst, index);
1269
1270                 if (set->variant->uref)
1271                         set->variant->uref(set, cb, false);
1272                 pr_debug("release set %s\n", set->name);
1273                 __ip_set_put_netlink(set);
1274         }
1275         return 0;
1276 }
1277
1278 static inline void
1279 dump_attrs(struct nlmsghdr *nlh)
1280 {
1281         const struct nlattr *attr;
1282         int rem;
1283
1284         pr_debug("dump nlmsg\n");
1285         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
1286                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
1287         }
1288 }
1289
1290 static int
1291 dump_init(struct netlink_callback *cb, struct ip_set_net *inst)
1292 {
1293         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
1294         int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1295         struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1296         struct nlattr *attr = (void *)nlh + min_len;
1297         u32 dump_type;
1298         ip_set_id_t index;
1299
1300         /* Second pass, so parser can't fail */
1301         nla_parse(cda, IPSET_ATTR_CMD_MAX, attr, nlh->nlmsg_len - min_len,
1302                   ip_set_setname_policy, NULL);
1303
1304         cb->args[IPSET_CB_PROTO] = nla_get_u8(cda[IPSET_ATTR_PROTOCOL]);
1305         if (cda[IPSET_ATTR_SETNAME]) {
1306                 struct ip_set *set;
1307
1308                 set = find_set_and_id(inst, nla_data(cda[IPSET_ATTR_SETNAME]),
1309                                       &index);
1310                 if (!set)
1311                         return -ENOENT;
1312
1313                 dump_type = DUMP_ONE;
1314                 cb->args[IPSET_CB_INDEX] = index;
1315         } else {
1316                 dump_type = DUMP_ALL;
1317         }
1318
1319         if (cda[IPSET_ATTR_FLAGS]) {
1320                 u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1321
1322                 dump_type |= (f << 16);
1323         }
1324         cb->args[IPSET_CB_NET] = (unsigned long)inst;
1325         cb->args[IPSET_CB_DUMP] = dump_type;
1326
1327         return 0;
1328 }
1329
1330 static int
1331 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1332 {
1333         ip_set_id_t index = IPSET_INVALID_ID, max;
1334         struct ip_set *set = NULL;
1335         struct nlmsghdr *nlh = NULL;
1336         unsigned int flags = NETLINK_CB(cb->skb).portid ? NLM_F_MULTI : 0;
1337         struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk));
1338         u32 dump_type, dump_flags;
1339         bool is_destroyed;
1340         int ret = 0;
1341
1342         if (!cb->args[IPSET_CB_DUMP]) {
1343                 ret = dump_init(cb, inst);
1344                 if (ret < 0) {
1345                         nlh = nlmsg_hdr(cb->skb);
1346                         /* We have to create and send the error message
1347                          * manually :-(
1348                          */
1349                         if (nlh->nlmsg_flags & NLM_F_ACK)
1350                                 netlink_ack(cb->skb, nlh, ret, NULL);
1351                         return ret;
1352                 }
1353         }
1354
1355         if (cb->args[IPSET_CB_INDEX] >= inst->ip_set_max)
1356                 goto out;
1357
1358         dump_type = DUMP_TYPE(cb->args[IPSET_CB_DUMP]);
1359         dump_flags = DUMP_FLAGS(cb->args[IPSET_CB_DUMP]);
1360         max = dump_type == DUMP_ONE ? cb->args[IPSET_CB_INDEX] + 1
1361                                     : inst->ip_set_max;
1362 dump_last:
1363         pr_debug("dump type, flag: %u %u index: %ld\n",
1364                  dump_type, dump_flags, cb->args[IPSET_CB_INDEX]);
1365         for (; cb->args[IPSET_CB_INDEX] < max; cb->args[IPSET_CB_INDEX]++) {
1366                 index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1367                 write_lock_bh(&ip_set_ref_lock);
1368                 set = ip_set(inst, index);
1369                 is_destroyed = inst->is_destroyed;
1370                 if (!set || is_destroyed) {
1371                         write_unlock_bh(&ip_set_ref_lock);
1372                         if (dump_type == DUMP_ONE) {
1373                                 ret = -ENOENT;
1374                                 goto out;
1375                         }
1376                         if (is_destroyed) {
1377                                 /* All sets are just being destroyed */
1378                                 ret = 0;
1379                                 goto out;
1380                         }
1381                         continue;
1382                 }
1383                 /* When dumping all sets, we must dump "sorted"
1384                  * so that lists (unions of sets) are dumped last.
1385                  */
1386                 if (dump_type != DUMP_ONE &&
1387                     ((dump_type == DUMP_ALL) ==
1388                      !!(set->type->features & IPSET_DUMP_LAST))) {
1389                         write_unlock_bh(&ip_set_ref_lock);
1390                         continue;
1391                 }
1392                 pr_debug("List set: %s\n", set->name);
1393                 if (!cb->args[IPSET_CB_ARG0]) {
1394                         /* Start listing: make sure set won't be destroyed */
1395                         pr_debug("reference set\n");
1396                         set->ref_netlink++;
1397                 }
1398                 write_unlock_bh(&ip_set_ref_lock);
1399                 nlh = start_msg(skb, NETLINK_CB(cb->skb).portid,
1400                                 cb->nlh->nlmsg_seq, flags,
1401                                 IPSET_CMD_LIST);
1402                 if (!nlh) {
1403                         ret = -EMSGSIZE;
1404                         goto release_refcount;
1405                 }
1406                 if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL,
1407                                cb->args[IPSET_CB_PROTO]) ||
1408                     nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
1409                         goto nla_put_failure;
1410                 if (dump_flags & IPSET_FLAG_LIST_SETNAME)
1411                         goto next_set;
1412                 switch (cb->args[IPSET_CB_ARG0]) {
1413                 case 0:
1414                         /* Core header data */
1415                         if (nla_put_string(skb, IPSET_ATTR_TYPENAME,
1416                                            set->type->name) ||
1417                             nla_put_u8(skb, IPSET_ATTR_FAMILY,
1418                                        set->family) ||
1419                             nla_put_u8(skb, IPSET_ATTR_REVISION,
1420                                        set->revision))
1421                                 goto nla_put_failure;
1422                         if (cb->args[IPSET_CB_PROTO] > IPSET_PROTOCOL_MIN &&
1423                             nla_put_net16(skb, IPSET_ATTR_INDEX, htons(index)))
1424                                 goto nla_put_failure;
1425                         ret = set->variant->head(set, skb);
1426                         if (ret < 0)
1427                                 goto release_refcount;
1428                         if (dump_flags & IPSET_FLAG_LIST_HEADER)
1429                                 goto next_set;
1430                         if (set->variant->uref)
1431                                 set->variant->uref(set, cb, true);
1432                         /* fall through */
1433                 default:
1434                         ret = set->variant->list(set, skb, cb);
1435                         if (!cb->args[IPSET_CB_ARG0])
1436                                 /* Set is done, proceed with next one */
1437                                 goto next_set;
1438                         goto release_refcount;
1439                 }
1440         }
1441         /* If we dump all sets, continue with dumping last ones */
1442         if (dump_type == DUMP_ALL) {
1443                 dump_type = DUMP_LAST;
1444                 cb->args[IPSET_CB_DUMP] = dump_type | (dump_flags << 16);
1445                 cb->args[IPSET_CB_INDEX] = 0;
1446                 if (set && set->variant->uref)
1447                         set->variant->uref(set, cb, false);
1448                 goto dump_last;
1449         }
1450         goto out;
1451
1452 nla_put_failure:
1453         ret = -EFAULT;
1454 next_set:
1455         if (dump_type == DUMP_ONE)
1456                 cb->args[IPSET_CB_INDEX] = IPSET_INVALID_ID;
1457         else
1458                 cb->args[IPSET_CB_INDEX]++;
1459 release_refcount:
1460         /* If there was an error or set is done, release set */
1461         if (ret || !cb->args[IPSET_CB_ARG0]) {
1462                 set = ip_set_ref_netlink(inst, index);
1463                 if (set->variant->uref)
1464                         set->variant->uref(set, cb, false);
1465                 pr_debug("release set %s\n", set->name);
1466                 __ip_set_put_netlink(set);
1467                 cb->args[IPSET_CB_ARG0] = 0;
1468         }
1469 out:
1470         if (nlh) {
1471                 nlmsg_end(skb, nlh);
1472                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1473                 dump_attrs(nlh);
1474         }
1475
1476         return ret < 0 ? ret : skb->len;
1477 }
1478
1479 static int ip_set_dump(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1480                        const struct nlmsghdr *nlh,
1481                        const struct nlattr * const attr[],
1482                        struct netlink_ext_ack *extack)
1483 {
1484         if (unlikely(protocol_min_failed(attr)))
1485                 return -IPSET_ERR_PROTOCOL;
1486
1487         {
1488                 struct netlink_dump_control c = {
1489                         .dump = ip_set_dump_start,
1490                         .done = ip_set_dump_done,
1491                 };
1492                 return netlink_dump_start(ctnl, skb, nlh, &c);
1493         }
1494 }
1495
1496 /* Add, del and test */
1497
1498 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1499         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1500         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1501                                     .len = IPSET_MAXNAMELEN - 1 },
1502         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1503         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1504         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1505 };
1506
1507 static int
1508 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1509         struct nlattr *tb[], enum ipset_adt adt,
1510         u32 flags, bool use_lineno)
1511 {
1512         int ret;
1513         u32 lineno = 0;
1514         bool eexist = flags & IPSET_FLAG_EXIST, retried = false;
1515
1516         do {
1517                 spin_lock_bh(&set->lock);
1518                 ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
1519                 spin_unlock_bh(&set->lock);
1520                 retried = true;
1521         } while (ret == -EAGAIN &&
1522                  set->variant->resize &&
1523                  (ret = set->variant->resize(set, retried)) == 0);
1524
1525         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1526                 return 0;
1527         if (lineno && use_lineno) {
1528                 /* Error in restore/batch mode: send back lineno */
1529                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1530                 struct sk_buff *skb2;
1531                 struct nlmsgerr *errmsg;
1532                 size_t payload = min(SIZE_MAX,
1533                                      sizeof(*errmsg) + nlmsg_len(nlh));
1534                 int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1535                 struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1536                 struct nlattr *cmdattr;
1537                 u32 *errline;
1538
1539                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1540                 if (!skb2)
1541                         return -ENOMEM;
1542                 rep = __nlmsg_put(skb2, NETLINK_CB(skb).portid,
1543                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1544                 errmsg = nlmsg_data(rep);
1545                 errmsg->error = ret;
1546                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1547                 cmdattr = (void *)&errmsg->msg + min_len;
1548
1549                 nla_parse(cda, IPSET_ATTR_CMD_MAX, cmdattr,
1550                           nlh->nlmsg_len - min_len, ip_set_adt_policy, NULL);
1551
1552                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1553
1554                 *errline = lineno;
1555
1556                 netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid,
1557                                 MSG_DONTWAIT);
1558                 /* Signal netlink not to send its ACK/errmsg.  */
1559                 return -EINTR;
1560         }
1561
1562         return ret;
1563 }
1564
1565 static int ip_set_uadd(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1566                        const struct nlmsghdr *nlh,
1567                        const struct nlattr * const attr[],
1568                        struct netlink_ext_ack *extack)
1569 {
1570         struct ip_set_net *inst = ip_set_pernet(net);
1571         struct ip_set *set;
1572         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1573         const struct nlattr *nla;
1574         u32 flags = flag_exist(nlh);
1575         bool use_lineno;
1576         int ret = 0;
1577
1578         if (unlikely(protocol_min_failed(attr) ||
1579                      !attr[IPSET_ATTR_SETNAME] ||
1580                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1581                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1582                      (attr[IPSET_ATTR_DATA] &&
1583                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1584                      (attr[IPSET_ATTR_ADT] &&
1585                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1586                        !attr[IPSET_ATTR_LINENO]))))
1587                 return -IPSET_ERR_PROTOCOL;
1588
1589         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1590         if (!set)
1591                 return -ENOENT;
1592
1593         use_lineno = !!attr[IPSET_ATTR_LINENO];
1594         if (attr[IPSET_ATTR_DATA]) {
1595                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1596                                      attr[IPSET_ATTR_DATA],
1597                                      set->type->adt_policy, NULL))
1598                         return -IPSET_ERR_PROTOCOL;
1599                 ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1600                               use_lineno);
1601         } else {
1602                 int nla_rem;
1603
1604                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1605                         memset(tb, 0, sizeof(tb));
1606                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1607                             !flag_nested(nla) ||
1608                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1609                                              set->type->adt_policy, NULL))
1610                                 return -IPSET_ERR_PROTOCOL;
1611                         ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1612                                       flags, use_lineno);
1613                         if (ret < 0)
1614                                 return ret;
1615                 }
1616         }
1617         return ret;
1618 }
1619
1620 static int ip_set_udel(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1621                        const struct nlmsghdr *nlh,
1622                        const struct nlattr * const attr[],
1623                        struct netlink_ext_ack *extack)
1624 {
1625         struct ip_set_net *inst = ip_set_pernet(net);
1626         struct ip_set *set;
1627         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1628         const struct nlattr *nla;
1629         u32 flags = flag_exist(nlh);
1630         bool use_lineno;
1631         int ret = 0;
1632
1633         if (unlikely(protocol_min_failed(attr) ||
1634                      !attr[IPSET_ATTR_SETNAME] ||
1635                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1636                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1637                      (attr[IPSET_ATTR_DATA] &&
1638                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1639                      (attr[IPSET_ATTR_ADT] &&
1640                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1641                        !attr[IPSET_ATTR_LINENO]))))
1642                 return -IPSET_ERR_PROTOCOL;
1643
1644         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1645         if (!set)
1646                 return -ENOENT;
1647
1648         use_lineno = !!attr[IPSET_ATTR_LINENO];
1649         if (attr[IPSET_ATTR_DATA]) {
1650                 if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX,
1651                                      attr[IPSET_ATTR_DATA],
1652                                      set->type->adt_policy, NULL))
1653                         return -IPSET_ERR_PROTOCOL;
1654                 ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1655                               use_lineno);
1656         } else {
1657                 int nla_rem;
1658
1659                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1660                         memset(tb, 0, sizeof(*tb));
1661                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1662                             !flag_nested(nla) ||
1663                             nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla,
1664                                              set->type->adt_policy, NULL))
1665                                 return -IPSET_ERR_PROTOCOL;
1666                         ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1667                                       flags, use_lineno);
1668                         if (ret < 0)
1669                                 return ret;
1670                 }
1671         }
1672         return ret;
1673 }
1674
1675 static int ip_set_utest(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1676                         const struct nlmsghdr *nlh,
1677                         const struct nlattr * const attr[],
1678                         struct netlink_ext_ack *extack)
1679 {
1680         struct ip_set_net *inst = ip_set_pernet(net);
1681         struct ip_set *set;
1682         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1683         int ret = 0;
1684
1685         if (unlikely(protocol_min_failed(attr) ||
1686                      !attr[IPSET_ATTR_SETNAME] ||
1687                      !attr[IPSET_ATTR_DATA] ||
1688                      !flag_nested(attr[IPSET_ATTR_DATA])))
1689                 return -IPSET_ERR_PROTOCOL;
1690
1691         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1692         if (!set)
1693                 return -ENOENT;
1694
1695         if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1696                              set->type->adt_policy, NULL))
1697                 return -IPSET_ERR_PROTOCOL;
1698
1699         rcu_read_lock_bh();
1700         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0, 0);
1701         rcu_read_unlock_bh();
1702         /* Userspace can't trigger element to be re-added */
1703         if (ret == -EAGAIN)
1704                 ret = 1;
1705
1706         return ret > 0 ? 0 : -IPSET_ERR_EXIST;
1707 }
1708
1709 /* Get headed data of a set */
1710
1711 static int ip_set_header(struct net *net, struct sock *ctnl,
1712                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1713                          const struct nlattr * const attr[],
1714                          struct netlink_ext_ack *extack)
1715 {
1716         struct ip_set_net *inst = ip_set_pernet(net);
1717         const struct ip_set *set;
1718         struct sk_buff *skb2;
1719         struct nlmsghdr *nlh2;
1720         int ret = 0;
1721
1722         if (unlikely(protocol_min_failed(attr) ||
1723                      !attr[IPSET_ATTR_SETNAME]))
1724                 return -IPSET_ERR_PROTOCOL;
1725
1726         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1727         if (!set)
1728                 return -ENOENT;
1729
1730         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1731         if (!skb2)
1732                 return -ENOMEM;
1733
1734         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1735                          IPSET_CMD_HEADER);
1736         if (!nlh2)
1737                 goto nlmsg_failure;
1738         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1739             nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) ||
1740             nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) ||
1741             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1742             nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision))
1743                 goto nla_put_failure;
1744         nlmsg_end(skb2, nlh2);
1745
1746         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1747         if (ret < 0)
1748                 return ret;
1749
1750         return 0;
1751
1752 nla_put_failure:
1753         nlmsg_cancel(skb2, nlh2);
1754 nlmsg_failure:
1755         kfree_skb(skb2);
1756         return -EMSGSIZE;
1757 }
1758
1759 /* Get type data */
1760
1761 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1762         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1763         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1764                                     .len = IPSET_MAXNAMELEN - 1 },
1765         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1766 };
1767
1768 static int ip_set_type(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1769                        const struct nlmsghdr *nlh,
1770                        const struct nlattr * const attr[],
1771                        struct netlink_ext_ack *extack)
1772 {
1773         struct sk_buff *skb2;
1774         struct nlmsghdr *nlh2;
1775         u8 family, min, max;
1776         const char *typename;
1777         int ret = 0;
1778
1779         if (unlikely(protocol_min_failed(attr) ||
1780                      !attr[IPSET_ATTR_TYPENAME] ||
1781                      !attr[IPSET_ATTR_FAMILY]))
1782                 return -IPSET_ERR_PROTOCOL;
1783
1784         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1785         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1786         ret = find_set_type_minmax(typename, family, &min, &max);
1787         if (ret)
1788                 return ret;
1789
1790         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1791         if (!skb2)
1792                 return -ENOMEM;
1793
1794         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1795                          IPSET_CMD_TYPE);
1796         if (!nlh2)
1797                 goto nlmsg_failure;
1798         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1799             nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) ||
1800             nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) ||
1801             nla_put_u8(skb2, IPSET_ATTR_REVISION, max) ||
1802             nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min))
1803                 goto nla_put_failure;
1804         nlmsg_end(skb2, nlh2);
1805
1806         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1807         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1808         if (ret < 0)
1809                 return ret;
1810
1811         return 0;
1812
1813 nla_put_failure:
1814         nlmsg_cancel(skb2, nlh2);
1815 nlmsg_failure:
1816         kfree_skb(skb2);
1817         return -EMSGSIZE;
1818 }
1819
1820 /* Get protocol version */
1821
1822 static const struct nla_policy
1823 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1824         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1825 };
1826
1827 static int ip_set_protocol(struct net *net, struct sock *ctnl,
1828                            struct sk_buff *skb, const struct nlmsghdr *nlh,
1829                            const struct nlattr * const attr[],
1830                            struct netlink_ext_ack *extack)
1831 {
1832         struct sk_buff *skb2;
1833         struct nlmsghdr *nlh2;
1834         int ret = 0;
1835
1836         if (unlikely(!attr[IPSET_ATTR_PROTOCOL]))
1837                 return -IPSET_ERR_PROTOCOL;
1838
1839         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1840         if (!skb2)
1841                 return -ENOMEM;
1842
1843         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1844                          IPSET_CMD_PROTOCOL);
1845         if (!nlh2)
1846                 goto nlmsg_failure;
1847         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL))
1848                 goto nla_put_failure;
1849         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL_MIN, IPSET_PROTOCOL_MIN))
1850                 goto nla_put_failure;
1851         nlmsg_end(skb2, nlh2);
1852
1853         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1854         if (ret < 0)
1855                 return ret;
1856
1857         return 0;
1858
1859 nla_put_failure:
1860         nlmsg_cancel(skb2, nlh2);
1861 nlmsg_failure:
1862         kfree_skb(skb2);
1863         return -EMSGSIZE;
1864 }
1865
1866 /* Get set by name or index, from userspace */
1867
1868 static int ip_set_byname(struct net *net, struct sock *ctnl,
1869                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1870                          const struct nlattr * const attr[],
1871                          struct netlink_ext_ack *extack)
1872 {
1873         struct ip_set_net *inst = ip_set_pernet(net);
1874         struct sk_buff *skb2;
1875         struct nlmsghdr *nlh2;
1876         ip_set_id_t id = IPSET_INVALID_ID;
1877         const struct ip_set *set;
1878         int ret = 0;
1879
1880         if (unlikely(protocol_failed(attr) ||
1881                      !attr[IPSET_ATTR_SETNAME]))
1882                 return -IPSET_ERR_PROTOCOL;
1883
1884         set = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), &id);
1885         if (id == IPSET_INVALID_ID)
1886                 return -ENOENT;
1887
1888         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1889         if (!skb2)
1890                 return -ENOMEM;
1891
1892         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1893                          IPSET_CMD_GET_BYNAME);
1894         if (!nlh2)
1895                 goto nlmsg_failure;
1896         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1897             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1898             nla_put_net16(skb2, IPSET_ATTR_INDEX, htons(id)))
1899                 goto nla_put_failure;
1900         nlmsg_end(skb2, nlh2);
1901
1902         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1903         if (ret < 0)
1904                 return ret;
1905
1906         return 0;
1907
1908 nla_put_failure:
1909         nlmsg_cancel(skb2, nlh2);
1910 nlmsg_failure:
1911         kfree_skb(skb2);
1912         return -EMSGSIZE;
1913 }
1914
1915 static const struct nla_policy ip_set_index_policy[IPSET_ATTR_CMD_MAX + 1] = {
1916         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1917         [IPSET_ATTR_INDEX]      = { .type = NLA_U16 },
1918 };
1919
1920 static int ip_set_byindex(struct net *net, struct sock *ctnl,
1921                           struct sk_buff *skb, const struct nlmsghdr *nlh,
1922                           const struct nlattr * const attr[],
1923                           struct netlink_ext_ack *extack)
1924 {
1925         struct ip_set_net *inst = ip_set_pernet(net);
1926         struct sk_buff *skb2;
1927         struct nlmsghdr *nlh2;
1928         ip_set_id_t id = IPSET_INVALID_ID;
1929         const struct ip_set *set;
1930         int ret = 0;
1931
1932         if (unlikely(protocol_failed(attr) ||
1933                      !attr[IPSET_ATTR_INDEX]))
1934                 return -IPSET_ERR_PROTOCOL;
1935
1936         id = ip_set_get_h16(attr[IPSET_ATTR_INDEX]);
1937         if (id >= inst->ip_set_max)
1938                 return -ENOENT;
1939         set = ip_set(inst, id);
1940         if (set == NULL)
1941                 return -ENOENT;
1942
1943         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1944         if (!skb2)
1945                 return -ENOMEM;
1946
1947         nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1948                          IPSET_CMD_GET_BYINDEX);
1949         if (!nlh2)
1950                 goto nlmsg_failure;
1951         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1952             nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
1953                 goto nla_put_failure;
1954         nlmsg_end(skb2, nlh2);
1955
1956         ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1957         if (ret < 0)
1958                 return ret;
1959
1960         return 0;
1961
1962 nla_put_failure:
1963         nlmsg_cancel(skb2, nlh2);
1964 nlmsg_failure:
1965         kfree_skb(skb2);
1966         return -EMSGSIZE;
1967 }
1968
1969 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1970         [IPSET_CMD_NONE]        = {
1971                 .call           = ip_set_none,
1972                 .attr_count     = IPSET_ATTR_CMD_MAX,
1973         },
1974         [IPSET_CMD_CREATE]      = {
1975                 .call           = ip_set_create,
1976                 .attr_count     = IPSET_ATTR_CMD_MAX,
1977                 .policy         = ip_set_create_policy,
1978         },
1979         [IPSET_CMD_DESTROY]     = {
1980                 .call           = ip_set_destroy,
1981                 .attr_count     = IPSET_ATTR_CMD_MAX,
1982                 .policy         = ip_set_setname_policy,
1983         },
1984         [IPSET_CMD_FLUSH]       = {
1985                 .call           = ip_set_flush,
1986                 .attr_count     = IPSET_ATTR_CMD_MAX,
1987                 .policy         = ip_set_setname_policy,
1988         },
1989         [IPSET_CMD_RENAME]      = {
1990                 .call           = ip_set_rename,
1991                 .attr_count     = IPSET_ATTR_CMD_MAX,
1992                 .policy         = ip_set_setname2_policy,
1993         },
1994         [IPSET_CMD_SWAP]        = {
1995                 .call           = ip_set_swap,
1996                 .attr_count     = IPSET_ATTR_CMD_MAX,
1997                 .policy         = ip_set_setname2_policy,
1998         },
1999         [IPSET_CMD_LIST]        = {
2000                 .call           = ip_set_dump,
2001                 .attr_count     = IPSET_ATTR_CMD_MAX,
2002                 .policy         = ip_set_setname_policy,
2003         },
2004         [IPSET_CMD_SAVE]        = {
2005                 .call           = ip_set_dump,
2006                 .attr_count     = IPSET_ATTR_CMD_MAX,
2007                 .policy         = ip_set_setname_policy,
2008         },
2009         [IPSET_CMD_ADD] = {
2010                 .call           = ip_set_uadd,
2011                 .attr_count     = IPSET_ATTR_CMD_MAX,
2012                 .policy         = ip_set_adt_policy,
2013         },
2014         [IPSET_CMD_DEL] = {
2015                 .call           = ip_set_udel,
2016                 .attr_count     = IPSET_ATTR_CMD_MAX,
2017                 .policy         = ip_set_adt_policy,
2018         },
2019         [IPSET_CMD_TEST]        = {
2020                 .call           = ip_set_utest,
2021                 .attr_count     = IPSET_ATTR_CMD_MAX,
2022                 .policy         = ip_set_adt_policy,
2023         },
2024         [IPSET_CMD_HEADER]      = {
2025                 .call           = ip_set_header,
2026                 .attr_count     = IPSET_ATTR_CMD_MAX,
2027                 .policy         = ip_set_setname_policy,
2028         },
2029         [IPSET_CMD_TYPE]        = {
2030                 .call           = ip_set_type,
2031                 .attr_count     = IPSET_ATTR_CMD_MAX,
2032                 .policy         = ip_set_type_policy,
2033         },
2034         [IPSET_CMD_PROTOCOL]    = {
2035                 .call           = ip_set_protocol,
2036                 .attr_count     = IPSET_ATTR_CMD_MAX,
2037                 .policy         = ip_set_protocol_policy,
2038         },
2039         [IPSET_CMD_GET_BYNAME]  = {
2040                 .call           = ip_set_byname,
2041                 .attr_count     = IPSET_ATTR_CMD_MAX,
2042                 .policy         = ip_set_setname_policy,
2043         },
2044         [IPSET_CMD_GET_BYINDEX] = {
2045                 .call           = ip_set_byindex,
2046                 .attr_count     = IPSET_ATTR_CMD_MAX,
2047                 .policy         = ip_set_index_policy,
2048         },
2049 };
2050
2051 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
2052         .name           = "ip_set",
2053         .subsys_id      = NFNL_SUBSYS_IPSET,
2054         .cb_count       = IPSET_MSG_MAX,
2055         .cb             = ip_set_netlink_subsys_cb,
2056 };
2057
2058 /* Interface to iptables/ip6tables */
2059
2060 static int
2061 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
2062 {
2063         unsigned int *op;
2064         void *data;
2065         int copylen = *len, ret = 0;
2066         struct net *net = sock_net(sk);
2067         struct ip_set_net *inst = ip_set_pernet(net);
2068
2069         if (!ns_capable(net->user_ns, CAP_NET_ADMIN))
2070                 return -EPERM;
2071         if (optval != SO_IP_SET)
2072                 return -EBADF;
2073         if (*len < sizeof(unsigned int))
2074                 return -EINVAL;
2075
2076         data = vmalloc(*len);
2077         if (!data)
2078                 return -ENOMEM;
2079         if (copy_from_user(data, user, *len) != 0) {
2080                 ret = -EFAULT;
2081                 goto done;
2082         }
2083         op = data;
2084
2085         if (*op < IP_SET_OP_VERSION) {
2086                 /* Check the version at the beginning of operations */
2087                 struct ip_set_req_version *req_version = data;
2088
2089                 if (*len < sizeof(struct ip_set_req_version)) {
2090                         ret = -EINVAL;
2091                         goto done;
2092                 }
2093
2094                 if (req_version->version < IPSET_PROTOCOL_MIN) {
2095                         ret = -EPROTO;
2096                         goto done;
2097                 }
2098         }
2099
2100         switch (*op) {
2101         case IP_SET_OP_VERSION: {
2102                 struct ip_set_req_version *req_version = data;
2103
2104                 if (*len != sizeof(struct ip_set_req_version)) {
2105                         ret = -EINVAL;
2106                         goto done;
2107                 }
2108
2109                 req_version->version = IPSET_PROTOCOL;
2110                 ret = copy_to_user(user, req_version,
2111                                    sizeof(struct ip_set_req_version));
2112                 goto done;
2113         }
2114         case IP_SET_OP_GET_BYNAME: {
2115                 struct ip_set_req_get_set *req_get = data;
2116                 ip_set_id_t id;
2117
2118                 if (*len != sizeof(struct ip_set_req_get_set)) {
2119                         ret = -EINVAL;
2120                         goto done;
2121                 }
2122                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2123                 nfnl_lock(NFNL_SUBSYS_IPSET);
2124                 find_set_and_id(inst, req_get->set.name, &id);
2125                 req_get->set.index = id;
2126                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2127                 goto copy;
2128         }
2129         case IP_SET_OP_GET_FNAME: {
2130                 struct ip_set_req_get_set_family *req_get = data;
2131                 ip_set_id_t id;
2132
2133                 if (*len != sizeof(struct ip_set_req_get_set_family)) {
2134                         ret = -EINVAL;
2135                         goto done;
2136                 }
2137                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2138                 nfnl_lock(NFNL_SUBSYS_IPSET);
2139                 find_set_and_id(inst, req_get->set.name, &id);
2140                 req_get->set.index = id;
2141                 if (id != IPSET_INVALID_ID)
2142                         req_get->family = ip_set(inst, id)->family;
2143                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2144                 goto copy;
2145         }
2146         case IP_SET_OP_GET_BYINDEX: {
2147                 struct ip_set_req_get_set *req_get = data;
2148                 struct ip_set *set;
2149
2150                 if (*len != sizeof(struct ip_set_req_get_set) ||
2151                     req_get->set.index >= inst->ip_set_max) {
2152                         ret = -EINVAL;
2153                         goto done;
2154                 }
2155                 nfnl_lock(NFNL_SUBSYS_IPSET);
2156                 set = ip_set(inst, req_get->set.index);
2157                 strncpy(req_get->set.name, set ? set->name : "",
2158                         IPSET_MAXNAMELEN);
2159                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2160                 goto copy;
2161         }
2162         default:
2163                 ret = -EBADMSG;
2164                 goto done;
2165         }       /* end of switch(op) */
2166
2167 copy:
2168         ret = copy_to_user(user, data, copylen);
2169
2170 done:
2171         vfree(data);
2172         if (ret > 0)
2173                 ret = 0;
2174         return ret;
2175 }
2176
2177 static struct nf_sockopt_ops so_set __read_mostly = {
2178         .pf             = PF_INET,
2179         .get_optmin     = SO_IP_SET,
2180         .get_optmax     = SO_IP_SET + 1,
2181         .get            = ip_set_sockfn_get,
2182         .owner          = THIS_MODULE,
2183 };
2184
2185 static int __net_init
2186 ip_set_net_init(struct net *net)
2187 {
2188         struct ip_set_net *inst = ip_set_pernet(net);
2189         struct ip_set **list;
2190
2191         inst->ip_set_max = max_sets ? max_sets : CONFIG_IP_SET_MAX;
2192         if (inst->ip_set_max >= IPSET_INVALID_ID)
2193                 inst->ip_set_max = IPSET_INVALID_ID - 1;
2194
2195         list = kvcalloc(inst->ip_set_max, sizeof(struct ip_set *), GFP_KERNEL);
2196         if (!list)
2197                 return -ENOMEM;
2198         inst->is_deleted = false;
2199         inst->is_destroyed = false;
2200         rcu_assign_pointer(inst->ip_set_list, list);
2201         return 0;
2202 }
2203
2204 static void __net_exit
2205 ip_set_net_exit(struct net *net)
2206 {
2207         struct ip_set_net *inst = ip_set_pernet(net);
2208
2209         struct ip_set *set = NULL;
2210         ip_set_id_t i;
2211
2212         inst->is_deleted = true; /* flag for ip_set_nfnl_put */
2213
2214         nfnl_lock(NFNL_SUBSYS_IPSET);
2215         for (i = 0; i < inst->ip_set_max; i++) {
2216                 set = ip_set(inst, i);
2217                 if (set) {
2218                         ip_set(inst, i) = NULL;
2219                         ip_set_destroy_set(set);
2220                 }
2221         }
2222         nfnl_unlock(NFNL_SUBSYS_IPSET);
2223         kvfree(rcu_dereference_protected(inst->ip_set_list, 1));
2224 }
2225
2226 static struct pernet_operations ip_set_net_ops = {
2227         .init   = ip_set_net_init,
2228         .exit   = ip_set_net_exit,
2229         .id     = &ip_set_net_id,
2230         .size   = sizeof(struct ip_set_net),
2231 };
2232
2233 static int __init
2234 ip_set_init(void)
2235 {
2236         int ret = register_pernet_subsys(&ip_set_net_ops);
2237
2238         if (ret) {
2239                 pr_err("ip_set: cannot register pernet_subsys.\n");
2240                 return ret;
2241         }
2242
2243         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
2244         if (ret != 0) {
2245                 pr_err("ip_set: cannot register with nfnetlink.\n");
2246                 unregister_pernet_subsys(&ip_set_net_ops);
2247                 return ret;
2248         }
2249
2250         ret = nf_register_sockopt(&so_set);
2251         if (ret != 0) {
2252                 pr_err("SO_SET registry failed: %d\n", ret);
2253                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2254                 unregister_pernet_subsys(&ip_set_net_ops);
2255                 return ret;
2256         }
2257
2258         return 0;
2259 }
2260
2261 static void __exit
2262 ip_set_fini(void)
2263 {
2264         nf_unregister_sockopt(&so_set);
2265         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2266
2267         unregister_pernet_subsys(&ip_set_net_ops);
2268         pr_debug("these are the famous last words\n");
2269 }
2270
2271 module_init(ip_set_init);
2272 module_exit(ip_set_fini);
2273
2274 MODULE_DESCRIPTION("ip_set: protocol " __stringify(IPSET_PROTOCOL));