Merge tag 'linux_kselftest-fixes-6.12-rc2' of git://git.kernel.org/pub/scm/linux...
[linux-2.6-block.git] / tools / testing / selftests / bpf / progs / test_sock_fields.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook */
3
4 #include <linux/bpf.h>
5 #include <netinet/in.h>
6 #include <stdbool.h>
7
8 #include <bpf/bpf_helpers.h>
9 #include <bpf/bpf_endian.h>
10
11 enum bpf_linum_array_idx {
12         EGRESS_LINUM_IDX,
13         INGRESS_LINUM_IDX,
14         READ_SK_DST_PORT_LINUM_IDX,
15         __NR_BPF_LINUM_ARRAY_IDX,
16 };
17
18 struct {
19         __uint(type, BPF_MAP_TYPE_ARRAY);
20         __uint(max_entries, __NR_BPF_LINUM_ARRAY_IDX);
21         __type(key, __u32);
22         __type(value, __u32);
23 } linum_map SEC(".maps");
24
25 struct bpf_spinlock_cnt {
26         struct bpf_spin_lock lock;
27         __u32 cnt;
28 };
29
30 struct {
31         __uint(type, BPF_MAP_TYPE_SK_STORAGE);
32         __uint(map_flags, BPF_F_NO_PREALLOC);
33         __type(key, int);
34         __type(value, struct bpf_spinlock_cnt);
35 } sk_pkt_out_cnt SEC(".maps");
36
37 struct {
38         __uint(type, BPF_MAP_TYPE_SK_STORAGE);
39         __uint(map_flags, BPF_F_NO_PREALLOC);
40         __type(key, int);
41         __type(value, struct bpf_spinlock_cnt);
42 } sk_pkt_out_cnt10 SEC(".maps");
43
44 struct tcp_sock {
45         __u32   lsndtime;
46 } __attribute__((preserve_access_index));
47
48 struct bpf_tcp_sock listen_tp = {};
49 struct sockaddr_in6 srv_sa6 = {};
50 struct bpf_tcp_sock cli_tp = {};
51 struct bpf_tcp_sock srv_tp = {};
52 struct bpf_sock listen_sk = {};
53 struct bpf_sock srv_sk = {};
54 struct bpf_sock cli_sk = {};
55 __u64 parent_cg_id = 0;
56 __u64 child_cg_id = 0;
57 __u64 lsndtime = 0;
58
59 static bool is_loopback6(__u32 *a6)
60 {
61         return !a6[0] && !a6[1] && !a6[2] && a6[3] == bpf_htonl(1);
62 }
63
64 static void skcpy(struct bpf_sock *dst,
65                   const struct bpf_sock *src)
66 {
67         dst->bound_dev_if = src->bound_dev_if;
68         dst->family = src->family;
69         dst->type = src->type;
70         dst->protocol = src->protocol;
71         dst->mark = src->mark;
72         dst->priority = src->priority;
73         dst->src_ip4 = src->src_ip4;
74         dst->src_ip6[0] = src->src_ip6[0];
75         dst->src_ip6[1] = src->src_ip6[1];
76         dst->src_ip6[2] = src->src_ip6[2];
77         dst->src_ip6[3] = src->src_ip6[3];
78         dst->src_port = src->src_port;
79         dst->dst_ip4 = src->dst_ip4;
80         dst->dst_ip6[0] = src->dst_ip6[0];
81         dst->dst_ip6[1] = src->dst_ip6[1];
82         dst->dst_ip6[2] = src->dst_ip6[2];
83         dst->dst_ip6[3] = src->dst_ip6[3];
84         dst->dst_port = src->dst_port;
85         dst->state = src->state;
86 }
87
88 static void tpcpy(struct bpf_tcp_sock *dst,
89                   const struct bpf_tcp_sock *src)
90 {
91         dst->snd_cwnd = src->snd_cwnd;
92         dst->srtt_us = src->srtt_us;
93         dst->rtt_min = src->rtt_min;
94         dst->snd_ssthresh = src->snd_ssthresh;
95         dst->rcv_nxt = src->rcv_nxt;
96         dst->snd_nxt = src->snd_nxt;
97         dst->snd_una = src->snd_una;
98         dst->mss_cache = src->mss_cache;
99         dst->ecn_flags = src->ecn_flags;
100         dst->rate_delivered = src->rate_delivered;
101         dst->rate_interval_us = src->rate_interval_us;
102         dst->packets_out = src->packets_out;
103         dst->retrans_out = src->retrans_out;
104         dst->total_retrans = src->total_retrans;
105         dst->segs_in = src->segs_in;
106         dst->data_segs_in = src->data_segs_in;
107         dst->segs_out = src->segs_out;
108         dst->data_segs_out = src->data_segs_out;
109         dst->lost_out = src->lost_out;
110         dst->sacked_out = src->sacked_out;
111         dst->bytes_received = src->bytes_received;
112         dst->bytes_acked = src->bytes_acked;
113 }
114
115 /* Always return CG_OK so that no pkt will be filtered out */
116 #define CG_OK 1
117
118 #define RET_LOG() ({                                            \
119         linum = __LINE__;                                       \
120         bpf_map_update_elem(&linum_map, &linum_idx, &linum, BPF_ANY);   \
121         return CG_OK;                                           \
122 })
123
124 SEC("cgroup_skb/egress")
125 int egress_read_sock_fields(struct __sk_buff *skb)
126 {
127         struct bpf_spinlock_cnt cli_cnt_init = { .lock = {}, .cnt = 0xeB9F };
128         struct bpf_spinlock_cnt *pkt_out_cnt, *pkt_out_cnt10;
129         struct bpf_tcp_sock *tp, *tp_ret;
130         struct bpf_sock *sk, *sk_ret;
131         __u32 linum, linum_idx;
132         struct tcp_sock *ktp;
133
134         linum_idx = EGRESS_LINUM_IDX;
135
136         sk = skb->sk;
137         if (!sk)
138                 RET_LOG();
139
140         /* Not testing the egress traffic or the listening socket,
141          * which are covered by the cgroup_skb/ingress test program.
142          */
143         if (sk->family != AF_INET6 || !is_loopback6(sk->src_ip6) ||
144             sk->state == BPF_TCP_LISTEN)
145                 return CG_OK;
146
147         if (sk->src_port == bpf_ntohs(srv_sa6.sin6_port)) {
148                 /* Server socket */
149                 sk_ret = &srv_sk;
150                 tp_ret = &srv_tp;
151         } else if (sk->dst_port == srv_sa6.sin6_port) {
152                 /* Client socket */
153                 sk_ret = &cli_sk;
154                 tp_ret = &cli_tp;
155         } else {
156                 /* Not the testing egress traffic */
157                 return CG_OK;
158         }
159
160         /* It must be a fullsock for cgroup_skb/egress prog */
161         sk = bpf_sk_fullsock(sk);
162         if (!sk)
163                 RET_LOG();
164
165         /* Not the testing egress traffic */
166         if (sk->protocol != IPPROTO_TCP)
167                 return CG_OK;
168
169         tp = bpf_tcp_sock(sk);
170         if (!tp)
171                 RET_LOG();
172
173         skcpy(sk_ret, sk);
174         tpcpy(tp_ret, tp);
175
176         if (sk_ret == &srv_sk) {
177                 ktp = bpf_skc_to_tcp_sock(sk);
178
179                 if (!ktp)
180                         RET_LOG();
181
182                 lsndtime = ktp->lsndtime;
183
184                 child_cg_id = bpf_sk_cgroup_id(ktp);
185                 if (!child_cg_id)
186                         RET_LOG();
187
188                 parent_cg_id = bpf_sk_ancestor_cgroup_id(ktp, 2);
189                 if (!parent_cg_id)
190                         RET_LOG();
191
192                 /* The userspace has created it for srv sk */
193                 pkt_out_cnt = bpf_sk_storage_get(&sk_pkt_out_cnt, ktp, 0, 0);
194                 pkt_out_cnt10 = bpf_sk_storage_get(&sk_pkt_out_cnt10, ktp,
195                                                    0, 0);
196         } else {
197                 pkt_out_cnt = bpf_sk_storage_get(&sk_pkt_out_cnt, sk,
198                                                  &cli_cnt_init,
199                                                  BPF_SK_STORAGE_GET_F_CREATE);
200                 pkt_out_cnt10 = bpf_sk_storage_get(&sk_pkt_out_cnt10,
201                                                    sk, &cli_cnt_init,
202                                                    BPF_SK_STORAGE_GET_F_CREATE);
203         }
204
205         if (!pkt_out_cnt || !pkt_out_cnt10)
206                 RET_LOG();
207
208         /* Even both cnt and cnt10 have lock defined in their BTF,
209          * intentionally one cnt takes lock while one does not
210          * as a test for the spinlock support in BPF_MAP_TYPE_SK_STORAGE.
211          */
212         pkt_out_cnt->cnt += 1;
213         bpf_spin_lock(&pkt_out_cnt10->lock);
214         pkt_out_cnt10->cnt += 10;
215         bpf_spin_unlock(&pkt_out_cnt10->lock);
216
217         return CG_OK;
218 }
219
220 SEC("cgroup_skb/ingress")
221 int ingress_read_sock_fields(struct __sk_buff *skb)
222 {
223         struct bpf_tcp_sock *tp;
224         __u32 linum, linum_idx;
225         struct bpf_sock *sk;
226
227         linum_idx = INGRESS_LINUM_IDX;
228
229         sk = skb->sk;
230         if (!sk)
231                 RET_LOG();
232
233         /* Not the testing ingress traffic to the server */
234         if (sk->family != AF_INET6 || !is_loopback6(sk->src_ip6) ||
235             sk->src_port != bpf_ntohs(srv_sa6.sin6_port))
236                 return CG_OK;
237
238         /* Only interested in the listening socket */
239         if (sk->state != BPF_TCP_LISTEN)
240                 return CG_OK;
241
242         /* It must be a fullsock for cgroup_skb/ingress prog */
243         sk = bpf_sk_fullsock(sk);
244         if (!sk)
245                 RET_LOG();
246
247         tp = bpf_tcp_sock(sk);
248         if (!tp)
249                 RET_LOG();
250
251         skcpy(&listen_sk, sk);
252         tpcpy(&listen_tp, tp);
253
254         return CG_OK;
255 }
256
257 /*
258  * NOTE: 4-byte load from bpf_sock at dst_port offset is quirky. It
259  * gets rewritten by the access converter to a 2-byte load for
260  * backward compatibility. Treating the load result as a be16 value
261  * makes the code portable across little- and big-endian platforms.
262  */
263 static __noinline bool sk_dst_port__load_word(struct bpf_sock *sk)
264 {
265         __u32 *word = (__u32 *)&sk->dst_port;
266         return word[0] == bpf_htons(0xcafe);
267 }
268
269 static __noinline bool sk_dst_port__load_half(struct bpf_sock *sk)
270 {
271         __u16 *half;
272
273         asm volatile ("");
274         half = (__u16 *)&sk->dst_port;
275         return half[0] == bpf_htons(0xcafe);
276 }
277
278 static __noinline bool sk_dst_port__load_byte(struct bpf_sock *sk)
279 {
280         __u8 *byte = (__u8 *)&sk->dst_port;
281         return byte[0] == 0xca && byte[1] == 0xfe;
282 }
283
284 SEC("cgroup_skb/egress")
285 int read_sk_dst_port(struct __sk_buff *skb)
286 {
287         __u32 linum, linum_idx;
288         struct bpf_sock *sk;
289
290         linum_idx = READ_SK_DST_PORT_LINUM_IDX;
291
292         sk = skb->sk;
293         if (!sk)
294                 RET_LOG();
295
296         /* Ignore everything but the SYN from the client socket */
297         if (sk->state != BPF_TCP_SYN_SENT)
298                 return CG_OK;
299
300         if (!sk_dst_port__load_word(sk))
301                 RET_LOG();
302         if (!sk_dst_port__load_half(sk))
303                 RET_LOG();
304         if (!sk_dst_port__load_byte(sk))
305                 RET_LOG();
306
307         return CG_OK;
308 }
309
310 char _license[] SEC("license") = "GPL";