amt: use workqueue for gateway side message handling
[linux-2.6-block.git] / drivers / net / amt.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Copyright (c) 2021 Taehee Yoo <ap420073@gmail.com> */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/module.h>
7 #include <linux/skbuff.h>
8 #include <linux/udp.h>
9 #include <linux/jhash.h>
10 #include <linux/if_tunnel.h>
11 #include <linux/net.h>
12 #include <linux/igmp.h>
13 #include <linux/workqueue.h>
14 #include <net/sch_generic.h>
15 #include <net/net_namespace.h>
16 #include <net/ip.h>
17 #include <net/udp.h>
18 #include <net/udp_tunnel.h>
19 #include <net/icmp.h>
20 #include <net/mld.h>
21 #include <net/amt.h>
22 #include <uapi/linux/amt.h>
23 #include <linux/security.h>
24 #include <net/gro_cells.h>
25 #include <net/ipv6.h>
26 #include <net/if_inet6.h>
27 #include <net/ndisc.h>
28 #include <net/addrconf.h>
29 #include <net/ip6_route.h>
30 #include <net/inet_common.h>
31 #include <net/ip6_checksum.h>
32
33 static struct workqueue_struct *amt_wq;
34
35 static HLIST_HEAD(source_gc_list);
36 /* Lock for source_gc_list */
37 static spinlock_t source_gc_lock;
38 static struct delayed_work source_gc_wq;
39 static char *status_str[] = {
40         "AMT_STATUS_INIT",
41         "AMT_STATUS_SENT_DISCOVERY",
42         "AMT_STATUS_RECEIVED_DISCOVERY",
43         "AMT_STATUS_SENT_ADVERTISEMENT",
44         "AMT_STATUS_RECEIVED_ADVERTISEMENT",
45         "AMT_STATUS_SENT_REQUEST",
46         "AMT_STATUS_RECEIVED_REQUEST",
47         "AMT_STATUS_SENT_QUERY",
48         "AMT_STATUS_RECEIVED_QUERY",
49         "AMT_STATUS_SENT_UPDATE",
50         "AMT_STATUS_RECEIVED_UPDATE",
51 };
52
53 static char *type_str[] = {
54         "", /* Type 0 is not defined */
55         "AMT_MSG_DISCOVERY",
56         "AMT_MSG_ADVERTISEMENT",
57         "AMT_MSG_REQUEST",
58         "AMT_MSG_MEMBERSHIP_QUERY",
59         "AMT_MSG_MEMBERSHIP_UPDATE",
60         "AMT_MSG_MULTICAST_DATA",
61         "AMT_MSG_TEARDOWN",
62 };
63
64 static char *action_str[] = {
65         "AMT_ACT_GMI",
66         "AMT_ACT_GMI_ZERO",
67         "AMT_ACT_GT",
68         "AMT_ACT_STATUS_FWD_NEW",
69         "AMT_ACT_STATUS_D_FWD_NEW",
70         "AMT_ACT_STATUS_NONE_NEW",
71 };
72
73 static struct igmpv3_grec igmpv3_zero_grec;
74
75 #if IS_ENABLED(CONFIG_IPV6)
76 #define MLD2_ALL_NODE_INIT { { { 0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 } } }
77 static struct in6_addr mld2_all_node = MLD2_ALL_NODE_INIT;
78 static struct mld2_grec mldv2_zero_grec;
79 #endif
80
81 static struct amt_skb_cb *amt_skb_cb(struct sk_buff *skb)
82 {
83         BUILD_BUG_ON(sizeof(struct amt_skb_cb) + sizeof(struct qdisc_skb_cb) >
84                      sizeof_field(struct sk_buff, cb));
85
86         return (struct amt_skb_cb *)((void *)skb->cb +
87                 sizeof(struct qdisc_skb_cb));
88 }
89
90 static void __amt_source_gc_work(void)
91 {
92         struct amt_source_node *snode;
93         struct hlist_head gc_list;
94         struct hlist_node *t;
95
96         spin_lock_bh(&source_gc_lock);
97         hlist_move_list(&source_gc_list, &gc_list);
98         spin_unlock_bh(&source_gc_lock);
99
100         hlist_for_each_entry_safe(snode, t, &gc_list, node) {
101                 hlist_del_rcu(&snode->node);
102                 kfree_rcu(snode, rcu);
103         }
104 }
105
106 static void amt_source_gc_work(struct work_struct *work)
107 {
108         __amt_source_gc_work();
109
110         spin_lock_bh(&source_gc_lock);
111         mod_delayed_work(amt_wq, &source_gc_wq,
112                          msecs_to_jiffies(AMT_GC_INTERVAL));
113         spin_unlock_bh(&source_gc_lock);
114 }
115
116 static bool amt_addr_equal(union amt_addr *a, union amt_addr *b)
117 {
118         return !memcmp(a, b, sizeof(union amt_addr));
119 }
120
121 static u32 amt_source_hash(struct amt_tunnel_list *tunnel, union amt_addr *src)
122 {
123         u32 hash = jhash(src, sizeof(*src), tunnel->amt->hash_seed);
124
125         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
126 }
127
128 static bool amt_status_filter(struct amt_source_node *snode,
129                               enum amt_filter filter)
130 {
131         bool rc = false;
132
133         switch (filter) {
134         case AMT_FILTER_FWD:
135                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
136                     snode->flags == AMT_SOURCE_OLD)
137                         rc = true;
138                 break;
139         case AMT_FILTER_D_FWD:
140                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
141                     snode->flags == AMT_SOURCE_OLD)
142                         rc = true;
143                 break;
144         case AMT_FILTER_FWD_NEW:
145                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
146                     snode->flags == AMT_SOURCE_NEW)
147                         rc = true;
148                 break;
149         case AMT_FILTER_D_FWD_NEW:
150                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
151                     snode->flags == AMT_SOURCE_NEW)
152                         rc = true;
153                 break;
154         case AMT_FILTER_ALL:
155                 rc = true;
156                 break;
157         case AMT_FILTER_NONE_NEW:
158                 if (snode->status == AMT_SOURCE_STATUS_NONE &&
159                     snode->flags == AMT_SOURCE_NEW)
160                         rc = true;
161                 break;
162         case AMT_FILTER_BOTH:
163                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
164                      snode->status == AMT_SOURCE_STATUS_FWD) &&
165                     snode->flags == AMT_SOURCE_OLD)
166                         rc = true;
167                 break;
168         case AMT_FILTER_BOTH_NEW:
169                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
170                      snode->status == AMT_SOURCE_STATUS_FWD) &&
171                     snode->flags == AMT_SOURCE_NEW)
172                         rc = true;
173                 break;
174         default:
175                 WARN_ON_ONCE(1);
176                 break;
177         }
178
179         return rc;
180 }
181
182 static struct amt_source_node *amt_lookup_src(struct amt_tunnel_list *tunnel,
183                                               struct amt_group_node *gnode,
184                                               enum amt_filter filter,
185                                               union amt_addr *src)
186 {
187         u32 hash = amt_source_hash(tunnel, src);
188         struct amt_source_node *snode;
189
190         hlist_for_each_entry_rcu(snode, &gnode->sources[hash], node)
191                 if (amt_status_filter(snode, filter) &&
192                     amt_addr_equal(&snode->source_addr, src))
193                         return snode;
194
195         return NULL;
196 }
197
198 static u32 amt_group_hash(struct amt_tunnel_list *tunnel, union amt_addr *group)
199 {
200         u32 hash = jhash(group, sizeof(*group), tunnel->amt->hash_seed);
201
202         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
203 }
204
205 static struct amt_group_node *amt_lookup_group(struct amt_tunnel_list *tunnel,
206                                                union amt_addr *group,
207                                                union amt_addr *host,
208                                                bool v6)
209 {
210         u32 hash = amt_group_hash(tunnel, group);
211         struct amt_group_node *gnode;
212
213         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash], node) {
214                 if (amt_addr_equal(&gnode->group_addr, group) &&
215                     amt_addr_equal(&gnode->host_addr, host) &&
216                     gnode->v6 == v6)
217                         return gnode;
218         }
219
220         return NULL;
221 }
222
223 static void amt_destroy_source(struct amt_source_node *snode)
224 {
225         struct amt_group_node *gnode = snode->gnode;
226         struct amt_tunnel_list *tunnel;
227
228         tunnel = gnode->tunnel_list;
229
230         if (!gnode->v6) {
231                 netdev_dbg(snode->gnode->amt->dev,
232                            "Delete source %pI4 from %pI4\n",
233                            &snode->source_addr.ip4,
234                            &gnode->group_addr.ip4);
235 #if IS_ENABLED(CONFIG_IPV6)
236         } else {
237                 netdev_dbg(snode->gnode->amt->dev,
238                            "Delete source %pI6 from %pI6\n",
239                            &snode->source_addr.ip6,
240                            &gnode->group_addr.ip6);
241 #endif
242         }
243
244         cancel_delayed_work(&snode->source_timer);
245         hlist_del_init_rcu(&snode->node);
246         tunnel->nr_sources--;
247         gnode->nr_sources--;
248         spin_lock_bh(&source_gc_lock);
249         hlist_add_head_rcu(&snode->node, &source_gc_list);
250         spin_unlock_bh(&source_gc_lock);
251 }
252
253 static void amt_del_group(struct amt_dev *amt, struct amt_group_node *gnode)
254 {
255         struct amt_source_node *snode;
256         struct hlist_node *t;
257         int i;
258
259         if (cancel_delayed_work(&gnode->group_timer))
260                 dev_put(amt->dev);
261         hlist_del_rcu(&gnode->node);
262         gnode->tunnel_list->nr_groups--;
263
264         if (!gnode->v6)
265                 netdev_dbg(amt->dev, "Leave group %pI4\n",
266                            &gnode->group_addr.ip4);
267 #if IS_ENABLED(CONFIG_IPV6)
268         else
269                 netdev_dbg(amt->dev, "Leave group %pI6\n",
270                            &gnode->group_addr.ip6);
271 #endif
272         for (i = 0; i < amt->hash_buckets; i++)
273                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node)
274                         amt_destroy_source(snode);
275
276         /* tunnel->lock was acquired outside of amt_del_group()
277          * But rcu_read_lock() was acquired too so It's safe.
278          */
279         kfree_rcu(gnode, rcu);
280 }
281
282 /* If a source timer expires with a router filter-mode for the group of
283  * INCLUDE, the router concludes that traffic from this particular
284  * source is no longer desired on the attached network, and deletes the
285  * associated source record.
286  */
287 static void amt_source_work(struct work_struct *work)
288 {
289         struct amt_source_node *snode = container_of(to_delayed_work(work),
290                                                      struct amt_source_node,
291                                                      source_timer);
292         struct amt_group_node *gnode = snode->gnode;
293         struct amt_dev *amt = gnode->amt;
294         struct amt_tunnel_list *tunnel;
295
296         tunnel = gnode->tunnel_list;
297         spin_lock_bh(&tunnel->lock);
298         rcu_read_lock();
299         if (gnode->filter_mode == MCAST_INCLUDE) {
300                 amt_destroy_source(snode);
301                 if (!gnode->nr_sources)
302                         amt_del_group(amt, gnode);
303         } else {
304                 /* When a router filter-mode for a group is EXCLUDE,
305                  * source records are only deleted when the group timer expires
306                  */
307                 snode->status = AMT_SOURCE_STATUS_D_FWD;
308         }
309         rcu_read_unlock();
310         spin_unlock_bh(&tunnel->lock);
311 }
312
313 static void amt_act_src(struct amt_tunnel_list *tunnel,
314                         struct amt_group_node *gnode,
315                         struct amt_source_node *snode,
316                         enum amt_act act)
317 {
318         struct amt_dev *amt = tunnel->amt;
319
320         switch (act) {
321         case AMT_ACT_GMI:
322                 mod_delayed_work(amt_wq, &snode->source_timer,
323                                  msecs_to_jiffies(amt_gmi(amt)));
324                 break;
325         case AMT_ACT_GMI_ZERO:
326                 cancel_delayed_work(&snode->source_timer);
327                 break;
328         case AMT_ACT_GT:
329                 mod_delayed_work(amt_wq, &snode->source_timer,
330                                  gnode->group_timer.timer.expires);
331                 break;
332         case AMT_ACT_STATUS_FWD_NEW:
333                 snode->status = AMT_SOURCE_STATUS_FWD;
334                 snode->flags = AMT_SOURCE_NEW;
335                 break;
336         case AMT_ACT_STATUS_D_FWD_NEW:
337                 snode->status = AMT_SOURCE_STATUS_D_FWD;
338                 snode->flags = AMT_SOURCE_NEW;
339                 break;
340         case AMT_ACT_STATUS_NONE_NEW:
341                 cancel_delayed_work(&snode->source_timer);
342                 snode->status = AMT_SOURCE_STATUS_NONE;
343                 snode->flags = AMT_SOURCE_NEW;
344                 break;
345         default:
346                 WARN_ON_ONCE(1);
347                 return;
348         }
349
350         if (!gnode->v6)
351                 netdev_dbg(amt->dev, "Source %pI4 from %pI4 Acted %s\n",
352                            &snode->source_addr.ip4,
353                            &gnode->group_addr.ip4,
354                            action_str[act]);
355 #if IS_ENABLED(CONFIG_IPV6)
356         else
357                 netdev_dbg(amt->dev, "Source %pI6 from %pI6 Acted %s\n",
358                            &snode->source_addr.ip6,
359                            &gnode->group_addr.ip6,
360                            action_str[act]);
361 #endif
362 }
363
364 static struct amt_source_node *amt_alloc_snode(struct amt_group_node *gnode,
365                                                union amt_addr *src)
366 {
367         struct amt_source_node *snode;
368
369         snode = kzalloc(sizeof(*snode), GFP_ATOMIC);
370         if (!snode)
371                 return NULL;
372
373         memcpy(&snode->source_addr, src, sizeof(union amt_addr));
374         snode->gnode = gnode;
375         snode->status = AMT_SOURCE_STATUS_NONE;
376         snode->flags = AMT_SOURCE_NEW;
377         INIT_HLIST_NODE(&snode->node);
378         INIT_DELAYED_WORK(&snode->source_timer, amt_source_work);
379
380         return snode;
381 }
382
383 /* RFC 3810 - 7.2.2.  Definition of Filter Timers
384  *
385  *  Router Mode          Filter Timer         Actions/Comments
386  *  -----------       -----------------       ----------------
387  *
388  *    INCLUDE             Not Used            All listeners in
389  *                                            INCLUDE mode.
390  *
391  *    EXCLUDE             Timer > 0           At least one listener
392  *                                            in EXCLUDE mode.
393  *
394  *    EXCLUDE             Timer == 0          No more listeners in
395  *                                            EXCLUDE mode for the
396  *                                            multicast address.
397  *                                            If the Requested List
398  *                                            is empty, delete
399  *                                            Multicast Address
400  *                                            Record.  If not, switch
401  *                                            to INCLUDE filter mode;
402  *                                            the sources in the
403  *                                            Requested List are
404  *                                            moved to the Include
405  *                                            List, and the Exclude
406  *                                            List is deleted.
407  */
408 static void amt_group_work(struct work_struct *work)
409 {
410         struct amt_group_node *gnode = container_of(to_delayed_work(work),
411                                                     struct amt_group_node,
412                                                     group_timer);
413         struct amt_tunnel_list *tunnel = gnode->tunnel_list;
414         struct amt_dev *amt = gnode->amt;
415         struct amt_source_node *snode;
416         bool delete_group = true;
417         struct hlist_node *t;
418         int i, buckets;
419
420         buckets = amt->hash_buckets;
421
422         spin_lock_bh(&tunnel->lock);
423         if (gnode->filter_mode == MCAST_INCLUDE) {
424                 /* Not Used */
425                 spin_unlock_bh(&tunnel->lock);
426                 goto out;
427         }
428
429         rcu_read_lock();
430         for (i = 0; i < buckets; i++) {
431                 hlist_for_each_entry_safe(snode, t,
432                                           &gnode->sources[i], node) {
433                         if (!delayed_work_pending(&snode->source_timer) ||
434                             snode->status == AMT_SOURCE_STATUS_D_FWD) {
435                                 amt_destroy_source(snode);
436                         } else {
437                                 delete_group = false;
438                                 snode->status = AMT_SOURCE_STATUS_FWD;
439                         }
440                 }
441         }
442         if (delete_group)
443                 amt_del_group(amt, gnode);
444         else
445                 gnode->filter_mode = MCAST_INCLUDE;
446         rcu_read_unlock();
447         spin_unlock_bh(&tunnel->lock);
448 out:
449         dev_put(amt->dev);
450 }
451
452 /* Non-existant group is created as INCLUDE {empty}:
453  *
454  * RFC 3376 - 5.1. Action on Change of Interface State
455  *
456  * If no interface state existed for that multicast address before
457  * the change (i.e., the change consisted of creating a new
458  * per-interface record), or if no state exists after the change
459  * (i.e., the change consisted of deleting a per-interface record),
460  * then the "non-existent" state is considered to have a filter mode
461  * of INCLUDE and an empty source list.
462  */
463 static struct amt_group_node *amt_add_group(struct amt_dev *amt,
464                                             struct amt_tunnel_list *tunnel,
465                                             union amt_addr *group,
466                                             union amt_addr *host,
467                                             bool v6)
468 {
469         struct amt_group_node *gnode;
470         u32 hash;
471         int i;
472
473         if (tunnel->nr_groups >= amt->max_groups)
474                 return ERR_PTR(-ENOSPC);
475
476         gnode = kzalloc(sizeof(*gnode) +
477                         (sizeof(struct hlist_head) * amt->hash_buckets),
478                         GFP_ATOMIC);
479         if (unlikely(!gnode))
480                 return ERR_PTR(-ENOMEM);
481
482         gnode->amt = amt;
483         gnode->group_addr = *group;
484         gnode->host_addr = *host;
485         gnode->v6 = v6;
486         gnode->tunnel_list = tunnel;
487         gnode->filter_mode = MCAST_INCLUDE;
488         INIT_HLIST_NODE(&gnode->node);
489         INIT_DELAYED_WORK(&gnode->group_timer, amt_group_work);
490         for (i = 0; i < amt->hash_buckets; i++)
491                 INIT_HLIST_HEAD(&gnode->sources[i]);
492
493         hash = amt_group_hash(tunnel, group);
494         hlist_add_head_rcu(&gnode->node, &tunnel->groups[hash]);
495         tunnel->nr_groups++;
496
497         if (!gnode->v6)
498                 netdev_dbg(amt->dev, "Join group %pI4\n",
499                            &gnode->group_addr.ip4);
500 #if IS_ENABLED(CONFIG_IPV6)
501         else
502                 netdev_dbg(amt->dev, "Join group %pI6\n",
503                            &gnode->group_addr.ip6);
504 #endif
505
506         return gnode;
507 }
508
509 static struct sk_buff *amt_build_igmp_gq(struct amt_dev *amt)
510 {
511         u8 ra[AMT_IPHDR_OPTS] = { IPOPT_RA, 4, 0, 0 };
512         int hlen = LL_RESERVED_SPACE(amt->dev);
513         int tlen = amt->dev->needed_tailroom;
514         struct igmpv3_query *ihv3;
515         void *csum_start = NULL;
516         __sum16 *csum = NULL;
517         struct sk_buff *skb;
518         struct ethhdr *eth;
519         struct iphdr *iph;
520         unsigned int len;
521         int offset;
522
523         len = hlen + tlen + sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3);
524         skb = netdev_alloc_skb_ip_align(amt->dev, len);
525         if (!skb)
526                 return NULL;
527
528         skb_reserve(skb, hlen);
529         skb_push(skb, sizeof(*eth));
530         skb->protocol = htons(ETH_P_IP);
531         skb_reset_mac_header(skb);
532         skb->priority = TC_PRIO_CONTROL;
533         skb_put(skb, sizeof(*iph));
534         skb_put_data(skb, ra, sizeof(ra));
535         skb_put(skb, sizeof(*ihv3));
536         skb_pull(skb, sizeof(*eth));
537         skb_reset_network_header(skb);
538
539         iph             = ip_hdr(skb);
540         iph->version    = 4;
541         iph->ihl        = (sizeof(struct iphdr) + AMT_IPHDR_OPTS) >> 2;
542         iph->tos        = AMT_TOS;
543         iph->tot_len    = htons(sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3));
544         iph->frag_off   = htons(IP_DF);
545         iph->ttl        = 1;
546         iph->id         = 0;
547         iph->protocol   = IPPROTO_IGMP;
548         iph->daddr      = htonl(INADDR_ALLHOSTS_GROUP);
549         iph->saddr      = htonl(INADDR_ANY);
550         ip_send_check(iph);
551
552         eth = eth_hdr(skb);
553         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
554         ip_eth_mc_map(htonl(INADDR_ALLHOSTS_GROUP), eth->h_dest);
555         eth->h_proto = htons(ETH_P_IP);
556
557         ihv3            = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
558         skb_reset_transport_header(skb);
559         ihv3->type      = IGMP_HOST_MEMBERSHIP_QUERY;
560         ihv3->code      = 1;
561         ihv3->group     = 0;
562         ihv3->qqic      = amt->qi;
563         ihv3->nsrcs     = 0;
564         ihv3->resv      = 0;
565         ihv3->suppress  = false;
566         ihv3->qrv       = READ_ONCE(amt->net->ipv4.sysctl_igmp_qrv);
567         ihv3->csum      = 0;
568         csum            = &ihv3->csum;
569         csum_start      = (void *)ihv3;
570         *csum           = ip_compute_csum(csum_start, sizeof(*ihv3));
571         offset          = skb_transport_offset(skb);
572         skb->csum       = skb_checksum(skb, offset, skb->len - offset, 0);
573         skb->ip_summed  = CHECKSUM_NONE;
574
575         skb_push(skb, sizeof(*eth) + sizeof(*iph) + AMT_IPHDR_OPTS);
576
577         return skb;
578 }
579
580 static void __amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
581                                    bool validate)
582 {
583         if (validate && amt->status >= status)
584                 return;
585         netdev_dbg(amt->dev, "Update GW status %s -> %s",
586                    status_str[amt->status], status_str[status]);
587         amt->status = status;
588 }
589
590 static void __amt_update_relay_status(struct amt_tunnel_list *tunnel,
591                                       enum amt_status status,
592                                       bool validate)
593 {
594         if (validate && tunnel->status >= status)
595                 return;
596         netdev_dbg(tunnel->amt->dev,
597                    "Update Tunnel(IP = %pI4, PORT = %u) status %s -> %s",
598                    &tunnel->ip4, ntohs(tunnel->source_port),
599                    status_str[tunnel->status], status_str[status]);
600         tunnel->status = status;
601 }
602
603 static void amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
604                                  bool validate)
605 {
606         spin_lock_bh(&amt->lock);
607         __amt_update_gw_status(amt, status, validate);
608         spin_unlock_bh(&amt->lock);
609 }
610
611 static void amt_update_relay_status(struct amt_tunnel_list *tunnel,
612                                     enum amt_status status, bool validate)
613 {
614         spin_lock_bh(&tunnel->lock);
615         __amt_update_relay_status(tunnel, status, validate);
616         spin_unlock_bh(&tunnel->lock);
617 }
618
619 static void amt_send_discovery(struct amt_dev *amt)
620 {
621         struct amt_header_discovery *amtd;
622         int hlen, tlen, offset;
623         struct socket *sock;
624         struct udphdr *udph;
625         struct sk_buff *skb;
626         struct iphdr *iph;
627         struct rtable *rt;
628         struct flowi4 fl4;
629         u32 len;
630         int err;
631
632         rcu_read_lock();
633         sock = rcu_dereference(amt->sock);
634         if (!sock)
635                 goto out;
636
637         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
638                 goto out;
639
640         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
641                                    amt->discovery_ip, amt->local_ip,
642                                    amt->gw_port, amt->relay_port,
643                                    IPPROTO_UDP, 0,
644                                    amt->stream_dev->ifindex);
645         if (IS_ERR(rt)) {
646                 amt->dev->stats.tx_errors++;
647                 goto out;
648         }
649
650         hlen = LL_RESERVED_SPACE(amt->dev);
651         tlen = amt->dev->needed_tailroom;
652         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
653         skb = netdev_alloc_skb_ip_align(amt->dev, len);
654         if (!skb) {
655                 ip_rt_put(rt);
656                 amt->dev->stats.tx_errors++;
657                 goto out;
658         }
659
660         skb->priority = TC_PRIO_CONTROL;
661         skb_dst_set(skb, &rt->dst);
662
663         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
664         skb_reset_network_header(skb);
665         skb_put(skb, len);
666         amtd = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
667         amtd->version   = 0;
668         amtd->type      = AMT_MSG_DISCOVERY;
669         amtd->reserved  = 0;
670         amtd->nonce     = amt->nonce;
671         skb_push(skb, sizeof(*udph));
672         skb_reset_transport_header(skb);
673         udph            = udp_hdr(skb);
674         udph->source    = amt->gw_port;
675         udph->dest      = amt->relay_port;
676         udph->len       = htons(sizeof(*udph) + sizeof(*amtd));
677         udph->check     = 0;
678         offset = skb_transport_offset(skb);
679         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
680         udph->check = csum_tcpudp_magic(amt->local_ip, amt->discovery_ip,
681                                         sizeof(*udph) + sizeof(*amtd),
682                                         IPPROTO_UDP, skb->csum);
683
684         skb_push(skb, sizeof(*iph));
685         iph             = ip_hdr(skb);
686         iph->version    = 4;
687         iph->ihl        = (sizeof(struct iphdr)) >> 2;
688         iph->tos        = AMT_TOS;
689         iph->frag_off   = 0;
690         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
691         iph->daddr      = amt->discovery_ip;
692         iph->saddr      = amt->local_ip;
693         iph->protocol   = IPPROTO_UDP;
694         iph->tot_len    = htons(len);
695
696         skb->ip_summed = CHECKSUM_NONE;
697         ip_select_ident(amt->net, skb, NULL);
698         ip_send_check(iph);
699         err = ip_local_out(amt->net, sock->sk, skb);
700         if (unlikely(net_xmit_eval(err)))
701                 amt->dev->stats.tx_errors++;
702
703         spin_lock_bh(&amt->lock);
704         __amt_update_gw_status(amt, AMT_STATUS_SENT_DISCOVERY, true);
705         spin_unlock_bh(&amt->lock);
706 out:
707         rcu_read_unlock();
708 }
709
710 static void amt_send_request(struct amt_dev *amt, bool v6)
711 {
712         struct amt_header_request *amtrh;
713         int hlen, tlen, offset;
714         struct socket *sock;
715         struct udphdr *udph;
716         struct sk_buff *skb;
717         struct iphdr *iph;
718         struct rtable *rt;
719         struct flowi4 fl4;
720         u32 len;
721         int err;
722
723         rcu_read_lock();
724         sock = rcu_dereference(amt->sock);
725         if (!sock)
726                 goto out;
727
728         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
729                 goto out;
730
731         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
732                                    amt->remote_ip, amt->local_ip,
733                                    amt->gw_port, amt->relay_port,
734                                    IPPROTO_UDP, 0,
735                                    amt->stream_dev->ifindex);
736         if (IS_ERR(rt)) {
737                 amt->dev->stats.tx_errors++;
738                 goto out;
739         }
740
741         hlen = LL_RESERVED_SPACE(amt->dev);
742         tlen = amt->dev->needed_tailroom;
743         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
744         skb = netdev_alloc_skb_ip_align(amt->dev, len);
745         if (!skb) {
746                 ip_rt_put(rt);
747                 amt->dev->stats.tx_errors++;
748                 goto out;
749         }
750
751         skb->priority = TC_PRIO_CONTROL;
752         skb_dst_set(skb, &rt->dst);
753
754         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
755         skb_reset_network_header(skb);
756         skb_put(skb, len);
757         amtrh = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
758         amtrh->version   = 0;
759         amtrh->type      = AMT_MSG_REQUEST;
760         amtrh->reserved1 = 0;
761         amtrh->p         = v6;
762         amtrh->reserved2 = 0;
763         amtrh->nonce     = amt->nonce;
764         skb_push(skb, sizeof(*udph));
765         skb_reset_transport_header(skb);
766         udph            = udp_hdr(skb);
767         udph->source    = amt->gw_port;
768         udph->dest      = amt->relay_port;
769         udph->len       = htons(sizeof(*amtrh) + sizeof(*udph));
770         udph->check     = 0;
771         offset = skb_transport_offset(skb);
772         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
773         udph->check = csum_tcpudp_magic(amt->local_ip, amt->remote_ip,
774                                         sizeof(*udph) + sizeof(*amtrh),
775                                         IPPROTO_UDP, skb->csum);
776
777         skb_push(skb, sizeof(*iph));
778         iph             = ip_hdr(skb);
779         iph->version    = 4;
780         iph->ihl        = (sizeof(struct iphdr)) >> 2;
781         iph->tos        = AMT_TOS;
782         iph->frag_off   = 0;
783         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
784         iph->daddr      = amt->remote_ip;
785         iph->saddr      = amt->local_ip;
786         iph->protocol   = IPPROTO_UDP;
787         iph->tot_len    = htons(len);
788
789         skb->ip_summed = CHECKSUM_NONE;
790         ip_select_ident(amt->net, skb, NULL);
791         ip_send_check(iph);
792         err = ip_local_out(amt->net, sock->sk, skb);
793         if (unlikely(net_xmit_eval(err)))
794                 amt->dev->stats.tx_errors++;
795
796 out:
797         rcu_read_unlock();
798 }
799
800 static void amt_send_igmp_gq(struct amt_dev *amt,
801                              struct amt_tunnel_list *tunnel)
802 {
803         struct sk_buff *skb;
804
805         skb = amt_build_igmp_gq(amt);
806         if (!skb)
807                 return;
808
809         amt_skb_cb(skb)->tunnel = tunnel;
810         dev_queue_xmit(skb);
811 }
812
813 #if IS_ENABLED(CONFIG_IPV6)
814 static struct sk_buff *amt_build_mld_gq(struct amt_dev *amt)
815 {
816         u8 ra[AMT_IP6HDR_OPTS] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT,
817                                    2, 0, 0, IPV6_TLV_PAD1, IPV6_TLV_PAD1 };
818         int hlen = LL_RESERVED_SPACE(amt->dev);
819         int tlen = amt->dev->needed_tailroom;
820         struct mld2_query *mld2q;
821         void *csum_start = NULL;
822         struct ipv6hdr *ip6h;
823         struct sk_buff *skb;
824         struct ethhdr *eth;
825         u32 len;
826
827         len = hlen + tlen + sizeof(*ip6h) + sizeof(ra) + sizeof(*mld2q);
828         skb = netdev_alloc_skb_ip_align(amt->dev, len);
829         if (!skb)
830                 return NULL;
831
832         skb_reserve(skb, hlen);
833         skb_push(skb, sizeof(*eth));
834         skb_reset_mac_header(skb);
835         eth = eth_hdr(skb);
836         skb->priority = TC_PRIO_CONTROL;
837         skb->protocol = htons(ETH_P_IPV6);
838         skb_put_zero(skb, sizeof(*ip6h));
839         skb_put_data(skb, ra, sizeof(ra));
840         skb_put_zero(skb, sizeof(*mld2q));
841         skb_pull(skb, sizeof(*eth));
842         skb_reset_network_header(skb);
843         ip6h                    = ipv6_hdr(skb);
844         ip6h->payload_len       = htons(sizeof(ra) + sizeof(*mld2q));
845         ip6h->nexthdr           = NEXTHDR_HOP;
846         ip6h->hop_limit         = 1;
847         ip6h->daddr             = mld2_all_node;
848         ip6_flow_hdr(ip6h, 0, 0);
849
850         if (ipv6_dev_get_saddr(amt->net, amt->dev, &ip6h->daddr, 0,
851                                &ip6h->saddr)) {
852                 amt->dev->stats.tx_errors++;
853                 kfree_skb(skb);
854                 return NULL;
855         }
856
857         eth->h_proto = htons(ETH_P_IPV6);
858         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
859         ipv6_eth_mc_map(&mld2_all_node, eth->h_dest);
860
861         skb_pull(skb, sizeof(*ip6h) + sizeof(ra));
862         skb_reset_transport_header(skb);
863         mld2q                   = (struct mld2_query *)icmp6_hdr(skb);
864         mld2q->mld2q_mrc        = htons(1);
865         mld2q->mld2q_type       = ICMPV6_MGM_QUERY;
866         mld2q->mld2q_code       = 0;
867         mld2q->mld2q_cksum      = 0;
868         mld2q->mld2q_resv1      = 0;
869         mld2q->mld2q_resv2      = 0;
870         mld2q->mld2q_suppress   = 0;
871         mld2q->mld2q_qrv        = amt->qrv;
872         mld2q->mld2q_nsrcs      = 0;
873         mld2q->mld2q_qqic       = amt->qi;
874         csum_start              = (void *)mld2q;
875         mld2q->mld2q_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
876                                              sizeof(*mld2q),
877                                              IPPROTO_ICMPV6,
878                                              csum_partial(csum_start,
879                                                           sizeof(*mld2q), 0));
880
881         skb->ip_summed = CHECKSUM_NONE;
882         skb_push(skb, sizeof(*eth) + sizeof(*ip6h) + sizeof(ra));
883         return skb;
884 }
885
886 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
887 {
888         struct sk_buff *skb;
889
890         skb = amt_build_mld_gq(amt);
891         if (!skb)
892                 return;
893
894         amt_skb_cb(skb)->tunnel = tunnel;
895         dev_queue_xmit(skb);
896 }
897 #else
898 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
899 {
900 }
901 #endif
902
903 static bool amt_queue_event(struct amt_dev *amt, enum amt_event event,
904                             struct sk_buff *skb)
905 {
906         int index;
907
908         spin_lock_bh(&amt->lock);
909         if (amt->nr_events >= AMT_MAX_EVENTS) {
910                 spin_unlock_bh(&amt->lock);
911                 return 1;
912         }
913
914         index = (amt->event_idx + amt->nr_events) % AMT_MAX_EVENTS;
915         amt->events[index].event = event;
916         amt->events[index].skb = skb;
917         amt->nr_events++;
918         amt->event_idx %= AMT_MAX_EVENTS;
919         queue_work(amt_wq, &amt->event_wq);
920         spin_unlock_bh(&amt->lock);
921
922         return 0;
923 }
924
925 static void amt_secret_work(struct work_struct *work)
926 {
927         struct amt_dev *amt = container_of(to_delayed_work(work),
928                                            struct amt_dev,
929                                            secret_wq);
930
931         spin_lock_bh(&amt->lock);
932         get_random_bytes(&amt->key, sizeof(siphash_key_t));
933         spin_unlock_bh(&amt->lock);
934         mod_delayed_work(amt_wq, &amt->secret_wq,
935                          msecs_to_jiffies(AMT_SECRET_TIMEOUT));
936 }
937
938 static void amt_event_send_discovery(struct amt_dev *amt)
939 {
940         spin_lock_bh(&amt->lock);
941         if (amt->status > AMT_STATUS_SENT_DISCOVERY)
942                 goto out;
943         get_random_bytes(&amt->nonce, sizeof(__be32));
944         spin_unlock_bh(&amt->lock);
945
946         amt_send_discovery(amt);
947         spin_lock_bh(&amt->lock);
948 out:
949         mod_delayed_work(amt_wq, &amt->discovery_wq,
950                          msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
951         spin_unlock_bh(&amt->lock);
952 }
953
954 static void amt_discovery_work(struct work_struct *work)
955 {
956         struct amt_dev *amt = container_of(to_delayed_work(work),
957                                            struct amt_dev,
958                                            discovery_wq);
959
960         if (amt_queue_event(amt, AMT_EVENT_SEND_DISCOVERY, NULL))
961                 mod_delayed_work(amt_wq, &amt->discovery_wq,
962                                  msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
963 }
964
965 static void amt_event_send_request(struct amt_dev *amt)
966 {
967         u32 exp;
968
969         spin_lock_bh(&amt->lock);
970         if (amt->status < AMT_STATUS_RECEIVED_ADVERTISEMENT)
971                 goto out;
972
973         if (amt->req_cnt > AMT_MAX_REQ_COUNT) {
974                 netdev_dbg(amt->dev, "Gateway is not ready");
975                 amt->qi = AMT_INIT_REQ_TIMEOUT;
976                 amt->ready4 = false;
977                 amt->ready6 = false;
978                 amt->remote_ip = 0;
979                 __amt_update_gw_status(amt, AMT_STATUS_INIT, false);
980                 amt->req_cnt = 0;
981                 goto out;
982         }
983         spin_unlock_bh(&amt->lock);
984
985         amt_send_request(amt, false);
986         amt_send_request(amt, true);
987         spin_lock_bh(&amt->lock);
988         __amt_update_gw_status(amt, AMT_STATUS_SENT_REQUEST, true);
989         amt->req_cnt++;
990 out:
991         exp = min_t(u32, (1 * (1 << amt->req_cnt)), AMT_MAX_REQ_TIMEOUT);
992         mod_delayed_work(amt_wq, &amt->req_wq, msecs_to_jiffies(exp * 1000));
993         spin_unlock_bh(&amt->lock);
994 }
995
996 static void amt_req_work(struct work_struct *work)
997 {
998         struct amt_dev *amt = container_of(to_delayed_work(work),
999                                            struct amt_dev,
1000                                            req_wq);
1001
1002         if (amt_queue_event(amt, AMT_EVENT_SEND_REQUEST, NULL))
1003                 mod_delayed_work(amt_wq, &amt->req_wq,
1004                                  msecs_to_jiffies(100));
1005 }
1006
1007 static bool amt_send_membership_update(struct amt_dev *amt,
1008                                        struct sk_buff *skb,
1009                                        bool v6)
1010 {
1011         struct amt_header_membership_update *amtmu;
1012         struct socket *sock;
1013         struct iphdr *iph;
1014         struct flowi4 fl4;
1015         struct rtable *rt;
1016         int err;
1017
1018         sock = rcu_dereference_bh(amt->sock);
1019         if (!sock)
1020                 return true;
1021
1022         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
1023                            sizeof(*iph) + sizeof(struct udphdr));
1024         if (err)
1025                 return true;
1026
1027         skb_reset_inner_headers(skb);
1028         memset(&fl4, 0, sizeof(struct flowi4));
1029         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1030         fl4.daddr              = amt->remote_ip;
1031         fl4.saddr              = amt->local_ip;
1032         fl4.flowi4_tos         = AMT_TOS;
1033         fl4.flowi4_proto       = IPPROTO_UDP;
1034         rt = ip_route_output_key(amt->net, &fl4);
1035         if (IS_ERR(rt)) {
1036                 netdev_dbg(amt->dev, "no route to %pI4\n", &amt->remote_ip);
1037                 return true;
1038         }
1039
1040         amtmu                   = skb_push(skb, sizeof(*amtmu));
1041         amtmu->version          = 0;
1042         amtmu->type             = AMT_MSG_MEMBERSHIP_UPDATE;
1043         amtmu->reserved         = 0;
1044         amtmu->nonce            = amt->nonce;
1045         amtmu->response_mac     = amt->mac;
1046
1047         if (!v6)
1048                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1049         else
1050                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1051         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1052                             fl4.saddr,
1053                             fl4.daddr,
1054                             AMT_TOS,
1055                             ip4_dst_hoplimit(&rt->dst),
1056                             0,
1057                             amt->gw_port,
1058                             amt->relay_port,
1059                             false,
1060                             false);
1061         amt_update_gw_status(amt, AMT_STATUS_SENT_UPDATE, true);
1062         return false;
1063 }
1064
1065 static void amt_send_multicast_data(struct amt_dev *amt,
1066                                     const struct sk_buff *oskb,
1067                                     struct amt_tunnel_list *tunnel,
1068                                     bool v6)
1069 {
1070         struct amt_header_mcast_data *amtmd;
1071         struct socket *sock;
1072         struct sk_buff *skb;
1073         struct iphdr *iph;
1074         struct flowi4 fl4;
1075         struct rtable *rt;
1076
1077         sock = rcu_dereference_bh(amt->sock);
1078         if (!sock)
1079                 return;
1080
1081         skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
1082                               sizeof(struct udphdr), 0, GFP_ATOMIC);
1083         if (!skb)
1084                 return;
1085
1086         skb_reset_inner_headers(skb);
1087         memset(&fl4, 0, sizeof(struct flowi4));
1088         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1089         fl4.daddr              = tunnel->ip4;
1090         fl4.saddr              = amt->local_ip;
1091         fl4.flowi4_proto       = IPPROTO_UDP;
1092         rt = ip_route_output_key(amt->net, &fl4);
1093         if (IS_ERR(rt)) {
1094                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1095                 kfree_skb(skb);
1096                 return;
1097         }
1098
1099         amtmd = skb_push(skb, sizeof(*amtmd));
1100         amtmd->version = 0;
1101         amtmd->reserved = 0;
1102         amtmd->type = AMT_MSG_MULTICAST_DATA;
1103
1104         if (!v6)
1105                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1106         else
1107                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1108         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1109                             fl4.saddr,
1110                             fl4.daddr,
1111                             AMT_TOS,
1112                             ip4_dst_hoplimit(&rt->dst),
1113                             0,
1114                             amt->relay_port,
1115                             tunnel->source_port,
1116                             false,
1117                             false);
1118 }
1119
1120 static bool amt_send_membership_query(struct amt_dev *amt,
1121                                       struct sk_buff *skb,
1122                                       struct amt_tunnel_list *tunnel,
1123                                       bool v6)
1124 {
1125         struct amt_header_membership_query *amtmq;
1126         struct socket *sock;
1127         struct rtable *rt;
1128         struct flowi4 fl4;
1129         int err;
1130
1131         sock = rcu_dereference_bh(amt->sock);
1132         if (!sock)
1133                 return true;
1134
1135         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
1136                            sizeof(struct iphdr) + sizeof(struct udphdr));
1137         if (err)
1138                 return true;
1139
1140         skb_reset_inner_headers(skb);
1141         memset(&fl4, 0, sizeof(struct flowi4));
1142         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1143         fl4.daddr              = tunnel->ip4;
1144         fl4.saddr              = amt->local_ip;
1145         fl4.flowi4_tos         = AMT_TOS;
1146         fl4.flowi4_proto       = IPPROTO_UDP;
1147         rt = ip_route_output_key(amt->net, &fl4);
1148         if (IS_ERR(rt)) {
1149                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1150                 return true;
1151         }
1152
1153         amtmq           = skb_push(skb, sizeof(*amtmq));
1154         amtmq->version  = 0;
1155         amtmq->type     = AMT_MSG_MEMBERSHIP_QUERY;
1156         amtmq->reserved = 0;
1157         amtmq->l        = 0;
1158         amtmq->g        = 0;
1159         amtmq->nonce    = tunnel->nonce;
1160         amtmq->response_mac = tunnel->mac;
1161
1162         if (!v6)
1163                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1164         else
1165                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1166         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1167                             fl4.saddr,
1168                             fl4.daddr,
1169                             AMT_TOS,
1170                             ip4_dst_hoplimit(&rt->dst),
1171                             0,
1172                             amt->relay_port,
1173                             tunnel->source_port,
1174                             false,
1175                             false);
1176         amt_update_relay_status(tunnel, AMT_STATUS_SENT_QUERY, true);
1177         return false;
1178 }
1179
1180 static netdev_tx_t amt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1181 {
1182         struct amt_dev *amt = netdev_priv(dev);
1183         struct amt_tunnel_list *tunnel;
1184         struct amt_group_node *gnode;
1185         union amt_addr group = {0,};
1186 #if IS_ENABLED(CONFIG_IPV6)
1187         struct ipv6hdr *ip6h;
1188         struct mld_msg *mld;
1189 #endif
1190         bool report = false;
1191         struct igmphdr *ih;
1192         bool query = false;
1193         struct iphdr *iph;
1194         bool data = false;
1195         bool v6 = false;
1196         u32 hash;
1197
1198         iph = ip_hdr(skb);
1199         if (iph->version == 4) {
1200                 if (!ipv4_is_multicast(iph->daddr))
1201                         goto free;
1202
1203                 if (!ip_mc_check_igmp(skb)) {
1204                         ih = igmp_hdr(skb);
1205                         switch (ih->type) {
1206                         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1207                         case IGMP_HOST_MEMBERSHIP_REPORT:
1208                                 report = true;
1209                                 break;
1210                         case IGMP_HOST_MEMBERSHIP_QUERY:
1211                                 query = true;
1212                                 break;
1213                         default:
1214                                 goto free;
1215                         }
1216                 } else {
1217                         data = true;
1218                 }
1219                 v6 = false;
1220                 group.ip4 = iph->daddr;
1221 #if IS_ENABLED(CONFIG_IPV6)
1222         } else if (iph->version == 6) {
1223                 ip6h = ipv6_hdr(skb);
1224                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
1225                         goto free;
1226
1227                 if (!ipv6_mc_check_mld(skb)) {
1228                         mld = (struct mld_msg *)skb_transport_header(skb);
1229                         switch (mld->mld_type) {
1230                         case ICMPV6_MGM_REPORT:
1231                         case ICMPV6_MLD2_REPORT:
1232                                 report = true;
1233                                 break;
1234                         case ICMPV6_MGM_QUERY:
1235                                 query = true;
1236                                 break;
1237                         default:
1238                                 goto free;
1239                         }
1240                 } else {
1241                         data = true;
1242                 }
1243                 v6 = true;
1244                 group.ip6 = ip6h->daddr;
1245 #endif
1246         } else {
1247                 dev->stats.tx_errors++;
1248                 goto free;
1249         }
1250
1251         if (!pskb_may_pull(skb, sizeof(struct ethhdr)))
1252                 goto free;
1253
1254         skb_pull(skb, sizeof(struct ethhdr));
1255
1256         if (amt->mode == AMT_MODE_GATEWAY) {
1257                 /* Gateway only passes IGMP/MLD packets */
1258                 if (!report)
1259                         goto free;
1260                 if ((!v6 && !amt->ready4) || (v6 && !amt->ready6))
1261                         goto free;
1262                 if (amt_send_membership_update(amt, skb,  v6))
1263                         goto free;
1264                 goto unlock;
1265         } else if (amt->mode == AMT_MODE_RELAY) {
1266                 if (query) {
1267                         tunnel = amt_skb_cb(skb)->tunnel;
1268                         if (!tunnel) {
1269                                 WARN_ON(1);
1270                                 goto free;
1271                         }
1272
1273                         /* Do not forward unexpected query */
1274                         if (amt_send_membership_query(amt, skb, tunnel, v6))
1275                                 goto free;
1276                         goto unlock;
1277                 }
1278
1279                 if (!data)
1280                         goto free;
1281                 list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
1282                         hash = amt_group_hash(tunnel, &group);
1283                         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash],
1284                                                  node) {
1285                                 if (!v6) {
1286                                         if (gnode->group_addr.ip4 == iph->daddr)
1287                                                 goto found;
1288 #if IS_ENABLED(CONFIG_IPV6)
1289                                 } else {
1290                                         if (ipv6_addr_equal(&gnode->group_addr.ip6,
1291                                                             &ip6h->daddr))
1292                                                 goto found;
1293 #endif
1294                                 }
1295                         }
1296                         continue;
1297 found:
1298                         amt_send_multicast_data(amt, skb, tunnel, v6);
1299                 }
1300         }
1301
1302         dev_kfree_skb(skb);
1303         return NETDEV_TX_OK;
1304 free:
1305         dev_kfree_skb(skb);
1306 unlock:
1307         dev->stats.tx_dropped++;
1308         return NETDEV_TX_OK;
1309 }
1310
1311 static int amt_parse_type(struct sk_buff *skb)
1312 {
1313         struct amt_header *amth;
1314
1315         if (!pskb_may_pull(skb, sizeof(struct udphdr) +
1316                            sizeof(struct amt_header)))
1317                 return -1;
1318
1319         amth = (struct amt_header *)(udp_hdr(skb) + 1);
1320
1321         if (amth->version != 0)
1322                 return -1;
1323
1324         if (amth->type >= __AMT_MSG_MAX || !amth->type)
1325                 return -1;
1326         return amth->type;
1327 }
1328
1329 static void amt_clear_groups(struct amt_tunnel_list *tunnel)
1330 {
1331         struct amt_dev *amt = tunnel->amt;
1332         struct amt_group_node *gnode;
1333         struct hlist_node *t;
1334         int i;
1335
1336         spin_lock_bh(&tunnel->lock);
1337         rcu_read_lock();
1338         for (i = 0; i < amt->hash_buckets; i++)
1339                 hlist_for_each_entry_safe(gnode, t, &tunnel->groups[i], node)
1340                         amt_del_group(amt, gnode);
1341         rcu_read_unlock();
1342         spin_unlock_bh(&tunnel->lock);
1343 }
1344
1345 static void amt_tunnel_expire(struct work_struct *work)
1346 {
1347         struct amt_tunnel_list *tunnel = container_of(to_delayed_work(work),
1348                                                       struct amt_tunnel_list,
1349                                                       gc_wq);
1350         struct amt_dev *amt = tunnel->amt;
1351
1352         spin_lock_bh(&amt->lock);
1353         rcu_read_lock();
1354         list_del_rcu(&tunnel->list);
1355         amt->nr_tunnels--;
1356         amt_clear_groups(tunnel);
1357         rcu_read_unlock();
1358         spin_unlock_bh(&amt->lock);
1359         kfree_rcu(tunnel, rcu);
1360 }
1361
1362 static void amt_cleanup_srcs(struct amt_dev *amt,
1363                              struct amt_tunnel_list *tunnel,
1364                              struct amt_group_node *gnode)
1365 {
1366         struct amt_source_node *snode;
1367         struct hlist_node *t;
1368         int i;
1369
1370         /* Delete old sources */
1371         for (i = 0; i < amt->hash_buckets; i++) {
1372                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node) {
1373                         if (snode->flags == AMT_SOURCE_OLD)
1374                                 amt_destroy_source(snode);
1375                 }
1376         }
1377
1378         /* switch from new to old */
1379         for (i = 0; i < amt->hash_buckets; i++)  {
1380                 hlist_for_each_entry_rcu(snode, &gnode->sources[i], node) {
1381                         snode->flags = AMT_SOURCE_OLD;
1382                         if (!gnode->v6)
1383                                 netdev_dbg(snode->gnode->amt->dev,
1384                                            "Add source as OLD %pI4 from %pI4\n",
1385                                            &snode->source_addr.ip4,
1386                                            &gnode->group_addr.ip4);
1387 #if IS_ENABLED(CONFIG_IPV6)
1388                         else
1389                                 netdev_dbg(snode->gnode->amt->dev,
1390                                            "Add source as OLD %pI6 from %pI6\n",
1391                                            &snode->source_addr.ip6,
1392                                            &gnode->group_addr.ip6);
1393 #endif
1394                 }
1395         }
1396 }
1397
1398 static void amt_add_srcs(struct amt_dev *amt, struct amt_tunnel_list *tunnel,
1399                          struct amt_group_node *gnode, void *grec,
1400                          bool v6)
1401 {
1402         struct igmpv3_grec *igmp_grec;
1403         struct amt_source_node *snode;
1404 #if IS_ENABLED(CONFIG_IPV6)
1405         struct mld2_grec *mld_grec;
1406 #endif
1407         union amt_addr src = {0,};
1408         u16 nsrcs;
1409         u32 hash;
1410         int i;
1411
1412         if (!v6) {
1413                 igmp_grec = (struct igmpv3_grec *)grec;
1414                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1415         } else {
1416 #if IS_ENABLED(CONFIG_IPV6)
1417                 mld_grec = (struct mld2_grec *)grec;
1418                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1419 #else
1420         return;
1421 #endif
1422         }
1423         for (i = 0; i < nsrcs; i++) {
1424                 if (tunnel->nr_sources >= amt->max_sources)
1425                         return;
1426                 if (!v6)
1427                         src.ip4 = igmp_grec->grec_src[i];
1428 #if IS_ENABLED(CONFIG_IPV6)
1429                 else
1430                         memcpy(&src.ip6, &mld_grec->grec_src[i],
1431                                sizeof(struct in6_addr));
1432 #endif
1433                 if (amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL, &src))
1434                         continue;
1435
1436                 snode = amt_alloc_snode(gnode, &src);
1437                 if (snode) {
1438                         hash = amt_source_hash(tunnel, &snode->source_addr);
1439                         hlist_add_head_rcu(&snode->node, &gnode->sources[hash]);
1440                         tunnel->nr_sources++;
1441                         gnode->nr_sources++;
1442
1443                         if (!gnode->v6)
1444                                 netdev_dbg(snode->gnode->amt->dev,
1445                                            "Add source as NEW %pI4 from %pI4\n",
1446                                            &snode->source_addr.ip4,
1447                                            &gnode->group_addr.ip4);
1448 #if IS_ENABLED(CONFIG_IPV6)
1449                         else
1450                                 netdev_dbg(snode->gnode->amt->dev,
1451                                            "Add source as NEW %pI6 from %pI6\n",
1452                                            &snode->source_addr.ip6,
1453                                            &gnode->group_addr.ip6);
1454 #endif
1455                 }
1456         }
1457 }
1458
1459 /* Router State   Report Rec'd New Router State
1460  * ------------   ------------ ----------------
1461  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)
1462  *
1463  * -----------+-----------+-----------+
1464  *            |    OLD    |    NEW    |
1465  * -----------+-----------+-----------+
1466  *    FWD     |     X     |    X+A    |
1467  * -----------+-----------+-----------+
1468  *    D_FWD   |     Y     |    Y-A    |
1469  * -----------+-----------+-----------+
1470  *    NONE    |           |     A     |
1471  * -----------+-----------+-----------+
1472  *
1473  * a) Received sources are NONE/NEW
1474  * b) All NONE will be deleted by amt_cleanup_srcs().
1475  * c) All OLD will be deleted by amt_cleanup_srcs().
1476  * d) After delete, NEW source will be switched to OLD.
1477  */
1478 static void amt_lookup_act_srcs(struct amt_tunnel_list *tunnel,
1479                                 struct amt_group_node *gnode,
1480                                 void *grec,
1481                                 enum amt_ops ops,
1482                                 enum amt_filter filter,
1483                                 enum amt_act act,
1484                                 bool v6)
1485 {
1486         struct amt_dev *amt = tunnel->amt;
1487         struct amt_source_node *snode;
1488         struct igmpv3_grec *igmp_grec;
1489 #if IS_ENABLED(CONFIG_IPV6)
1490         struct mld2_grec *mld_grec;
1491 #endif
1492         union amt_addr src = {0,};
1493         struct hlist_node *t;
1494         u16 nsrcs;
1495         int i, j;
1496
1497         if (!v6) {
1498                 igmp_grec = (struct igmpv3_grec *)grec;
1499                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1500         } else {
1501 #if IS_ENABLED(CONFIG_IPV6)
1502                 mld_grec = (struct mld2_grec *)grec;
1503                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1504 #else
1505         return;
1506 #endif
1507         }
1508
1509         memset(&src, 0, sizeof(union amt_addr));
1510         switch (ops) {
1511         case AMT_OPS_INT:
1512                 /* A*B */
1513                 for (i = 0; i < nsrcs; i++) {
1514                         if (!v6)
1515                                 src.ip4 = igmp_grec->grec_src[i];
1516 #if IS_ENABLED(CONFIG_IPV6)
1517                         else
1518                                 memcpy(&src.ip6, &mld_grec->grec_src[i],
1519                                        sizeof(struct in6_addr));
1520 #endif
1521                         snode = amt_lookup_src(tunnel, gnode, filter, &src);
1522                         if (!snode)
1523                                 continue;
1524                         amt_act_src(tunnel, gnode, snode, act);
1525                 }
1526                 break;
1527         case AMT_OPS_UNI:
1528                 /* A+B */
1529                 for (i = 0; i < amt->hash_buckets; i++) {
1530                         hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
1531                                                   node) {
1532                                 if (amt_status_filter(snode, filter))
1533                                         amt_act_src(tunnel, gnode, snode, act);
1534                         }
1535                 }
1536                 for (i = 0; i < nsrcs; i++) {
1537                         if (!v6)
1538                                 src.ip4 = igmp_grec->grec_src[i];
1539 #if IS_ENABLED(CONFIG_IPV6)
1540                         else
1541                                 memcpy(&src.ip6, &mld_grec->grec_src[i],
1542                                        sizeof(struct in6_addr));
1543 #endif
1544                         snode = amt_lookup_src(tunnel, gnode, filter, &src);
1545                         if (!snode)
1546                                 continue;
1547                         amt_act_src(tunnel, gnode, snode, act);
1548                 }
1549                 break;
1550         case AMT_OPS_SUB:
1551                 /* A-B */
1552                 for (i = 0; i < amt->hash_buckets; i++) {
1553                         hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
1554                                                   node) {
1555                                 if (!amt_status_filter(snode, filter))
1556                                         continue;
1557                                 for (j = 0; j < nsrcs; j++) {
1558                                         if (!v6)
1559                                                 src.ip4 = igmp_grec->grec_src[j];
1560 #if IS_ENABLED(CONFIG_IPV6)
1561                                         else
1562                                                 memcpy(&src.ip6,
1563                                                        &mld_grec->grec_src[j],
1564                                                        sizeof(struct in6_addr));
1565 #endif
1566                                         if (amt_addr_equal(&snode->source_addr,
1567                                                            &src))
1568                                                 goto out_sub;
1569                                 }
1570                                 amt_act_src(tunnel, gnode, snode, act);
1571                                 continue;
1572 out_sub:;
1573                         }
1574                 }
1575                 break;
1576         case AMT_OPS_SUB_REV:
1577                 /* B-A */
1578                 for (i = 0; i < nsrcs; i++) {
1579                         if (!v6)
1580                                 src.ip4 = igmp_grec->grec_src[i];
1581 #if IS_ENABLED(CONFIG_IPV6)
1582                         else
1583                                 memcpy(&src.ip6, &mld_grec->grec_src[i],
1584                                        sizeof(struct in6_addr));
1585 #endif
1586                         snode = amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL,
1587                                                &src);
1588                         if (!snode) {
1589                                 snode = amt_lookup_src(tunnel, gnode,
1590                                                        filter, &src);
1591                                 if (snode)
1592                                         amt_act_src(tunnel, gnode, snode, act);
1593                         }
1594                 }
1595                 break;
1596         default:
1597                 netdev_dbg(amt->dev, "Invalid type\n");
1598                 return;
1599         }
1600 }
1601
1602 static void amt_mcast_is_in_handler(struct amt_dev *amt,
1603                                     struct amt_tunnel_list *tunnel,
1604                                     struct amt_group_node *gnode,
1605                                     void *grec, void *zero_grec, bool v6)
1606 {
1607         if (gnode->filter_mode == MCAST_INCLUDE) {
1608 /* Router State   Report Rec'd New Router State        Actions
1609  * ------------   ------------ ----------------        -------
1610  * INCLUDE (A)    IS_IN (B)    INCLUDE (A+B)           (B)=GMI
1611  */
1612                 /* Update IS_IN (B) as FWD/NEW */
1613                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1614                                     AMT_FILTER_NONE_NEW,
1615                                     AMT_ACT_STATUS_FWD_NEW,
1616                                     v6);
1617                 /* Update INCLUDE (A) as NEW */
1618                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1619                                     AMT_FILTER_FWD,
1620                                     AMT_ACT_STATUS_FWD_NEW,
1621                                     v6);
1622                 /* (B)=GMI */
1623                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1624                                     AMT_FILTER_FWD_NEW,
1625                                     AMT_ACT_GMI,
1626                                     v6);
1627         } else {
1628 /* State        Actions
1629  * ------------   ------------ ----------------        -------
1630  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1631  */
1632                 /* Update (A) in (X, Y) as NONE/NEW */
1633                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1634                                     AMT_FILTER_BOTH,
1635                                     AMT_ACT_STATUS_NONE_NEW,
1636                                     v6);
1637                 /* Update FWD/OLD as FWD/NEW */
1638                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1639                                     AMT_FILTER_FWD,
1640                                     AMT_ACT_STATUS_FWD_NEW,
1641                                     v6);
1642                 /* Update IS_IN (A) as FWD/NEW */
1643                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1644                                     AMT_FILTER_NONE_NEW,
1645                                     AMT_ACT_STATUS_FWD_NEW,
1646                                     v6);
1647                 /* Update EXCLUDE (, Y-A) as D_FWD_NEW */
1648                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1649                                     AMT_FILTER_D_FWD,
1650                                     AMT_ACT_STATUS_D_FWD_NEW,
1651                                     v6);
1652         }
1653 }
1654
1655 static void amt_mcast_is_ex_handler(struct amt_dev *amt,
1656                                     struct amt_tunnel_list *tunnel,
1657                                     struct amt_group_node *gnode,
1658                                     void *grec, void *zero_grec, bool v6)
1659 {
1660         if (gnode->filter_mode == MCAST_INCLUDE) {
1661 /* Router State   Report Rec'd  New Router State         Actions
1662  * ------------   ------------  ----------------         -------
1663  * INCLUDE (A)    IS_EX (B)     EXCLUDE (A*B,B-A)        (B-A)=0
1664  *                                                       Delete (A-B)
1665  *                                                       Group Timer=GMI
1666  */
1667                 /* EXCLUDE(A*B, ) */
1668                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1669                                     AMT_FILTER_FWD,
1670                                     AMT_ACT_STATUS_FWD_NEW,
1671                                     v6);
1672                 /* EXCLUDE(, B-A) */
1673                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1674                                     AMT_FILTER_FWD,
1675                                     AMT_ACT_STATUS_D_FWD_NEW,
1676                                     v6);
1677                 /* (B-A)=0 */
1678                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1679                                     AMT_FILTER_D_FWD_NEW,
1680                                     AMT_ACT_GMI_ZERO,
1681                                     v6);
1682                 /* Group Timer=GMI */
1683                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1684                                       msecs_to_jiffies(amt_gmi(amt))))
1685                         dev_hold(amt->dev);
1686                 gnode->filter_mode = MCAST_EXCLUDE;
1687                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1688         } else {
1689 /* Router State   Report Rec'd  New Router State        Actions
1690  * ------------   ------------  ----------------        -------
1691  * EXCLUDE (X,Y)  IS_EX (A)     EXCLUDE (A-Y,Y*A)       (A-X-Y)=GMI
1692  *                                                      Delete (X-A)
1693  *                                                      Delete (Y-A)
1694  *                                                      Group Timer=GMI
1695  */
1696                 /* EXCLUDE (A-Y, ) */
1697                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1698                                     AMT_FILTER_D_FWD,
1699                                     AMT_ACT_STATUS_FWD_NEW,
1700                                     v6);
1701                 /* EXCLUDE (, Y*A ) */
1702                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1703                                     AMT_FILTER_D_FWD,
1704                                     AMT_ACT_STATUS_D_FWD_NEW,
1705                                     v6);
1706                 /* (A-X-Y)=GMI */
1707                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1708                                     AMT_FILTER_BOTH_NEW,
1709                                     AMT_ACT_GMI,
1710                                     v6);
1711                 /* Group Timer=GMI */
1712                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1713                                       msecs_to_jiffies(amt_gmi(amt))))
1714                         dev_hold(amt->dev);
1715                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1716         }
1717 }
1718
1719 static void amt_mcast_to_in_handler(struct amt_dev *amt,
1720                                     struct amt_tunnel_list *tunnel,
1721                                     struct amt_group_node *gnode,
1722                                     void *grec, void *zero_grec, bool v6)
1723 {
1724         if (gnode->filter_mode == MCAST_INCLUDE) {
1725 /* Router State   Report Rec'd New Router State        Actions
1726  * ------------   ------------ ----------------        -------
1727  * INCLUDE (A)    TO_IN (B)    INCLUDE (A+B)           (B)=GMI
1728  *                                                     Send Q(G,A-B)
1729  */
1730                 /* Update TO_IN (B) sources as FWD/NEW */
1731                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1732                                     AMT_FILTER_NONE_NEW,
1733                                     AMT_ACT_STATUS_FWD_NEW,
1734                                     v6);
1735                 /* Update INCLUDE (A) sources as NEW */
1736                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1737                                     AMT_FILTER_FWD,
1738                                     AMT_ACT_STATUS_FWD_NEW,
1739                                     v6);
1740                 /* (B)=GMI */
1741                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1742                                     AMT_FILTER_FWD_NEW,
1743                                     AMT_ACT_GMI,
1744                                     v6);
1745         } else {
1746 /* Router State   Report Rec'd New Router State        Actions
1747  * ------------   ------------ ----------------        -------
1748  * EXCLUDE (X,Y)  TO_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1749  *                                                     Send Q(G,X-A)
1750  *                                                     Send Q(G)
1751  */
1752                 /* Update TO_IN (A) sources as FWD/NEW */
1753                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1754                                     AMT_FILTER_NONE_NEW,
1755                                     AMT_ACT_STATUS_FWD_NEW,
1756                                     v6);
1757                 /* Update EXCLUDE(X,) sources as FWD/NEW */
1758                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1759                                     AMT_FILTER_FWD,
1760                                     AMT_ACT_STATUS_FWD_NEW,
1761                                     v6);
1762                 /* EXCLUDE (, Y-A)
1763                  * (A) are already switched to FWD_NEW.
1764                  * So, D_FWD/OLD -> D_FWD/NEW is okay.
1765                  */
1766                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1767                                     AMT_FILTER_D_FWD,
1768                                     AMT_ACT_STATUS_D_FWD_NEW,
1769                                     v6);
1770                 /* (A)=GMI
1771                  * Only FWD_NEW will have (A) sources.
1772                  */
1773                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1774                                     AMT_FILTER_FWD_NEW,
1775                                     AMT_ACT_GMI,
1776                                     v6);
1777         }
1778 }
1779
1780 static void amt_mcast_to_ex_handler(struct amt_dev *amt,
1781                                     struct amt_tunnel_list *tunnel,
1782                                     struct amt_group_node *gnode,
1783                                     void *grec, void *zero_grec, bool v6)
1784 {
1785         if (gnode->filter_mode == MCAST_INCLUDE) {
1786 /* Router State   Report Rec'd New Router State        Actions
1787  * ------------   ------------ ----------------        -------
1788  * INCLUDE (A)    TO_EX (B)    EXCLUDE (A*B,B-A)       (B-A)=0
1789  *                                                     Delete (A-B)
1790  *                                                     Send Q(G,A*B)
1791  *                                                     Group Timer=GMI
1792  */
1793                 /* EXCLUDE (A*B, ) */
1794                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1795                                     AMT_FILTER_FWD,
1796                                     AMT_ACT_STATUS_FWD_NEW,
1797                                     v6);
1798                 /* EXCLUDE (, B-A) */
1799                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1800                                     AMT_FILTER_FWD,
1801                                     AMT_ACT_STATUS_D_FWD_NEW,
1802                                     v6);
1803                 /* (B-A)=0 */
1804                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1805                                     AMT_FILTER_D_FWD_NEW,
1806                                     AMT_ACT_GMI_ZERO,
1807                                     v6);
1808                 /* Group Timer=GMI */
1809                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1810                                       msecs_to_jiffies(amt_gmi(amt))))
1811                         dev_hold(amt->dev);
1812                 gnode->filter_mode = MCAST_EXCLUDE;
1813                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1814         } else {
1815 /* Router State   Report Rec'd New Router State        Actions
1816  * ------------   ------------ ----------------        -------
1817  * EXCLUDE (X,Y)  TO_EX (A)    EXCLUDE (A-Y,Y*A)       (A-X-Y)=Group Timer
1818  *                                                     Delete (X-A)
1819  *                                                     Delete (Y-A)
1820  *                                                     Send Q(G,A-Y)
1821  *                                                     Group Timer=GMI
1822  */
1823                 /* Update (A-X-Y) as NONE/OLD */
1824                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1825                                     AMT_FILTER_BOTH,
1826                                     AMT_ACT_GT,
1827                                     v6);
1828                 /* EXCLUDE (A-Y, ) */
1829                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1830                                     AMT_FILTER_D_FWD,
1831                                     AMT_ACT_STATUS_FWD_NEW,
1832                                     v6);
1833                 /* EXCLUDE (, Y*A) */
1834                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1835                                     AMT_FILTER_D_FWD,
1836                                     AMT_ACT_STATUS_D_FWD_NEW,
1837                                     v6);
1838                 /* Group Timer=GMI */
1839                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1840                                       msecs_to_jiffies(amt_gmi(amt))))
1841                         dev_hold(amt->dev);
1842                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1843         }
1844 }
1845
1846 static void amt_mcast_allow_handler(struct amt_dev *amt,
1847                                     struct amt_tunnel_list *tunnel,
1848                                     struct amt_group_node *gnode,
1849                                     void *grec, void *zero_grec, bool v6)
1850 {
1851         if (gnode->filter_mode == MCAST_INCLUDE) {
1852 /* Router State   Report Rec'd New Router State        Actions
1853  * ------------   ------------ ----------------        -------
1854  * INCLUDE (A)    ALLOW (B)    INCLUDE (A+B)           (B)=GMI
1855  */
1856                 /* INCLUDE (A+B) */
1857                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1858                                     AMT_FILTER_FWD,
1859                                     AMT_ACT_STATUS_FWD_NEW,
1860                                     v6);
1861                 /* (B)=GMI */
1862                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1863                                     AMT_FILTER_FWD_NEW,
1864                                     AMT_ACT_GMI,
1865                                     v6);
1866         } else {
1867 /* Router State   Report Rec'd New Router State        Actions
1868  * ------------   ------------ ----------------        -------
1869  * EXCLUDE (X,Y)  ALLOW (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1870  */
1871                 /* EXCLUDE (X+A, ) */
1872                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1873                                     AMT_FILTER_FWD,
1874                                     AMT_ACT_STATUS_FWD_NEW,
1875                                     v6);
1876                 /* EXCLUDE (, Y-A) */
1877                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1878                                     AMT_FILTER_D_FWD,
1879                                     AMT_ACT_STATUS_D_FWD_NEW,
1880                                     v6);
1881                 /* (A)=GMI
1882                  * All (A) source are now FWD/NEW status.
1883                  */
1884                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1885                                     AMT_FILTER_FWD_NEW,
1886                                     AMT_ACT_GMI,
1887                                     v6);
1888         }
1889 }
1890
1891 static void amt_mcast_block_handler(struct amt_dev *amt,
1892                                     struct amt_tunnel_list *tunnel,
1893                                     struct amt_group_node *gnode,
1894                                     void *grec, void *zero_grec, bool v6)
1895 {
1896         if (gnode->filter_mode == MCAST_INCLUDE) {
1897 /* Router State   Report Rec'd New Router State        Actions
1898  * ------------   ------------ ----------------        -------
1899  * INCLUDE (A)    BLOCK (B)    INCLUDE (A)             Send Q(G,A*B)
1900  */
1901                 /* INCLUDE (A) */
1902                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1903                                     AMT_FILTER_FWD,
1904                                     AMT_ACT_STATUS_FWD_NEW,
1905                                     v6);
1906         } else {
1907 /* Router State   Report Rec'd New Router State        Actions
1908  * ------------   ------------ ----------------        -------
1909  * EXCLUDE (X,Y)  BLOCK (A)    EXCLUDE (X+(A-Y),Y)     (A-X-Y)=Group Timer
1910  *                                                     Send Q(G,A-Y)
1911  */
1912                 /* (A-X-Y)=Group Timer */
1913                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1914                                     AMT_FILTER_BOTH,
1915                                     AMT_ACT_GT,
1916                                     v6);
1917                 /* EXCLUDE (X, ) */
1918                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1919                                     AMT_FILTER_FWD,
1920                                     AMT_ACT_STATUS_FWD_NEW,
1921                                     v6);
1922                 /* EXCLUDE (X+(A-Y) */
1923                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1924                                     AMT_FILTER_D_FWD,
1925                                     AMT_ACT_STATUS_FWD_NEW,
1926                                     v6);
1927                 /* EXCLUDE (, Y) */
1928                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1929                                     AMT_FILTER_D_FWD,
1930                                     AMT_ACT_STATUS_D_FWD_NEW,
1931                                     v6);
1932         }
1933 }
1934
1935 /* RFC 3376
1936  * 7.3.2. In the Presence of Older Version Group Members
1937  *
1938  * When Group Compatibility Mode is IGMPv2, a router internally
1939  * translates the following IGMPv2 messages for that group to their
1940  * IGMPv3 equivalents:
1941  *
1942  * IGMPv2 Message                IGMPv3 Equivalent
1943  * --------------                -----------------
1944  * Report                        IS_EX( {} )
1945  * Leave                         TO_IN( {} )
1946  */
1947 static void amt_igmpv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1948                                       struct amt_tunnel_list *tunnel)
1949 {
1950         struct igmphdr *ih = igmp_hdr(skb);
1951         struct iphdr *iph = ip_hdr(skb);
1952         struct amt_group_node *gnode;
1953         union amt_addr group, host;
1954
1955         memset(&group, 0, sizeof(union amt_addr));
1956         group.ip4 = ih->group;
1957         memset(&host, 0, sizeof(union amt_addr));
1958         host.ip4 = iph->saddr;
1959
1960         gnode = amt_lookup_group(tunnel, &group, &host, false);
1961         if (!gnode) {
1962                 gnode = amt_add_group(amt, tunnel, &group, &host, false);
1963                 if (!IS_ERR(gnode)) {
1964                         gnode->filter_mode = MCAST_EXCLUDE;
1965                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1966                                               msecs_to_jiffies(amt_gmi(amt))))
1967                                 dev_hold(amt->dev);
1968                 }
1969         }
1970 }
1971
1972 /* RFC 3376
1973  * 7.3.2. In the Presence of Older Version Group Members
1974  *
1975  * When Group Compatibility Mode is IGMPv2, a router internally
1976  * translates the following IGMPv2 messages for that group to their
1977  * IGMPv3 equivalents:
1978  *
1979  * IGMPv2 Message                IGMPv3 Equivalent
1980  * --------------                -----------------
1981  * Report                        IS_EX( {} )
1982  * Leave                         TO_IN( {} )
1983  */
1984 static void amt_igmpv2_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
1985                                      struct amt_tunnel_list *tunnel)
1986 {
1987         struct igmphdr *ih = igmp_hdr(skb);
1988         struct iphdr *iph = ip_hdr(skb);
1989         struct amt_group_node *gnode;
1990         union amt_addr group, host;
1991
1992         memset(&group, 0, sizeof(union amt_addr));
1993         group.ip4 = ih->group;
1994         memset(&host, 0, sizeof(union amt_addr));
1995         host.ip4 = iph->saddr;
1996
1997         gnode = amt_lookup_group(tunnel, &group, &host, false);
1998         if (gnode)
1999                 amt_del_group(amt, gnode);
2000 }
2001
2002 static void amt_igmpv3_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2003                                       struct amt_tunnel_list *tunnel)
2004 {
2005         struct igmpv3_report *ihrv3 = igmpv3_report_hdr(skb);
2006         int len = skb_transport_offset(skb) + sizeof(*ihrv3);
2007         void *zero_grec = (void *)&igmpv3_zero_grec;
2008         struct iphdr *iph = ip_hdr(skb);
2009         struct amt_group_node *gnode;
2010         union amt_addr group, host;
2011         struct igmpv3_grec *grec;
2012         u16 nsrcs;
2013         int i;
2014
2015         for (i = 0; i < ntohs(ihrv3->ngrec); i++) {
2016                 len += sizeof(*grec);
2017                 if (!ip_mc_may_pull(skb, len))
2018                         break;
2019
2020                 grec = (void *)(skb->data + len - sizeof(*grec));
2021                 nsrcs = ntohs(grec->grec_nsrcs);
2022
2023                 len += nsrcs * sizeof(__be32);
2024                 if (!ip_mc_may_pull(skb, len))
2025                         break;
2026
2027                 memset(&group, 0, sizeof(union amt_addr));
2028                 group.ip4 = grec->grec_mca;
2029                 memset(&host, 0, sizeof(union amt_addr));
2030                 host.ip4 = iph->saddr;
2031                 gnode = amt_lookup_group(tunnel, &group, &host, false);
2032                 if (!gnode) {
2033                         gnode = amt_add_group(amt, tunnel, &group, &host,
2034                                               false);
2035                         if (IS_ERR(gnode))
2036                                 continue;
2037                 }
2038
2039                 amt_add_srcs(amt, tunnel, gnode, grec, false);
2040                 switch (grec->grec_type) {
2041                 case IGMPV3_MODE_IS_INCLUDE:
2042                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2043                                                 zero_grec, false);
2044                         break;
2045                 case IGMPV3_MODE_IS_EXCLUDE:
2046                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2047                                                 zero_grec, false);
2048                         break;
2049                 case IGMPV3_CHANGE_TO_INCLUDE:
2050                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2051                                                 zero_grec, false);
2052                         break;
2053                 case IGMPV3_CHANGE_TO_EXCLUDE:
2054                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2055                                                 zero_grec, false);
2056                         break;
2057                 case IGMPV3_ALLOW_NEW_SOURCES:
2058                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2059                                                 zero_grec, false);
2060                         break;
2061                 case IGMPV3_BLOCK_OLD_SOURCES:
2062                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2063                                                 zero_grec, false);
2064                         break;
2065                 default:
2066                         break;
2067                 }
2068                 amt_cleanup_srcs(amt, tunnel, gnode);
2069         }
2070 }
2071
2072 /* caller held tunnel->lock */
2073 static void amt_igmp_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2074                                     struct amt_tunnel_list *tunnel)
2075 {
2076         struct igmphdr *ih = igmp_hdr(skb);
2077
2078         switch (ih->type) {
2079         case IGMPV3_HOST_MEMBERSHIP_REPORT:
2080                 amt_igmpv3_report_handler(amt, skb, tunnel);
2081                 break;
2082         case IGMPV2_HOST_MEMBERSHIP_REPORT:
2083                 amt_igmpv2_report_handler(amt, skb, tunnel);
2084                 break;
2085         case IGMP_HOST_LEAVE_MESSAGE:
2086                 amt_igmpv2_leave_handler(amt, skb, tunnel);
2087                 break;
2088         default:
2089                 break;
2090         }
2091 }
2092
2093 #if IS_ENABLED(CONFIG_IPV6)
2094 /* RFC 3810
2095  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2096  *
2097  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2098  * using the MLDv2 protocol for that multicast address.  When Multicast
2099  * Address Compatibility Mode is MLDv1, a router internally translates
2100  * the following MLDv1 messages for that multicast address to their
2101  * MLDv2 equivalents:
2102  *
2103  * MLDv1 Message                 MLDv2 Equivalent
2104  * --------------                -----------------
2105  * Report                        IS_EX( {} )
2106  * Done                          TO_IN( {} )
2107  */
2108 static void amt_mldv1_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2109                                      struct amt_tunnel_list *tunnel)
2110 {
2111         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2112         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2113         struct amt_group_node *gnode;
2114         union amt_addr group, host;
2115
2116         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2117         memcpy(&host.ip6, &ip6h->saddr, sizeof(struct in6_addr));
2118
2119         gnode = amt_lookup_group(tunnel, &group, &host, true);
2120         if (!gnode) {
2121                 gnode = amt_add_group(amt, tunnel, &group, &host, true);
2122                 if (!IS_ERR(gnode)) {
2123                         gnode->filter_mode = MCAST_EXCLUDE;
2124                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
2125                                               msecs_to_jiffies(amt_gmi(amt))))
2126                                 dev_hold(amt->dev);
2127                 }
2128         }
2129 }
2130
2131 /* RFC 3810
2132  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2133  *
2134  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2135  * using the MLDv2 protocol for that multicast address.  When Multicast
2136  * Address Compatibility Mode is MLDv1, a router internally translates
2137  * the following MLDv1 messages for that multicast address to their
2138  * MLDv2 equivalents:
2139  *
2140  * MLDv1 Message                 MLDv2 Equivalent
2141  * --------------                -----------------
2142  * Report                        IS_EX( {} )
2143  * Done                          TO_IN( {} )
2144  */
2145 static void amt_mldv1_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
2146                                     struct amt_tunnel_list *tunnel)
2147 {
2148         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2149         struct iphdr *iph = ip_hdr(skb);
2150         struct amt_group_node *gnode;
2151         union amt_addr group, host;
2152
2153         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2154         memset(&host, 0, sizeof(union amt_addr));
2155         host.ip4 = iph->saddr;
2156
2157         gnode = amt_lookup_group(tunnel, &group, &host, true);
2158         if (gnode) {
2159                 amt_del_group(amt, gnode);
2160                 return;
2161         }
2162 }
2163
2164 static void amt_mldv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2165                                      struct amt_tunnel_list *tunnel)
2166 {
2167         struct mld2_report *mld2r = (struct mld2_report *)icmp6_hdr(skb);
2168         int len = skb_transport_offset(skb) + sizeof(*mld2r);
2169         void *zero_grec = (void *)&mldv2_zero_grec;
2170         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2171         struct amt_group_node *gnode;
2172         union amt_addr group, host;
2173         struct mld2_grec *grec;
2174         u16 nsrcs;
2175         int i;
2176
2177         for (i = 0; i < ntohs(mld2r->mld2r_ngrec); i++) {
2178                 len += sizeof(*grec);
2179                 if (!ipv6_mc_may_pull(skb, len))
2180                         break;
2181
2182                 grec = (void *)(skb->data + len - sizeof(*grec));
2183                 nsrcs = ntohs(grec->grec_nsrcs);
2184
2185                 len += nsrcs * sizeof(struct in6_addr);
2186                 if (!ipv6_mc_may_pull(skb, len))
2187                         break;
2188
2189                 memset(&group, 0, sizeof(union amt_addr));
2190                 group.ip6 = grec->grec_mca;
2191                 memset(&host, 0, sizeof(union amt_addr));
2192                 host.ip6 = ip6h->saddr;
2193                 gnode = amt_lookup_group(tunnel, &group, &host, true);
2194                 if (!gnode) {
2195                         gnode = amt_add_group(amt, tunnel, &group, &host,
2196                                               ETH_P_IPV6);
2197                         if (IS_ERR(gnode))
2198                                 continue;
2199                 }
2200
2201                 amt_add_srcs(amt, tunnel, gnode, grec, true);
2202                 switch (grec->grec_type) {
2203                 case MLD2_MODE_IS_INCLUDE:
2204                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2205                                                 zero_grec, true);
2206                         break;
2207                 case MLD2_MODE_IS_EXCLUDE:
2208                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2209                                                 zero_grec, true);
2210                         break;
2211                 case MLD2_CHANGE_TO_INCLUDE:
2212                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2213                                                 zero_grec, true);
2214                         break;
2215                 case MLD2_CHANGE_TO_EXCLUDE:
2216                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2217                                                 zero_grec, true);
2218                         break;
2219                 case MLD2_ALLOW_NEW_SOURCES:
2220                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2221                                                 zero_grec, true);
2222                         break;
2223                 case MLD2_BLOCK_OLD_SOURCES:
2224                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2225                                                 zero_grec, true);
2226                         break;
2227                 default:
2228                         break;
2229                 }
2230                 amt_cleanup_srcs(amt, tunnel, gnode);
2231         }
2232 }
2233
2234 /* caller held tunnel->lock */
2235 static void amt_mld_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2236                                    struct amt_tunnel_list *tunnel)
2237 {
2238         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2239
2240         switch (mld->mld_type) {
2241         case ICMPV6_MGM_REPORT:
2242                 amt_mldv1_report_handler(amt, skb, tunnel);
2243                 break;
2244         case ICMPV6_MLD2_REPORT:
2245                 amt_mldv2_report_handler(amt, skb, tunnel);
2246                 break;
2247         case ICMPV6_MGM_REDUCTION:
2248                 amt_mldv1_leave_handler(amt, skb, tunnel);
2249                 break;
2250         default:
2251                 break;
2252         }
2253 }
2254 #endif
2255
2256 static bool amt_advertisement_handler(struct amt_dev *amt, struct sk_buff *skb)
2257 {
2258         struct amt_header_advertisement *amta;
2259         int hdr_size;
2260
2261         hdr_size = sizeof(*amta) + sizeof(struct udphdr);
2262         if (!pskb_may_pull(skb, hdr_size))
2263                 return true;
2264
2265         amta = (struct amt_header_advertisement *)(udp_hdr(skb) + 1);
2266         if (!amta->ip4)
2267                 return true;
2268
2269         if (amta->reserved || amta->version)
2270                 return true;
2271
2272         if (ipv4_is_loopback(amta->ip4) || ipv4_is_multicast(amta->ip4) ||
2273             ipv4_is_zeronet(amta->ip4))
2274                 return true;
2275
2276         amt->remote_ip = amta->ip4;
2277         netdev_dbg(amt->dev, "advertised remote ip = %pI4\n", &amt->remote_ip);
2278         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2279
2280         amt_update_gw_status(amt, AMT_STATUS_RECEIVED_ADVERTISEMENT, true);
2281         return false;
2282 }
2283
2284 static bool amt_multicast_data_handler(struct amt_dev *amt, struct sk_buff *skb)
2285 {
2286         struct amt_header_mcast_data *amtmd;
2287         int hdr_size, len, err;
2288         struct ethhdr *eth;
2289         struct iphdr *iph;
2290
2291         hdr_size = sizeof(*amtmd) + sizeof(struct udphdr);
2292         if (!pskb_may_pull(skb, hdr_size))
2293                 return true;
2294
2295         amtmd = (struct amt_header_mcast_data *)(udp_hdr(skb) + 1);
2296         if (amtmd->reserved || amtmd->version)
2297                 return true;
2298
2299         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_IP), false))
2300                 return true;
2301
2302         skb_reset_network_header(skb);
2303         skb_push(skb, sizeof(*eth));
2304         skb_reset_mac_header(skb);
2305         skb_pull(skb, sizeof(*eth));
2306         eth = eth_hdr(skb);
2307
2308         if (!pskb_may_pull(skb, sizeof(*iph)))
2309                 return true;
2310         iph = ip_hdr(skb);
2311
2312         if (iph->version == 4) {
2313                 if (!ipv4_is_multicast(iph->daddr))
2314                         return true;
2315                 skb->protocol = htons(ETH_P_IP);
2316                 eth->h_proto = htons(ETH_P_IP);
2317                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2318 #if IS_ENABLED(CONFIG_IPV6)
2319         } else if (iph->version == 6) {
2320                 struct ipv6hdr *ip6h;
2321
2322                 if (!pskb_may_pull(skb, sizeof(*ip6h)))
2323                         return true;
2324
2325                 ip6h = ipv6_hdr(skb);
2326                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2327                         return true;
2328                 skb->protocol = htons(ETH_P_IPV6);
2329                 eth->h_proto = htons(ETH_P_IPV6);
2330                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2331 #endif
2332         } else {
2333                 return true;
2334         }
2335
2336         skb->pkt_type = PACKET_MULTICAST;
2337         skb->ip_summed = CHECKSUM_NONE;
2338         len = skb->len;
2339         err = gro_cells_receive(&amt->gro_cells, skb);
2340         if (likely(err == NET_RX_SUCCESS))
2341                 dev_sw_netstats_rx_add(amt->dev, len);
2342         else
2343                 amt->dev->stats.rx_dropped++;
2344
2345         return false;
2346 }
2347
2348 static bool amt_membership_query_handler(struct amt_dev *amt,
2349                                          struct sk_buff *skb)
2350 {
2351         struct amt_header_membership_query *amtmq;
2352         struct igmpv3_query *ihv3;
2353         struct ethhdr *eth, *oeth;
2354         struct iphdr *iph;
2355         int hdr_size, len;
2356
2357         hdr_size = sizeof(*amtmq) + sizeof(struct udphdr);
2358         if (!pskb_may_pull(skb, hdr_size))
2359                 return true;
2360
2361         amtmq = (struct amt_header_membership_query *)(udp_hdr(skb) + 1);
2362         if (amtmq->reserved || amtmq->version)
2363                 return true;
2364
2365         hdr_size -= sizeof(*eth);
2366         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_TEB), false))
2367                 return true;
2368
2369         oeth = eth_hdr(skb);
2370         skb_reset_mac_header(skb);
2371         skb_pull(skb, sizeof(*eth));
2372         skb_reset_network_header(skb);
2373         eth = eth_hdr(skb);
2374         if (!pskb_may_pull(skb, sizeof(*iph)))
2375                 return true;
2376
2377         iph = ip_hdr(skb);
2378         if (iph->version == 4) {
2379                 if (!pskb_may_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS +
2380                                    sizeof(*ihv3)))
2381                         return true;
2382
2383                 if (!ipv4_is_multicast(iph->daddr))
2384                         return true;
2385
2386                 ihv3 = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2387                 skb_reset_transport_header(skb);
2388                 skb_push(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2389                 spin_lock_bh(&amt->lock);
2390                 amt->ready4 = true;
2391                 amt->mac = amtmq->response_mac;
2392                 amt->req_cnt = 0;
2393                 amt->qi = ihv3->qqic;
2394                 spin_unlock_bh(&amt->lock);
2395                 skb->protocol = htons(ETH_P_IP);
2396                 eth->h_proto = htons(ETH_P_IP);
2397                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2398 #if IS_ENABLED(CONFIG_IPV6)
2399         } else if (iph->version == 6) {
2400                 struct mld2_query *mld2q;
2401                 struct ipv6hdr *ip6h;
2402
2403                 if (!pskb_may_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS +
2404                                    sizeof(*mld2q)))
2405                         return true;
2406
2407                 ip6h = ipv6_hdr(skb);
2408                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2409                         return true;
2410
2411                 mld2q = skb_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2412                 skb_reset_transport_header(skb);
2413                 skb_push(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2414                 spin_lock_bh(&amt->lock);
2415                 amt->ready6 = true;
2416                 amt->mac = amtmq->response_mac;
2417                 amt->req_cnt = 0;
2418                 amt->qi = mld2q->mld2q_qqic;
2419                 spin_unlock_bh(&amt->lock);
2420                 skb->protocol = htons(ETH_P_IPV6);
2421                 eth->h_proto = htons(ETH_P_IPV6);
2422                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2423 #endif
2424         } else {
2425                 return true;
2426         }
2427
2428         ether_addr_copy(eth->h_source, oeth->h_source);
2429         skb->pkt_type = PACKET_MULTICAST;
2430         skb->ip_summed = CHECKSUM_NONE;
2431         len = skb->len;
2432         local_bh_disable();
2433         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2434                 amt_update_gw_status(amt, AMT_STATUS_RECEIVED_QUERY, true);
2435                 dev_sw_netstats_rx_add(amt->dev, len);
2436         } else {
2437                 amt->dev->stats.rx_dropped++;
2438         }
2439         local_bh_enable();
2440
2441         return false;
2442 }
2443
2444 static bool amt_update_handler(struct amt_dev *amt, struct sk_buff *skb)
2445 {
2446         struct amt_header_membership_update *amtmu;
2447         struct amt_tunnel_list *tunnel;
2448         struct ethhdr *eth;
2449         struct iphdr *iph;
2450         int len, hdr_size;
2451
2452         iph = ip_hdr(skb);
2453
2454         hdr_size = sizeof(*amtmu) + sizeof(struct udphdr);
2455         if (!pskb_may_pull(skb, hdr_size))
2456                 return true;
2457
2458         amtmu = (struct amt_header_membership_update *)(udp_hdr(skb) + 1);
2459         if (amtmu->reserved || amtmu->version)
2460                 return true;
2461
2462         if (iptunnel_pull_header(skb, hdr_size, skb->protocol, false))
2463                 return true;
2464
2465         skb_reset_network_header(skb);
2466
2467         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
2468                 if (tunnel->ip4 == iph->saddr) {
2469                         if ((amtmu->nonce == tunnel->nonce &&
2470                              amtmu->response_mac == tunnel->mac)) {
2471                                 mod_delayed_work(amt_wq, &tunnel->gc_wq,
2472                                                  msecs_to_jiffies(amt_gmi(amt))
2473                                                                   * 3);
2474                                 goto report;
2475                         } else {
2476                                 netdev_dbg(amt->dev, "Invalid MAC\n");
2477                                 return true;
2478                         }
2479                 }
2480         }
2481
2482         return true;
2483
2484 report:
2485         if (!pskb_may_pull(skb, sizeof(*iph)))
2486                 return true;
2487
2488         iph = ip_hdr(skb);
2489         if (iph->version == 4) {
2490                 if (ip_mc_check_igmp(skb)) {
2491                         netdev_dbg(amt->dev, "Invalid IGMP\n");
2492                         return true;
2493                 }
2494
2495                 spin_lock_bh(&tunnel->lock);
2496                 amt_igmp_report_handler(amt, skb, tunnel);
2497                 spin_unlock_bh(&tunnel->lock);
2498
2499                 skb_push(skb, sizeof(struct ethhdr));
2500                 skb_reset_mac_header(skb);
2501                 eth = eth_hdr(skb);
2502                 skb->protocol = htons(ETH_P_IP);
2503                 eth->h_proto = htons(ETH_P_IP);
2504                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2505 #if IS_ENABLED(CONFIG_IPV6)
2506         } else if (iph->version == 6) {
2507                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2508
2509                 if (ipv6_mc_check_mld(skb)) {
2510                         netdev_dbg(amt->dev, "Invalid MLD\n");
2511                         return true;
2512                 }
2513
2514                 spin_lock_bh(&tunnel->lock);
2515                 amt_mld_report_handler(amt, skb, tunnel);
2516                 spin_unlock_bh(&tunnel->lock);
2517
2518                 skb_push(skb, sizeof(struct ethhdr));
2519                 skb_reset_mac_header(skb);
2520                 eth = eth_hdr(skb);
2521                 skb->protocol = htons(ETH_P_IPV6);
2522                 eth->h_proto = htons(ETH_P_IPV6);
2523                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2524 #endif
2525         } else {
2526                 netdev_dbg(amt->dev, "Unsupported Protocol\n");
2527                 return true;
2528         }
2529
2530         skb_pull(skb, sizeof(struct ethhdr));
2531         skb->pkt_type = PACKET_MULTICAST;
2532         skb->ip_summed = CHECKSUM_NONE;
2533         len = skb->len;
2534         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2535                 amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_UPDATE,
2536                                         true);
2537                 dev_sw_netstats_rx_add(amt->dev, len);
2538         } else {
2539                 amt->dev->stats.rx_dropped++;
2540         }
2541
2542         return false;
2543 }
2544
2545 static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
2546                                    __be32 daddr, __be16 dport)
2547 {
2548         struct amt_header_advertisement *amta;
2549         int hlen, tlen, offset;
2550         struct socket *sock;
2551         struct udphdr *udph;
2552         struct sk_buff *skb;
2553         struct iphdr *iph;
2554         struct rtable *rt;
2555         struct flowi4 fl4;
2556         u32 len;
2557         int err;
2558
2559         rcu_read_lock();
2560         sock = rcu_dereference(amt->sock);
2561         if (!sock)
2562                 goto out;
2563
2564         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
2565                 goto out;
2566
2567         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
2568                                    daddr, amt->local_ip,
2569                                    dport, amt->relay_port,
2570                                    IPPROTO_UDP, 0,
2571                                    amt->stream_dev->ifindex);
2572         if (IS_ERR(rt)) {
2573                 amt->dev->stats.tx_errors++;
2574                 goto out;
2575         }
2576
2577         hlen = LL_RESERVED_SPACE(amt->dev);
2578         tlen = amt->dev->needed_tailroom;
2579         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2580         skb = netdev_alloc_skb_ip_align(amt->dev, len);
2581         if (!skb) {
2582                 ip_rt_put(rt);
2583                 amt->dev->stats.tx_errors++;
2584                 goto out;
2585         }
2586
2587         skb->priority = TC_PRIO_CONTROL;
2588         skb_dst_set(skb, &rt->dst);
2589
2590         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2591         skb_reset_network_header(skb);
2592         skb_put(skb, len);
2593         amta = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
2594         amta->version   = 0;
2595         amta->type      = AMT_MSG_ADVERTISEMENT;
2596         amta->reserved  = 0;
2597         amta->nonce     = nonce;
2598         amta->ip4       = amt->local_ip;
2599         skb_push(skb, sizeof(*udph));
2600         skb_reset_transport_header(skb);
2601         udph            = udp_hdr(skb);
2602         udph->source    = amt->relay_port;
2603         udph->dest      = dport;
2604         udph->len       = htons(sizeof(*amta) + sizeof(*udph));
2605         udph->check     = 0;
2606         offset = skb_transport_offset(skb);
2607         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
2608         udph->check = csum_tcpudp_magic(amt->local_ip, daddr,
2609                                         sizeof(*udph) + sizeof(*amta),
2610                                         IPPROTO_UDP, skb->csum);
2611
2612         skb_push(skb, sizeof(*iph));
2613         iph             = ip_hdr(skb);
2614         iph->version    = 4;
2615         iph->ihl        = (sizeof(struct iphdr)) >> 2;
2616         iph->tos        = AMT_TOS;
2617         iph->frag_off   = 0;
2618         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
2619         iph->daddr      = daddr;
2620         iph->saddr      = amt->local_ip;
2621         iph->protocol   = IPPROTO_UDP;
2622         iph->tot_len    = htons(len);
2623
2624         skb->ip_summed = CHECKSUM_NONE;
2625         ip_select_ident(amt->net, skb, NULL);
2626         ip_send_check(iph);
2627         err = ip_local_out(amt->net, sock->sk, skb);
2628         if (unlikely(net_xmit_eval(err)))
2629                 amt->dev->stats.tx_errors++;
2630
2631 out:
2632         rcu_read_unlock();
2633 }
2634
2635 static bool amt_discovery_handler(struct amt_dev *amt, struct sk_buff *skb)
2636 {
2637         struct amt_header_discovery *amtd;
2638         struct udphdr *udph;
2639         struct iphdr *iph;
2640
2641         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtd)))
2642                 return true;
2643
2644         iph = ip_hdr(skb);
2645         udph = udp_hdr(skb);
2646         amtd = (struct amt_header_discovery *)(udp_hdr(skb) + 1);
2647
2648         if (amtd->reserved || amtd->version)
2649                 return true;
2650
2651         amt_send_advertisement(amt, amtd->nonce, iph->saddr, udph->source);
2652
2653         return false;
2654 }
2655
2656 static bool amt_request_handler(struct amt_dev *amt, struct sk_buff *skb)
2657 {
2658         struct amt_header_request *amtrh;
2659         struct amt_tunnel_list *tunnel;
2660         unsigned long long key;
2661         struct udphdr *udph;
2662         struct iphdr *iph;
2663         u64 mac;
2664         int i;
2665
2666         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtrh)))
2667                 return true;
2668
2669         iph = ip_hdr(skb);
2670         udph = udp_hdr(skb);
2671         amtrh = (struct amt_header_request *)(udp_hdr(skb) + 1);
2672
2673         if (amtrh->reserved1 || amtrh->reserved2 || amtrh->version)
2674                 return true;
2675
2676         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list)
2677                 if (tunnel->ip4 == iph->saddr)
2678                         goto send;
2679
2680         if (amt->nr_tunnels >= amt->max_tunnels) {
2681                 icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
2682                 return true;
2683         }
2684
2685         tunnel = kzalloc(sizeof(*tunnel) +
2686                          (sizeof(struct hlist_head) * amt->hash_buckets),
2687                          GFP_ATOMIC);
2688         if (!tunnel)
2689                 return true;
2690
2691         tunnel->source_port = udph->source;
2692         tunnel->ip4 = iph->saddr;
2693
2694         memcpy(&key, &tunnel->key, sizeof(unsigned long long));
2695         tunnel->amt = amt;
2696         spin_lock_init(&tunnel->lock);
2697         for (i = 0; i < amt->hash_buckets; i++)
2698                 INIT_HLIST_HEAD(&tunnel->groups[i]);
2699
2700         INIT_DELAYED_WORK(&tunnel->gc_wq, amt_tunnel_expire);
2701
2702         spin_lock_bh(&amt->lock);
2703         list_add_tail_rcu(&tunnel->list, &amt->tunnel_list);
2704         tunnel->key = amt->key;
2705         amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_REQUEST, true);
2706         amt->nr_tunnels++;
2707         mod_delayed_work(amt_wq, &tunnel->gc_wq,
2708                          msecs_to_jiffies(amt_gmi(amt)));
2709         spin_unlock_bh(&amt->lock);
2710
2711 send:
2712         tunnel->nonce = amtrh->nonce;
2713         mac = siphash_3u32((__force u32)tunnel->ip4,
2714                            (__force u32)tunnel->source_port,
2715                            (__force u32)tunnel->nonce,
2716                            &tunnel->key);
2717         tunnel->mac = mac >> 16;
2718
2719         if (!netif_running(amt->dev) || !netif_running(amt->stream_dev))
2720                 return true;
2721
2722         if (!amtrh->p)
2723                 amt_send_igmp_gq(amt, tunnel);
2724         else
2725                 amt_send_mld_gq(amt, tunnel);
2726
2727         return false;
2728 }
2729
2730 static void amt_gw_rcv(struct amt_dev *amt, struct sk_buff *skb)
2731 {
2732         int type = amt_parse_type(skb);
2733         int err = 1;
2734
2735         if (type == -1)
2736                 goto drop;
2737
2738         if (amt->mode == AMT_MODE_GATEWAY) {
2739                 switch (type) {
2740                 case AMT_MSG_ADVERTISEMENT:
2741                         err = amt_advertisement_handler(amt, skb);
2742                         break;
2743                 case AMT_MSG_MEMBERSHIP_QUERY:
2744                         err = amt_membership_query_handler(amt, skb);
2745                         if (!err)
2746                                 return;
2747                         break;
2748                 default:
2749                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2750                         break;
2751                 }
2752         }
2753 drop:
2754         if (err) {
2755                 amt->dev->stats.rx_dropped++;
2756                 kfree_skb(skb);
2757         } else {
2758                 consume_skb(skb);
2759         }
2760 }
2761
2762 static int amt_rcv(struct sock *sk, struct sk_buff *skb)
2763 {
2764         struct amt_dev *amt;
2765         struct iphdr *iph;
2766         int type;
2767         bool err;
2768
2769         rcu_read_lock_bh();
2770         amt = rcu_dereference_sk_user_data(sk);
2771         if (!amt) {
2772                 err = true;
2773                 kfree_skb(skb);
2774                 goto out;
2775         }
2776
2777         skb->dev = amt->dev;
2778         iph = ip_hdr(skb);
2779         type = amt_parse_type(skb);
2780         if (type == -1) {
2781                 err = true;
2782                 goto drop;
2783         }
2784
2785         if (amt->mode == AMT_MODE_GATEWAY) {
2786                 switch (type) {
2787                 case AMT_MSG_ADVERTISEMENT:
2788                         if (iph->saddr != amt->discovery_ip) {
2789                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2790                                 err = true;
2791                                 goto drop;
2792                         }
2793                         if (amt_queue_event(amt, AMT_EVENT_RECEIVE, skb)) {
2794                                 netdev_dbg(amt->dev, "AMT Event queue full\n");
2795                                 err = true;
2796                                 goto drop;
2797                         }
2798                         goto out;
2799                 case AMT_MSG_MULTICAST_DATA:
2800                         if (iph->saddr != amt->remote_ip) {
2801                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2802                                 err = true;
2803                                 goto drop;
2804                         }
2805                         err = amt_multicast_data_handler(amt, skb);
2806                         if (err)
2807                                 goto drop;
2808                         else
2809                                 goto out;
2810                 case AMT_MSG_MEMBERSHIP_QUERY:
2811                         if (iph->saddr != amt->remote_ip) {
2812                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2813                                 err = true;
2814                                 goto drop;
2815                         }
2816                         if (amt_queue_event(amt, AMT_EVENT_RECEIVE, skb)) {
2817                                 netdev_dbg(amt->dev, "AMT Event queue full\n");
2818                                 err = true;
2819                                 goto drop;
2820                         }
2821                         goto out;
2822                 default:
2823                         err = true;
2824                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2825                         break;
2826                 }
2827         } else {
2828                 switch (type) {
2829                 case AMT_MSG_DISCOVERY:
2830                         err = amt_discovery_handler(amt, skb);
2831                         break;
2832                 case AMT_MSG_REQUEST:
2833                         err = amt_request_handler(amt, skb);
2834                         break;
2835                 case AMT_MSG_MEMBERSHIP_UPDATE:
2836                         err = amt_update_handler(amt, skb);
2837                         if (err)
2838                                 goto drop;
2839                         else
2840                                 goto out;
2841                 default:
2842                         err = true;
2843                         netdev_dbg(amt->dev, "Invalid type of relay\n");
2844                         break;
2845                 }
2846         }
2847 drop:
2848         if (err) {
2849                 amt->dev->stats.rx_dropped++;
2850                 kfree_skb(skb);
2851         } else {
2852                 consume_skb(skb);
2853         }
2854 out:
2855         rcu_read_unlock_bh();
2856         return 0;
2857 }
2858
2859 static void amt_event_work(struct work_struct *work)
2860 {
2861         struct amt_dev *amt = container_of(work, struct amt_dev, event_wq);
2862         struct sk_buff *skb;
2863         u8 event;
2864         int i;
2865
2866         for (i = 0; i < AMT_MAX_EVENTS; i++) {
2867                 spin_lock_bh(&amt->lock);
2868                 if (amt->nr_events == 0) {
2869                         spin_unlock_bh(&amt->lock);
2870                         return;
2871                 }
2872                 event = amt->events[amt->event_idx].event;
2873                 skb = amt->events[amt->event_idx].skb;
2874                 amt->events[amt->event_idx].event = AMT_EVENT_NONE;
2875                 amt->events[amt->event_idx].skb = NULL;
2876                 amt->nr_events--;
2877                 amt->event_idx++;
2878                 amt->event_idx %= AMT_MAX_EVENTS;
2879                 spin_unlock_bh(&amt->lock);
2880
2881                 switch (event) {
2882                 case AMT_EVENT_RECEIVE:
2883                         amt_gw_rcv(amt, skb);
2884                         break;
2885                 case AMT_EVENT_SEND_DISCOVERY:
2886                         amt_event_send_discovery(amt);
2887                         break;
2888                 case AMT_EVENT_SEND_REQUEST:
2889                         amt_event_send_request(amt);
2890                         break;
2891                 default:
2892                         if (skb)
2893                                 kfree_skb(skb);
2894                         break;
2895                 }
2896         }
2897 }
2898
2899 static int amt_err_lookup(struct sock *sk, struct sk_buff *skb)
2900 {
2901         struct amt_dev *amt;
2902         int type;
2903
2904         rcu_read_lock_bh();
2905         amt = rcu_dereference_sk_user_data(sk);
2906         if (!amt)
2907                 goto out;
2908
2909         if (amt->mode != AMT_MODE_GATEWAY)
2910                 goto drop;
2911
2912         type = amt_parse_type(skb);
2913         if (type == -1)
2914                 goto drop;
2915
2916         netdev_dbg(amt->dev, "Received IGMP Unreachable of %s\n",
2917                    type_str[type]);
2918         switch (type) {
2919         case AMT_MSG_DISCOVERY:
2920                 break;
2921         case AMT_MSG_REQUEST:
2922         case AMT_MSG_MEMBERSHIP_UPDATE:
2923                 if (amt->status >= AMT_STATUS_RECEIVED_ADVERTISEMENT)
2924                         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2925                 break;
2926         default:
2927                 goto drop;
2928         }
2929 out:
2930         rcu_read_unlock_bh();
2931         return 0;
2932 drop:
2933         rcu_read_unlock_bh();
2934         amt->dev->stats.rx_dropped++;
2935         return 0;
2936 }
2937
2938 static struct socket *amt_create_sock(struct net *net, __be16 port)
2939 {
2940         struct udp_port_cfg udp_conf;
2941         struct socket *sock;
2942         int err;
2943
2944         memset(&udp_conf, 0, sizeof(udp_conf));
2945         udp_conf.family = AF_INET;
2946         udp_conf.local_ip.s_addr = htonl(INADDR_ANY);
2947
2948         udp_conf.local_udp_port = port;
2949
2950         err = udp_sock_create(net, &udp_conf, &sock);
2951         if (err < 0)
2952                 return ERR_PTR(err);
2953
2954         return sock;
2955 }
2956
2957 static int amt_socket_create(struct amt_dev *amt)
2958 {
2959         struct udp_tunnel_sock_cfg tunnel_cfg;
2960         struct socket *sock;
2961
2962         sock = amt_create_sock(amt->net, amt->relay_port);
2963         if (IS_ERR(sock))
2964                 return PTR_ERR(sock);
2965
2966         /* Mark socket as an encapsulation socket */
2967         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
2968         tunnel_cfg.sk_user_data = amt;
2969         tunnel_cfg.encap_type = 1;
2970         tunnel_cfg.encap_rcv = amt_rcv;
2971         tunnel_cfg.encap_err_lookup = amt_err_lookup;
2972         tunnel_cfg.encap_destroy = NULL;
2973         setup_udp_tunnel_sock(amt->net, sock, &tunnel_cfg);
2974
2975         rcu_assign_pointer(amt->sock, sock);
2976         return 0;
2977 }
2978
2979 static int amt_dev_open(struct net_device *dev)
2980 {
2981         struct amt_dev *amt = netdev_priv(dev);
2982         int err;
2983
2984         amt->ready4 = false;
2985         amt->ready6 = false;
2986         amt->event_idx = 0;
2987         amt->nr_events = 0;
2988
2989         err = amt_socket_create(amt);
2990         if (err)
2991                 return err;
2992
2993         amt->req_cnt = 0;
2994         amt->remote_ip = 0;
2995         get_random_bytes(&amt->key, sizeof(siphash_key_t));
2996
2997         amt->status = AMT_STATUS_INIT;
2998         if (amt->mode == AMT_MODE_GATEWAY) {
2999                 mod_delayed_work(amt_wq, &amt->discovery_wq, 0);
3000                 mod_delayed_work(amt_wq, &amt->req_wq, 0);
3001         } else if (amt->mode == AMT_MODE_RELAY) {
3002                 mod_delayed_work(amt_wq, &amt->secret_wq,
3003                                  msecs_to_jiffies(AMT_SECRET_TIMEOUT));
3004         }
3005         return err;
3006 }
3007
3008 static int amt_dev_stop(struct net_device *dev)
3009 {
3010         struct amt_dev *amt = netdev_priv(dev);
3011         struct amt_tunnel_list *tunnel, *tmp;
3012         struct socket *sock;
3013         struct sk_buff *skb;
3014         int i;
3015
3016         cancel_delayed_work_sync(&amt->req_wq);
3017         cancel_delayed_work_sync(&amt->discovery_wq);
3018         cancel_delayed_work_sync(&amt->secret_wq);
3019
3020         /* shutdown */
3021         sock = rtnl_dereference(amt->sock);
3022         RCU_INIT_POINTER(amt->sock, NULL);
3023         synchronize_net();
3024         if (sock)
3025                 udp_tunnel_sock_release(sock);
3026
3027         cancel_work_sync(&amt->event_wq);
3028         for (i = 0; i < AMT_MAX_EVENTS; i++) {
3029                 skb = amt->events[i].skb;
3030                 if (skb)
3031                         kfree_skb(skb);
3032                 amt->events[i].event = AMT_EVENT_NONE;
3033                 amt->events[i].skb = NULL;
3034         }
3035
3036         amt->ready4 = false;
3037         amt->ready6 = false;
3038         amt->req_cnt = 0;
3039         amt->remote_ip = 0;
3040
3041         list_for_each_entry_safe(tunnel, tmp, &amt->tunnel_list, list) {
3042                 list_del_rcu(&tunnel->list);
3043                 amt->nr_tunnels--;
3044                 cancel_delayed_work_sync(&tunnel->gc_wq);
3045                 amt_clear_groups(tunnel);
3046                 kfree_rcu(tunnel, rcu);
3047         }
3048
3049         return 0;
3050 }
3051
3052 static const struct device_type amt_type = {
3053         .name = "amt",
3054 };
3055
3056 static int amt_dev_init(struct net_device *dev)
3057 {
3058         struct amt_dev *amt = netdev_priv(dev);
3059         int err;
3060
3061         amt->dev = dev;
3062         dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
3063         if (!dev->tstats)
3064                 return -ENOMEM;
3065
3066         err = gro_cells_init(&amt->gro_cells, dev);
3067         if (err) {
3068                 free_percpu(dev->tstats);
3069                 return err;
3070         }
3071
3072         return 0;
3073 }
3074
3075 static void amt_dev_uninit(struct net_device *dev)
3076 {
3077         struct amt_dev *amt = netdev_priv(dev);
3078
3079         gro_cells_destroy(&amt->gro_cells);
3080         free_percpu(dev->tstats);
3081 }
3082
3083 static const struct net_device_ops amt_netdev_ops = {
3084         .ndo_init               = amt_dev_init,
3085         .ndo_uninit             = amt_dev_uninit,
3086         .ndo_open               = amt_dev_open,
3087         .ndo_stop               = amt_dev_stop,
3088         .ndo_start_xmit         = amt_dev_xmit,
3089         .ndo_get_stats64        = dev_get_tstats64,
3090 };
3091
3092 static void amt_link_setup(struct net_device *dev)
3093 {
3094         dev->netdev_ops         = &amt_netdev_ops;
3095         dev->needs_free_netdev  = true;
3096         SET_NETDEV_DEVTYPE(dev, &amt_type);
3097         dev->min_mtu            = ETH_MIN_MTU;
3098         dev->max_mtu            = ETH_MAX_MTU;
3099         dev->type               = ARPHRD_NONE;
3100         dev->flags              = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
3101         dev->hard_header_len    = 0;
3102         dev->addr_len           = 0;
3103         dev->priv_flags         |= IFF_NO_QUEUE;
3104         dev->features           |= NETIF_F_LLTX;
3105         dev->features           |= NETIF_F_GSO_SOFTWARE;
3106         dev->features           |= NETIF_F_NETNS_LOCAL;
3107         dev->hw_features        |= NETIF_F_SG | NETIF_F_HW_CSUM;
3108         dev->hw_features        |= NETIF_F_FRAGLIST | NETIF_F_RXCSUM;
3109         dev->hw_features        |= NETIF_F_GSO_SOFTWARE;
3110         eth_hw_addr_random(dev);
3111         eth_zero_addr(dev->broadcast);
3112         ether_setup(dev);
3113 }
3114
3115 static const struct nla_policy amt_policy[IFLA_AMT_MAX + 1] = {
3116         [IFLA_AMT_MODE]         = { .type = NLA_U32 },
3117         [IFLA_AMT_RELAY_PORT]   = { .type = NLA_U16 },
3118         [IFLA_AMT_GATEWAY_PORT] = { .type = NLA_U16 },
3119         [IFLA_AMT_LINK]         = { .type = NLA_U32 },
3120         [IFLA_AMT_LOCAL_IP]     = { .len = sizeof_field(struct iphdr, daddr) },
3121         [IFLA_AMT_REMOTE_IP]    = { .len = sizeof_field(struct iphdr, daddr) },
3122         [IFLA_AMT_DISCOVERY_IP] = { .len = sizeof_field(struct iphdr, daddr) },
3123         [IFLA_AMT_MAX_TUNNELS]  = { .type = NLA_U32 },
3124 };
3125
3126 static int amt_validate(struct nlattr *tb[], struct nlattr *data[],
3127                         struct netlink_ext_ack *extack)
3128 {
3129         if (!data)
3130                 return -EINVAL;
3131
3132         if (!data[IFLA_AMT_LINK]) {
3133                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LINK],
3134                                     "Link attribute is required");
3135                 return -EINVAL;
3136         }
3137
3138         if (!data[IFLA_AMT_MODE]) {
3139                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3140                                     "Mode attribute is required");
3141                 return -EINVAL;
3142         }
3143
3144         if (nla_get_u32(data[IFLA_AMT_MODE]) > AMT_MODE_MAX) {
3145                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3146                                     "Mode attribute is not valid");
3147                 return -EINVAL;
3148         }
3149
3150         if (!data[IFLA_AMT_LOCAL_IP]) {
3151                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_DISCOVERY_IP],
3152                                     "Local attribute is required");
3153                 return -EINVAL;
3154         }
3155
3156         if (!data[IFLA_AMT_DISCOVERY_IP] &&
3157             nla_get_u32(data[IFLA_AMT_MODE]) == AMT_MODE_GATEWAY) {
3158                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LOCAL_IP],
3159                                     "Discovery attribute is required");
3160                 return -EINVAL;
3161         }
3162
3163         return 0;
3164 }
3165
3166 static int amt_newlink(struct net *net, struct net_device *dev,
3167                        struct nlattr *tb[], struct nlattr *data[],
3168                        struct netlink_ext_ack *extack)
3169 {
3170         struct amt_dev *amt = netdev_priv(dev);
3171         int err = -EINVAL;
3172
3173         amt->net = net;
3174         amt->mode = nla_get_u32(data[IFLA_AMT_MODE]);
3175
3176         if (data[IFLA_AMT_MAX_TUNNELS] &&
3177             nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]))
3178                 amt->max_tunnels = nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]);
3179         else
3180                 amt->max_tunnels = AMT_MAX_TUNNELS;
3181
3182         spin_lock_init(&amt->lock);
3183         amt->max_groups = AMT_MAX_GROUP;
3184         amt->max_sources = AMT_MAX_SOURCE;
3185         amt->hash_buckets = AMT_HSIZE;
3186         amt->nr_tunnels = 0;
3187         get_random_bytes(&amt->hash_seed, sizeof(amt->hash_seed));
3188         amt->stream_dev = dev_get_by_index(net,
3189                                            nla_get_u32(data[IFLA_AMT_LINK]));
3190         if (!amt->stream_dev) {
3191                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3192                                     "Can't find stream device");
3193                 return -ENODEV;
3194         }
3195
3196         if (amt->stream_dev->type != ARPHRD_ETHER) {
3197                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3198                                     "Invalid stream device type");
3199                 goto err;
3200         }
3201
3202         amt->local_ip = nla_get_in_addr(data[IFLA_AMT_LOCAL_IP]);
3203         if (ipv4_is_loopback(amt->local_ip) ||
3204             ipv4_is_zeronet(amt->local_ip) ||
3205             ipv4_is_multicast(amt->local_ip)) {
3206                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LOCAL_IP],
3207                                     "Invalid Local address");
3208                 goto err;
3209         }
3210
3211         if (data[IFLA_AMT_RELAY_PORT])
3212                 amt->relay_port = nla_get_be16(data[IFLA_AMT_RELAY_PORT]);
3213         else
3214                 amt->relay_port = htons(IANA_AMT_UDP_PORT);
3215
3216         if (data[IFLA_AMT_GATEWAY_PORT])
3217                 amt->gw_port = nla_get_be16(data[IFLA_AMT_GATEWAY_PORT]);
3218         else
3219                 amt->gw_port = htons(IANA_AMT_UDP_PORT);
3220
3221         if (!amt->relay_port) {
3222                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3223                                     "relay port must not be 0");
3224                 goto err;
3225         }
3226         if (amt->mode == AMT_MODE_RELAY) {
3227                 amt->qrv = READ_ONCE(amt->net->ipv4.sysctl_igmp_qrv);
3228                 amt->qri = 10;
3229                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3230                                        AMT_RELAY_HLEN;
3231                 dev->mtu = amt->stream_dev->mtu - AMT_RELAY_HLEN;
3232                 dev->max_mtu = dev->mtu;
3233                 dev->min_mtu = ETH_MIN_MTU + AMT_RELAY_HLEN;
3234         } else {
3235                 if (!data[IFLA_AMT_DISCOVERY_IP]) {
3236                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3237                                             "discovery must be set in gateway mode");
3238                         goto err;
3239                 }
3240                 if (!amt->gw_port) {
3241                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3242                                             "gateway port must not be 0");
3243                         goto err;
3244                 }
3245                 amt->remote_ip = 0;
3246                 amt->discovery_ip = nla_get_in_addr(data[IFLA_AMT_DISCOVERY_IP]);
3247                 if (ipv4_is_loopback(amt->discovery_ip) ||
3248                     ipv4_is_zeronet(amt->discovery_ip) ||
3249                     ipv4_is_multicast(amt->discovery_ip)) {
3250                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3251                                             "discovery must be unicast");
3252                         goto err;
3253                 }
3254
3255                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3256                                        AMT_GW_HLEN;
3257                 dev->mtu = amt->stream_dev->mtu - AMT_GW_HLEN;
3258                 dev->max_mtu = dev->mtu;
3259                 dev->min_mtu = ETH_MIN_MTU + AMT_GW_HLEN;
3260         }
3261         amt->qi = AMT_INIT_QUERY_INTERVAL;
3262
3263         err = register_netdevice(dev);
3264         if (err < 0) {
3265                 netdev_dbg(dev, "failed to register new netdev %d\n", err);
3266                 goto err;
3267         }
3268
3269         err = netdev_upper_dev_link(amt->stream_dev, dev, extack);
3270         if (err < 0) {
3271                 unregister_netdevice(dev);
3272                 goto err;
3273         }
3274
3275         INIT_DELAYED_WORK(&amt->discovery_wq, amt_discovery_work);
3276         INIT_DELAYED_WORK(&amt->req_wq, amt_req_work);
3277         INIT_DELAYED_WORK(&amt->secret_wq, amt_secret_work);
3278         INIT_WORK(&amt->event_wq, amt_event_work);
3279         INIT_LIST_HEAD(&amt->tunnel_list);
3280         return 0;
3281 err:
3282         dev_put(amt->stream_dev);
3283         return err;
3284 }
3285
3286 static void amt_dellink(struct net_device *dev, struct list_head *head)
3287 {
3288         struct amt_dev *amt = netdev_priv(dev);
3289
3290         unregister_netdevice_queue(dev, head);
3291         netdev_upper_dev_unlink(amt->stream_dev, dev);
3292         dev_put(amt->stream_dev);
3293 }
3294
3295 static size_t amt_get_size(const struct net_device *dev)
3296 {
3297         return nla_total_size(sizeof(__u32)) + /* IFLA_AMT_MODE */
3298                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_RELAY_PORT */
3299                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_GATEWAY_PORT */
3300                nla_total_size(sizeof(__u32)) + /* IFLA_AMT_LINK */
3301                nla_total_size(sizeof(__u32)) + /* IFLA_MAX_TUNNELS */
3302                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_DISCOVERY_IP */
3303                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_REMOTE_IP */
3304                nla_total_size(sizeof(struct iphdr)); /* IFLA_AMT_LOCAL_IP */
3305 }
3306
3307 static int amt_fill_info(struct sk_buff *skb, const struct net_device *dev)
3308 {
3309         struct amt_dev *amt = netdev_priv(dev);
3310
3311         if (nla_put_u32(skb, IFLA_AMT_MODE, amt->mode))
3312                 goto nla_put_failure;
3313         if (nla_put_be16(skb, IFLA_AMT_RELAY_PORT, amt->relay_port))
3314                 goto nla_put_failure;
3315         if (nla_put_be16(skb, IFLA_AMT_GATEWAY_PORT, amt->gw_port))
3316                 goto nla_put_failure;
3317         if (nla_put_u32(skb, IFLA_AMT_LINK, amt->stream_dev->ifindex))
3318                 goto nla_put_failure;
3319         if (nla_put_in_addr(skb, IFLA_AMT_LOCAL_IP, amt->local_ip))
3320                 goto nla_put_failure;
3321         if (nla_put_in_addr(skb, IFLA_AMT_DISCOVERY_IP, amt->discovery_ip))
3322                 goto nla_put_failure;
3323         if (amt->remote_ip)
3324                 if (nla_put_in_addr(skb, IFLA_AMT_REMOTE_IP, amt->remote_ip))
3325                         goto nla_put_failure;
3326         if (nla_put_u32(skb, IFLA_AMT_MAX_TUNNELS, amt->max_tunnels))
3327                 goto nla_put_failure;
3328
3329         return 0;
3330
3331 nla_put_failure:
3332         return -EMSGSIZE;
3333 }
3334
3335 static struct rtnl_link_ops amt_link_ops __read_mostly = {
3336         .kind           = "amt",
3337         .maxtype        = IFLA_AMT_MAX,
3338         .policy         = amt_policy,
3339         .priv_size      = sizeof(struct amt_dev),
3340         .setup          = amt_link_setup,
3341         .validate       = amt_validate,
3342         .newlink        = amt_newlink,
3343         .dellink        = amt_dellink,
3344         .get_size       = amt_get_size,
3345         .fill_info      = amt_fill_info,
3346 };
3347
3348 static struct net_device *amt_lookup_upper_dev(struct net_device *dev)
3349 {
3350         struct net_device *upper_dev;
3351         struct amt_dev *amt;
3352
3353         for_each_netdev(dev_net(dev), upper_dev) {
3354                 if (netif_is_amt(upper_dev)) {
3355                         amt = netdev_priv(upper_dev);
3356                         if (amt->stream_dev == dev)
3357                                 return upper_dev;
3358                 }
3359         }
3360
3361         return NULL;
3362 }
3363
3364 static int amt_device_event(struct notifier_block *unused,
3365                             unsigned long event, void *ptr)
3366 {
3367         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
3368         struct net_device *upper_dev;
3369         struct amt_dev *amt;
3370         LIST_HEAD(list);
3371         int new_mtu;
3372
3373         upper_dev = amt_lookup_upper_dev(dev);
3374         if (!upper_dev)
3375                 return NOTIFY_DONE;
3376         amt = netdev_priv(upper_dev);
3377
3378         switch (event) {
3379         case NETDEV_UNREGISTER:
3380                 amt_dellink(amt->dev, &list);
3381                 unregister_netdevice_many(&list);
3382                 break;
3383         case NETDEV_CHANGEMTU:
3384                 if (amt->mode == AMT_MODE_RELAY)
3385                         new_mtu = dev->mtu - AMT_RELAY_HLEN;
3386                 else
3387                         new_mtu = dev->mtu - AMT_GW_HLEN;
3388
3389                 dev_set_mtu(amt->dev, new_mtu);
3390                 break;
3391         }
3392
3393         return NOTIFY_DONE;
3394 }
3395
3396 static struct notifier_block amt_notifier_block __read_mostly = {
3397         .notifier_call = amt_device_event,
3398 };
3399
3400 static int __init amt_init(void)
3401 {
3402         int err;
3403
3404         err = register_netdevice_notifier(&amt_notifier_block);
3405         if (err < 0)
3406                 goto err;
3407
3408         err = rtnl_link_register(&amt_link_ops);
3409         if (err < 0)
3410                 goto unregister_notifier;
3411
3412         amt_wq = alloc_workqueue("amt", WQ_UNBOUND, 0);
3413         if (!amt_wq) {
3414                 err = -ENOMEM;
3415                 goto rtnl_unregister;
3416         }
3417
3418         spin_lock_init(&source_gc_lock);
3419         spin_lock_bh(&source_gc_lock);
3420         INIT_DELAYED_WORK(&source_gc_wq, amt_source_gc_work);
3421         mod_delayed_work(amt_wq, &source_gc_wq,
3422                          msecs_to_jiffies(AMT_GC_INTERVAL));
3423         spin_unlock_bh(&source_gc_lock);
3424
3425         return 0;
3426
3427 rtnl_unregister:
3428         rtnl_link_unregister(&amt_link_ops);
3429 unregister_notifier:
3430         unregister_netdevice_notifier(&amt_notifier_block);
3431 err:
3432         pr_err("error loading AMT module loaded\n");
3433         return err;
3434 }
3435 late_initcall(amt_init);
3436
3437 static void __exit amt_fini(void)
3438 {
3439         rtnl_link_unregister(&amt_link_ops);
3440         unregister_netdevice_notifier(&amt_notifier_block);
3441         cancel_delayed_work_sync(&source_gc_wq);
3442         __amt_source_gc_work();
3443         destroy_workqueue(amt_wq);
3444 }
3445 module_exit(amt_fini);
3446
3447 MODULE_LICENSE("GPL");
3448 MODULE_AUTHOR("Taehee Yoo <ap420073@gmail.com>");
3449 MODULE_ALIAS_RTNL_LINK("amt");