Merge tag 'arm64-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/arm64/linux
[linux-block.git] / drivers / net / netkit.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2023 Isovalent */
3
4 #include <linux/netdevice.h>
5 #include <linux/ethtool.h>
6 #include <linux/etherdevice.h>
7 #include <linux/filter.h>
8 #include <linux/netfilter_netdev.h>
9 #include <linux/bpf_mprog.h>
10
11 #include <net/netkit.h>
12 #include <net/dst.h>
13 #include <net/tcx.h>
14
15 #define DRV_NAME "netkit"
16
17 struct netkit {
18         /* Needed in fast-path */
19         struct net_device __rcu *peer;
20         struct bpf_mprog_entry __rcu *active;
21         enum netkit_action policy;
22         struct bpf_mprog_bundle bundle;
23
24         /* Needed in slow-path */
25         enum netkit_mode mode;
26         bool primary;
27         u32 headroom;
28 };
29
30 struct netkit_link {
31         struct bpf_link link;
32         struct net_device *dev;
33         u32 location;
34 };
35
36 static __always_inline int
37 netkit_run(const struct bpf_mprog_entry *entry, struct sk_buff *skb,
38            enum netkit_action ret)
39 {
40         const struct bpf_mprog_fp *fp;
41         const struct bpf_prog *prog;
42
43         bpf_mprog_foreach_prog(entry, fp, prog) {
44                 bpf_compute_data_pointers(skb);
45                 ret = bpf_prog_run(prog, skb);
46                 if (ret != NETKIT_NEXT)
47                         break;
48         }
49         return ret;
50 }
51
52 static void netkit_prep_forward(struct sk_buff *skb, bool xnet)
53 {
54         skb_scrub_packet(skb, xnet);
55         skb->priority = 0;
56         nf_skip_egress(skb, true);
57 }
58
/* Return the netkit private area that belongs to @dev. */
static struct netkit *netkit_priv(const struct net_device *dev)
{
	struct netkit *nk = netdev_priv(dev);

	return nk;
}
63
64 static netdev_tx_t netkit_xmit(struct sk_buff *skb, struct net_device *dev)
65 {
66         struct netkit *nk = netkit_priv(dev);
67         enum netkit_action ret = READ_ONCE(nk->policy);
68         netdev_tx_t ret_dev = NET_XMIT_SUCCESS;
69         const struct bpf_mprog_entry *entry;
70         struct net_device *peer;
71
72         rcu_read_lock();
73         peer = rcu_dereference(nk->peer);
74         if (unlikely(!peer || !(peer->flags & IFF_UP) ||
75                      !pskb_may_pull(skb, ETH_HLEN) ||
76                      skb_orphan_frags(skb, GFP_ATOMIC)))
77                 goto drop;
78         netkit_prep_forward(skb, !net_eq(dev_net(dev), dev_net(peer)));
79         skb->dev = peer;
80         entry = rcu_dereference(nk->active);
81         if (entry)
82                 ret = netkit_run(entry, skb, ret);
83         switch (ret) {
84         case NETKIT_NEXT:
85         case NETKIT_PASS:
86                 skb->protocol = eth_type_trans(skb, skb->dev);
87                 skb_postpull_rcsum(skb, eth_hdr(skb), ETH_HLEN);
88                 __netif_rx(skb);
89                 break;
90         case NETKIT_REDIRECT:
91                 skb_do_redirect(skb);
92                 break;
93         case NETKIT_DROP:
94         default:
95 drop:
96                 kfree_skb(skb);
97                 dev_core_stats_tx_dropped_inc(dev);
98                 ret_dev = NET_XMIT_DROP;
99                 break;
100         }
101         rcu_read_unlock();
102         return ret_dev;
103 }
104
105 static int netkit_open(struct net_device *dev)
106 {
107         struct netkit *nk = netkit_priv(dev);
108         struct net_device *peer = rtnl_dereference(nk->peer);
109
110         if (!peer)
111                 return -ENOTCONN;
112         if (peer->flags & IFF_UP) {
113                 netif_carrier_on(dev);
114                 netif_carrier_on(peer);
115         }
116         return 0;
117 }
118
119 static int netkit_close(struct net_device *dev)
120 {
121         struct netkit *nk = netkit_priv(dev);
122         struct net_device *peer = rtnl_dereference(nk->peer);
123
124         netif_carrier_off(dev);
125         if (peer)
126                 netif_carrier_off(peer);
127         return 0;
128 }
129
130 static int netkit_get_iflink(const struct net_device *dev)
131 {
132         struct netkit *nk = netkit_priv(dev);
133         struct net_device *peer;
134         int iflink = 0;
135
136         rcu_read_lock();
137         peer = rcu_dereference(nk->peer);
138         if (peer)
139                 iflink = peer->ifindex;
140         rcu_read_unlock();
141         return iflink;
142 }
143
/* ndo_set_rx_mode: intentionally empty. */
static void netkit_set_multicast(struct net_device *dev)
{
	/* No filtering to program: the device takes whatever is pushed to it. */
}
148
149 static void netkit_set_headroom(struct net_device *dev, int headroom)
150 {
151         struct netkit *nk = netkit_priv(dev), *nk2;
152         struct net_device *peer;
153
154         if (headroom < 0)
155                 headroom = NET_SKB_PAD;
156
157         rcu_read_lock();
158         peer = rcu_dereference(nk->peer);
159         if (unlikely(!peer))
160                 goto out;
161
162         nk2 = netkit_priv(peer);
163         nk->headroom = headroom;
164         headroom = max(nk->headroom, nk2->headroom);
165
166         peer->needed_headroom = headroom;
167         dev->needed_headroom = headroom;
168 out:
169         rcu_read_unlock();
170 }
171
172 static struct net_device *netkit_peer_dev(struct net_device *dev)
173 {
174         return rcu_dereference(netkit_priv(dev)->peer);
175 }
176
177 static void netkit_uninit(struct net_device *dev);
178
179 static const struct net_device_ops netkit_netdev_ops = {
180         .ndo_open               = netkit_open,
181         .ndo_stop               = netkit_close,
182         .ndo_start_xmit         = netkit_xmit,
183         .ndo_set_rx_mode        = netkit_set_multicast,
184         .ndo_set_rx_headroom    = netkit_set_headroom,
185         .ndo_get_iflink         = netkit_get_iflink,
186         .ndo_get_peer_dev       = netkit_peer_dev,
187         .ndo_uninit             = netkit_uninit,
188         .ndo_features_check     = passthru_features_check,
189 };
190
191 static void netkit_get_drvinfo(struct net_device *dev,
192                                struct ethtool_drvinfo *info)
193 {
194         strscpy(info->driver, DRV_NAME, sizeof(info->driver));
195 }
196
197 static const struct ethtool_ops netkit_ethtool_ops = {
198         .get_drvinfo            = netkit_get_drvinfo,
199 };
200
201 static void netkit_setup(struct net_device *dev)
202 {
203         static const netdev_features_t netkit_features_hw_vlan =
204                 NETIF_F_HW_VLAN_CTAG_TX |
205                 NETIF_F_HW_VLAN_CTAG_RX |
206                 NETIF_F_HW_VLAN_STAG_TX |
207                 NETIF_F_HW_VLAN_STAG_RX;
208         static const netdev_features_t netkit_features =
209                 netkit_features_hw_vlan |
210                 NETIF_F_SG |
211                 NETIF_F_FRAGLIST |
212                 NETIF_F_HW_CSUM |
213                 NETIF_F_RXCSUM |
214                 NETIF_F_SCTP_CRC |
215                 NETIF_F_HIGHDMA |
216                 NETIF_F_GSO_SOFTWARE |
217                 NETIF_F_GSO_ENCAP_ALL;
218
219         ether_setup(dev);
220         dev->max_mtu = ETH_MAX_MTU;
221
222         dev->flags |= IFF_NOARP;
223         dev->priv_flags &= ~IFF_TX_SKB_SHARING;
224         dev->priv_flags |= IFF_LIVE_ADDR_CHANGE;
225         dev->priv_flags |= IFF_PHONY_HEADROOM;
226         dev->priv_flags |= IFF_NO_QUEUE;
227
228         dev->ethtool_ops = &netkit_ethtool_ops;
229         dev->netdev_ops  = &netkit_netdev_ops;
230
231         dev->features |= netkit_features | NETIF_F_LLTX;
232         dev->hw_features = netkit_features;
233         dev->hw_enc_features = netkit_features;
234         dev->mpls_features = NETIF_F_HW_CSUM | NETIF_F_GSO_SOFTWARE;
235         dev->vlan_features = dev->features & ~netkit_features_hw_vlan;
236
237         dev->needs_free_netdev = true;
238
239         netif_set_tso_max_size(dev, GSO_MAX_SIZE);
240 }
241
242 static struct net *netkit_get_link_net(const struct net_device *dev)
243 {
244         struct netkit *nk = netkit_priv(dev);
245         struct net_device *peer = rtnl_dereference(nk->peer);
246
247         return peer ? dev_net(peer) : dev_net(dev);
248 }
249
250 static int netkit_check_policy(int policy, struct nlattr *tb,
251                                struct netlink_ext_ack *extack)
252 {
253         switch (policy) {
254         case NETKIT_PASS:
255         case NETKIT_DROP:
256                 return 0;
257         default:
258                 NL_SET_ERR_MSG_ATTR(extack, tb,
259                                     "Provided default xmit policy not supported");
260                 return -EINVAL;
261         }
262 }
263
264 static int netkit_check_mode(int mode, struct nlattr *tb,
265                              struct netlink_ext_ack *extack)
266 {
267         switch (mode) {
268         case NETKIT_L2:
269         case NETKIT_L3:
270                 return 0;
271         default:
272                 NL_SET_ERR_MSG_ATTR(extack, tb,
273                                     "Provided device mode can only be L2 or L3");
274                 return -EINVAL;
275         }
276 }
277
278 static int netkit_validate(struct nlattr *tb[], struct nlattr *data[],
279                            struct netlink_ext_ack *extack)
280 {
281         struct nlattr *attr = tb[IFLA_ADDRESS];
282
283         if (!attr)
284                 return 0;
285         NL_SET_ERR_MSG_ATTR(extack, attr,
286                             "Setting Ethernet address is not supported");
287         return -EOPNOTSUPP;
288 }
289
290 static struct rtnl_link_ops netkit_link_ops;
291
292 static int netkit_new_link(struct net *src_net, struct net_device *dev,
293                            struct nlattr *tb[], struct nlattr *data[],
294                            struct netlink_ext_ack *extack)
295 {
296         struct nlattr *peer_tb[IFLA_MAX + 1], **tbp = tb, *attr;
297         enum netkit_action default_prim = NETKIT_PASS;
298         enum netkit_action default_peer = NETKIT_PASS;
299         enum netkit_mode mode = NETKIT_L3;
300         unsigned char ifname_assign_type;
301         struct ifinfomsg *ifmp = NULL;
302         struct net_device *peer;
303         char ifname[IFNAMSIZ];
304         struct netkit *nk;
305         struct net *net;
306         int err;
307
308         if (data) {
309                 if (data[IFLA_NETKIT_MODE]) {
310                         attr = data[IFLA_NETKIT_MODE];
311                         mode = nla_get_u32(attr);
312                         err = netkit_check_mode(mode, attr, extack);
313                         if (err < 0)
314                                 return err;
315                 }
316                 if (data[IFLA_NETKIT_PEER_INFO]) {
317                         attr = data[IFLA_NETKIT_PEER_INFO];
318                         ifmp = nla_data(attr);
319                         err = rtnl_nla_parse_ifinfomsg(peer_tb, attr, extack);
320                         if (err < 0)
321                                 return err;
322                         err = netkit_validate(peer_tb, NULL, extack);
323                         if (err < 0)
324                                 return err;
325                         tbp = peer_tb;
326                 }
327                 if (data[IFLA_NETKIT_POLICY]) {
328                         attr = data[IFLA_NETKIT_POLICY];
329                         default_prim = nla_get_u32(attr);
330                         err = netkit_check_policy(default_prim, attr, extack);
331                         if (err < 0)
332                                 return err;
333                 }
334                 if (data[IFLA_NETKIT_PEER_POLICY]) {
335                         attr = data[IFLA_NETKIT_PEER_POLICY];
336                         default_peer = nla_get_u32(attr);
337                         err = netkit_check_policy(default_peer, attr, extack);
338                         if (err < 0)
339                                 return err;
340                 }
341         }
342
343         if (ifmp && tbp[IFLA_IFNAME]) {
344                 nla_strscpy(ifname, tbp[IFLA_IFNAME], IFNAMSIZ);
345                 ifname_assign_type = NET_NAME_USER;
346         } else {
347                 strscpy(ifname, "nk%d", IFNAMSIZ);
348                 ifname_assign_type = NET_NAME_ENUM;
349         }
350
351         net = rtnl_link_get_net(src_net, tbp);
352         if (IS_ERR(net))
353                 return PTR_ERR(net);
354
355         peer = rtnl_create_link(net, ifname, ifname_assign_type,
356                                 &netkit_link_ops, tbp, extack);
357         if (IS_ERR(peer)) {
358                 put_net(net);
359                 return PTR_ERR(peer);
360         }
361
362         netif_inherit_tso_max(peer, dev);
363
364         if (mode == NETKIT_L2)
365                 eth_hw_addr_random(peer);
366         if (ifmp && dev->ifindex)
367                 peer->ifindex = ifmp->ifi_index;
368
369         nk = netkit_priv(peer);
370         nk->primary = false;
371         nk->policy = default_peer;
372         nk->mode = mode;
373         bpf_mprog_bundle_init(&nk->bundle);
374
375         err = register_netdevice(peer);
376         put_net(net);
377         if (err < 0)
378                 goto err_register_peer;
379         netif_carrier_off(peer);
380         if (mode == NETKIT_L2)
381                 dev_change_flags(peer, peer->flags & ~IFF_NOARP, NULL);
382
383         err = rtnl_configure_link(peer, NULL, 0, NULL);
384         if (err < 0)
385                 goto err_configure_peer;
386
387         if (mode == NETKIT_L2)
388                 eth_hw_addr_random(dev);
389         if (tb[IFLA_IFNAME])
390                 nla_strscpy(dev->name, tb[IFLA_IFNAME], IFNAMSIZ);
391         else
392                 strscpy(dev->name, "nk%d", IFNAMSIZ);
393
394         nk = netkit_priv(dev);
395         nk->primary = true;
396         nk->policy = default_prim;
397         nk->mode = mode;
398         bpf_mprog_bundle_init(&nk->bundle);
399
400         err = register_netdevice(dev);
401         if (err < 0)
402                 goto err_configure_peer;
403         netif_carrier_off(dev);
404         if (mode == NETKIT_L2)
405                 dev_change_flags(dev, dev->flags & ~IFF_NOARP, NULL);
406
407         rcu_assign_pointer(netkit_priv(dev)->peer, peer);
408         rcu_assign_pointer(netkit_priv(peer)->peer, dev);
409         return 0;
410 err_configure_peer:
411         unregister_netdevice(peer);
412         return err;
413 err_register_peer:
414         free_netdev(peer);
415         return err;
416 }
417
418 static struct bpf_mprog_entry *netkit_entry_fetch(struct net_device *dev,
419                                                   bool bundle_fallback)
420 {
421         struct netkit *nk = netkit_priv(dev);
422         struct bpf_mprog_entry *entry;
423
424         ASSERT_RTNL();
425         entry = rcu_dereference_rtnl(nk->active);
426         if (entry)
427                 return entry;
428         if (bundle_fallback)
429                 return &nk->bundle.a;
430         return NULL;
431 }
432
433 static void netkit_entry_update(struct net_device *dev,
434                                 struct bpf_mprog_entry *entry)
435 {
436         struct netkit *nk = netkit_priv(dev);
437
438         ASSERT_RTNL();
439         rcu_assign_pointer(nk->active, entry);
440 }
441
/* Wait for an RCU grace period after the active entry has been replaced. */
static void netkit_entry_sync(void)
{
	synchronize_rcu();
}
446
447 static struct net_device *netkit_dev_fetch(struct net *net, u32 ifindex, u32 which)
448 {
449         struct net_device *dev;
450         struct netkit *nk;
451
452         ASSERT_RTNL();
453
454         switch (which) {
455         case BPF_NETKIT_PRIMARY:
456         case BPF_NETKIT_PEER:
457                 break;
458         default:
459                 return ERR_PTR(-EINVAL);
460         }
461
462         dev = __dev_get_by_index(net, ifindex);
463         if (!dev)
464                 return ERR_PTR(-ENODEV);
465         if (dev->netdev_ops != &netkit_netdev_ops)
466                 return ERR_PTR(-ENXIO);
467
468         nk = netkit_priv(dev);
469         if (!nk->primary)
470                 return ERR_PTR(-EACCES);
471         if (which == BPF_NETKIT_PEER) {
472                 dev = rcu_dereference_rtnl(nk->peer);
473                 if (!dev)
474                         return ERR_PTR(-ENODEV);
475         }
476         return dev;
477 }
478
479 int netkit_prog_attach(const union bpf_attr *attr, struct bpf_prog *prog)
480 {
481         struct bpf_mprog_entry *entry, *entry_new;
482         struct bpf_prog *replace_prog = NULL;
483         struct net_device *dev;
484         int ret;
485
486         rtnl_lock();
487         dev = netkit_dev_fetch(current->nsproxy->net_ns, attr->target_ifindex,
488                                attr->attach_type);
489         if (IS_ERR(dev)) {
490                 ret = PTR_ERR(dev);
491                 goto out;
492         }
493         entry = netkit_entry_fetch(dev, true);
494         if (attr->attach_flags & BPF_F_REPLACE) {
495                 replace_prog = bpf_prog_get_type(attr->replace_bpf_fd,
496                                                  prog->type);
497                 if (IS_ERR(replace_prog)) {
498                         ret = PTR_ERR(replace_prog);
499                         replace_prog = NULL;
500                         goto out;
501                 }
502         }
503         ret = bpf_mprog_attach(entry, &entry_new, prog, NULL, replace_prog,
504                                attr->attach_flags, attr->relative_fd,
505                                attr->expected_revision);
506         if (!ret) {
507                 if (entry != entry_new) {
508                         netkit_entry_update(dev, entry_new);
509                         netkit_entry_sync();
510                 }
511                 bpf_mprog_commit(entry);
512         }
513 out:
514         if (replace_prog)
515                 bpf_prog_put(replace_prog);
516         rtnl_unlock();
517         return ret;
518 }
519
520 int netkit_prog_detach(const union bpf_attr *attr, struct bpf_prog *prog)
521 {
522         struct bpf_mprog_entry *entry, *entry_new;
523         struct net_device *dev;
524         int ret;
525
526         rtnl_lock();
527         dev = netkit_dev_fetch(current->nsproxy->net_ns, attr->target_ifindex,
528                                attr->attach_type);
529         if (IS_ERR(dev)) {
530                 ret = PTR_ERR(dev);
531                 goto out;
532         }
533         entry = netkit_entry_fetch(dev, false);
534         if (!entry) {
535                 ret = -ENOENT;
536                 goto out;
537         }
538         ret = bpf_mprog_detach(entry, &entry_new, prog, NULL, attr->attach_flags,
539                                attr->relative_fd, attr->expected_revision);
540         if (!ret) {
541                 if (!bpf_mprog_total(entry_new))
542                         entry_new = NULL;
543                 netkit_entry_update(dev, entry_new);
544                 netkit_entry_sync();
545                 bpf_mprog_commit(entry);
546         }
547 out:
548         rtnl_unlock();
549         return ret;
550 }
551
552 int netkit_prog_query(const union bpf_attr *attr, union bpf_attr __user *uattr)
553 {
554         struct net_device *dev;
555         int ret;
556
557         rtnl_lock();
558         dev = netkit_dev_fetch(current->nsproxy->net_ns,
559                                attr->query.target_ifindex,
560                                attr->query.attach_type);
561         if (IS_ERR(dev)) {
562                 ret = PTR_ERR(dev);
563                 goto out;
564         }
565         ret = bpf_mprog_query(attr, uattr, netkit_entry_fetch(dev, false));
566 out:
567         rtnl_unlock();
568         return ret;
569 }
570
571 static struct netkit_link *netkit_link(const struct bpf_link *link)
572 {
573         return container_of(link, struct netkit_link, link);
574 }
575
576 static int netkit_link_prog_attach(struct bpf_link *link, u32 flags,
577                                    u32 id_or_fd, u64 revision)
578 {
579         struct netkit_link *nkl = netkit_link(link);
580         struct bpf_mprog_entry *entry, *entry_new;
581         struct net_device *dev = nkl->dev;
582         int ret;
583
584         ASSERT_RTNL();
585         entry = netkit_entry_fetch(dev, true);
586         ret = bpf_mprog_attach(entry, &entry_new, link->prog, link, NULL, flags,
587                                id_or_fd, revision);
588         if (!ret) {
589                 if (entry != entry_new) {
590                         netkit_entry_update(dev, entry_new);
591                         netkit_entry_sync();
592                 }
593                 bpf_mprog_commit(entry);
594         }
595         return ret;
596 }
597
598 static void netkit_link_release(struct bpf_link *link)
599 {
600         struct netkit_link *nkl = netkit_link(link);
601         struct bpf_mprog_entry *entry, *entry_new;
602         struct net_device *dev;
603         int ret = 0;
604
605         rtnl_lock();
606         dev = nkl->dev;
607         if (!dev)
608                 goto out;
609         entry = netkit_entry_fetch(dev, false);
610         if (!entry) {
611                 ret = -ENOENT;
612                 goto out;
613         }
614         ret = bpf_mprog_detach(entry, &entry_new, link->prog, link, 0, 0, 0);
615         if (!ret) {
616                 if (!bpf_mprog_total(entry_new))
617                         entry_new = NULL;
618                 netkit_entry_update(dev, entry_new);
619                 netkit_entry_sync();
620                 bpf_mprog_commit(entry);
621                 nkl->dev = NULL;
622         }
623 out:
624         WARN_ON_ONCE(ret);
625         rtnl_unlock();
626 }
627
628 static int netkit_link_update(struct bpf_link *link, struct bpf_prog *nprog,
629                               struct bpf_prog *oprog)
630 {
631         struct netkit_link *nkl = netkit_link(link);
632         struct bpf_mprog_entry *entry, *entry_new;
633         struct net_device *dev;
634         int ret = 0;
635
636         rtnl_lock();
637         dev = nkl->dev;
638         if (!dev) {
639                 ret = -ENOLINK;
640                 goto out;
641         }
642         if (oprog && link->prog != oprog) {
643                 ret = -EPERM;
644                 goto out;
645         }
646         oprog = link->prog;
647         if (oprog == nprog) {
648                 bpf_prog_put(nprog);
649                 goto out;
650         }
651         entry = netkit_entry_fetch(dev, false);
652         if (!entry) {
653                 ret = -ENOENT;
654                 goto out;
655         }
656         ret = bpf_mprog_attach(entry, &entry_new, nprog, link, oprog,
657                                BPF_F_REPLACE | BPF_F_ID,
658                                link->prog->aux->id, 0);
659         if (!ret) {
660                 WARN_ON_ONCE(entry != entry_new);
661                 oprog = xchg(&link->prog, nprog);
662                 bpf_prog_put(oprog);
663                 bpf_mprog_commit(entry);
664         }
665 out:
666         rtnl_unlock();
667         return ret;
668 }
669
/* bpf_link dealloc callback: free the enclosing netkit_link. */
static void netkit_link_dealloc(struct bpf_link *link)
{
	struct netkit_link *nkl = netkit_link(link);

	kfree(nkl);
}
674
675 static void netkit_link_fdinfo(const struct bpf_link *link, struct seq_file *seq)
676 {
677         const struct netkit_link *nkl = netkit_link(link);
678         u32 ifindex = 0;
679
680         rtnl_lock();
681         if (nkl->dev)
682                 ifindex = nkl->dev->ifindex;
683         rtnl_unlock();
684
685         seq_printf(seq, "ifindex:\t%u\n", ifindex);
686         seq_printf(seq, "attach_type:\t%u (%s)\n",
687                    nkl->location,
688                    nkl->location == BPF_NETKIT_PRIMARY ? "primary" : "peer");
689 }
690
691 static int netkit_link_fill_info(const struct bpf_link *link,
692                                  struct bpf_link_info *info)
693 {
694         const struct netkit_link *nkl = netkit_link(link);
695         u32 ifindex = 0;
696
697         rtnl_lock();
698         if (nkl->dev)
699                 ifindex = nkl->dev->ifindex;
700         rtnl_unlock();
701
702         info->netkit.ifindex = ifindex;
703         info->netkit.attach_type = nkl->location;
704         return 0;
705 }
706
/* bpf_link detach callback: same work as release; always reports success. */
static int netkit_link_detach(struct bpf_link *link)
{
	netkit_link_release(link);

	return 0;
}
712
713 static const struct bpf_link_ops netkit_link_lops = {
714         .release        = netkit_link_release,
715         .detach         = netkit_link_detach,
716         .dealloc        = netkit_link_dealloc,
717         .update_prog    = netkit_link_update,
718         .show_fdinfo    = netkit_link_fdinfo,
719         .fill_link_info = netkit_link_fill_info,
720 };
721
722 static int netkit_link_init(struct netkit_link *nkl,
723                             struct bpf_link_primer *link_primer,
724                             const union bpf_attr *attr,
725                             struct net_device *dev,
726                             struct bpf_prog *prog)
727 {
728         bpf_link_init(&nkl->link, BPF_LINK_TYPE_NETKIT,
729                       &netkit_link_lops, prog);
730         nkl->location = attr->link_create.attach_type;
731         nkl->dev = dev;
732         return bpf_link_prime(&nkl->link, link_primer);
733 }
734
735 int netkit_link_attach(const union bpf_attr *attr, struct bpf_prog *prog)
736 {
737         struct bpf_link_primer link_primer;
738         struct netkit_link *nkl;
739         struct net_device *dev;
740         int ret;
741
742         rtnl_lock();
743         dev = netkit_dev_fetch(current->nsproxy->net_ns,
744                                attr->link_create.target_ifindex,
745                                attr->link_create.attach_type);
746         if (IS_ERR(dev)) {
747                 ret = PTR_ERR(dev);
748                 goto out;
749         }
750         nkl = kzalloc(sizeof(*nkl), GFP_KERNEL_ACCOUNT);
751         if (!nkl) {
752                 ret = -ENOMEM;
753                 goto out;
754         }
755         ret = netkit_link_init(nkl, &link_primer, attr, dev, prog);
756         if (ret) {
757                 kfree(nkl);
758                 goto out;
759         }
760         ret = netkit_link_prog_attach(&nkl->link,
761                                       attr->link_create.flags,
762                                       attr->link_create.netkit.relative_fd,
763                                       attr->link_create.netkit.expected_revision);
764         if (ret) {
765                 nkl->dev = NULL;
766                 bpf_link_cleanup(&link_primer);
767                 goto out;
768         }
769         ret = bpf_link_settle(&link_primer);
770 out:
771         rtnl_unlock();
772         return ret;
773 }
774
775 static void netkit_release_all(struct net_device *dev)
776 {
777         struct bpf_mprog_entry *entry;
778         struct bpf_tuple tuple = {};
779         struct bpf_mprog_fp *fp;
780         struct bpf_mprog_cp *cp;
781
782         entry = netkit_entry_fetch(dev, false);
783         if (!entry)
784                 return;
785         netkit_entry_update(dev, NULL);
786         netkit_entry_sync();
787         bpf_mprog_foreach_tuple(entry, fp, cp, tuple) {
788                 if (tuple.link)
789                         netkit_link(tuple.link)->dev = NULL;
790                 else
791                         bpf_prog_put(tuple.prog);
792         }
793 }
794
/* ndo_uninit: release all BPF attachments of @dev. */
static void netkit_uninit(struct net_device *dev)
{
	netkit_release_all(dev);
}
799
800 static void netkit_del_link(struct net_device *dev, struct list_head *head)
801 {
802         struct netkit *nk = netkit_priv(dev);
803         struct net_device *peer = rtnl_dereference(nk->peer);
804
805         RCU_INIT_POINTER(nk->peer, NULL);
806         unregister_netdevice_queue(dev, head);
807         if (peer) {
808                 nk = netkit_priv(peer);
809                 RCU_INIT_POINTER(nk->peer, NULL);
810                 unregister_netdevice_queue(peer, head);
811         }
812 }
813
814 static int netkit_change_link(struct net_device *dev, struct nlattr *tb[],
815                               struct nlattr *data[],
816                               struct netlink_ext_ack *extack)
817 {
818         struct netkit *nk = netkit_priv(dev);
819         struct net_device *peer = rtnl_dereference(nk->peer);
820         enum netkit_action policy;
821         struct nlattr *attr;
822         int err;
823
824         if (!nk->primary) {
825                 NL_SET_ERR_MSG(extack,
826                                "netkit link settings can be changed only through the primary device");
827                 return -EACCES;
828         }
829
830         if (data[IFLA_NETKIT_MODE]) {
831                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_NETKIT_MODE],
832                                     "netkit link operating mode cannot be changed after device creation");
833                 return -EACCES;
834         }
835
836         if (data[IFLA_NETKIT_POLICY]) {
837                 attr = data[IFLA_NETKIT_POLICY];
838                 policy = nla_get_u32(attr);
839                 err = netkit_check_policy(policy, attr, extack);
840                 if (err)
841                         return err;
842                 WRITE_ONCE(nk->policy, policy);
843         }
844
845         if (data[IFLA_NETKIT_PEER_POLICY]) {
846                 err = -EOPNOTSUPP;
847                 attr = data[IFLA_NETKIT_PEER_POLICY];
848                 policy = nla_get_u32(attr);
849                 if (peer)
850                         err = netkit_check_policy(policy, attr, extack);
851                 if (err)
852                         return err;
853                 nk = netkit_priv(peer);
854                 WRITE_ONCE(nk->policy, policy);
855         }
856
857         return 0;
858 }
859
860 static size_t netkit_get_size(const struct net_device *dev)
861 {
862         return nla_total_size(sizeof(u32)) + /* IFLA_NETKIT_POLICY */
863                nla_total_size(sizeof(u32)) + /* IFLA_NETKIT_PEER_POLICY */
864                nla_total_size(sizeof(u8))  + /* IFLA_NETKIT_PRIMARY */
865                nla_total_size(sizeof(u32)) + /* IFLA_NETKIT_MODE */
866                0;
867 }
868
869 static int netkit_fill_info(struct sk_buff *skb, const struct net_device *dev)
870 {
871         struct netkit *nk = netkit_priv(dev);
872         struct net_device *peer = rtnl_dereference(nk->peer);
873
874         if (nla_put_u8(skb, IFLA_NETKIT_PRIMARY, nk->primary))
875                 return -EMSGSIZE;
876         if (nla_put_u32(skb, IFLA_NETKIT_POLICY, nk->policy))
877                 return -EMSGSIZE;
878         if (nla_put_u32(skb, IFLA_NETKIT_MODE, nk->mode))
879                 return -EMSGSIZE;
880
881         if (peer) {
882                 nk = netkit_priv(peer);
883                 if (nla_put_u32(skb, IFLA_NETKIT_PEER_POLICY, nk->policy))
884                         return -EMSGSIZE;
885         }
886
887         return 0;
888 }
889
890 static const struct nla_policy netkit_policy[IFLA_NETKIT_MAX + 1] = {
891         [IFLA_NETKIT_PEER_INFO]         = { .len = sizeof(struct ifinfomsg) },
892         [IFLA_NETKIT_POLICY]            = { .type = NLA_U32 },
893         [IFLA_NETKIT_MODE]              = { .type = NLA_U32 },
894         [IFLA_NETKIT_PEER_POLICY]       = { .type = NLA_U32 },
895         [IFLA_NETKIT_PRIMARY]           = { .type = NLA_REJECT,
896                                             .reject_message = "Primary attribute is read-only" },
897 };
898
899 static struct rtnl_link_ops netkit_link_ops = {
900         .kind           = DRV_NAME,
901         .priv_size      = sizeof(struct netkit),
902         .setup          = netkit_setup,
903         .newlink        = netkit_new_link,
904         .dellink        = netkit_del_link,
905         .changelink     = netkit_change_link,
906         .get_link_net   = netkit_get_link_net,
907         .get_size       = netkit_get_size,
908         .fill_info      = netkit_fill_info,
909         .policy         = netkit_policy,
910         .validate       = netkit_validate,
911         .maxtype        = IFLA_NETKIT_MAX,
912 };
913
914 static __init int netkit_init(void)
915 {
916         BUILD_BUG_ON((int)NETKIT_NEXT != (int)TCX_NEXT ||
917                      (int)NETKIT_PASS != (int)TCX_PASS ||
918                      (int)NETKIT_DROP != (int)TCX_DROP ||
919                      (int)NETKIT_REDIRECT != (int)TCX_REDIRECT);
920
921         return rtnl_link_register(&netkit_link_ops);
922 }
923
924 static __exit void netkit_exit(void)
925 {
926         rtnl_link_unregister(&netkit_link_ops);
927 }
928
929 module_init(netkit_init);
930 module_exit(netkit_exit);
931
932 MODULE_DESCRIPTION("BPF-programmable network device");
933 MODULE_AUTHOR("Daniel Borkmann <daniel@iogearbox.net>");
934 MODULE_AUTHOR("Nikolay Aleksandrov <razor@blackwall.org>");
935 MODULE_LICENSE("GPL");
936 MODULE_ALIAS_RTNL_LINK(DRV_NAME);