Merge tag '6.7-rc-ksmbd-server-fixes' of git://git.samba.org/ksmbd
[linux-block.git] / drivers / net / vxlan / vxlan_mdb.c
1 // SPDX-License-Identifier: GPL-2.0-only
2
3 #include <linux/if_bridge.h>
4 #include <linux/in.h>
5 #include <linux/list.h>
6 #include <linux/netdevice.h>
7 #include <linux/netlink.h>
8 #include <linux/rhashtable.h>
9 #include <linux/rhashtable-types.h>
10 #include <linux/rtnetlink.h>
11 #include <linux/skbuff.h>
12 #include <linux/types.h>
13 #include <net/netlink.h>
14 #include <net/vxlan.h>
15
16 #include "vxlan_private.h"
17
18 struct vxlan_mdb_entry_key {
19         union vxlan_addr src;
20         union vxlan_addr dst;
21         __be32 vni;
22 };
23
24 struct vxlan_mdb_entry {
25         struct rhash_head rhnode;
26         struct list_head remotes;
27         struct vxlan_mdb_entry_key key;
28         struct hlist_node mdb_node;
29         struct rcu_head rcu;
30 };
31
32 #define VXLAN_MDB_REMOTE_F_BLOCKED      BIT(0)
33
34 struct vxlan_mdb_remote {
35         struct list_head list;
36         struct vxlan_rdst __rcu *rd;
37         u8 flags;
38         u8 filter_mode;
39         u8 rt_protocol;
40         struct hlist_head src_list;
41         struct rcu_head rcu;
42 };
43
44 #define VXLAN_SGRP_F_DELETE     BIT(0)
45
46 struct vxlan_mdb_src_entry {
47         struct hlist_node node;
48         union vxlan_addr addr;
49         u8 flags;
50 };
51
52 struct vxlan_mdb_dump_ctx {
53         long reserved;
54         long entry_idx;
55         long remote_idx;
56 };
57
58 struct vxlan_mdb_config_src_entry {
59         union vxlan_addr addr;
60         struct list_head node;
61 };
62
63 struct vxlan_mdb_config {
64         struct vxlan_dev *vxlan;
65         struct vxlan_mdb_entry_key group;
66         struct list_head src_list;
67         union vxlan_addr remote_ip;
68         u32 remote_ifindex;
69         __be32 remote_vni;
70         __be16 remote_port;
71         u16 nlflags;
72         u8 flags;
73         u8 filter_mode;
74         u8 rt_protocol;
75 };
76
77 static const struct rhashtable_params vxlan_mdb_rht_params = {
78         .head_offset = offsetof(struct vxlan_mdb_entry, rhnode),
79         .key_offset = offsetof(struct vxlan_mdb_entry, key),
80         .key_len = sizeof(struct vxlan_mdb_entry_key),
81         .automatic_shrinking = true,
82 };
83
84 static int __vxlan_mdb_add(const struct vxlan_mdb_config *cfg,
85                            struct netlink_ext_ack *extack);
86 static int __vxlan_mdb_del(const struct vxlan_mdb_config *cfg,
87                            struct netlink_ext_ack *extack);
88
89 static void vxlan_br_mdb_entry_fill(const struct vxlan_dev *vxlan,
90                                     const struct vxlan_mdb_entry *mdb_entry,
91                                     const struct vxlan_mdb_remote *remote,
92                                     struct br_mdb_entry *e)
93 {
94         const union vxlan_addr *dst = &mdb_entry->key.dst;
95
96         memset(e, 0, sizeof(*e));
97         e->ifindex = vxlan->dev->ifindex;
98         e->state = MDB_PERMANENT;
99
100         if (remote->flags & VXLAN_MDB_REMOTE_F_BLOCKED)
101                 e->flags |= MDB_FLAGS_BLOCKED;
102
103         switch (dst->sa.sa_family) {
104         case AF_INET:
105                 e->addr.u.ip4 = dst->sin.sin_addr.s_addr;
106                 e->addr.proto = htons(ETH_P_IP);
107                 break;
108 #if IS_ENABLED(CONFIG_IPV6)
109         case AF_INET6:
110                 e->addr.u.ip6 = dst->sin6.sin6_addr;
111                 e->addr.proto = htons(ETH_P_IPV6);
112                 break;
113 #endif
114         }
115 }
116
117 static int vxlan_mdb_entry_info_fill_srcs(struct sk_buff *skb,
118                                           const struct vxlan_mdb_remote *remote)
119 {
120         struct vxlan_mdb_src_entry *ent;
121         struct nlattr *nest;
122
123         if (hlist_empty(&remote->src_list))
124                 return 0;
125
126         nest = nla_nest_start(skb, MDBA_MDB_EATTR_SRC_LIST);
127         if (!nest)
128                 return -EMSGSIZE;
129
130         hlist_for_each_entry(ent, &remote->src_list, node) {
131                 struct nlattr *nest_ent;
132
133                 nest_ent = nla_nest_start(skb, MDBA_MDB_SRCLIST_ENTRY);
134                 if (!nest_ent)
135                         goto out_cancel_err;
136
137                 if (vxlan_nla_put_addr(skb, MDBA_MDB_SRCATTR_ADDRESS,
138                                        &ent->addr) ||
139                     nla_put_u32(skb, MDBA_MDB_SRCATTR_TIMER, 0))
140                         goto out_cancel_err;
141
142                 nla_nest_end(skb, nest_ent);
143         }
144
145         nla_nest_end(skb, nest);
146
147         return 0;
148
149 out_cancel_err:
150         nla_nest_cancel(skb, nest);
151         return -EMSGSIZE;
152 }
153
154 static int vxlan_mdb_entry_info_fill(const struct vxlan_dev *vxlan,
155                                      struct sk_buff *skb,
156                                      const struct vxlan_mdb_entry *mdb_entry,
157                                      const struct vxlan_mdb_remote *remote)
158 {
159         struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
160         struct br_mdb_entry e;
161         struct nlattr *nest;
162
163         nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY_INFO);
164         if (!nest)
165                 return -EMSGSIZE;
166
167         vxlan_br_mdb_entry_fill(vxlan, mdb_entry, remote, &e);
168
169         if (nla_put_nohdr(skb, sizeof(e), &e) ||
170             nla_put_u32(skb, MDBA_MDB_EATTR_TIMER, 0))
171                 goto nest_err;
172
173         if (!vxlan_addr_any(&mdb_entry->key.src) &&
174             vxlan_nla_put_addr(skb, MDBA_MDB_EATTR_SOURCE, &mdb_entry->key.src))
175                 goto nest_err;
176
177         if (nla_put_u8(skb, MDBA_MDB_EATTR_RTPROT, remote->rt_protocol) ||
178             nla_put_u8(skb, MDBA_MDB_EATTR_GROUP_MODE, remote->filter_mode) ||
179             vxlan_mdb_entry_info_fill_srcs(skb, remote) ||
180             vxlan_nla_put_addr(skb, MDBA_MDB_EATTR_DST, &rd->remote_ip))
181                 goto nest_err;
182
183         if (rd->remote_port && rd->remote_port != vxlan->cfg.dst_port &&
184             nla_put_u16(skb, MDBA_MDB_EATTR_DST_PORT,
185                         be16_to_cpu(rd->remote_port)))
186                 goto nest_err;
187
188         if (rd->remote_vni != vxlan->default_dst.remote_vni &&
189             nla_put_u32(skb, MDBA_MDB_EATTR_VNI, be32_to_cpu(rd->remote_vni)))
190                 goto nest_err;
191
192         if (rd->remote_ifindex &&
193             nla_put_u32(skb, MDBA_MDB_EATTR_IFINDEX, rd->remote_ifindex))
194                 goto nest_err;
195
196         if ((vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA) &&
197             mdb_entry->key.vni && nla_put_u32(skb, MDBA_MDB_EATTR_SRC_VNI,
198                                               be32_to_cpu(mdb_entry->key.vni)))
199                 goto nest_err;
200
201         nla_nest_end(skb, nest);
202
203         return 0;
204
205 nest_err:
206         nla_nest_cancel(skb, nest);
207         return -EMSGSIZE;
208 }
209
210 static int vxlan_mdb_entry_fill(const struct vxlan_dev *vxlan,
211                                 struct sk_buff *skb,
212                                 struct vxlan_mdb_dump_ctx *ctx,
213                                 const struct vxlan_mdb_entry *mdb_entry)
214 {
215         int remote_idx = 0, s_remote_idx = ctx->remote_idx;
216         struct vxlan_mdb_remote *remote;
217         struct nlattr *nest;
218         int err = 0;
219
220         nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
221         if (!nest)
222                 return -EMSGSIZE;
223
224         list_for_each_entry(remote, &mdb_entry->remotes, list) {
225                 if (remote_idx < s_remote_idx)
226                         goto skip;
227
228                 err = vxlan_mdb_entry_info_fill(vxlan, skb, mdb_entry, remote);
229                 if (err)
230                         break;
231 skip:
232                 remote_idx++;
233         }
234
235         ctx->remote_idx = err ? remote_idx : 0;
236         nla_nest_end(skb, nest);
237         return err;
238 }
239
240 static int vxlan_mdb_fill(const struct vxlan_dev *vxlan, struct sk_buff *skb,
241                           struct vxlan_mdb_dump_ctx *ctx)
242 {
243         int entry_idx = 0, s_entry_idx = ctx->entry_idx;
244         struct vxlan_mdb_entry *mdb_entry;
245         struct nlattr *nest;
246         int err = 0;
247
248         nest = nla_nest_start_noflag(skb, MDBA_MDB);
249         if (!nest)
250                 return -EMSGSIZE;
251
252         hlist_for_each_entry(mdb_entry, &vxlan->mdb_list, mdb_node) {
253                 if (entry_idx < s_entry_idx)
254                         goto skip;
255
256                 err = vxlan_mdb_entry_fill(vxlan, skb, ctx, mdb_entry);
257                 if (err)
258                         break;
259 skip:
260                 entry_idx++;
261         }
262
263         ctx->entry_idx = err ? entry_idx : 0;
264         nla_nest_end(skb, nest);
265         return err;
266 }
267
268 int vxlan_mdb_dump(struct net_device *dev, struct sk_buff *skb,
269                    struct netlink_callback *cb)
270 {
271         struct vxlan_mdb_dump_ctx *ctx = (void *)cb->ctx;
272         struct vxlan_dev *vxlan = netdev_priv(dev);
273         struct br_port_msg *bpm;
274         struct nlmsghdr *nlh;
275         int err;
276
277         ASSERT_RTNL();
278
279         NL_ASSERT_DUMP_CTX_FITS(struct vxlan_mdb_dump_ctx);
280
281         nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
282                         cb->nlh->nlmsg_seq, RTM_NEWMDB, sizeof(*bpm),
283                         NLM_F_MULTI);
284         if (!nlh)
285                 return -EMSGSIZE;
286
287         bpm = nlmsg_data(nlh);
288         memset(bpm, 0, sizeof(*bpm));
289         bpm->family = AF_BRIDGE;
290         bpm->ifindex = dev->ifindex;
291
292         err = vxlan_mdb_fill(vxlan, skb, ctx);
293
294         nlmsg_end(skb, nlh);
295
296         cb->seq = vxlan->mdb_seq;
297         nl_dump_check_consistent(cb, nlh);
298
299         return err;
300 }
301
302 static const struct nla_policy
303 vxlan_mdbe_src_list_entry_pol[MDBE_SRCATTR_MAX + 1] = {
304         [MDBE_SRCATTR_ADDRESS] = NLA_POLICY_RANGE(NLA_BINARY,
305                                                   sizeof(struct in_addr),
306                                                   sizeof(struct in6_addr)),
307 };
308
309 static const struct nla_policy
310 vxlan_mdbe_src_list_pol[MDBE_SRC_LIST_MAX + 1] = {
311         [MDBE_SRC_LIST_ENTRY] = NLA_POLICY_NESTED(vxlan_mdbe_src_list_entry_pol),
312 };
313
314 static const struct netlink_range_validation vni_range = {
315         .max = VXLAN_N_VID - 1,
316 };
317
318 static const struct nla_policy vxlan_mdbe_attrs_pol[MDBE_ATTR_MAX + 1] = {
319         [MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY,
320                                               sizeof(struct in_addr),
321                                               sizeof(struct in6_addr)),
322         [MDBE_ATTR_GROUP_MODE] = NLA_POLICY_RANGE(NLA_U8, MCAST_EXCLUDE,
323                                                   MCAST_INCLUDE),
324         [MDBE_ATTR_SRC_LIST] = NLA_POLICY_NESTED(vxlan_mdbe_src_list_pol),
325         [MDBE_ATTR_RTPROT] = NLA_POLICY_MIN(NLA_U8, RTPROT_STATIC),
326         [MDBE_ATTR_DST] = NLA_POLICY_RANGE(NLA_BINARY,
327                                            sizeof(struct in_addr),
328                                            sizeof(struct in6_addr)),
329         [MDBE_ATTR_DST_PORT] = { .type = NLA_U16 },
330         [MDBE_ATTR_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
331         [MDBE_ATTR_IFINDEX] = NLA_POLICY_MIN(NLA_S32, 1),
332         [MDBE_ATTR_SRC_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
333 };
334
335 static bool vxlan_mdb_is_valid_source(const struct nlattr *attr, __be16 proto,
336                                       struct netlink_ext_ack *extack)
337 {
338         switch (proto) {
339         case htons(ETH_P_IP):
340                 if (nla_len(attr) != sizeof(struct in_addr)) {
341                         NL_SET_ERR_MSG_MOD(extack, "IPv4 invalid source address length");
342                         return false;
343                 }
344                 if (ipv4_is_multicast(nla_get_in_addr(attr))) {
345                         NL_SET_ERR_MSG_MOD(extack, "IPv4 multicast source address is not allowed");
346                         return false;
347                 }
348                 break;
349 #if IS_ENABLED(CONFIG_IPV6)
350         case htons(ETH_P_IPV6): {
351                 struct in6_addr src;
352
353                 if (nla_len(attr) != sizeof(struct in6_addr)) {
354                         NL_SET_ERR_MSG_MOD(extack, "IPv6 invalid source address length");
355                         return false;
356                 }
357                 src = nla_get_in6_addr(attr);
358                 if (ipv6_addr_is_multicast(&src)) {
359                         NL_SET_ERR_MSG_MOD(extack, "IPv6 multicast source address is not allowed");
360                         return false;
361                 }
362                 break;
363         }
364 #endif
365         default:
366                 NL_SET_ERR_MSG_MOD(extack, "Invalid protocol used with source address");
367                 return false;
368         }
369
370         return true;
371 }
372
373 static void vxlan_mdb_group_set(struct vxlan_mdb_entry_key *group,
374                                 const struct br_mdb_entry *entry,
375                                 const struct nlattr *source_attr)
376 {
377         switch (entry->addr.proto) {
378         case htons(ETH_P_IP):
379                 group->dst.sa.sa_family = AF_INET;
380                 group->dst.sin.sin_addr.s_addr = entry->addr.u.ip4;
381                 break;
382 #if IS_ENABLED(CONFIG_IPV6)
383         case htons(ETH_P_IPV6):
384                 group->dst.sa.sa_family = AF_INET6;
385                 group->dst.sin6.sin6_addr = entry->addr.u.ip6;
386                 break;
387 #endif
388         }
389
390         if (source_attr)
391                 vxlan_nla_get_addr(&group->src, source_attr);
392 }
393
394 static bool vxlan_mdb_is_star_g(const struct vxlan_mdb_entry_key *group)
395 {
396         return !vxlan_addr_any(&group->dst) && vxlan_addr_any(&group->src);
397 }
398
399 static bool vxlan_mdb_is_sg(const struct vxlan_mdb_entry_key *group)
400 {
401         return !vxlan_addr_any(&group->dst) && !vxlan_addr_any(&group->src);
402 }
403
404 static int vxlan_mdb_config_src_entry_init(struct vxlan_mdb_config *cfg,
405                                            __be16 proto,
406                                            const struct nlattr *src_entry,
407                                            struct netlink_ext_ack *extack)
408 {
409         struct nlattr *tb[MDBE_SRCATTR_MAX + 1];
410         struct vxlan_mdb_config_src_entry *src;
411         int err;
412
413         err = nla_parse_nested(tb, MDBE_SRCATTR_MAX, src_entry,
414                                vxlan_mdbe_src_list_entry_pol, extack);
415         if (err)
416                 return err;
417
418         if (NL_REQ_ATTR_CHECK(extack, src_entry, tb, MDBE_SRCATTR_ADDRESS))
419                 return -EINVAL;
420
421         if (!vxlan_mdb_is_valid_source(tb[MDBE_SRCATTR_ADDRESS], proto,
422                                        extack))
423                 return -EINVAL;
424
425         src = kzalloc(sizeof(*src), GFP_KERNEL);
426         if (!src)
427                 return -ENOMEM;
428
429         err = vxlan_nla_get_addr(&src->addr, tb[MDBE_SRCATTR_ADDRESS]);
430         if (err)
431                 goto err_free_src;
432
433         list_add_tail(&src->node, &cfg->src_list);
434
435         return 0;
436
437 err_free_src:
438         kfree(src);
439         return err;
440 }
441
442 static void
443 vxlan_mdb_config_src_entry_fini(struct vxlan_mdb_config_src_entry *src)
444 {
445         list_del(&src->node);
446         kfree(src);
447 }
448
449 static int vxlan_mdb_config_src_list_init(struct vxlan_mdb_config *cfg,
450                                           __be16 proto,
451                                           const struct nlattr *src_list,
452                                           struct netlink_ext_ack *extack)
453 {
454         struct vxlan_mdb_config_src_entry *src, *tmp;
455         struct nlattr *src_entry;
456         int rem, err;
457
458         nla_for_each_nested(src_entry, src_list, rem) {
459                 err = vxlan_mdb_config_src_entry_init(cfg, proto, src_entry,
460                                                       extack);
461                 if (err)
462                         goto err_src_entry_init;
463         }
464
465         return 0;
466
467 err_src_entry_init:
468         list_for_each_entry_safe_reverse(src, tmp, &cfg->src_list, node)
469                 vxlan_mdb_config_src_entry_fini(src);
470         return err;
471 }
472
473 static void vxlan_mdb_config_src_list_fini(struct vxlan_mdb_config *cfg)
474 {
475         struct vxlan_mdb_config_src_entry *src, *tmp;
476
477         list_for_each_entry_safe_reverse(src, tmp, &cfg->src_list, node)
478                 vxlan_mdb_config_src_entry_fini(src);
479 }
480
481 static int vxlan_mdb_config_attrs_init(struct vxlan_mdb_config *cfg,
482                                        const struct br_mdb_entry *entry,
483                                        const struct nlattr *set_attrs,
484                                        struct netlink_ext_ack *extack)
485 {
486         struct nlattr *mdbe_attrs[MDBE_ATTR_MAX + 1];
487         int err;
488
489         err = nla_parse_nested(mdbe_attrs, MDBE_ATTR_MAX, set_attrs,
490                                vxlan_mdbe_attrs_pol, extack);
491         if (err)
492                 return err;
493
494         if (NL_REQ_ATTR_CHECK(extack, set_attrs, mdbe_attrs, MDBE_ATTR_DST)) {
495                 NL_SET_ERR_MSG_MOD(extack, "Missing remote destination IP address");
496                 return -EINVAL;
497         }
498
499         if (mdbe_attrs[MDBE_ATTR_SOURCE] &&
500             !vxlan_mdb_is_valid_source(mdbe_attrs[MDBE_ATTR_SOURCE],
501                                        entry->addr.proto, extack))
502                 return -EINVAL;
503
504         vxlan_mdb_group_set(&cfg->group, entry, mdbe_attrs[MDBE_ATTR_SOURCE]);
505
506         /* rtnetlink code only validates that IPv4 group address is
507          * multicast.
508          */
509         if (!vxlan_addr_is_multicast(&cfg->group.dst) &&
510             !vxlan_addr_any(&cfg->group.dst)) {
511                 NL_SET_ERR_MSG_MOD(extack, "Group address is not multicast");
512                 return -EINVAL;
513         }
514
515         if (vxlan_addr_any(&cfg->group.dst) &&
516             mdbe_attrs[MDBE_ATTR_SOURCE]) {
517                 NL_SET_ERR_MSG_MOD(extack, "Source cannot be specified for the all-zeros entry");
518                 return -EINVAL;
519         }
520
521         if (vxlan_mdb_is_sg(&cfg->group))
522                 cfg->filter_mode = MCAST_INCLUDE;
523
524         if (mdbe_attrs[MDBE_ATTR_GROUP_MODE]) {
525                 if (!vxlan_mdb_is_star_g(&cfg->group)) {
526                         NL_SET_ERR_MSG_MOD(extack, "Filter mode can only be set for (*, G) entries");
527                         return -EINVAL;
528                 }
529                 cfg->filter_mode = nla_get_u8(mdbe_attrs[MDBE_ATTR_GROUP_MODE]);
530         }
531
532         if (mdbe_attrs[MDBE_ATTR_SRC_LIST]) {
533                 if (!vxlan_mdb_is_star_g(&cfg->group)) {
534                         NL_SET_ERR_MSG_MOD(extack, "Source list can only be set for (*, G) entries");
535                         return -EINVAL;
536                 }
537                 if (!mdbe_attrs[MDBE_ATTR_GROUP_MODE]) {
538                         NL_SET_ERR_MSG_MOD(extack, "Source list cannot be set without filter mode");
539                         return -EINVAL;
540                 }
541                 err = vxlan_mdb_config_src_list_init(cfg, entry->addr.proto,
542                                                      mdbe_attrs[MDBE_ATTR_SRC_LIST],
543                                                      extack);
544                 if (err)
545                         return err;
546         }
547
548         if (vxlan_mdb_is_star_g(&cfg->group) && list_empty(&cfg->src_list) &&
549             cfg->filter_mode == MCAST_INCLUDE) {
550                 NL_SET_ERR_MSG_MOD(extack, "Cannot add (*, G) INCLUDE with an empty source list");
551                 return -EINVAL;
552         }
553
554         if (mdbe_attrs[MDBE_ATTR_RTPROT])
555                 cfg->rt_protocol = nla_get_u8(mdbe_attrs[MDBE_ATTR_RTPROT]);
556
557         err = vxlan_nla_get_addr(&cfg->remote_ip, mdbe_attrs[MDBE_ATTR_DST]);
558         if (err) {
559                 NL_SET_ERR_MSG_MOD(extack, "Invalid remote destination address");
560                 goto err_src_list_fini;
561         }
562
563         if (mdbe_attrs[MDBE_ATTR_DST_PORT])
564                 cfg->remote_port =
565                         cpu_to_be16(nla_get_u16(mdbe_attrs[MDBE_ATTR_DST_PORT]));
566
567         if (mdbe_attrs[MDBE_ATTR_VNI])
568                 cfg->remote_vni =
569                         cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_VNI]));
570
571         if (mdbe_attrs[MDBE_ATTR_IFINDEX]) {
572                 cfg->remote_ifindex =
573                         nla_get_s32(mdbe_attrs[MDBE_ATTR_IFINDEX]);
574                 if (!__dev_get_by_index(cfg->vxlan->net, cfg->remote_ifindex)) {
575                         NL_SET_ERR_MSG_MOD(extack, "Outgoing interface not found");
576                         err = -EINVAL;
577                         goto err_src_list_fini;
578                 }
579         }
580
581         if (mdbe_attrs[MDBE_ATTR_SRC_VNI])
582                 cfg->group.vni =
583                         cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_SRC_VNI]));
584
585         return 0;
586
587 err_src_list_fini:
588         vxlan_mdb_config_src_list_fini(cfg);
589         return err;
590 }
591
592 static int vxlan_mdb_config_init(struct vxlan_mdb_config *cfg,
593                                  struct net_device *dev, struct nlattr *tb[],
594                                  u16 nlmsg_flags,
595                                  struct netlink_ext_ack *extack)
596 {
597         struct br_mdb_entry *entry = nla_data(tb[MDBA_SET_ENTRY]);
598         struct vxlan_dev *vxlan = netdev_priv(dev);
599
600         memset(cfg, 0, sizeof(*cfg));
601         cfg->vxlan = vxlan;
602         cfg->group.vni = vxlan->default_dst.remote_vni;
603         INIT_LIST_HEAD(&cfg->src_list);
604         cfg->nlflags = nlmsg_flags;
605         cfg->filter_mode = MCAST_EXCLUDE;
606         cfg->rt_protocol = RTPROT_STATIC;
607         cfg->remote_vni = vxlan->default_dst.remote_vni;
608         cfg->remote_port = vxlan->cfg.dst_port;
609
610         if (entry->ifindex != dev->ifindex) {
611                 NL_SET_ERR_MSG_MOD(extack, "Port net device must be the VXLAN net device");
612                 return -EINVAL;
613         }
614
615         /* State is not part of the entry key and can be ignored on deletion
616          * requests.
617          */
618         if ((nlmsg_flags & (NLM_F_CREATE | NLM_F_REPLACE)) &&
619             entry->state != MDB_PERMANENT) {
620                 NL_SET_ERR_MSG_MOD(extack, "MDB entry must be permanent");
621                 return -EINVAL;
622         }
623
624         if (entry->flags) {
625                 NL_SET_ERR_MSG_MOD(extack, "Invalid MDB entry flags");
626                 return -EINVAL;
627         }
628
629         if (entry->vid) {
630                 NL_SET_ERR_MSG_MOD(extack, "VID must not be specified");
631                 return -EINVAL;
632         }
633
634         if (entry->addr.proto != htons(ETH_P_IP) &&
635             entry->addr.proto != htons(ETH_P_IPV6)) {
636                 NL_SET_ERR_MSG_MOD(extack, "Group address must be an IPv4 / IPv6 address");
637                 return -EINVAL;
638         }
639
640         if (NL_REQ_ATTR_CHECK(extack, NULL, tb, MDBA_SET_ENTRY_ATTRS)) {
641                 NL_SET_ERR_MSG_MOD(extack, "Missing MDBA_SET_ENTRY_ATTRS attribute");
642                 return -EINVAL;
643         }
644
645         return vxlan_mdb_config_attrs_init(cfg, entry, tb[MDBA_SET_ENTRY_ATTRS],
646                                            extack);
647 }
648
649 static void vxlan_mdb_config_fini(struct vxlan_mdb_config *cfg)
650 {
651         vxlan_mdb_config_src_list_fini(cfg);
652 }
653
654 static struct vxlan_mdb_entry *
655 vxlan_mdb_entry_lookup(struct vxlan_dev *vxlan,
656                        const struct vxlan_mdb_entry_key *group)
657 {
658         return rhashtable_lookup_fast(&vxlan->mdb_tbl, group,
659                                       vxlan_mdb_rht_params);
660 }
661
662 static struct vxlan_mdb_remote *
663 vxlan_mdb_remote_lookup(const struct vxlan_mdb_entry *mdb_entry,
664                         const union vxlan_addr *addr)
665 {
666         struct vxlan_mdb_remote *remote;
667
668         list_for_each_entry(remote, &mdb_entry->remotes, list) {
669                 struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
670
671                 if (vxlan_addr_equal(addr, &rd->remote_ip))
672                         return remote;
673         }
674
675         return NULL;
676 }
677
678 static void vxlan_mdb_rdst_free(struct rcu_head *head)
679 {
680         struct vxlan_rdst *rd = container_of(head, struct vxlan_rdst, rcu);
681
682         dst_cache_destroy(&rd->dst_cache);
683         kfree(rd);
684 }
685
686 static int vxlan_mdb_remote_rdst_init(const struct vxlan_mdb_config *cfg,
687                                       struct vxlan_mdb_remote *remote)
688 {
689         struct vxlan_rdst *rd;
690         int err;
691
692         rd = kzalloc(sizeof(*rd), GFP_KERNEL);
693         if (!rd)
694                 return -ENOMEM;
695
696         err = dst_cache_init(&rd->dst_cache, GFP_KERNEL);
697         if (err)
698                 goto err_free_rdst;
699
700         rd->remote_ip = cfg->remote_ip;
701         rd->remote_port = cfg->remote_port;
702         rd->remote_vni = cfg->remote_vni;
703         rd->remote_ifindex = cfg->remote_ifindex;
704         rcu_assign_pointer(remote->rd, rd);
705
706         return 0;
707
708 err_free_rdst:
709         kfree(rd);
710         return err;
711 }
712
713 static void vxlan_mdb_remote_rdst_fini(struct vxlan_rdst *rd)
714 {
715         call_rcu(&rd->rcu, vxlan_mdb_rdst_free);
716 }
717
718 static int vxlan_mdb_remote_init(const struct vxlan_mdb_config *cfg,
719                                  struct vxlan_mdb_remote *remote)
720 {
721         int err;
722
723         err = vxlan_mdb_remote_rdst_init(cfg, remote);
724         if (err)
725                 return err;
726
727         remote->flags = cfg->flags;
728         remote->filter_mode = cfg->filter_mode;
729         remote->rt_protocol = cfg->rt_protocol;
730         INIT_HLIST_HEAD(&remote->src_list);
731
732         return 0;
733 }
734
735 static void vxlan_mdb_remote_fini(struct vxlan_dev *vxlan,
736                                   struct vxlan_mdb_remote *remote)
737 {
738         WARN_ON_ONCE(!hlist_empty(&remote->src_list));
739         vxlan_mdb_remote_rdst_fini(rtnl_dereference(remote->rd));
740 }
741
742 static struct vxlan_mdb_src_entry *
743 vxlan_mdb_remote_src_entry_lookup(const struct vxlan_mdb_remote *remote,
744                                   const union vxlan_addr *addr)
745 {
746         struct vxlan_mdb_src_entry *ent;
747
748         hlist_for_each_entry(ent, &remote->src_list, node) {
749                 if (vxlan_addr_equal(&ent->addr, addr))
750                         return ent;
751         }
752
753         return NULL;
754 }
755
756 static struct vxlan_mdb_src_entry *
757 vxlan_mdb_remote_src_entry_add(struct vxlan_mdb_remote *remote,
758                                const union vxlan_addr *addr)
759 {
760         struct vxlan_mdb_src_entry *ent;
761
762         ent = kzalloc(sizeof(*ent), GFP_KERNEL);
763         if (!ent)
764                 return NULL;
765
766         ent->addr = *addr;
767         hlist_add_head(&ent->node, &remote->src_list);
768
769         return ent;
770 }
771
772 static void
773 vxlan_mdb_remote_src_entry_del(struct vxlan_mdb_src_entry *ent)
774 {
775         hlist_del(&ent->node);
776         kfree(ent);
777 }
778
779 static int
780 vxlan_mdb_remote_src_fwd_add(const struct vxlan_mdb_config *cfg,
781                              const union vxlan_addr *addr,
782                              struct netlink_ext_ack *extack)
783 {
784         struct vxlan_mdb_config sg_cfg;
785
786         memset(&sg_cfg, 0, sizeof(sg_cfg));
787         sg_cfg.vxlan = cfg->vxlan;
788         sg_cfg.group.src = *addr;
789         sg_cfg.group.dst = cfg->group.dst;
790         sg_cfg.group.vni = cfg->group.vni;
791         INIT_LIST_HEAD(&sg_cfg.src_list);
792         sg_cfg.remote_ip = cfg->remote_ip;
793         sg_cfg.remote_ifindex = cfg->remote_ifindex;
794         sg_cfg.remote_vni = cfg->remote_vni;
795         sg_cfg.remote_port = cfg->remote_port;
796         sg_cfg.nlflags = cfg->nlflags;
797         sg_cfg.filter_mode = MCAST_INCLUDE;
798         if (cfg->filter_mode == MCAST_EXCLUDE)
799                 sg_cfg.flags = VXLAN_MDB_REMOTE_F_BLOCKED;
800         sg_cfg.rt_protocol = cfg->rt_protocol;
801
802         return __vxlan_mdb_add(&sg_cfg, extack);
803 }
804
805 static void
806 vxlan_mdb_remote_src_fwd_del(struct vxlan_dev *vxlan,
807                              const struct vxlan_mdb_entry_key *group,
808                              const struct vxlan_mdb_remote *remote,
809                              const union vxlan_addr *addr)
810 {
811         struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
812         struct vxlan_mdb_config sg_cfg;
813
814         memset(&sg_cfg, 0, sizeof(sg_cfg));
815         sg_cfg.vxlan = vxlan;
816         sg_cfg.group.src = *addr;
817         sg_cfg.group.dst = group->dst;
818         sg_cfg.group.vni = group->vni;
819         INIT_LIST_HEAD(&sg_cfg.src_list);
820         sg_cfg.remote_ip = rd->remote_ip;
821
822         __vxlan_mdb_del(&sg_cfg, NULL);
823 }
824
825 static int
826 vxlan_mdb_remote_src_add(const struct vxlan_mdb_config *cfg,
827                          struct vxlan_mdb_remote *remote,
828                          const struct vxlan_mdb_config_src_entry *src,
829                          struct netlink_ext_ack *extack)
830 {
831         struct vxlan_mdb_src_entry *ent;
832         int err;
833
834         ent = vxlan_mdb_remote_src_entry_lookup(remote, &src->addr);
835         if (!ent) {
836                 ent = vxlan_mdb_remote_src_entry_add(remote, &src->addr);
837                 if (!ent)
838                         return -ENOMEM;
839         } else if (!(cfg->nlflags & NLM_F_REPLACE)) {
840                 NL_SET_ERR_MSG_MOD(extack, "Source entry already exists");
841                 return -EEXIST;
842         }
843
844         err = vxlan_mdb_remote_src_fwd_add(cfg, &ent->addr, extack);
845         if (err)
846                 goto err_src_del;
847
848         /* Clear flags in case source entry was marked for deletion as part of
849          * replace flow.
850          */
851         ent->flags = 0;
852
853         return 0;
854
855 err_src_del:
856         vxlan_mdb_remote_src_entry_del(ent);
857         return err;
858 }
859
860 static void vxlan_mdb_remote_src_del(struct vxlan_dev *vxlan,
861                                      const struct vxlan_mdb_entry_key *group,
862                                      const struct vxlan_mdb_remote *remote,
863                                      struct vxlan_mdb_src_entry *ent)
864 {
865         vxlan_mdb_remote_src_fwd_del(vxlan, group, remote, &ent->addr);
866         vxlan_mdb_remote_src_entry_del(ent);
867 }
868
869 static int vxlan_mdb_remote_srcs_add(const struct vxlan_mdb_config *cfg,
870                                      struct vxlan_mdb_remote *remote,
871                                      struct netlink_ext_ack *extack)
872 {
873         struct vxlan_mdb_config_src_entry *src;
874         struct vxlan_mdb_src_entry *ent;
875         struct hlist_node *tmp;
876         int err;
877
878         list_for_each_entry(src, &cfg->src_list, node) {
879                 err = vxlan_mdb_remote_src_add(cfg, remote, src, extack);
880                 if (err)
881                         goto err_src_del;
882         }
883
884         return 0;
885
886 err_src_del:
887         hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node)
888                 vxlan_mdb_remote_src_del(cfg->vxlan, &cfg->group, remote, ent);
889         return err;
890 }
891
892 static void vxlan_mdb_remote_srcs_del(struct vxlan_dev *vxlan,
893                                       const struct vxlan_mdb_entry_key *group,
894                                       struct vxlan_mdb_remote *remote)
895 {
896         struct vxlan_mdb_src_entry *ent;
897         struct hlist_node *tmp;
898
899         hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node)
900                 vxlan_mdb_remote_src_del(vxlan, group, remote, ent);
901 }
902
903 static size_t
904 vxlan_mdb_nlmsg_src_list_size(const struct vxlan_mdb_entry_key *group,
905                               const struct vxlan_mdb_remote *remote)
906 {
907         struct vxlan_mdb_src_entry *ent;
908         size_t nlmsg_size;
909
910         if (hlist_empty(&remote->src_list))
911                 return 0;
912
913         /* MDBA_MDB_EATTR_SRC_LIST */
914         nlmsg_size = nla_total_size(0);
915
916         hlist_for_each_entry(ent, &remote->src_list, node) {
917                               /* MDBA_MDB_SRCLIST_ENTRY */
918                 nlmsg_size += nla_total_size(0) +
919                               /* MDBA_MDB_SRCATTR_ADDRESS */
920                               nla_total_size(vxlan_addr_size(&group->dst)) +
921                               /* MDBA_MDB_SRCATTR_TIMER */
922                               nla_total_size(sizeof(u8));
923         }
924
925         return nlmsg_size;
926 }
927
928 static size_t
929 vxlan_mdb_nlmsg_remote_size(const struct vxlan_dev *vxlan,
930                             const struct vxlan_mdb_entry *mdb_entry,
931                             const struct vxlan_mdb_remote *remote)
932 {
933         const struct vxlan_mdb_entry_key *group = &mdb_entry->key;
934         struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
935         size_t nlmsg_size;
936
937                      /* MDBA_MDB_ENTRY_INFO */
938         nlmsg_size = nla_total_size(sizeof(struct br_mdb_entry)) +
939                      /* MDBA_MDB_EATTR_TIMER */
940                      nla_total_size(sizeof(u32));
941
942         /* MDBA_MDB_EATTR_SOURCE */
943         if (vxlan_mdb_is_sg(group))
944                 nlmsg_size += nla_total_size(vxlan_addr_size(&group->dst));
945         /* MDBA_MDB_EATTR_RTPROT */
946         nlmsg_size += nla_total_size(sizeof(u8));
947         /* MDBA_MDB_EATTR_SRC_LIST */
948         nlmsg_size += vxlan_mdb_nlmsg_src_list_size(group, remote);
949         /* MDBA_MDB_EATTR_GROUP_MODE */
950         nlmsg_size += nla_total_size(sizeof(u8));
951         /* MDBA_MDB_EATTR_DST */
952         nlmsg_size += nla_total_size(vxlan_addr_size(&rd->remote_ip));
953         /* MDBA_MDB_EATTR_DST_PORT */
954         if (rd->remote_port && rd->remote_port != vxlan->cfg.dst_port)
955                 nlmsg_size += nla_total_size(sizeof(u16));
956         /* MDBA_MDB_EATTR_VNI */
957         if (rd->remote_vni != vxlan->default_dst.remote_vni)
958                 nlmsg_size += nla_total_size(sizeof(u32));
959         /* MDBA_MDB_EATTR_IFINDEX */
960         if (rd->remote_ifindex)
961                 nlmsg_size += nla_total_size(sizeof(u32));
962         /* MDBA_MDB_EATTR_SRC_VNI */
963         if ((vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA) && group->vni)
964                 nlmsg_size += nla_total_size(sizeof(u32));
965
966         return nlmsg_size;
967 }
968
969 static size_t vxlan_mdb_nlmsg_size(const struct vxlan_dev *vxlan,
970                                    const struct vxlan_mdb_entry *mdb_entry,
971                                    const struct vxlan_mdb_remote *remote)
972 {
973         return NLMSG_ALIGN(sizeof(struct br_port_msg)) +
974                /* MDBA_MDB */
975                nla_total_size(0) +
976                /* MDBA_MDB_ENTRY */
977                nla_total_size(0) +
978                /* Remote entry */
979                vxlan_mdb_nlmsg_remote_size(vxlan, mdb_entry, remote);
980 }
981
982 static int vxlan_mdb_nlmsg_fill(const struct vxlan_dev *vxlan,
983                                 struct sk_buff *skb,
984                                 const struct vxlan_mdb_entry *mdb_entry,
985                                 const struct vxlan_mdb_remote *remote,
986                                 int type)
987 {
988         struct nlattr *mdb_nest, *mdb_entry_nest;
989         struct br_port_msg *bpm;
990         struct nlmsghdr *nlh;
991
992         nlh = nlmsg_put(skb, 0, 0, type, sizeof(*bpm), 0);
993         if (!nlh)
994                 return -EMSGSIZE;
995
996         bpm = nlmsg_data(nlh);
997         memset(bpm, 0, sizeof(*bpm));
998         bpm->family  = AF_BRIDGE;
999         bpm->ifindex = vxlan->dev->ifindex;
1000
1001         mdb_nest = nla_nest_start_noflag(skb, MDBA_MDB);
1002         if (!mdb_nest)
1003                 goto cancel;
1004         mdb_entry_nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
1005         if (!mdb_entry_nest)
1006                 goto cancel;
1007
1008         if (vxlan_mdb_entry_info_fill(vxlan, skb, mdb_entry, remote))
1009                 goto cancel;
1010
1011         nla_nest_end(skb, mdb_entry_nest);
1012         nla_nest_end(skb, mdb_nest);
1013         nlmsg_end(skb, nlh);
1014
1015         return 0;
1016
1017 cancel:
1018         nlmsg_cancel(skb, nlh);
1019         return -EMSGSIZE;
1020 }
1021
1022 static void vxlan_mdb_remote_notify(const struct vxlan_dev *vxlan,
1023                                     const struct vxlan_mdb_entry *mdb_entry,
1024                                     const struct vxlan_mdb_remote *remote,
1025                                     int type)
1026 {
1027         struct net *net = dev_net(vxlan->dev);
1028         struct sk_buff *skb;
1029         int err = -ENOBUFS;
1030
1031         skb = nlmsg_new(vxlan_mdb_nlmsg_size(vxlan, mdb_entry, remote),
1032                         GFP_KERNEL);
1033         if (!skb)
1034                 goto errout;
1035
1036         err = vxlan_mdb_nlmsg_fill(vxlan, skb, mdb_entry, remote, type);
1037         if (err) {
1038                 kfree_skb(skb);
1039                 goto errout;
1040         }
1041
1042         rtnl_notify(skb, net, 0, RTNLGRP_MDB, NULL, GFP_KERNEL);
1043         return;
1044 errout:
1045         rtnl_set_sk_err(net, RTNLGRP_MDB, err);
1046 }
1047
1048 static int
1049 vxlan_mdb_remote_srcs_replace(const struct vxlan_mdb_config *cfg,
1050                               const struct vxlan_mdb_entry *mdb_entry,
1051                               struct vxlan_mdb_remote *remote,
1052                               struct netlink_ext_ack *extack)
1053 {
1054         struct vxlan_dev *vxlan = cfg->vxlan;
1055         struct vxlan_mdb_src_entry *ent;
1056         struct hlist_node *tmp;
1057         int err;
1058
1059         hlist_for_each_entry(ent, &remote->src_list, node)
1060                 ent->flags |= VXLAN_SGRP_F_DELETE;
1061
1062         err = vxlan_mdb_remote_srcs_add(cfg, remote, extack);
1063         if (err)
1064                 goto err_clear_delete;
1065
1066         hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node) {
1067                 if (ent->flags & VXLAN_SGRP_F_DELETE)
1068                         vxlan_mdb_remote_src_del(vxlan, &mdb_entry->key, remote,
1069                                                  ent);
1070         }
1071
1072         return 0;
1073
1074 err_clear_delete:
1075         hlist_for_each_entry(ent, &remote->src_list, node)
1076                 ent->flags &= ~VXLAN_SGRP_F_DELETE;
1077         return err;
1078 }
1079
1080 static int vxlan_mdb_remote_replace(const struct vxlan_mdb_config *cfg,
1081                                     const struct vxlan_mdb_entry *mdb_entry,
1082                                     struct vxlan_mdb_remote *remote,
1083                                     struct netlink_ext_ack *extack)
1084 {
1085         struct vxlan_rdst *new_rd, *old_rd = rtnl_dereference(remote->rd);
1086         struct vxlan_dev *vxlan = cfg->vxlan;
1087         int err;
1088
1089         err = vxlan_mdb_remote_rdst_init(cfg, remote);
1090         if (err)
1091                 return err;
1092         new_rd = rtnl_dereference(remote->rd);
1093
1094         err = vxlan_mdb_remote_srcs_replace(cfg, mdb_entry, remote, extack);
1095         if (err)
1096                 goto err_rdst_reset;
1097
1098         WRITE_ONCE(remote->flags, cfg->flags);
1099         WRITE_ONCE(remote->filter_mode, cfg->filter_mode);
1100         remote->rt_protocol = cfg->rt_protocol;
1101         vxlan_mdb_remote_notify(vxlan, mdb_entry, remote, RTM_NEWMDB);
1102
1103         vxlan_mdb_remote_rdst_fini(old_rd);
1104
1105         return 0;
1106
1107 err_rdst_reset:
1108         rcu_assign_pointer(remote->rd, old_rd);
1109         vxlan_mdb_remote_rdst_fini(new_rd);
1110         return err;
1111 }
1112
1113 static int vxlan_mdb_remote_add(const struct vxlan_mdb_config *cfg,
1114                                 struct vxlan_mdb_entry *mdb_entry,
1115                                 struct netlink_ext_ack *extack)
1116 {
1117         struct vxlan_mdb_remote *remote;
1118         int err;
1119
1120         remote = vxlan_mdb_remote_lookup(mdb_entry, &cfg->remote_ip);
1121         if (remote) {
1122                 if (!(cfg->nlflags & NLM_F_REPLACE)) {
1123                         NL_SET_ERR_MSG_MOD(extack, "Replace not specified and MDB remote entry already exists");
1124                         return -EEXIST;
1125                 }
1126                 return vxlan_mdb_remote_replace(cfg, mdb_entry, remote, extack);
1127         }
1128
1129         if (!(cfg->nlflags & NLM_F_CREATE)) {
1130                 NL_SET_ERR_MSG_MOD(extack, "Create not specified and entry does not exist");
1131                 return -ENOENT;
1132         }
1133
1134         remote = kzalloc(sizeof(*remote), GFP_KERNEL);
1135         if (!remote)
1136                 return -ENOMEM;
1137
1138         err = vxlan_mdb_remote_init(cfg, remote);
1139         if (err) {
1140                 NL_SET_ERR_MSG_MOD(extack, "Failed to initialize remote MDB entry");
1141                 goto err_free_remote;
1142         }
1143
1144         err = vxlan_mdb_remote_srcs_add(cfg, remote, extack);
1145         if (err)
1146                 goto err_remote_fini;
1147
1148         list_add_rcu(&remote->list, &mdb_entry->remotes);
1149         vxlan_mdb_remote_notify(cfg->vxlan, mdb_entry, remote, RTM_NEWMDB);
1150
1151         return 0;
1152
1153 err_remote_fini:
1154         vxlan_mdb_remote_fini(cfg->vxlan, remote);
1155 err_free_remote:
1156         kfree(remote);
1157         return err;
1158 }
1159
1160 static void vxlan_mdb_remote_del(struct vxlan_dev *vxlan,
1161                                  struct vxlan_mdb_entry *mdb_entry,
1162                                  struct vxlan_mdb_remote *remote)
1163 {
1164         vxlan_mdb_remote_notify(vxlan, mdb_entry, remote, RTM_DELMDB);
1165         list_del_rcu(&remote->list);
1166         vxlan_mdb_remote_srcs_del(vxlan, &mdb_entry->key, remote);
1167         vxlan_mdb_remote_fini(vxlan, remote);
1168         kfree_rcu(remote, rcu);
1169 }
1170
1171 static struct vxlan_mdb_entry *
1172 vxlan_mdb_entry_get(struct vxlan_dev *vxlan,
1173                     const struct vxlan_mdb_entry_key *group)
1174 {
1175         struct vxlan_mdb_entry *mdb_entry;
1176         int err;
1177
1178         mdb_entry = vxlan_mdb_entry_lookup(vxlan, group);
1179         if (mdb_entry)
1180                 return mdb_entry;
1181
1182         mdb_entry = kzalloc(sizeof(*mdb_entry), GFP_KERNEL);
1183         if (!mdb_entry)
1184                 return ERR_PTR(-ENOMEM);
1185
1186         INIT_LIST_HEAD(&mdb_entry->remotes);
1187         memcpy(&mdb_entry->key, group, sizeof(mdb_entry->key));
1188         hlist_add_head(&mdb_entry->mdb_node, &vxlan->mdb_list);
1189
1190         err = rhashtable_lookup_insert_fast(&vxlan->mdb_tbl,
1191                                             &mdb_entry->rhnode,
1192                                             vxlan_mdb_rht_params);
1193         if (err)
1194                 goto err_free_entry;
1195
1196         if (hlist_is_singular_node(&mdb_entry->mdb_node, &vxlan->mdb_list))
1197                 vxlan->cfg.flags |= VXLAN_F_MDB;
1198
1199         return mdb_entry;
1200
1201 err_free_entry:
1202         hlist_del(&mdb_entry->mdb_node);
1203         kfree(mdb_entry);
1204         return ERR_PTR(err);
1205 }
1206
1207 static void vxlan_mdb_entry_put(struct vxlan_dev *vxlan,
1208                                 struct vxlan_mdb_entry *mdb_entry)
1209 {
1210         if (!list_empty(&mdb_entry->remotes))
1211                 return;
1212
1213         if (hlist_is_singular_node(&mdb_entry->mdb_node, &vxlan->mdb_list))
1214                 vxlan->cfg.flags &= ~VXLAN_F_MDB;
1215
1216         rhashtable_remove_fast(&vxlan->mdb_tbl, &mdb_entry->rhnode,
1217                                vxlan_mdb_rht_params);
1218         hlist_del(&mdb_entry->mdb_node);
1219         kfree_rcu(mdb_entry, rcu);
1220 }
1221
1222 static int __vxlan_mdb_add(const struct vxlan_mdb_config *cfg,
1223                            struct netlink_ext_ack *extack)
1224 {
1225         struct vxlan_dev *vxlan = cfg->vxlan;
1226         struct vxlan_mdb_entry *mdb_entry;
1227         int err;
1228
1229         mdb_entry = vxlan_mdb_entry_get(vxlan, &cfg->group);
1230         if (IS_ERR(mdb_entry))
1231                 return PTR_ERR(mdb_entry);
1232
1233         err = vxlan_mdb_remote_add(cfg, mdb_entry, extack);
1234         if (err)
1235                 goto err_entry_put;
1236
1237         vxlan->mdb_seq++;
1238
1239         return 0;
1240
1241 err_entry_put:
1242         vxlan_mdb_entry_put(vxlan, mdb_entry);
1243         return err;
1244 }
1245
1246 static int __vxlan_mdb_del(const struct vxlan_mdb_config *cfg,
1247                            struct netlink_ext_ack *extack)
1248 {
1249         struct vxlan_dev *vxlan = cfg->vxlan;
1250         struct vxlan_mdb_entry *mdb_entry;
1251         struct vxlan_mdb_remote *remote;
1252
1253         mdb_entry = vxlan_mdb_entry_lookup(vxlan, &cfg->group);
1254         if (!mdb_entry) {
1255                 NL_SET_ERR_MSG_MOD(extack, "Did not find MDB entry");
1256                 return -ENOENT;
1257         }
1258
1259         remote = vxlan_mdb_remote_lookup(mdb_entry, &cfg->remote_ip);
1260         if (!remote) {
1261                 NL_SET_ERR_MSG_MOD(extack, "Did not find MDB remote entry");
1262                 return -ENOENT;
1263         }
1264
1265         vxlan_mdb_remote_del(vxlan, mdb_entry, remote);
1266         vxlan_mdb_entry_put(vxlan, mdb_entry);
1267
1268         vxlan->mdb_seq++;
1269
1270         return 0;
1271 }
1272
1273 int vxlan_mdb_add(struct net_device *dev, struct nlattr *tb[], u16 nlmsg_flags,
1274                   struct netlink_ext_ack *extack)
1275 {
1276         struct vxlan_mdb_config cfg;
1277         int err;
1278
1279         ASSERT_RTNL();
1280
1281         err = vxlan_mdb_config_init(&cfg, dev, tb, nlmsg_flags, extack);
1282         if (err)
1283                 return err;
1284
1285         err = __vxlan_mdb_add(&cfg, extack);
1286
1287         vxlan_mdb_config_fini(&cfg);
1288         return err;
1289 }
1290
1291 int vxlan_mdb_del(struct net_device *dev, struct nlattr *tb[],
1292                   struct netlink_ext_ack *extack)
1293 {
1294         struct vxlan_mdb_config cfg;
1295         int err;
1296
1297         ASSERT_RTNL();
1298
1299         err = vxlan_mdb_config_init(&cfg, dev, tb, 0, extack);
1300         if (err)
1301                 return err;
1302
1303         err = __vxlan_mdb_del(&cfg, extack);
1304
1305         vxlan_mdb_config_fini(&cfg);
1306         return err;
1307 }
1308
1309 static const struct nla_policy vxlan_mdbe_attrs_get_pol[MDBE_ATTR_MAX + 1] = {
1310         [MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY,
1311                                               sizeof(struct in_addr),
1312                                               sizeof(struct in6_addr)),
1313         [MDBE_ATTR_SRC_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
1314 };
1315
1316 static int vxlan_mdb_get_parse(struct net_device *dev, struct nlattr *tb[],
1317                                struct vxlan_mdb_entry_key *group,
1318                                struct netlink_ext_ack *extack)
1319 {
1320         struct br_mdb_entry *entry = nla_data(tb[MDBA_GET_ENTRY]);
1321         struct nlattr *mdbe_attrs[MDBE_ATTR_MAX + 1];
1322         struct vxlan_dev *vxlan = netdev_priv(dev);
1323         int err;
1324
1325         memset(group, 0, sizeof(*group));
1326         group->vni = vxlan->default_dst.remote_vni;
1327
1328         if (!tb[MDBA_GET_ENTRY_ATTRS]) {
1329                 vxlan_mdb_group_set(group, entry, NULL);
1330                 return 0;
1331         }
1332
1333         err = nla_parse_nested(mdbe_attrs, MDBE_ATTR_MAX,
1334                                tb[MDBA_GET_ENTRY_ATTRS],
1335                                vxlan_mdbe_attrs_get_pol, extack);
1336         if (err)
1337                 return err;
1338
1339         if (mdbe_attrs[MDBE_ATTR_SOURCE] &&
1340             !vxlan_mdb_is_valid_source(mdbe_attrs[MDBE_ATTR_SOURCE],
1341                                        entry->addr.proto, extack))
1342                 return -EINVAL;
1343
1344         vxlan_mdb_group_set(group, entry, mdbe_attrs[MDBE_ATTR_SOURCE]);
1345
1346         if (mdbe_attrs[MDBE_ATTR_SRC_VNI])
1347                 group->vni =
1348                         cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_SRC_VNI]));
1349
1350         return 0;
1351 }
1352
1353 static struct sk_buff *
1354 vxlan_mdb_get_reply_alloc(const struct vxlan_dev *vxlan,
1355                           const struct vxlan_mdb_entry *mdb_entry)
1356 {
1357         struct vxlan_mdb_remote *remote;
1358         size_t nlmsg_size;
1359
1360         nlmsg_size = NLMSG_ALIGN(sizeof(struct br_port_msg)) +
1361                      /* MDBA_MDB */
1362                      nla_total_size(0) +
1363                      /* MDBA_MDB_ENTRY */
1364                      nla_total_size(0);
1365
1366         list_for_each_entry(remote, &mdb_entry->remotes, list)
1367                 nlmsg_size += vxlan_mdb_nlmsg_remote_size(vxlan, mdb_entry,
1368                                                           remote);
1369
1370         return nlmsg_new(nlmsg_size, GFP_KERNEL);
1371 }
1372
1373 static int
1374 vxlan_mdb_get_reply_fill(const struct vxlan_dev *vxlan,
1375                          struct sk_buff *skb,
1376                          const struct vxlan_mdb_entry *mdb_entry,
1377                          u32 portid, u32 seq)
1378 {
1379         struct nlattr *mdb_nest, *mdb_entry_nest;
1380         struct vxlan_mdb_remote *remote;
1381         struct br_port_msg *bpm;
1382         struct nlmsghdr *nlh;
1383         int err;
1384
1385         nlh = nlmsg_put(skb, portid, seq, RTM_NEWMDB, sizeof(*bpm), 0);
1386         if (!nlh)
1387                 return -EMSGSIZE;
1388
1389         bpm = nlmsg_data(nlh);
1390         memset(bpm, 0, sizeof(*bpm));
1391         bpm->family  = AF_BRIDGE;
1392         bpm->ifindex = vxlan->dev->ifindex;
1393         mdb_nest = nla_nest_start_noflag(skb, MDBA_MDB);
1394         if (!mdb_nest) {
1395                 err = -EMSGSIZE;
1396                 goto cancel;
1397         }
1398         mdb_entry_nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
1399         if (!mdb_entry_nest) {
1400                 err = -EMSGSIZE;
1401                 goto cancel;
1402         }
1403
1404         list_for_each_entry(remote, &mdb_entry->remotes, list) {
1405                 err = vxlan_mdb_entry_info_fill(vxlan, skb, mdb_entry, remote);
1406                 if (err)
1407                         goto cancel;
1408         }
1409
1410         nla_nest_end(skb, mdb_entry_nest);
1411         nla_nest_end(skb, mdb_nest);
1412         nlmsg_end(skb, nlh);
1413
1414         return 0;
1415
1416 cancel:
1417         nlmsg_cancel(skb, nlh);
1418         return err;
1419 }
1420
1421 int vxlan_mdb_get(struct net_device *dev, struct nlattr *tb[], u32 portid,
1422                   u32 seq, struct netlink_ext_ack *extack)
1423 {
1424         struct vxlan_dev *vxlan = netdev_priv(dev);
1425         struct vxlan_mdb_entry *mdb_entry;
1426         struct vxlan_mdb_entry_key group;
1427         struct sk_buff *skb;
1428         int err;
1429
1430         ASSERT_RTNL();
1431
1432         err = vxlan_mdb_get_parse(dev, tb, &group, extack);
1433         if (err)
1434                 return err;
1435
1436         mdb_entry = vxlan_mdb_entry_lookup(vxlan, &group);
1437         if (!mdb_entry) {
1438                 NL_SET_ERR_MSG_MOD(extack, "MDB entry not found");
1439                 return -ENOENT;
1440         }
1441
1442         skb = vxlan_mdb_get_reply_alloc(vxlan, mdb_entry);
1443         if (!skb)
1444                 return -ENOMEM;
1445
1446         err = vxlan_mdb_get_reply_fill(vxlan, skb, mdb_entry, portid, seq);
1447         if (err) {
1448                 NL_SET_ERR_MSG_MOD(extack, "Failed to fill MDB get reply");
1449                 goto free;
1450         }
1451
1452         return rtnl_unicast(skb, dev_net(dev), portid);
1453
1454 free:
1455         kfree_skb(skb);
1456         return err;
1457 }
1458
1459 struct vxlan_mdb_entry *vxlan_mdb_entry_skb_get(struct vxlan_dev *vxlan,
1460                                                 struct sk_buff *skb,
1461                                                 __be32 src_vni)
1462 {
1463         struct vxlan_mdb_entry *mdb_entry;
1464         struct vxlan_mdb_entry_key group;
1465
1466         if (!is_multicast_ether_addr(eth_hdr(skb)->h_dest) ||
1467             is_broadcast_ether_addr(eth_hdr(skb)->h_dest))
1468                 return NULL;
1469
1470         /* When not in collect metadata mode, 'src_vni' is zero, but MDB
1471          * entries are stored with the VNI of the VXLAN device.
1472          */
1473         if (!(vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA))
1474                 src_vni = vxlan->default_dst.remote_vni;
1475
1476         memset(&group, 0, sizeof(group));
1477         group.vni = src_vni;
1478
1479         switch (skb->protocol) {
1480         case htons(ETH_P_IP):
1481                 if (!pskb_may_pull(skb, sizeof(struct iphdr)))
1482                         return NULL;
1483                 group.dst.sa.sa_family = AF_INET;
1484                 group.dst.sin.sin_addr.s_addr = ip_hdr(skb)->daddr;
1485                 group.src.sa.sa_family = AF_INET;
1486                 group.src.sin.sin_addr.s_addr = ip_hdr(skb)->saddr;
1487                 break;
1488 #if IS_ENABLED(CONFIG_IPV6)
1489         case htons(ETH_P_IPV6):
1490                 if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
1491                         return NULL;
1492                 group.dst.sa.sa_family = AF_INET6;
1493                 group.dst.sin6.sin6_addr = ipv6_hdr(skb)->daddr;
1494                 group.src.sa.sa_family = AF_INET6;
1495                 group.src.sin6.sin6_addr = ipv6_hdr(skb)->saddr;
1496                 break;
1497 #endif
1498         default:
1499                 return NULL;
1500         }
1501
1502         mdb_entry = vxlan_mdb_entry_lookup(vxlan, &group);
1503         if (mdb_entry)
1504                 return mdb_entry;
1505
1506         memset(&group.src, 0, sizeof(group.src));
1507         mdb_entry = vxlan_mdb_entry_lookup(vxlan, &group);
1508         if (mdb_entry)
1509                 return mdb_entry;
1510
1511         /* No (S, G) or (*, G) found. Look up the all-zeros entry, but only if
1512          * the destination IP address is not link-local multicast since we want
1513          * to transmit such traffic together with broadcast and unknown unicast
1514          * traffic.
1515          */
1516         switch (skb->protocol) {
1517         case htons(ETH_P_IP):
1518                 if (ipv4_is_local_multicast(group.dst.sin.sin_addr.s_addr))
1519                         return NULL;
1520                 group.dst.sin.sin_addr.s_addr = 0;
1521                 break;
1522 #if IS_ENABLED(CONFIG_IPV6)
1523         case htons(ETH_P_IPV6):
1524                 if (ipv6_addr_type(&group.dst.sin6.sin6_addr) &
1525                     IPV6_ADDR_LINKLOCAL)
1526                         return NULL;
1527                 memset(&group.dst.sin6.sin6_addr, 0,
1528                        sizeof(group.dst.sin6.sin6_addr));
1529                 break;
1530 #endif
1531         default:
1532                 return NULL;
1533         }
1534
1535         return vxlan_mdb_entry_lookup(vxlan, &group);
1536 }
1537
1538 netdev_tx_t vxlan_mdb_xmit(struct vxlan_dev *vxlan,
1539                            const struct vxlan_mdb_entry *mdb_entry,
1540                            struct sk_buff *skb)
1541 {
1542         struct vxlan_mdb_remote *remote, *fremote = NULL;
1543         __be32 src_vni = mdb_entry->key.vni;
1544
1545         list_for_each_entry_rcu(remote, &mdb_entry->remotes, list) {
1546                 struct sk_buff *skb1;
1547
1548                 if ((vxlan_mdb_is_star_g(&mdb_entry->key) &&
1549                      READ_ONCE(remote->filter_mode) == MCAST_INCLUDE) ||
1550                     (READ_ONCE(remote->flags) & VXLAN_MDB_REMOTE_F_BLOCKED))
1551                         continue;
1552
1553                 if (!fremote) {
1554                         fremote = remote;
1555                         continue;
1556                 }
1557
1558                 skb1 = skb_clone(skb, GFP_ATOMIC);
1559                 if (skb1)
1560                         vxlan_xmit_one(skb1, vxlan->dev, src_vni,
1561                                        rcu_dereference(remote->rd), false);
1562         }
1563
1564         if (fremote)
1565                 vxlan_xmit_one(skb, vxlan->dev, src_vni,
1566                                rcu_dereference(fremote->rd), false);
1567         else
1568                 kfree_skb(skb);
1569
1570         return NETDEV_TX_OK;
1571 }
1572
1573 static void vxlan_mdb_check_empty(void *ptr, void *arg)
1574 {
1575         WARN_ON_ONCE(1);
1576 }
1577
1578 static void vxlan_mdb_remotes_flush(struct vxlan_dev *vxlan,
1579                                     struct vxlan_mdb_entry *mdb_entry)
1580 {
1581         struct vxlan_mdb_remote *remote, *tmp;
1582
1583         list_for_each_entry_safe(remote, tmp, &mdb_entry->remotes, list)
1584                 vxlan_mdb_remote_del(vxlan, mdb_entry, remote);
1585 }
1586
1587 static void vxlan_mdb_entries_flush(struct vxlan_dev *vxlan)
1588 {
1589         struct vxlan_mdb_entry *mdb_entry;
1590         struct hlist_node *tmp;
1591
1592         /* The removal of an entry cannot trigger the removal of another entry
1593          * since entries are always added to the head of the list.
1594          */
1595         hlist_for_each_entry_safe(mdb_entry, tmp, &vxlan->mdb_list, mdb_node) {
1596                 vxlan_mdb_remotes_flush(vxlan, mdb_entry);
1597                 vxlan_mdb_entry_put(vxlan, mdb_entry);
1598         }
1599 }
1600
1601 int vxlan_mdb_init(struct vxlan_dev *vxlan)
1602 {
1603         int err;
1604
1605         err = rhashtable_init(&vxlan->mdb_tbl, &vxlan_mdb_rht_params);
1606         if (err)
1607                 return err;
1608
1609         INIT_HLIST_HEAD(&vxlan->mdb_list);
1610
1611         return 0;
1612 }
1613
1614 void vxlan_mdb_fini(struct vxlan_dev *vxlan)
1615 {
1616         vxlan_mdb_entries_flush(vxlan);
1617         WARN_ON_ONCE(vxlan->cfg.flags & VXLAN_F_MDB);
1618         rhashtable_free_and_destroy(&vxlan->mdb_tbl, vxlan_mdb_check_empty,
1619                                     NULL);
1620 }