Merge branch 'for-linus-4.11' of git://git.kernel.org/pub/scm/linux/kernel/git/mason...
[linux-2.6-block.git] / net / netfilter / nft_ct.c
1 /*
2  * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
3  * Copyright (c) 2016 Pablo Neira Ayuso <pablo@netfilter.org>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  *
9  * Development of this code funded by Astaro AG (http://www.astaro.com/)
10  */
11
12 #include <linux/kernel.h>
13 #include <linux/init.h>
14 #include <linux/module.h>
15 #include <linux/netlink.h>
16 #include <linux/netfilter.h>
17 #include <linux/netfilter/nf_tables.h>
18 #include <net/netfilter/nf_tables.h>
19 #include <net/netfilter/nf_conntrack.h>
20 #include <net/netfilter/nf_conntrack_acct.h>
21 #include <net/netfilter/nf_conntrack_tuple.h>
22 #include <net/netfilter/nf_conntrack_helper.h>
23 #include <net/netfilter/nf_conntrack_ecache.h>
24 #include <net/netfilter/nf_conntrack_labels.h>
25
26 struct nft_ct {
27         enum nft_ct_keys        key:8;
28         enum ip_conntrack_dir   dir:8;
29         union {
30                 enum nft_registers      dreg:8;
31                 enum nft_registers      sreg:8;
32         };
33 };
34
#ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * One conntrack template per CPU for "ct zone set". The counter records
 * how many zone-set expressions currently share these templates.
 */
static DEFINE_PER_CPU(struct nf_conn *, nft_ct_pcpu_template);
static unsigned int nft_ct_pcpu_template_refcnt __read_mostly;
#endif
39
40 static u64 nft_ct_get_eval_counter(const struct nf_conn_counter *c,
41                                    enum nft_ct_keys k,
42                                    enum ip_conntrack_dir d)
43 {
44         if (d < IP_CT_DIR_MAX)
45                 return k == NFT_CT_BYTES ? atomic64_read(&c[d].bytes) :
46                                            atomic64_read(&c[d].packets);
47
48         return nft_ct_get_eval_counter(c, k, IP_CT_DIR_ORIGINAL) +
49                nft_ct_get_eval_counter(c, k, IP_CT_DIR_REPLY);
50 }
51
52 static void nft_ct_get_eval(const struct nft_expr *expr,
53                             struct nft_regs *regs,
54                             const struct nft_pktinfo *pkt)
55 {
56         const struct nft_ct *priv = nft_expr_priv(expr);
57         u32 *dest = &regs->data[priv->dreg];
58         enum ip_conntrack_info ctinfo;
59         const struct nf_conn *ct;
60         const struct nf_conn_help *help;
61         const struct nf_conntrack_tuple *tuple;
62         const struct nf_conntrack_helper *helper;
63         unsigned int state;
64
65         ct = nf_ct_get(pkt->skb, &ctinfo);
66
67         switch (priv->key) {
68         case NFT_CT_STATE:
69                 if (ct == NULL)
70                         state = NF_CT_STATE_INVALID_BIT;
71                 else if (nf_ct_is_untracked(ct))
72                         state = NF_CT_STATE_UNTRACKED_BIT;
73                 else
74                         state = NF_CT_STATE_BIT(ctinfo);
75                 *dest = state;
76                 return;
77         default:
78                 break;
79         }
80
81         if (ct == NULL)
82                 goto err;
83
84         switch (priv->key) {
85         case NFT_CT_DIRECTION:
86                 *dest = CTINFO2DIR(ctinfo);
87                 return;
88         case NFT_CT_STATUS:
89                 *dest = ct->status;
90                 return;
91 #ifdef CONFIG_NF_CONNTRACK_MARK
92         case NFT_CT_MARK:
93                 *dest = ct->mark;
94                 return;
95 #endif
96 #ifdef CONFIG_NF_CONNTRACK_SECMARK
97         case NFT_CT_SECMARK:
98                 *dest = ct->secmark;
99                 return;
100 #endif
101         case NFT_CT_EXPIRATION:
102                 *dest = jiffies_to_msecs(nf_ct_expires(ct));
103                 return;
104         case NFT_CT_HELPER:
105                 if (ct->master == NULL)
106                         goto err;
107                 help = nfct_help(ct->master);
108                 if (help == NULL)
109                         goto err;
110                 helper = rcu_dereference(help->helper);
111                 if (helper == NULL)
112                         goto err;
113                 strncpy((char *)dest, helper->name, NF_CT_HELPER_NAME_LEN);
114                 return;
115 #ifdef CONFIG_NF_CONNTRACK_LABELS
116         case NFT_CT_LABELS: {
117                 struct nf_conn_labels *labels = nf_ct_labels_find(ct);
118
119                 if (labels)
120                         memcpy(dest, labels->bits, NF_CT_LABELS_MAX_SIZE);
121                 else
122                         memset(dest, 0, NF_CT_LABELS_MAX_SIZE);
123                 return;
124         }
125 #endif
126         case NFT_CT_BYTES: /* fallthrough */
127         case NFT_CT_PKTS: {
128                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
129                 u64 count = 0;
130
131                 if (acct)
132                         count = nft_ct_get_eval_counter(acct->counter,
133                                                         priv->key, priv->dir);
134                 memcpy(dest, &count, sizeof(count));
135                 return;
136         }
137         case NFT_CT_AVGPKT: {
138                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
139                 u64 avgcnt = 0, bcnt = 0, pcnt = 0;
140
141                 if (acct) {
142                         pcnt = nft_ct_get_eval_counter(acct->counter,
143                                                        NFT_CT_PKTS, priv->dir);
144                         bcnt = nft_ct_get_eval_counter(acct->counter,
145                                                        NFT_CT_BYTES, priv->dir);
146                         if (pcnt != 0)
147                                 avgcnt = div64_u64(bcnt, pcnt);
148                 }
149
150                 memcpy(dest, &avgcnt, sizeof(avgcnt));
151                 return;
152         }
153         case NFT_CT_L3PROTOCOL:
154                 *dest = nf_ct_l3num(ct);
155                 return;
156         case NFT_CT_PROTOCOL:
157                 *dest = nf_ct_protonum(ct);
158                 return;
159 #ifdef CONFIG_NF_CONNTRACK_ZONES
160         case NFT_CT_ZONE: {
161                 const struct nf_conntrack_zone *zone = nf_ct_zone(ct);
162
163                 if (priv->dir < IP_CT_DIR_MAX)
164                         *dest = nf_ct_zone_id(zone, priv->dir);
165                 else
166                         *dest = zone->id;
167
168                 return;
169         }
170 #endif
171         default:
172                 break;
173         }
174
175         tuple = &ct->tuplehash[priv->dir].tuple;
176         switch (priv->key) {
177         case NFT_CT_SRC:
178                 memcpy(dest, tuple->src.u3.all,
179                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
180                 return;
181         case NFT_CT_DST:
182                 memcpy(dest, tuple->dst.u3.all,
183                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
184                 return;
185         case NFT_CT_PROTO_SRC:
186                 *dest = (__force __u16)tuple->src.u.all;
187                 return;
188         case NFT_CT_PROTO_DST:
189                 *dest = (__force __u16)tuple->dst.u.all;
190                 return;
191         default:
192                 break;
193         }
194         return;
195 err:
196         regs->verdict.code = NFT_BREAK;
197 }
198
#ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * Evaluate "ct zone set": attach a conntrack template carrying the zone id
 * taken from the source register. Packets that already have a conntrack
 * are left alone.
 *
 * The per-CPU template is reused while its use count is exactly 1.
 * Otherwise a fresh template is allocated; if that fails, the packet is
 * dropped.
 */
static void nft_ct_set_zone_eval(const struct nft_expr *expr,
				 struct nft_regs *regs,
				 const struct nft_pktinfo *pkt)
{
	const struct nft_ct *priv = nft_expr_priv(expr);
	struct nf_conntrack_zone zone = {
		.id	= regs->data[priv->sreg],
		.dir	= NF_CT_DEFAULT_ZONE_DIR,
	};
	struct sk_buff *skb = pkt->skb;
	enum ip_conntrack_info ctinfo;
	struct nf_conn *tmpl;

	/* Packet is already tracked: nothing to do. */
	if (nf_ct_get(skb, &ctinfo))
		return;

	if (priv->dir == IP_CT_DIR_ORIGINAL)
		zone.dir = NF_CT_ZONE_DIR_ORIG;
	else if (priv->dir == IP_CT_DIR_REPLY)
		zone.dir = NF_CT_ZONE_DIR_REPL;

	tmpl = this_cpu_read(nft_ct_pcpu_template);
	if (unlikely(atomic_read(&tmpl->ct_general.use) != 1)) {
		/* previous skb got queued to userspace */
		tmpl = nf_ct_tmpl_alloc(nft_net(pkt), &zone, GFP_ATOMIC);
		if (!tmpl) {
			regs->verdict.code = NF_DROP;
			return;
		}
	} else {
		nf_ct_zone_add(tmpl, &zone);
	}

	atomic_inc(&tmpl->ct_general.use);
	nf_ct_set(skb, tmpl, IP_CT_NEW);
}
#endif
245
246 static void nft_ct_set_eval(const struct nft_expr *expr,
247                             struct nft_regs *regs,
248                             const struct nft_pktinfo *pkt)
249 {
250         const struct nft_ct *priv = nft_expr_priv(expr);
251         struct sk_buff *skb = pkt->skb;
252 #ifdef CONFIG_NF_CONNTRACK_MARK
253         u32 value = regs->data[priv->sreg];
254 #endif
255         enum ip_conntrack_info ctinfo;
256         struct nf_conn *ct;
257
258         ct = nf_ct_get(skb, &ctinfo);
259         if (ct == NULL)
260                 return;
261
262         switch (priv->key) {
263 #ifdef CONFIG_NF_CONNTRACK_MARK
264         case NFT_CT_MARK:
265                 if (ct->mark != value) {
266                         ct->mark = value;
267                         nf_conntrack_event_cache(IPCT_MARK, ct);
268                 }
269                 break;
270 #endif
271 #ifdef CONFIG_NF_CONNTRACK_LABELS
272         case NFT_CT_LABELS:
273                 nf_connlabels_replace(ct,
274                                       &regs->data[priv->sreg],
275                                       &regs->data[priv->sreg],
276                                       NF_CT_LABELS_MAX_SIZE / sizeof(u32));
277                 break;
278 #endif
279         default:
280                 break;
281         }
282 }
283
284 static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = {
285         [NFTA_CT_DREG]          = { .type = NLA_U32 },
286         [NFTA_CT_KEY]           = { .type = NLA_U32 },
287         [NFTA_CT_DIRECTION]     = { .type = NLA_U8 },
288         [NFTA_CT_SREG]          = { .type = NLA_U32 },
289 };
290
291 static int nft_ct_netns_get(struct net *net, uint8_t family)
292 {
293         int err;
294
295         if (family == NFPROTO_INET) {
296                 err = nf_ct_netns_get(net, NFPROTO_IPV4);
297                 if (err < 0)
298                         goto err1;
299                 err = nf_ct_netns_get(net, NFPROTO_IPV6);
300                 if (err < 0)
301                         goto err2;
302         } else {
303                 err = nf_ct_netns_get(net, family);
304                 if (err < 0)
305                         goto err1;
306         }
307         return 0;
308
309 err2:
310         nf_ct_netns_put(net, NFPROTO_IPV4);
311 err1:
312         return err;
313 }
314
315 static void nft_ct_netns_put(struct net *net, uint8_t family)
316 {
317         if (family == NFPROTO_INET) {
318                 nf_ct_netns_put(net, NFPROTO_IPV4);
319                 nf_ct_netns_put(net, NFPROTO_IPV6);
320         } else
321                 nf_ct_netns_put(net, family);
322 }
323
#ifdef CONFIG_NF_CONNTRACK_ZONES
/*
 * Put every per-CPU zone template and clear its slot. Templates are
 * allocated in CPU order, so the walk stops at the first empty slot.
 */
static void nft_ct_tmpl_put_pcpu(void)
{
	int cpu;

	for_each_possible_cpu(cpu) {
		struct nf_conn *tmpl = per_cpu(nft_ct_pcpu_template, cpu);

		if (tmpl == NULL)
			break;
		nf_ct_put(tmpl);
		per_cpu(nft_ct_pcpu_template, cpu) = NULL;
	}
}

/*
 * Allocate one zone-0 conntrack template per possible CPU, unless the
 * templates already exist (non-zero share count). On allocation failure,
 * the templates allocated so far are released and false is returned.
 */
static bool nft_ct_tmpl_alloc_pcpu(void)
{
	struct nf_conntrack_zone zone = { .id = 0 };
	struct nf_conn *tmpl;
	int cpu;

	if (nft_ct_pcpu_template_refcnt != 0)
		return true;

	for_each_possible_cpu(cpu) {
		tmpl = nf_ct_tmpl_alloc(&init_net, &zone, GFP_KERNEL);
		if (tmpl == NULL) {
			nft_ct_tmpl_put_pcpu();
			return false;
		}

		/*
		 * Start with a single reference; nft_ct_set_zone_eval()
		 * reuses the template only while the count is still 1.
		 */
		atomic_set(&tmpl->ct_general.use, 1);
		per_cpu(nft_ct_pcpu_template, cpu) = tmpl;
	}

	return true;
}
#endif
362
363 static int nft_ct_get_init(const struct nft_ctx *ctx,
364                            const struct nft_expr *expr,
365                            const struct nlattr * const tb[])
366 {
367         struct nft_ct *priv = nft_expr_priv(expr);
368         unsigned int len;
369         int err;
370
371         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
372         priv->dir = IP_CT_DIR_MAX;
373         switch (priv->key) {
374         case NFT_CT_DIRECTION:
375                 if (tb[NFTA_CT_DIRECTION] != NULL)
376                         return -EINVAL;
377                 len = sizeof(u8);
378                 break;
379         case NFT_CT_STATE:
380         case NFT_CT_STATUS:
381 #ifdef CONFIG_NF_CONNTRACK_MARK
382         case NFT_CT_MARK:
383 #endif
384 #ifdef CONFIG_NF_CONNTRACK_SECMARK
385         case NFT_CT_SECMARK:
386 #endif
387         case NFT_CT_EXPIRATION:
388                 if (tb[NFTA_CT_DIRECTION] != NULL)
389                         return -EINVAL;
390                 len = sizeof(u32);
391                 break;
392 #ifdef CONFIG_NF_CONNTRACK_LABELS
393         case NFT_CT_LABELS:
394                 if (tb[NFTA_CT_DIRECTION] != NULL)
395                         return -EINVAL;
396                 len = NF_CT_LABELS_MAX_SIZE;
397                 break;
398 #endif
399         case NFT_CT_HELPER:
400                 if (tb[NFTA_CT_DIRECTION] != NULL)
401                         return -EINVAL;
402                 len = NF_CT_HELPER_NAME_LEN;
403                 break;
404
405         case NFT_CT_L3PROTOCOL:
406         case NFT_CT_PROTOCOL:
407                 /* For compatibility, do not report error if NFTA_CT_DIRECTION
408                  * attribute is specified.
409                  */
410                 len = sizeof(u8);
411                 break;
412         case NFT_CT_SRC:
413         case NFT_CT_DST:
414                 if (tb[NFTA_CT_DIRECTION] == NULL)
415                         return -EINVAL;
416
417                 switch (ctx->afi->family) {
418                 case NFPROTO_IPV4:
419                         len = FIELD_SIZEOF(struct nf_conntrack_tuple,
420                                            src.u3.ip);
421                         break;
422                 case NFPROTO_IPV6:
423                 case NFPROTO_INET:
424                         len = FIELD_SIZEOF(struct nf_conntrack_tuple,
425                                            src.u3.ip6);
426                         break;
427                 default:
428                         return -EAFNOSUPPORT;
429                 }
430                 break;
431         case NFT_CT_PROTO_SRC:
432         case NFT_CT_PROTO_DST:
433                 if (tb[NFTA_CT_DIRECTION] == NULL)
434                         return -EINVAL;
435                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u.all);
436                 break;
437         case NFT_CT_BYTES:
438         case NFT_CT_PKTS:
439         case NFT_CT_AVGPKT:
440                 len = sizeof(u64);
441                 break;
442 #ifdef CONFIG_NF_CONNTRACK_ZONES
443         case NFT_CT_ZONE:
444                 len = sizeof(u16);
445                 break;
446 #endif
447         default:
448                 return -EOPNOTSUPP;
449         }
450
451         if (tb[NFTA_CT_DIRECTION] != NULL) {
452                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
453                 switch (priv->dir) {
454                 case IP_CT_DIR_ORIGINAL:
455                 case IP_CT_DIR_REPLY:
456                         break;
457                 default:
458                         return -EINVAL;
459                 }
460         }
461
462         priv->dreg = nft_parse_register(tb[NFTA_CT_DREG]);
463         err = nft_validate_register_store(ctx, priv->dreg, NULL,
464                                           NFT_DATA_VALUE, len);
465         if (err < 0)
466                 return err;
467
468         err = nft_ct_netns_get(ctx->net, ctx->afi->family);
469         if (err < 0)
470                 return err;
471
472         if (priv->key == NFT_CT_BYTES ||
473             priv->key == NFT_CT_PKTS  ||
474             priv->key == NFT_CT_AVGPKT)
475                 nf_ct_set_acct(ctx->net, true);
476
477         return 0;
478 }
479
480 static void __nft_ct_set_destroy(const struct nft_ctx *ctx, struct nft_ct *priv)
481 {
482         switch (priv->key) {
483 #ifdef CONFIG_NF_CONNTRACK_LABELS
484         case NFT_CT_LABELS:
485                 nf_connlabels_put(ctx->net);
486                 break;
487 #endif
488 #ifdef CONFIG_NF_CONNTRACK_ZONES
489         case NFT_CT_ZONE:
490                 if (--nft_ct_pcpu_template_refcnt == 0)
491                         nft_ct_tmpl_put_pcpu();
492 #endif
493         default:
494                 break;
495         }
496 }
497
498 static int nft_ct_set_init(const struct nft_ctx *ctx,
499                            const struct nft_expr *expr,
500                            const struct nlattr * const tb[])
501 {
502         struct nft_ct *priv = nft_expr_priv(expr);
503         unsigned int len;
504         int err;
505
506         priv->dir = IP_CT_DIR_MAX;
507         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
508         switch (priv->key) {
509 #ifdef CONFIG_NF_CONNTRACK_MARK
510         case NFT_CT_MARK:
511                 if (tb[NFTA_CT_DIRECTION])
512                         return -EINVAL;
513                 len = FIELD_SIZEOF(struct nf_conn, mark);
514                 break;
515 #endif
516 #ifdef CONFIG_NF_CONNTRACK_LABELS
517         case NFT_CT_LABELS:
518                 if (tb[NFTA_CT_DIRECTION])
519                         return -EINVAL;
520                 len = NF_CT_LABELS_MAX_SIZE;
521                 err = nf_connlabels_get(ctx->net, (len * BITS_PER_BYTE) - 1);
522                 if (err)
523                         return err;
524                 break;
525 #endif
526 #ifdef CONFIG_NF_CONNTRACK_ZONES
527         case NFT_CT_ZONE:
528                 if (!nft_ct_tmpl_alloc_pcpu())
529                         return -ENOMEM;
530                 nft_ct_pcpu_template_refcnt++;
531                 break;
532 #endif
533         default:
534                 return -EOPNOTSUPP;
535         }
536
537         if (tb[NFTA_CT_DIRECTION]) {
538                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
539                 switch (priv->dir) {
540                 case IP_CT_DIR_ORIGINAL:
541                 case IP_CT_DIR_REPLY:
542                         break;
543                 default:
544                         return -EINVAL;
545                 }
546         }
547
548         priv->sreg = nft_parse_register(tb[NFTA_CT_SREG]);
549         err = nft_validate_register_load(priv->sreg, len);
550         if (err < 0)
551                 goto err1;
552
553         err = nft_ct_netns_get(ctx->net, ctx->afi->family);
554         if (err < 0)
555                 goto err1;
556
557         return 0;
558
559 err1:
560         __nft_ct_set_destroy(ctx, priv);
561         return err;
562 }
563
564 static void nft_ct_get_destroy(const struct nft_ctx *ctx,
565                                const struct nft_expr *expr)
566 {
567         nf_ct_netns_put(ctx->net, ctx->afi->family);
568 }
569
570 static void nft_ct_set_destroy(const struct nft_ctx *ctx,
571                                const struct nft_expr *expr)
572 {
573         struct nft_ct *priv = nft_expr_priv(expr);
574
575         __nft_ct_set_destroy(ctx, priv);
576         nft_ct_netns_put(ctx->net, ctx->afi->family);
577 }
578
579 static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr)
580 {
581         const struct nft_ct *priv = nft_expr_priv(expr);
582
583         if (nft_dump_register(skb, NFTA_CT_DREG, priv->dreg))
584                 goto nla_put_failure;
585         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
586                 goto nla_put_failure;
587
588         switch (priv->key) {
589         case NFT_CT_SRC:
590         case NFT_CT_DST:
591         case NFT_CT_PROTO_SRC:
592         case NFT_CT_PROTO_DST:
593                 if (nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
594                         goto nla_put_failure;
595                 break;
596         case NFT_CT_BYTES:
597         case NFT_CT_PKTS:
598         case NFT_CT_AVGPKT:
599         case NFT_CT_ZONE:
600                 if (priv->dir < IP_CT_DIR_MAX &&
601                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
602                         goto nla_put_failure;
603                 break;
604         default:
605                 break;
606         }
607
608         return 0;
609
610 nla_put_failure:
611         return -1;
612 }
613
614 static int nft_ct_set_dump(struct sk_buff *skb, const struct nft_expr *expr)
615 {
616         const struct nft_ct *priv = nft_expr_priv(expr);
617
618         if (nft_dump_register(skb, NFTA_CT_SREG, priv->sreg))
619                 goto nla_put_failure;
620         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
621                 goto nla_put_failure;
622
623         switch (priv->key) {
624         case NFT_CT_ZONE:
625                 if (priv->dir < IP_CT_DIR_MAX &&
626                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
627                         goto nla_put_failure;
628                 break;
629         default:
630                 break;
631         }
632
633         return 0;
634
635 nla_put_failure:
636         return -1;
637 }
638
639 static struct nft_expr_type nft_ct_type;
640 static const struct nft_expr_ops nft_ct_get_ops = {
641         .type           = &nft_ct_type,
642         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
643         .eval           = nft_ct_get_eval,
644         .init           = nft_ct_get_init,
645         .destroy        = nft_ct_get_destroy,
646         .dump           = nft_ct_get_dump,
647 };
648
649 static const struct nft_expr_ops nft_ct_set_ops = {
650         .type           = &nft_ct_type,
651         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
652         .eval           = nft_ct_set_eval,
653         .init           = nft_ct_set_init,
654         .destroy        = nft_ct_set_destroy,
655         .dump           = nft_ct_set_dump,
656 };
657
658 #ifdef CONFIG_NF_CONNTRACK_ZONES
659 static const struct nft_expr_ops nft_ct_set_zone_ops = {
660         .type           = &nft_ct_type,
661         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
662         .eval           = nft_ct_set_zone_eval,
663         .init           = nft_ct_set_init,
664         .destroy        = nft_ct_set_destroy,
665         .dump           = nft_ct_set_dump,
666 };
667 #endif
668
669 static const struct nft_expr_ops *
670 nft_ct_select_ops(const struct nft_ctx *ctx,
671                     const struct nlattr * const tb[])
672 {
673         if (tb[NFTA_CT_KEY] == NULL)
674                 return ERR_PTR(-EINVAL);
675
676         if (tb[NFTA_CT_DREG] && tb[NFTA_CT_SREG])
677                 return ERR_PTR(-EINVAL);
678
679         if (tb[NFTA_CT_DREG])
680                 return &nft_ct_get_ops;
681
682         if (tb[NFTA_CT_SREG]) {
683 #ifdef CONFIG_NF_CONNTRACK_ZONES
684                 if (nla_get_be32(tb[NFTA_CT_KEY]) == htonl(NFT_CT_ZONE))
685                         return &nft_ct_set_zone_ops;
686 #endif
687                 return &nft_ct_set_ops;
688         }
689
690         return ERR_PTR(-EINVAL);
691 }
692
693 static struct nft_expr_type nft_ct_type __read_mostly = {
694         .name           = "ct",
695         .select_ops     = &nft_ct_select_ops,
696         .policy         = nft_ct_policy,
697         .maxattr        = NFTA_CT_MAX,
698         .owner          = THIS_MODULE,
699 };
700
701 static void nft_notrack_eval(const struct nft_expr *expr,
702                              struct nft_regs *regs,
703                              const struct nft_pktinfo *pkt)
704 {
705         struct sk_buff *skb = pkt->skb;
706         enum ip_conntrack_info ctinfo;
707         struct nf_conn *ct;
708
709         ct = nf_ct_get(pkt->skb, &ctinfo);
710         /* Previously seen (loopback or untracked)?  Ignore. */
711         if (ct)
712                 return;
713
714         ct = nf_ct_untracked_get();
715         atomic_inc(&ct->ct_general.use);
716         nf_ct_set(skb, ct, IP_CT_NEW);
717 }
718
719 static struct nft_expr_type nft_notrack_type;
720 static const struct nft_expr_ops nft_notrack_ops = {
721         .type           = &nft_notrack_type,
722         .size           = NFT_EXPR_SIZE(0),
723         .eval           = nft_notrack_eval,
724 };
725
726 static struct nft_expr_type nft_notrack_type __read_mostly = {
727         .name           = "notrack",
728         .ops            = &nft_notrack_ops,
729         .owner          = THIS_MODULE,
730 };
731
732 static int __init nft_ct_module_init(void)
733 {
734         int err;
735
736         BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE);
737
738         err = nft_register_expr(&nft_ct_type);
739         if (err < 0)
740                 return err;
741
742         err = nft_register_expr(&nft_notrack_type);
743         if (err < 0)
744                 goto err1;
745
746         return 0;
747 err1:
748         nft_unregister_expr(&nft_ct_type);
749         return err;
750 }
751
752 static void __exit nft_ct_module_exit(void)
753 {
754         nft_unregister_expr(&nft_notrack_type);
755         nft_unregister_expr(&nft_ct_type);
756 }
757
758 module_init(nft_ct_module_init);
759 module_exit(nft_ct_module_exit);
760
761 MODULE_LICENSE("GPL");
762 MODULE_AUTHOR("Patrick McHardy <kaber@trash.net>");
763 MODULE_ALIAS_NFT_EXPR("ct");
764 MODULE_ALIAS_NFT_EXPR("notrack");