Merge tag 'compiler-attributes-for-linus-4.20-rc1' of https://github.com/ojeda/linux
[linux-2.6-block.git] / net / netfilter / nf_conntrack_proto.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #include <linux/types.h>
4 #include <linux/netfilter.h>
5 #include <linux/module.h>
6 #include <linux/slab.h>
7 #include <linux/mutex.h>
8 #include <linux/vmalloc.h>
9 #include <linux/stddef.h>
10 #include <linux/err.h>
11 #include <linux/percpu.h>
12 #include <linux/notifier.h>
13 #include <linux/kernel.h>
14 #include <linux/netdevice.h>
15
16 #include <net/netfilter/nf_conntrack.h>
17 #include <net/netfilter/nf_conntrack_l4proto.h>
18 #include <net/netfilter/nf_conntrack_core.h>
19 #include <net/netfilter/nf_log.h>
20
21 #include <linux/ip.h>
22 #include <linux/icmp.h>
23 #include <linux/sysctl.h>
24 #include <net/route.h>
25 #include <net/ip.h>
26
27 #include <linux/netfilter_ipv4.h>
28 #include <linux/netfilter_ipv6.h>
29 #include <linux/netfilter_ipv6/ip6_tables.h>
30 #include <net/netfilter/nf_conntrack_helper.h>
31 #include <net/netfilter/nf_conntrack_zones.h>
32 #include <net/netfilter/nf_conntrack_seqadj.h>
33 #include <net/netfilter/ipv4/nf_conntrack_ipv4.h>
34 #include <net/netfilter/ipv6/nf_conntrack_ipv6.h>
35 #include <net/netfilter/nf_nat_helper.h>
36 #include <net/netfilter/ipv4/nf_defrag_ipv4.h>
37 #include <net/netfilter/ipv6/nf_defrag_ipv6.h>
38
39 #include <linux/ipv6.h>
40 #include <linux/in6.h>
41 #include <net/ipv6.h>
42 #include <net/inet_frag.h>
43
44 extern unsigned int nf_conntrack_net_id;
45
46 static struct nf_conntrack_l4proto __rcu *nf_ct_protos[MAX_NF_CT_PROTO + 1] __read_mostly;
47
48 static DEFINE_MUTEX(nf_ct_proto_mutex);
49
50 #ifdef CONFIG_SYSCTL
51 static int
52 nf_ct_register_sysctl(struct net *net,
53                       struct ctl_table_header **header,
54                       const char *path,
55                       struct ctl_table *table)
56 {
57         if (*header == NULL) {
58                 *header = register_net_sysctl(net, path, table);
59                 if (*header == NULL)
60                         return -ENOMEM;
61         }
62
63         return 0;
64 }
65
66 static void
67 nf_ct_unregister_sysctl(struct ctl_table_header **header,
68                         struct ctl_table **table,
69                         unsigned int users)
70 {
71         if (users > 0)
72                 return;
73
74         unregister_net_sysctl_table(*header);
75         kfree(*table);
76         *header = NULL;
77         *table = NULL;
78 }
79
80 __printf(5, 6)
81 void nf_l4proto_log_invalid(const struct sk_buff *skb,
82                             struct net *net,
83                             u16 pf, u8 protonum,
84                             const char *fmt, ...)
85 {
86         struct va_format vaf;
87         va_list args;
88
89         if (net->ct.sysctl_log_invalid != protonum ||
90             net->ct.sysctl_log_invalid != IPPROTO_RAW)
91                 return;
92
93         va_start(args, fmt);
94         vaf.fmt = fmt;
95         vaf.va = &args;
96
97         nf_log_packet(net, pf, 0, skb, NULL, NULL, NULL,
98                       "nf_ct_proto_%d: %pV ", protonum, &vaf);
99         va_end(args);
100 }
101 EXPORT_SYMBOL_GPL(nf_l4proto_log_invalid);
102
103 __printf(3, 4)
104 void nf_ct_l4proto_log_invalid(const struct sk_buff *skb,
105                                const struct nf_conn *ct,
106                                const char *fmt, ...)
107 {
108         struct va_format vaf;
109         struct net *net;
110         va_list args;
111
112         net = nf_ct_net(ct);
113         if (likely(net->ct.sysctl_log_invalid == 0))
114                 return;
115
116         va_start(args, fmt);
117         vaf.fmt = fmt;
118         vaf.va = &args;
119
120         nf_l4proto_log_invalid(skb, net, nf_ct_l3num(ct),
121                                nf_ct_protonum(ct), "%pV", &vaf);
122         va_end(args);
123 }
124 EXPORT_SYMBOL_GPL(nf_ct_l4proto_log_invalid);
125 #endif
126
127 const struct nf_conntrack_l4proto *__nf_ct_l4proto_find(u8 l4proto)
128 {
129         if (unlikely(l4proto >= ARRAY_SIZE(nf_ct_protos)))
130                 return &nf_conntrack_l4proto_generic;
131
132         return rcu_dereference(nf_ct_protos[l4proto]);
133 }
134 EXPORT_SYMBOL_GPL(__nf_ct_l4proto_find);
135
136 const struct nf_conntrack_l4proto *nf_ct_l4proto_find_get(u8 l4num)
137 {
138         const struct nf_conntrack_l4proto *p;
139
140         rcu_read_lock();
141         p = __nf_ct_l4proto_find(l4num);
142         if (!try_module_get(p->me))
143                 p = &nf_conntrack_l4proto_generic;
144         rcu_read_unlock();
145
146         return p;
147 }
148 EXPORT_SYMBOL_GPL(nf_ct_l4proto_find_get);
149
150 void nf_ct_l4proto_put(const struct nf_conntrack_l4proto *p)
151 {
152         module_put(p->me);
153 }
154 EXPORT_SYMBOL_GPL(nf_ct_l4proto_put);
155
156 static int kill_l4proto(struct nf_conn *i, void *data)
157 {
158         const struct nf_conntrack_l4proto *l4proto;
159         l4proto = data;
160         return nf_ct_protonum(i) == l4proto->l4proto;
161 }
162
163 static struct nf_proto_net *nf_ct_l4proto_net(struct net *net,
164                                 const struct nf_conntrack_l4proto *l4proto)
165 {
166         if (l4proto->get_net_proto) {
167                 /* statically built-in protocols use static per-net */
168                 return l4proto->get_net_proto(net);
169         } else if (l4proto->net_id) {
170                 /* ... and loadable protocols use dynamic per-net */
171                 return net_generic(net, *l4proto->net_id);
172         }
173         return NULL;
174 }
175
/*
 * Register the sysctl table in @pn (if there is one) below
 * "net/netfilter".
 *
 * On failure the table is freed and cleared, but only while @pn has no
 * users.  Always returns 0 without CONFIG_SYSCTL.
 */
static
int nf_ct_l4proto_register_sysctl(struct net *net,
				  struct nf_proto_net *pn,
				  const struct nf_conntrack_l4proto *l4proto)
{
	int err = 0;

#ifdef CONFIG_SYSCTL
	if (!pn->ctl_table)
		return 0;

	err = nf_ct_register_sysctl(net, &pn->ctl_table_header,
				    "net/netfilter", pn->ctl_table);
	if (err < 0 && !pn->users) {
		kfree(pn->ctl_table);
		pn->ctl_table = NULL;
	}
#endif /* CONFIG_SYSCTL */
	return err;
}
199
/*
 * Counterpart of nf_ct_l4proto_register_sysctl(): when @pn holds a
 * registered header, hand it to nf_ct_unregister_sysctl(), which only
 * acts once pn->users is zero.  No-op without CONFIG_SYSCTL.
 */
static
void nf_ct_l4proto_unregister_sysctl(struct net *net,
				struct nf_proto_net *pn,
				const struct nf_conntrack_l4proto *l4proto)
{
#ifdef CONFIG_SYSCTL
	if (!pn->ctl_table_header)
		return;

	nf_ct_unregister_sysctl(&pn->ctl_table_header,
				&pn->ctl_table,
				pn->users);
#endif /* CONFIG_SYSCTL */
}
212
213 /* FIXME: Allow NULL functions and sub in pointers to generic for
214    them. --RR */
215 int nf_ct_l4proto_register_one(const struct nf_conntrack_l4proto *l4proto)
216 {
217         int ret = 0;
218
219         if ((l4proto->to_nlattr && l4proto->nlattr_size == 0) ||
220             (l4proto->tuple_to_nlattr && !l4proto->nlattr_tuple_size))
221                 return -EINVAL;
222
223         mutex_lock(&nf_ct_proto_mutex);
224         if (rcu_dereference_protected(
225                         nf_ct_protos[l4proto->l4proto],
226                         lockdep_is_held(&nf_ct_proto_mutex)
227                         ) != &nf_conntrack_l4proto_generic) {
228                 ret = -EBUSY;
229                 goto out_unlock;
230         }
231
232         rcu_assign_pointer(nf_ct_protos[l4proto->l4proto], l4proto);
233 out_unlock:
234         mutex_unlock(&nf_ct_proto_mutex);
235         return ret;
236 }
237 EXPORT_SYMBOL_GPL(nf_ct_l4proto_register_one);
238
239 int nf_ct_l4proto_pernet_register_one(struct net *net,
240                                 const struct nf_conntrack_l4proto *l4proto)
241 {
242         int ret = 0;
243         struct nf_proto_net *pn = NULL;
244
245         if (l4proto->init_net) {
246                 ret = l4proto->init_net(net);
247                 if (ret < 0)
248                         goto out;
249         }
250
251         pn = nf_ct_l4proto_net(net, l4proto);
252         if (pn == NULL)
253                 goto out;
254
255         ret = nf_ct_l4proto_register_sysctl(net, pn, l4proto);
256         if (ret < 0)
257                 goto out;
258
259         pn->users++;
260 out:
261         return ret;
262 }
263 EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_register_one);
264
265 static void __nf_ct_l4proto_unregister_one(const struct nf_conntrack_l4proto *l4proto)
266
267 {
268         BUG_ON(l4proto->l4proto >= ARRAY_SIZE(nf_ct_protos));
269
270         BUG_ON(rcu_dereference_protected(
271                         nf_ct_protos[l4proto->l4proto],
272                         lockdep_is_held(&nf_ct_proto_mutex)
273                         ) != l4proto);
274         rcu_assign_pointer(nf_ct_protos[l4proto->l4proto],
275                            &nf_conntrack_l4proto_generic);
276 }
277
278 void nf_ct_l4proto_unregister_one(const struct nf_conntrack_l4proto *l4proto)
279 {
280         mutex_lock(&nf_ct_proto_mutex);
281         __nf_ct_l4proto_unregister_one(l4proto);
282         mutex_unlock(&nf_ct_proto_mutex);
283
284         synchronize_net();
285         /* Remove all contrack entries for this protocol */
286         nf_ct_iterate_destroy(kill_l4proto, (void *)l4proto);
287 }
288 EXPORT_SYMBOL_GPL(nf_ct_l4proto_unregister_one);
289
290 void nf_ct_l4proto_pernet_unregister_one(struct net *net,
291                                 const struct nf_conntrack_l4proto *l4proto)
292 {
293         struct nf_proto_net *pn = nf_ct_l4proto_net(net, l4proto);
294
295         if (pn == NULL)
296                 return;
297
298         pn->users--;
299         nf_ct_l4proto_unregister_sysctl(net, pn, l4proto);
300 }
301 EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_unregister_one);
302
303 static void
304 nf_ct_l4proto_unregister(const struct nf_conntrack_l4proto * const l4proto[],
305                          unsigned int num_proto)
306 {
307         int i;
308
309         mutex_lock(&nf_ct_proto_mutex);
310         for (i = 0; i < num_proto; i++)
311                 __nf_ct_l4proto_unregister_one(l4proto[i]);
312         mutex_unlock(&nf_ct_proto_mutex);
313
314         synchronize_net();
315
316         for (i = 0; i < num_proto; i++)
317                 nf_ct_iterate_destroy(kill_l4proto, (void *)l4proto[i]);
318 }
319
320 static int
321 nf_ct_l4proto_register(const struct nf_conntrack_l4proto * const l4proto[],
322                        unsigned int num_proto)
323 {
324         int ret = -EINVAL;
325         unsigned int i;
326
327         for (i = 0; i < num_proto; i++) {
328                 ret = nf_ct_l4proto_register_one(l4proto[i]);
329                 if (ret < 0)
330                         break;
331         }
332         if (i != num_proto) {
333                 pr_err("nf_conntrack: can't register l4 %d proto.\n",
334                        l4proto[i]->l4proto);
335                 nf_ct_l4proto_unregister(l4proto, i);
336         }
337         return ret;
338 }
339
340 int nf_ct_l4proto_pernet_register(struct net *net,
341                                   const struct nf_conntrack_l4proto *const l4proto[],
342                                   unsigned int num_proto)
343 {
344         int ret = -EINVAL;
345         unsigned int i;
346
347         for (i = 0; i < num_proto; i++) {
348                 ret = nf_ct_l4proto_pernet_register_one(net, l4proto[i]);
349                 if (ret < 0)
350                         break;
351         }
352         if (i != num_proto) {
353                 pr_err("nf_conntrack %d: pernet registration failed\n",
354                        l4proto[i]->l4proto);
355                 nf_ct_l4proto_pernet_unregister(net, l4proto, i);
356         }
357         return ret;
358 }
359 EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_register);
360
/*
 * Per-netns teardown of the first @num_proto trackers, walking the array
 * from the last entry down to the first.
 */
void nf_ct_l4proto_pernet_unregister(struct net *net,
				const struct nf_conntrack_l4proto *const l4proto[],
				unsigned int num_proto)
{
	unsigned int remaining;

	for (remaining = num_proto; remaining > 0; remaining--)
		nf_ct_l4proto_pernet_unregister_one(net,
						    l4proto[remaining - 1]);
}
368 EXPORT_SYMBOL_GPL(nf_ct_l4proto_pernet_unregister);
369
370 static unsigned int ipv4_helper(void *priv,
371                                 struct sk_buff *skb,
372                                 const struct nf_hook_state *state)
373 {
374         struct nf_conn *ct;
375         enum ip_conntrack_info ctinfo;
376         const struct nf_conn_help *help;
377         const struct nf_conntrack_helper *helper;
378
379         /* This is where we call the helper: as the packet goes out. */
380         ct = nf_ct_get(skb, &ctinfo);
381         if (!ct || ctinfo == IP_CT_RELATED_REPLY)
382                 return NF_ACCEPT;
383
384         help = nfct_help(ct);
385         if (!help)
386                 return NF_ACCEPT;
387
388         /* rcu_read_lock()ed by nf_hook_thresh */
389         helper = rcu_dereference(help->helper);
390         if (!helper)
391                 return NF_ACCEPT;
392
393         return helper->help(skb, skb_network_offset(skb) + ip_hdrlen(skb),
394                             ct, ctinfo);
395 }
396
397 static unsigned int ipv4_confirm(void *priv,
398                                  struct sk_buff *skb,
399                                  const struct nf_hook_state *state)
400 {
401         struct nf_conn *ct;
402         enum ip_conntrack_info ctinfo;
403
404         ct = nf_ct_get(skb, &ctinfo);
405         if (!ct || ctinfo == IP_CT_RELATED_REPLY)
406                 goto out;
407
408         /* adjust seqs for loopback traffic only in outgoing direction */
409         if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) &&
410             !nf_is_loopback_packet(skb)) {
411                 if (!nf_ct_seq_adjust(skb, ct, ctinfo, ip_hdrlen(skb))) {
412                         NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop);
413                         return NF_DROP;
414                 }
415         }
416 out:
417         /* We've seen it coming out the other side: confirm it */
418         return nf_conntrack_confirm(skb);
419 }
420
/* IPv4 hook entry point: hand the packet straight to nf_conntrack_in(). */
static unsigned int ipv4_conntrack_in(void *priv,
				      struct sk_buff *skb,
				      const struct nf_hook_state *state)
{
	unsigned int verdict = nf_conntrack_in(skb, state);

	return verdict;
}
427
428 static unsigned int ipv4_conntrack_local(void *priv,
429                                          struct sk_buff *skb,
430                                          const struct nf_hook_state *state)
431 {
432         if (ip_is_fragment(ip_hdr(skb))) { /* IP_NODEFRAG setsockopt set */
433                 enum ip_conntrack_info ctinfo;
434                 struct nf_conn *tmpl;
435
436                 tmpl = nf_ct_get(skb, &ctinfo);
437                 if (tmpl && nf_ct_is_template(tmpl)) {
438                         /* when skipping ct, clear templates to avoid fooling
439                          * later targets/matches
440                          */
441                         skb->_nfct = 0;
442                         nf_ct_put(tmpl);
443                 }
444                 return NF_ACCEPT;
445         }
446
447         return nf_conntrack_in(skb, state);
448 }
449
450 /* Connection tracking may drop packets, but never alters them, so
451  * make it the first hook.
452  */
453 static const struct nf_hook_ops ipv4_conntrack_ops[] = {
454         {
455                 .hook           = ipv4_conntrack_in,
456                 .pf             = NFPROTO_IPV4,
457                 .hooknum        = NF_INET_PRE_ROUTING,
458                 .priority       = NF_IP_PRI_CONNTRACK,
459         },
460         {
461                 .hook           = ipv4_conntrack_local,
462                 .pf             = NFPROTO_IPV4,
463                 .hooknum        = NF_INET_LOCAL_OUT,
464                 .priority       = NF_IP_PRI_CONNTRACK,
465         },
466         {
467                 .hook           = ipv4_helper,
468                 .pf             = NFPROTO_IPV4,
469                 .hooknum        = NF_INET_POST_ROUTING,
470                 .priority       = NF_IP_PRI_CONNTRACK_HELPER,
471         },
472         {
473                 .hook           = ipv4_confirm,
474                 .pf             = NFPROTO_IPV4,
475                 .hooknum        = NF_INET_POST_ROUTING,
476                 .priority       = NF_IP_PRI_CONNTRACK_CONFIRM,
477         },
478         {
479                 .hook           = ipv4_helper,
480                 .pf             = NFPROTO_IPV4,
481                 .hooknum        = NF_INET_LOCAL_IN,
482                 .priority       = NF_IP_PRI_CONNTRACK_HELPER,
483         },
484         {
485                 .hook           = ipv4_confirm,
486                 .pf             = NFPROTO_IPV4,
487                 .hooknum        = NF_INET_LOCAL_IN,
488                 .priority       = NF_IP_PRI_CONNTRACK_CONFIRM,
489         },
490 };
491
492 /* Fast function for those who don't want to parse /proc (and I don't
493  * blame them).
494  * Reversing the socket's dst/src point of view gives us the reply
495  * mapping.
496  */
497 static int
498 getorigdst(struct sock *sk, int optval, void __user *user, int *len)
499 {
500         const struct inet_sock *inet = inet_sk(sk);
501         const struct nf_conntrack_tuple_hash *h;
502         struct nf_conntrack_tuple tuple;
503
504         memset(&tuple, 0, sizeof(tuple));
505
506         lock_sock(sk);
507         tuple.src.u3.ip = inet->inet_rcv_saddr;
508         tuple.src.u.tcp.port = inet->inet_sport;
509         tuple.dst.u3.ip = inet->inet_daddr;
510         tuple.dst.u.tcp.port = inet->inet_dport;
511         tuple.src.l3num = PF_INET;
512         tuple.dst.protonum = sk->sk_protocol;
513         release_sock(sk);
514
515         /* We only do TCP and SCTP at the moment: is there a better way? */
516         if (tuple.dst.protonum != IPPROTO_TCP &&
517             tuple.dst.protonum != IPPROTO_SCTP) {
518                 pr_debug("SO_ORIGINAL_DST: Not a TCP/SCTP socket\n");
519                 return -ENOPROTOOPT;
520         }
521
522         if ((unsigned int)*len < sizeof(struct sockaddr_in)) {
523                 pr_debug("SO_ORIGINAL_DST: len %d not %zu\n",
524                          *len, sizeof(struct sockaddr_in));
525                 return -EINVAL;
526         }
527
528         h = nf_conntrack_find_get(sock_net(sk), &nf_ct_zone_dflt, &tuple);
529         if (h) {
530                 struct sockaddr_in sin;
531                 struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h);
532
533                 sin.sin_family = AF_INET;
534                 sin.sin_port = ct->tuplehash[IP_CT_DIR_ORIGINAL]
535                         .tuple.dst.u.tcp.port;
536                 sin.sin_addr.s_addr = ct->tuplehash[IP_CT_DIR_ORIGINAL]
537                         .tuple.dst.u3.ip;
538                 memset(sin.sin_zero, 0, sizeof(sin.sin_zero));
539
540                 pr_debug("SO_ORIGINAL_DST: %pI4 %u\n",
541                          &sin.sin_addr.s_addr, ntohs(sin.sin_port));
542                 nf_ct_put(ct);
543                 if (copy_to_user(user, &sin, sizeof(sin)) != 0)
544                         return -EFAULT;
545                 else
546                         return 0;
547         }
548         pr_debug("SO_ORIGINAL_DST: Can't find %pI4/%u-%pI4/%u.\n",
549                  &tuple.src.u3.ip, ntohs(tuple.src.u.tcp.port),
550                  &tuple.dst.u3.ip, ntohs(tuple.dst.u.tcp.port));
551         return -ENOENT;
552 }
553
554 static struct nf_sockopt_ops so_getorigdst = {
555         .pf             = PF_INET,
556         .get_optmin     = SO_ORIGINAL_DST,
557         .get_optmax     = SO_ORIGINAL_DST + 1,
558         .get            = getorigdst,
559         .owner          = THIS_MODULE,
560 };
561
562 #if IS_ENABLED(CONFIG_IPV6)
563 static int
564 ipv6_getorigdst(struct sock *sk, int optval, void __user *user, int *len)
565 {
566         struct nf_conntrack_tuple tuple = { .src.l3num = NFPROTO_IPV6 };
567         const struct ipv6_pinfo *inet6 = inet6_sk(sk);
568         const struct inet_sock *inet = inet_sk(sk);
569         const struct nf_conntrack_tuple_hash *h;
570         struct sockaddr_in6 sin6;
571         struct nf_conn *ct;
572         __be32 flow_label;
573         int bound_dev_if;
574
575         lock_sock(sk);
576         tuple.src.u3.in6 = sk->sk_v6_rcv_saddr;
577         tuple.src.u.tcp.port = inet->inet_sport;
578         tuple.dst.u3.in6 = sk->sk_v6_daddr;
579         tuple.dst.u.tcp.port = inet->inet_dport;
580         tuple.dst.protonum = sk->sk_protocol;
581         bound_dev_if = sk->sk_bound_dev_if;
582         flow_label = inet6->flow_label;
583         release_sock(sk);
584
585         if (tuple.dst.protonum != IPPROTO_TCP &&
586             tuple.dst.protonum != IPPROTO_SCTP)
587                 return -ENOPROTOOPT;
588
589         if (*len < 0 || (unsigned int)*len < sizeof(sin6))
590                 return -EINVAL;
591
592         h = nf_conntrack_find_get(sock_net(sk), &nf_ct_zone_dflt, &tuple);
593         if (!h) {
594                 pr_debug("IP6T_SO_ORIGINAL_DST: Can't find %pI6c/%u-%pI6c/%u.\n",
595                          &tuple.src.u3.ip6, ntohs(tuple.src.u.tcp.port),
596                          &tuple.dst.u3.ip6, ntohs(tuple.dst.u.tcp.port));
597                 return -ENOENT;
598         }
599
600         ct = nf_ct_tuplehash_to_ctrack(h);
601
602         sin6.sin6_family = AF_INET6;
603         sin6.sin6_port = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u.tcp.port;
604         sin6.sin6_flowinfo = flow_label & IPV6_FLOWINFO_MASK;
605         memcpy(&sin6.sin6_addr,
606                &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u3.in6,
607                sizeof(sin6.sin6_addr));
608
609         nf_ct_put(ct);
610         sin6.sin6_scope_id = ipv6_iface_scope_id(&sin6.sin6_addr, bound_dev_if);
611         return copy_to_user(user, &sin6, sizeof(sin6)) ? -EFAULT : 0;
612 }
613
614 static struct nf_sockopt_ops so_getorigdst6 = {
615         .pf             = NFPROTO_IPV6,
616         .get_optmin     = IP6T_SO_ORIGINAL_DST,
617         .get_optmax     = IP6T_SO_ORIGINAL_DST + 1,
618         .get            = ipv6_getorigdst,
619         .owner          = THIS_MODULE,
620 };
621
622 static unsigned int ipv6_confirm(void *priv,
623                                  struct sk_buff *skb,
624                                  const struct nf_hook_state *state)
625 {
626         struct nf_conn *ct;
627         enum ip_conntrack_info ctinfo;
628         unsigned char pnum = ipv6_hdr(skb)->nexthdr;
629         int protoff;
630         __be16 frag_off;
631
632         ct = nf_ct_get(skb, &ctinfo);
633         if (!ct || ctinfo == IP_CT_RELATED_REPLY)
634                 goto out;
635
636         protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum,
637                                    &frag_off);
638         if (protoff < 0 || (frag_off & htons(~0x7)) != 0) {
639                 pr_debug("proto header not found\n");
640                 goto out;
641         }
642
643         /* adjust seqs for loopback traffic only in outgoing direction */
644         if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) &&
645             !nf_is_loopback_packet(skb)) {
646                 if (!nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) {
647                         NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop);
648                         return NF_DROP;
649                 }
650         }
651 out:
652         /* We've seen it coming out the other side: confirm it */
653         return nf_conntrack_confirm(skb);
654 }
655
/* IPv6 hook entry point: hand the packet straight to nf_conntrack_in(). */
static unsigned int ipv6_conntrack_in(void *priv,
				      struct sk_buff *skb,
				      const struct nf_hook_state *state)
{
	unsigned int verdict = nf_conntrack_in(skb, state);

	return verdict;
}
662
/*
 * IPv6 hook for locally generated packets: hands the packet to
 * nf_conntrack_in() with no extra checks.
 */
static unsigned int ipv6_conntrack_local(void *priv,
					 struct sk_buff *skb,
					 const struct nf_hook_state *state)
{
	unsigned int verdict = nf_conntrack_in(skb, state);

	return verdict;
}
669
670 static unsigned int ipv6_helper(void *priv,
671                                 struct sk_buff *skb,
672                                 const struct nf_hook_state *state)
673 {
674         struct nf_conn *ct;
675         const struct nf_conn_help *help;
676         const struct nf_conntrack_helper *helper;
677         enum ip_conntrack_info ctinfo;
678         __be16 frag_off;
679         int protoff;
680         u8 nexthdr;
681
682         /* This is where we call the helper: as the packet goes out. */
683         ct = nf_ct_get(skb, &ctinfo);
684         if (!ct || ctinfo == IP_CT_RELATED_REPLY)
685                 return NF_ACCEPT;
686
687         help = nfct_help(ct);
688         if (!help)
689                 return NF_ACCEPT;
690         /* rcu_read_lock()ed by nf_hook_thresh */
691         helper = rcu_dereference(help->helper);
692         if (!helper)
693                 return NF_ACCEPT;
694
695         nexthdr = ipv6_hdr(skb)->nexthdr;
696         protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &nexthdr,
697                                    &frag_off);
698         if (protoff < 0 || (frag_off & htons(~0x7)) != 0) {
699                 pr_debug("proto header not found\n");
700                 return NF_ACCEPT;
701         }
702
703         return helper->help(skb, protoff, ct, ctinfo);
704 }
705
706 static const struct nf_hook_ops ipv6_conntrack_ops[] = {
707         {
708                 .hook           = ipv6_conntrack_in,
709                 .pf             = NFPROTO_IPV6,
710                 .hooknum        = NF_INET_PRE_ROUTING,
711                 .priority       = NF_IP6_PRI_CONNTRACK,
712         },
713         {
714                 .hook           = ipv6_conntrack_local,
715                 .pf             = NFPROTO_IPV6,
716                 .hooknum        = NF_INET_LOCAL_OUT,
717                 .priority       = NF_IP6_PRI_CONNTRACK,
718         },
719         {
720                 .hook           = ipv6_helper,
721                 .pf             = NFPROTO_IPV6,
722                 .hooknum        = NF_INET_POST_ROUTING,
723                 .priority       = NF_IP6_PRI_CONNTRACK_HELPER,
724         },
725         {
726                 .hook           = ipv6_confirm,
727                 .pf             = NFPROTO_IPV6,
728                 .hooknum        = NF_INET_POST_ROUTING,
729                 .priority       = NF_IP6_PRI_LAST,
730         },
731         {
732                 .hook           = ipv6_helper,
733                 .pf             = NFPROTO_IPV6,
734                 .hooknum        = NF_INET_LOCAL_IN,
735                 .priority       = NF_IP6_PRI_CONNTRACK_HELPER,
736         },
737         {
738                 .hook           = ipv6_confirm,
739                 .pf             = NFPROTO_IPV6,
740                 .hooknum        = NF_INET_LOCAL_IN,
741                 .priority       = NF_IP6_PRI_LAST - 1,
742         },
743 };
744 #endif
745
746 static int nf_ct_tcp_fixup(struct nf_conn *ct, void *_nfproto)
747 {
748         u8 nfproto = (unsigned long)_nfproto;
749
750         if (nf_ct_l3num(ct) != nfproto)
751                 return 0;
752
753         if (nf_ct_protonum(ct) == IPPROTO_TCP &&
754             ct->proto.tcp.state == TCP_CONNTRACK_ESTABLISHED) {
755                 ct->proto.tcp.seen[0].td_maxwin = 0;
756                 ct->proto.tcp.seen[1].td_maxwin = 0;
757         }
758
759         return 0;
760 }
761
762 static int nf_ct_netns_do_get(struct net *net, u8 nfproto)
763 {
764         struct nf_conntrack_net *cnet = net_generic(net, nf_conntrack_net_id);
765         bool fixup_needed = false;
766         int err = 0;
767
768         mutex_lock(&nf_ct_proto_mutex);
769
770         switch (nfproto) {
771         case NFPROTO_IPV4:
772                 cnet->users4++;
773                 if (cnet->users4 > 1)
774                         goto out_unlock;
775                 err = nf_defrag_ipv4_enable(net);
776                 if (err) {
777                         cnet->users4 = 0;
778                         goto out_unlock;
779                 }
780
781                 err = nf_register_net_hooks(net, ipv4_conntrack_ops,
782                                             ARRAY_SIZE(ipv4_conntrack_ops));
783                 if (err)
784                         cnet->users4 = 0;
785                 else
786                         fixup_needed = true;
787                 break;
788 #if IS_ENABLED(CONFIG_IPV6)
789         case NFPROTO_IPV6:
790                 cnet->users6++;
791                 if (cnet->users6 > 1)
792                         goto out_unlock;
793                 err = nf_defrag_ipv6_enable(net);
794                 if (err < 0) {
795                         cnet->users6 = 0;
796                         goto out_unlock;
797                 }
798
799                 err = nf_register_net_hooks(net, ipv6_conntrack_ops,
800                                             ARRAY_SIZE(ipv6_conntrack_ops));
801                 if (err)
802                         cnet->users6 = 0;
803                 else
804                         fixup_needed = true;
805                 break;
806 #endif
807         default:
808                 err = -EPROTO;
809                 break;
810         }
811  out_unlock:
812         mutex_unlock(&nf_ct_proto_mutex);
813
814         if (fixup_needed)
815                 nf_ct_iterate_cleanup_net(net, nf_ct_tcp_fixup,
816                                           (void *)(unsigned long)nfproto, 0, 0);
817
818         return err;
819 }
820
821 static void nf_ct_netns_do_put(struct net *net, u8 nfproto)
822 {
823         struct nf_conntrack_net *cnet = net_generic(net, nf_conntrack_net_id);
824
825         mutex_lock(&nf_ct_proto_mutex);
826         switch (nfproto) {
827         case NFPROTO_IPV4:
828                 if (cnet->users4 && (--cnet->users4 == 0))
829                         nf_unregister_net_hooks(net, ipv4_conntrack_ops,
830                                                 ARRAY_SIZE(ipv4_conntrack_ops));
831                 break;
832 #if IS_ENABLED(CONFIG_IPV6)
833         case NFPROTO_IPV6:
834                 if (cnet->users6 && (--cnet->users6 == 0))
835                         nf_unregister_net_hooks(net, ipv6_conntrack_ops,
836                                                 ARRAY_SIZE(ipv6_conntrack_ops));
837                 break;
838 #endif
839         }
840
841         mutex_unlock(&nf_ct_proto_mutex);
842 }
843
844 int nf_ct_netns_get(struct net *net, u8 nfproto)
845 {
846         int err;
847
848         if (nfproto == NFPROTO_INET) {
849                 err = nf_ct_netns_do_get(net, NFPROTO_IPV4);
850                 if (err < 0)
851                         goto err1;
852                 err = nf_ct_netns_do_get(net, NFPROTO_IPV6);
853                 if (err < 0)
854                         goto err2;
855         } else {
856                 err = nf_ct_netns_do_get(net, nfproto);
857                 if (err < 0)
858                         goto err1;
859         }
860         return 0;
861
862 err2:
863         nf_ct_netns_put(net, NFPROTO_IPV4);
864 err1:
865         return err;
866 }
867 EXPORT_SYMBOL_GPL(nf_ct_netns_get);
868
869 void nf_ct_netns_put(struct net *net, uint8_t nfproto)
870 {
871         if (nfproto == NFPROTO_INET) {
872                 nf_ct_netns_do_put(net, NFPROTO_IPV4);
873                 nf_ct_netns_do_put(net, NFPROTO_IPV6);
874         } else {
875                 nf_ct_netns_do_put(net, nfproto);
876         }
877 }
878 EXPORT_SYMBOL_GPL(nf_ct_netns_put);
879
880 static const struct nf_conntrack_l4proto * const builtin_l4proto[] = {
881         &nf_conntrack_l4proto_tcp,
882         &nf_conntrack_l4proto_udp,
883         &nf_conntrack_l4proto_icmp,
884 #ifdef CONFIG_NF_CT_PROTO_DCCP
885         &nf_conntrack_l4proto_dccp,
886 #endif
887 #ifdef CONFIG_NF_CT_PROTO_SCTP
888         &nf_conntrack_l4proto_sctp,
889 #endif
890 #ifdef CONFIG_NF_CT_PROTO_UDPLITE
891         &nf_conntrack_l4proto_udplite,
892 #endif
893 #if IS_ENABLED(CONFIG_IPV6)
894         &nf_conntrack_l4proto_icmpv6,
895 #endif /* CONFIG_IPV6 */
896 };
897
898 int nf_conntrack_proto_init(void)
899 {
900         int ret = 0, i;
901
902         ret = nf_register_sockopt(&so_getorigdst);
903         if (ret < 0)
904                 return ret;
905
906 #if IS_ENABLED(CONFIG_IPV6)
907         ret = nf_register_sockopt(&so_getorigdst6);
908         if (ret < 0)
909                 goto cleanup_sockopt;
910 #endif
911
912         for (i = 0; i < ARRAY_SIZE(nf_ct_protos); i++)
913                 RCU_INIT_POINTER(nf_ct_protos[i],
914                                  &nf_conntrack_l4proto_generic);
915
916         ret = nf_ct_l4proto_register(builtin_l4proto,
917                                      ARRAY_SIZE(builtin_l4proto));
918         if (ret < 0)
919                 goto cleanup_sockopt2;
920
921         return ret;
922 cleanup_sockopt2:
923         nf_unregister_sockopt(&so_getorigdst);
924 #if IS_ENABLED(CONFIG_IPV6)
925 cleanup_sockopt:
926         nf_unregister_sockopt(&so_getorigdst6);
927 #endif
928         return ret;
929 }
930
931 void nf_conntrack_proto_fini(void)
932 {
933         nf_unregister_sockopt(&so_getorigdst);
934 #if IS_ENABLED(CONFIG_IPV6)
935         nf_unregister_sockopt(&so_getorigdst6);
936 #endif
937 }
938
939 int nf_conntrack_proto_pernet_init(struct net *net)
940 {
941         int err;
942         struct nf_proto_net *pn = nf_ct_l4proto_net(net,
943                                         &nf_conntrack_l4proto_generic);
944
945         err = nf_conntrack_l4proto_generic.init_net(net);
946         if (err < 0)
947                 return err;
948         err = nf_ct_l4proto_register_sysctl(net,
949                                             pn,
950                                             &nf_conntrack_l4proto_generic);
951         if (err < 0)
952                 return err;
953
954         err = nf_ct_l4proto_pernet_register(net, builtin_l4proto,
955                                             ARRAY_SIZE(builtin_l4proto));
956         if (err < 0) {
957                 nf_ct_l4proto_unregister_sysctl(net, pn,
958                                                 &nf_conntrack_l4proto_generic);
959                 return err;
960         }
961
962         pn->users++;
963         return 0;
964 }
965
966 void nf_conntrack_proto_pernet_fini(struct net *net)
967 {
968         struct nf_proto_net *pn = nf_ct_l4proto_net(net,
969                                         &nf_conntrack_l4proto_generic);
970
971         nf_ct_l4proto_pernet_unregister(net, builtin_l4proto,
972                                         ARRAY_SIZE(builtin_l4proto));
973         pn->users--;
974         nf_ct_l4proto_unregister_sysctl(net,
975                                         pn,
976                                         &nf_conntrack_l4proto_generic);
977 }
978
979
980 module_param_call(hashsize, nf_conntrack_set_hashsize, param_get_uint,
981                   &nf_conntrack_htable_size, 0600);
982
983 MODULE_ALIAS("ip_conntrack");
984 MODULE_ALIAS("nf_conntrack-" __stringify(AF_INET));
985 MODULE_ALIAS("nf_conntrack-" __stringify(AF_INET6));
986 MODULE_LICENSE("GPL");