Merge tag 'kvm-x86-misc-6.9' of https://github.com/kvm-x86/linux into HEAD
[linux-2.6-block.git] / tools / testing / selftests / net / tcp_ao / lib / sock.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <alloca.h>
3 #include <fcntl.h>
4 #include <inttypes.h>
5 #include <string.h>
6 #include "../../../../../include/linux/kernel.h"
7 #include "../../../../../include/linux/stringify.h"
8 #include "aolib.h"
9
10 const unsigned int test_server_port = 7010;
11 int __test_listen_socket(int backlog, void *addr, size_t addr_sz)
12 {
13         int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
14         long flags;
15
16         if (sk < 0)
17                 test_error("socket()");
18
19         err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name,
20                          strlen(veth_name) + 1);
21         if (err < 0)
22                 test_error("setsockopt(SO_BINDTODEVICE)");
23
24         if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0)
25                 test_error("bind()");
26
27         flags = fcntl(sk, F_GETFL);
28         if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
29                 test_error("fcntl()");
30
31         if (listen(sk, backlog))
32                 test_error("listen()");
33
34         return sk;
35 }
36
/*
 * Wait for @sk to become writable (@write == true) or readable
 * (@write == false), then fetch any pending socket error.
 *
 * @sec: timeout in seconds; 0 means wait without a timeout.
 *
 * Returns:
 *  - 0 when the socket is ready and has no pending error;
 *  - -ETIMEDOUT when select() times out (errno is also set to ETIMEDOUT);
 *  - -errno when select() or getsockopt() fails;
 *  - the negated SO_ERROR value when the socket reports an error.
 */
int test_wait_fd(int sk, time_t sec, bool write)
{
	struct timeval limit = { .tv_sec = sec };
	fd_set ready, excepts;
	fd_set *rd_set, *wr_set;
	int res;
	socklen_t res_len = sizeof(res);

	FD_ZERO(&ready);
	FD_ZERO(&excepts);
	FD_SET(sk, &ready);
	FD_SET(sk, &excepts);

	/* Watch exactly one direction; exceptional conditions are always watched. */
	rd_set = write ? NULL : &ready;
	wr_set = write ? &ready : NULL;

	errno = 0;
	res = select(sk + 1, rd_set, wr_set, &excepts, sec ? &limit : NULL);
	if (res < 0)
		return -errno;
	if (!res) {
		errno = ETIMEDOUT;
		return -ETIMEDOUT;
	}

	if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &res, &res_len))
		return -errno;
	return res ? -res : 0;
}
71
72 int __test_connect_socket(int sk, const char *device,
73                           void *addr, size_t addr_sz, time_t timeout)
74 {
75         long flags;
76         int err;
77
78         if (device != NULL) {
79                 err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device,
80                                  strlen(device) + 1);
81                 if (err < 0)
82                         test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
83         }
84
85         if (!timeout) {
86                 err = connect(sk, addr, addr_sz);
87                 if (err) {
88                         err = -errno;
89                         goto out;
90                 }
91                 return 0;
92         }
93
94         flags = fcntl(sk, F_GETFL);
95         if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
96                 test_error("fcntl()");
97
98         if (connect(sk, addr, addr_sz) < 0) {
99                 if (errno != EINPROGRESS) {
100                         err = -errno;
101                         goto out;
102                 }
103                 if (timeout < 0)
104                         return sk;
105                 err = test_wait_fd(sk, timeout, 1);
106                 if (err)
107                         goto out;
108         }
109         return sk;
110
111 out:
112         close(sk);
113         return err;
114 }
115
/*
 * Install a TCP-MD5 signature key on @sk via setsockopt(TCP_MD5SIG_EXT).
 *
 * @addr/@addr_sz: peer address, matched with prefix length @prefix.
 * @vrf:      when >= 0, the key is bound to interface index @vrf
 *            (TCP_MD5SIG_FLAG_IFINDEX).
 * @password: key material; its full strlen() is passed in tcpm_keylen so
 *            an over-long key is still presented to the kernel. Only as many
 *            bytes as tcpm_key can hold are copied.
 *
 * Returns the setsockopt() result (0, or -1 with errno set).
 * Returns -1 with errno = EINVAL if @addr_sz exceeds the size of tcpm_addr.
 */
int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix,
		   int vrf, const char *password)
{
	size_t pwd_len = strlen(password);
	size_t copy_len = pwd_len;
	struct tcp_md5sig md5sig = {};

	if (addr_sz > sizeof(md5sig.tcpm_addr)) {
		errno = EINVAL;
		return -1;
	}

	/* Never write past tcpm_key, even for deliberately over-long keys. */
	if (copy_len > sizeof(md5sig.tcpm_key))
		copy_len = sizeof(md5sig.tcpm_key);

	md5sig.tcpm_keylen = pwd_len;
	memcpy(md5sig.tcpm_key, password, copy_len);
	md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX;
	md5sig.tcpm_prefixlen = prefix;
	if (vrf >= 0) {
		md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
		/* tcpm_ifindex is an int: do not truncate to 8 bits. */
		md5sig.tcpm_ifindex = vrf;
	}
	memcpy(&md5sig.tcpm_addr, addr, addr_sz);

	errno = 0;
	return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT,
			  &md5sig, sizeof(md5sig));
}
136
137
138 int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
139                 void *addr, size_t addr_sz, bool set_current, bool set_rnext,
140                 uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid,
141                 uint8_t maclen, uint8_t keyflags,
142                 uint8_t keylen, const char *key)
143 {
144         memset(ao, 0, sizeof(struct tcp_ao_add));
145
146         ao->set_current = !!set_current;
147         ao->set_rnext   = !!set_rnext;
148         ao->prefix      = prefix;
149         ao->sndid       = sndid;
150         ao->rcvid       = rcvid;
151         ao->maclen      = maclen;
152         ao->keyflags    = keyflags;
153         ao->keylen      = keylen;
154         ao->ifindex     = vrf;
155
156         memcpy(&ao->addr, addr, addr_sz);
157
158         if (strlen(alg) > 64)
159                 return -ENOBUFS;
160         strncpy(ao->alg_name, alg, 64);
161
162         memcpy(ao->key, key,
163                (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen);
164         return 0;
165 }
166
167 static int test_get_ao_keys_nr(int sk)
168 {
169         struct tcp_ao_getsockopt tmp = {};
170         socklen_t tmp_sz = sizeof(tmp);
171         int ret;
172
173         tmp.nkeys  = 1;
174         tmp.get_all = 1;
175
176         ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
177         if (ret)
178                 return -errno;
179         return (int)tmp.nkeys;
180 }
181
182 int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
183                 void *addr, size_t addr_sz, uint8_t prefix,
184                 uint8_t sndid, uint8_t rcvid)
185 {
186         struct tcp_ao_getsockopt tmp = {};
187         socklen_t tmp_sz = sizeof(tmp);
188         int ret;
189
190         memcpy(&tmp.addr, addr, addr_sz);
191         tmp.prefix = prefix;
192         tmp.sndid  = sndid;
193         tmp.rcvid  = rcvid;
194         tmp.nkeys  = 1;
195
196         ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
197         if (ret)
198                 return ret;
199         if (tmp.nkeys != 1)
200                 return -E2BIG;
201         *out = tmp;
202         return 0;
203 }
204
205 int test_get_ao_info(int sk, struct tcp_ao_info_opt *out)
206 {
207         socklen_t sz = sizeof(*out);
208
209         out->reserved = 0;
210         out->reserved2 = 0;
211         if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz))
212                 return -errno;
213         if (sz != sizeof(*out))
214                 return -EMSGSIZE;
215         return 0;
216 }
217
218 int test_set_ao_info(int sk, struct tcp_ao_info_opt *in)
219 {
220         socklen_t sz = sizeof(*in);
221
222         in->reserved = 0;
223         in->reserved2 = 0;
224         if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz))
225                 return -errno;
226         return 0;
227 }
228
229 int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
230                                    const struct tcp_ao_getsockopt *b)
231 {
232         bool is_kdf_aes_128_cmac = false;
233         bool is_cmac_aes = false;
234
235         if (!strcmp("cmac(aes128)", a->alg_name)) {
236                 is_kdf_aes_128_cmac = (a->keylen != 16);
237                 is_cmac_aes = true;
238         }
239
240 #define __cmp_ao(member)                                                \
241 do {                                                                    \
242         if (b->member != a->member) {                                   \
243                 test_fail("getsockopt(): " __stringify(member) " %u != %u",     \
244                                 b->member, a->member);                  \
245                 return -1;                                              \
246         }                                                               \
247 } while(0)
248         __cmp_ao(sndid);
249         __cmp_ao(rcvid);
250         __cmp_ao(prefix);
251         __cmp_ao(keyflags);
252         __cmp_ao(ifindex);
253         if (a->maclen) {
254                 __cmp_ao(maclen);
255         } else if (b->maclen != 12) {
256                 test_fail("getsockopt(): expected default maclen 12, but it's %u",
257                                 b->maclen);
258                 return -1;
259         }
260         if (!is_kdf_aes_128_cmac) {
261                 __cmp_ao(keylen);
262         } else if (b->keylen != 16) {
263                 test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
264                                 b->keylen);
265                 return -1;
266         }
267 #undef __cmp_ao
268         if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) {
269                 test_fail("getsockopt(): returned key is different `%s' != `%s'",
270                                 b->key, a->key);
271                 return -1;
272         }
273         if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) {
274                 test_fail("getsockopt(): returned address is different");
275                 return -1;
276         }
277         if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) {
278                 test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name);
279                 return -1;
280         }
281         if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) {
282                 test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name);
283                 return -1;
284         }
285         /* For a established key rotation test don't add a key with
286          * set_current = 1, as it's likely to change by peer's request;
287          * rather use setsockopt(TCP_AO_INFO)
288          */
289         if (a->set_current != b->is_current) {
290                 test_fail("getsockopt(): returned key is not Current_key");
291                 return -1;
292         }
293         if (a->set_rnext != b->is_rnext) {
294                 test_fail("getsockopt(): returned key is not RNext_key");
295                 return -1;
296         }
297
298         return 0;
299 }
300
301 int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
302                                       const struct tcp_ao_info_opt *b)
303 {
304         /* No check for ::current_key, as it may change by the peer */
305         if (a->ao_required != b->ao_required) {
306                 test_fail("getsockopt(): returned ao doesn't have ao_required");
307                 return -1;
308         }
309         if (a->accept_icmps != b->accept_icmps) {
310                 test_fail("getsockopt(): returned ao doesn't accept ICMPs");
311                 return -1;
312         }
313         if (a->set_rnext && a->rnext != b->rnext) {
314                 test_fail("getsockopt(): RNext KeyID has changed");
315                 return -1;
316         }
317 #define __cmp_cnt(member)                                               \
318 do {                                                                    \
319         if (b->member != a->member) {                                   \
320                 test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \
321                                 b->member, a->member);                  \
322                 return -1;                                              \
323         }                                                               \
324 } while(0)
325         if (a->set_counters) {
326                 __cmp_cnt(pkt_good);
327                 __cmp_cnt(pkt_bad);
328                 __cmp_cnt(pkt_key_not_found);
329                 __cmp_cnt(pkt_ao_required);
330                 __cmp_cnt(pkt_dropped_icmp);
331         }
332 #undef __cmp_cnt
333         return 0;
334 }
335
336 int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out)
337 {
338         struct tcp_ao_getsockopt *key_dump;
339         socklen_t key_dump_sz = sizeof(*key_dump);
340         struct tcp_ao_info_opt info = {};
341         bool c1, c2, c3, c4, c5;
342         struct netstat *ns;
343         int err, nr_keys;
344
345         memset(out, 0, sizeof(*out));
346
347         /* per-netns */
348         ns = netstat_read();
349         out->netns_ao_good = netstat_get(ns, "TCPAOGood", &c1);
350         out->netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2);
351         out->netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3);
352         out->netns_ao_required = netstat_get(ns, "TCPAORequired", &c4);
353         out->netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5);
354         netstat_free(ns);
355         if (c1 || c2 || c3 || c4 || c5)
356                 return -EOPNOTSUPP;
357
358         err = test_get_ao_info(sk, &info);
359         if (err)
360                 return err;
361
362         /* per-socket */
363         out->ao_info_pkt_good           = info.pkt_good;
364         out->ao_info_pkt_bad            = info.pkt_bad;
365         out->ao_info_pkt_key_not_found  = info.pkt_key_not_found;
366         out->ao_info_pkt_ao_required    = info.pkt_ao_required;
367         out->ao_info_pkt_dropped_icmp   = info.pkt_dropped_icmp;
368
369         /* per-key */
370         nr_keys = test_get_ao_keys_nr(sk);
371         if (nr_keys < 0)
372                 return nr_keys;
373         if (nr_keys == 0)
374                 test_error("test_get_ao_keys_nr() == 0");
375         out->nr_keys = (size_t)nr_keys;
376         key_dump = calloc(nr_keys, key_dump_sz);
377         if (!key_dump)
378                 return -errno;
379
380         key_dump[0].nkeys = nr_keys;
381         key_dump[0].get_all = 1;
382         key_dump[0].get_all = 1;
383         err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS,
384                          key_dump, &key_dump_sz);
385         if (err) {
386                 free(key_dump);
387                 return -errno;
388         }
389
390         out->key_cnts = calloc(nr_keys, sizeof(out->key_cnts[0]));
391         if (!out->key_cnts) {
392                 free(key_dump);
393                 return -errno;
394         }
395
396         while (nr_keys--) {
397                 out->key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid;
398                 out->key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid;
399                 out->key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good;
400                 out->key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad;
401         }
402         free(key_dump);
403
404         return 0;
405 }
406
407 int __test_tcp_ao_counters_cmp(const char *tst_name,
408                                struct tcp_ao_counters *before,
409                                struct tcp_ao_counters *after,
410                                test_cnt expected)
411 {
412 #define __cmp_ao(cnt, expecting_inc)                                    \
413 do {                                                                    \
414         if (before->cnt > after->cnt) {                                 \
415                 test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
416                           tst_name ?: "", before->cnt, after->cnt);             \
417                 return -1;                                              \
418         }                                                               \
419         if ((before->cnt != after->cnt) != (expecting_inc)) {           \
420                 test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
421                           tst_name ?: "", (expecting_inc) ? "" : "not ",        \
422                           before->cnt, after->cnt);                     \
423                 return -1;                                              \
424         }                                                               \
425 } while(0)
426
427         errno = 0;
428         /* per-netns */
429         __cmp_ao(netns_ao_good, !!(expected & TEST_CNT_NS_GOOD));
430         __cmp_ao(netns_ao_bad, !!(expected & TEST_CNT_NS_BAD));
431         __cmp_ao(netns_ao_key_not_found,
432                  !!(expected & TEST_CNT_NS_KEY_NOT_FOUND));
433         __cmp_ao(netns_ao_required, !!(expected & TEST_CNT_NS_AO_REQUIRED));
434         __cmp_ao(netns_ao_dropped_icmp,
435                  !!(expected & TEST_CNT_NS_DROPPED_ICMP));
436         /* per-socket */
437         __cmp_ao(ao_info_pkt_good, !!(expected & TEST_CNT_SOCK_GOOD));
438         __cmp_ao(ao_info_pkt_bad, !!(expected & TEST_CNT_SOCK_BAD));
439         __cmp_ao(ao_info_pkt_key_not_found,
440                  !!(expected & TEST_CNT_SOCK_KEY_NOT_FOUND));
441         __cmp_ao(ao_info_pkt_ao_required, !!(expected & TEST_CNT_SOCK_AO_REQUIRED));
442         __cmp_ao(ao_info_pkt_dropped_icmp,
443                  !!(expected & TEST_CNT_SOCK_DROPPED_ICMP));
444         return 0;
445 #undef __cmp_ao
446 }
447
448 int test_tcp_ao_key_counters_cmp(const char *tst_name,
449                                  struct tcp_ao_counters *before,
450                                  struct tcp_ao_counters *after,
451                                  test_cnt expected,
452                                  int sndid, int rcvid)
453 {
454         size_t i;
455 #define __cmp_ao(i, cnt, expecting_inc)                                 \
456 do {                                                                    \
457         if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) {         \
458                 test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
459                           tst_name ?: "", before->key_cnts[i].cnt,      \
460                           after->key_cnts[i].cnt,                       \
461                           before->key_cnts[i].sndid,                    \
462                           before->key_cnts[i].rcvid);                   \
463                 return -1;                                              \
464         }                                                               \
465         if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) {           \
466                 test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
467                           tst_name ?: "", (expecting_inc) ? "" : "not ",\
468                           before->key_cnts[i].cnt,                      \
469                           after->key_cnts[i].cnt,                       \
470                           before->key_cnts[i].sndid,                    \
471                           before->key_cnts[i].rcvid);                   \
472                 return -1;                                              \
473         }                                                               \
474 } while(0)
475
476         if (before->nr_keys != after->nr_keys) {
477                 test_fail("%s: Keys changed on the socket %zu != %zu",
478                           tst_name, before->nr_keys, after->nr_keys);
479                 return -1;
480         }
481
482         /* per-key */
483         i = before->nr_keys;
484         while (i--) {
485                 if (sndid >= 0 && before->key_cnts[i].sndid != sndid)
486                         continue;
487                 if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid)
488                         continue;
489                 __cmp_ao(i, pkt_good, !!(expected & TEST_CNT_KEY_GOOD));
490                 __cmp_ao(i, pkt_bad, !!(expected & TEST_CNT_KEY_BAD));
491         }
492         return 0;
493 #undef __cmp_ao
494 }
495
496 void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts)
497 {
498         free(cnts->key_cnts);
499 }
500
#define TEST_BUF_SIZE 4096
/*
 * Echo server loop: read up to TEST_BUF_SIZE bytes from @sk and send them
 * back. Each recv()/send() waits up to @timeout_sec for readiness first.
 *
 * Stops when the peer closes (recv() returns 0) or send() returns 0. With a
 * non-zero @quota, it also stops once at least @quota bytes were echoed.
 * recv() errors and short sends are reported via test_error().
 *
 * Returns the number of bytes echoed, or the negative error from
 * test_wait_fd().
 */
ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec)
{
	ssize_t echoed = 0;

	for (;;) {
		char chunk[TEST_BUF_SIZE];
		ssize_t got, put;
		int err;

		err = test_wait_fd(sk, timeout_sec, 0);
		if (err)
			return err;

		got = recv(sk, chunk, sizeof(chunk), 0);
		if (got < 0)
			test_error("recv(): %zd", got);
		if (got == 0)
			break;

		err = test_wait_fd(sk, timeout_sec, 1);
		if (err)
			return err;

		put = send(sk, chunk, got, 0);
		if (put == 0)
			break;
		if (put != got)
			test_error("send()");
		echoed += got;

		if (quota && echoed >= quota)
			break;
	}

	return echoed;
}
536
/*
 * Send @buf in chunks of at most @msg_len bytes over @sk (TCP_NODELAY
 * enabled). After each chunk, read the echoed reply and compare it with
 * what was sent.
 *
 * Returns:
 *  - the number of bytes of @buf fully processed (buf_sz on full success);
 *    a smaller value if send() returns 0 or recv() returns <= 0;
 *  - the negative error from test_wait_fd();
 *  - -1 if an echoed chunk differs (reported via test_fail());
 *  - -EINVAL if @msg_len is 0 while @buf_sz is not, since the loop could
 *    never advance.
 * Short sends and over-long replies are reported via test_error().
 */
ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
			 const size_t msg_len, time_t timeout_sec)
{
	/* Never declare a zero-length VLA; msg_len == 0 is handled below. */
	char msg[msg_len ? msg_len : 1];
	int nodelay = 1;
	size_t i;

	if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)))
		test_error("setsockopt(TCP_NODELAY)");

	if (!msg_len)
		return buf_sz ? -EINVAL : 0;

	for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) {
		size_t bytes = min(msg_len, buf_sz - i);
		size_t received = 0;
		ssize_t sent;
		int ret;

		ret = test_wait_fd(sk, timeout_sec, 1);
		if (ret)
			return ret;

		/* ssize_t keeps a -1 from send() distinguishable from a length. */
		sent = send(sk, buf + i, bytes, 0);
		if (sent == 0)
			break;
		if (sent < 0 || (size_t)sent != bytes)
			test_error("send()");

		do {
			ssize_t got;

			ret = test_wait_fd(sk, timeout_sec, 0);
			if (ret)
				return ret;

			got = recv(sk, msg + received, sizeof(msg) - received, 0);
			if (got <= 0)
				return i;
			received += got;
		} while (received < bytes);
		if (received > bytes)
			test_error("recv(): %zu > %zu", received, bytes);
		if (memcmp(buf + i, msg, received) != 0) {
			test_fail("received message differs");
			return -1;
		}
	}
	return i;
}
583
/*
 * Exchange @nr random messages of @msg_len bytes with an echo peer on @sk
 * and verify every reply.
 *
 * The payload is heap-allocated rather than placed on the stack with
 * alloca(), so large @msg_len * @nr values cannot overflow the stack.
 *
 * Returns 0 when the whole payload was echoed back intact.
 * Returns -1 on a short exchange.
 * Returns the negative error from test_client_loop().
 * Returns -EOVERFLOW if @msg_len * @nr overflows size_t.
 * Returns -ENOMEM if the buffer cannot be allocated.
 */
int test_client_verify(int sk, const size_t msg_len, const size_t nr,
		       time_t timeout_sec)
{
	size_t buf_sz;
	ssize_t ret;
	char *buf;

	if (nr && msg_len > SIZE_MAX / nr)
		return -EOVERFLOW;
	buf_sz = msg_len * nr;

	/* malloc(0) may return NULL; always request at least one byte. */
	buf = malloc(buf_sz ? buf_sz : 1);
	if (!buf)
		return -ENOMEM;

	randomize_buffer(buf, buf_sz);
	ret = test_client_loop(sk, buf, buf_sz, msg_len, timeout_sec);
	free(buf);

	if (ret < 0)
		return (int)ret;
	return (size_t)ret != buf_sz ? -1 : 0;
}