selftests: tls: test splicing decrypted records
[linux-block.git] / tools / testing / selftests / net / tls.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17
18 #include <sys/types.h>
19 #include <sys/sendfile.h>
20 #include <sys/socket.h>
21 #include <sys/stat.h>
22
23 #include "../kselftest_harness.h"
24
25 #define TLS_PAYLOAD_MAX_LEN 16384
26 #define SOL_TLS 282
27
28 struct tls_crypto_info_keys {
29         union {
30                 struct tls12_crypto_info_aes_gcm_128 aes128;
31                 struct tls12_crypto_info_chacha20_poly1305 chacha20;
32                 struct tls12_crypto_info_sm4_gcm sm4gcm;
33                 struct tls12_crypto_info_sm4_ccm sm4ccm;
34         };
35         size_t len;
36 };
37
38 static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
39                                  struct tls_crypto_info_keys *tls12)
40 {
41         memset(tls12, 0, sizeof(*tls12));
42
43         switch (cipher_type) {
44         case TLS_CIPHER_CHACHA20_POLY1305:
45                 tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
46                 tls12->chacha20.info.version = tls_version;
47                 tls12->chacha20.info.cipher_type = cipher_type;
48                 break;
49         case TLS_CIPHER_AES_GCM_128:
50                 tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
51                 tls12->aes128.info.version = tls_version;
52                 tls12->aes128.info.cipher_type = cipher_type;
53                 break;
54         case TLS_CIPHER_SM4_GCM:
55                 tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
56                 tls12->sm4gcm.info.version = tls_version;
57                 tls12->sm4gcm.info.cipher_type = cipher_type;
58                 break;
59         case TLS_CIPHER_SM4_CCM:
60                 tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
61                 tls12->sm4ccm.info.version = tls_version;
62                 tls12->sm4ccm.info.cipher_type = cipher_type;
63                 break;
64         default:
65                 break;
66         }
67 }
68
69 static void memrnd(void *s, size_t n)
70 {
71         int *dword = s;
72         char *byte;
73
74         for (; n >= 4; n -= 4)
75                 *dword++ = rand();
76         byte = (void *)dword;
77         while (n--)
78                 *byte++ = rand();
79 }
80
81 static void ulp_sock_pair(struct __test_metadata *_metadata,
82                           int *fd, int *cfd, bool *notls)
83 {
84         struct sockaddr_in addr;
85         socklen_t len;
86         int sfd, ret;
87
88         *notls = false;
89         len = sizeof(addr);
90
91         addr.sin_family = AF_INET;
92         addr.sin_addr.s_addr = htonl(INADDR_ANY);
93         addr.sin_port = 0;
94
95         *fd = socket(AF_INET, SOCK_STREAM, 0);
96         sfd = socket(AF_INET, SOCK_STREAM, 0);
97
98         ret = bind(sfd, &addr, sizeof(addr));
99         ASSERT_EQ(ret, 0);
100         ret = listen(sfd, 10);
101         ASSERT_EQ(ret, 0);
102
103         ret = getsockname(sfd, &addr, &len);
104         ASSERT_EQ(ret, 0);
105
106         ret = connect(*fd, &addr, sizeof(addr));
107         ASSERT_EQ(ret, 0);
108
109         *cfd = accept(sfd, &addr, &len);
110         ASSERT_GE(*cfd, 0);
111
112         close(sfd);
113
114         ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
115         if (ret != 0) {
116                 ASSERT_EQ(errno, ENOENT);
117                 *notls = true;
118                 printf("Failure setting TCP_ULP, testing without tls\n");
119                 return;
120         }
121
122         ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
123         ASSERT_EQ(ret, 0);
124 }
125
126 /* Produce a basic cmsg */
127 static int tls_send_cmsg(int fd, unsigned char record_type,
128                          void *data, size_t len, int flags)
129 {
130         char cbuf[CMSG_SPACE(sizeof(char))];
131         int cmsg_len = sizeof(char);
132         struct cmsghdr *cmsg;
133         struct msghdr msg;
134         struct iovec vec;
135
136         vec.iov_base = data;
137         vec.iov_len = len;
138         memset(&msg, 0, sizeof(struct msghdr));
139         msg.msg_iov = &vec;
140         msg.msg_iovlen = 1;
141         msg.msg_control = cbuf;
142         msg.msg_controllen = sizeof(cbuf);
143         cmsg = CMSG_FIRSTHDR(&msg);
144         cmsg->cmsg_level = SOL_TLS;
145         /* test sending non-record types. */
146         cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
147         cmsg->cmsg_len = CMSG_LEN(cmsg_len);
148         *CMSG_DATA(cmsg) = record_type;
149         msg.msg_controllen = cmsg->cmsg_len;
150
151         return sendmsg(fd, &msg, flags);
152 }
153
154 static int tls_recv_cmsg(struct __test_metadata *_metadata,
155                          int fd, unsigned char record_type,
156                          void *data, size_t len, int flags)
157 {
158         char cbuf[CMSG_SPACE(sizeof(char))];
159         struct cmsghdr *cmsg;
160         unsigned char ctype;
161         struct msghdr msg;
162         struct iovec vec;
163         int n;
164
165         vec.iov_base = data;
166         vec.iov_len = len;
167         memset(&msg, 0, sizeof(struct msghdr));
168         msg.msg_iov = &vec;
169         msg.msg_iovlen = 1;
170         msg.msg_control = cbuf;
171         msg.msg_controllen = sizeof(cbuf);
172
173         n = recvmsg(fd, &msg, flags);
174
175         cmsg = CMSG_FIRSTHDR(&msg);
176         EXPECT_NE(cmsg, NULL);
177         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
178         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
179         ctype = *((unsigned char *)CMSG_DATA(cmsg));
180         EXPECT_EQ(ctype, record_type);
181
182         return n;
183 }
184
185 FIXTURE(tls_basic)
186 {
187         int fd, cfd;
188         bool notls;
189 };
190
191 FIXTURE_SETUP(tls_basic)
192 {
193         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
194 }
195
196 FIXTURE_TEARDOWN(tls_basic)
197 {
198         close(self->fd);
199         close(self->cfd);
200 }
201
202 /* Send some data through with ULP but no keys */
203 TEST_F(tls_basic, base_base)
204 {
205         char const *test_str = "test_read";
206         int send_len = 10;
207         char buf[10];
208
209         ASSERT_EQ(strlen(test_str) + 1, send_len);
210
211         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
212         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
213         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
214 };
215
216 FIXTURE(tls)
217 {
218         int fd, cfd;
219         bool notls;
220 };
221
222 FIXTURE_VARIANT(tls)
223 {
224         uint16_t tls_version;
225         uint16_t cipher_type;
226 };
227
228 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
229 {
230         .tls_version = TLS_1_2_VERSION,
231         .cipher_type = TLS_CIPHER_AES_GCM_128,
232 };
233
234 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
235 {
236         .tls_version = TLS_1_3_VERSION,
237         .cipher_type = TLS_CIPHER_AES_GCM_128,
238 };
239
240 FIXTURE_VARIANT_ADD(tls, 12_chacha)
241 {
242         .tls_version = TLS_1_2_VERSION,
243         .cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
244 };
245
246 FIXTURE_VARIANT_ADD(tls, 13_chacha)
247 {
248         .tls_version = TLS_1_3_VERSION,
249         .cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
250 };
251
252 FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
253 {
254         .tls_version = TLS_1_3_VERSION,
255         .cipher_type = TLS_CIPHER_SM4_GCM,
256 };
257
258 FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
259 {
260         .tls_version = TLS_1_3_VERSION,
261         .cipher_type = TLS_CIPHER_SM4_CCM,
262 };
263
264 FIXTURE_SETUP(tls)
265 {
266         struct tls_crypto_info_keys tls12;
267         int ret;
268
269         tls_crypto_info_init(variant->tls_version, variant->cipher_type,
270                              &tls12);
271
272         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
273
274         if (self->notls)
275                 return;
276
277         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
278         ASSERT_EQ(ret, 0);
279
280         ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
281         ASSERT_EQ(ret, 0);
282 }
283
284 FIXTURE_TEARDOWN(tls)
285 {
286         close(self->fd);
287         close(self->cfd);
288 }
289
290 TEST_F(tls, sendfile)
291 {
292         int filefd = open("/proc/self/exe", O_RDONLY);
293         struct stat st;
294
295         EXPECT_GE(filefd, 0);
296         fstat(filefd, &st);
297         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
298 }
299
300 TEST_F(tls, send_then_sendfile)
301 {
302         int filefd = open("/proc/self/exe", O_RDONLY);
303         char const *test_str = "test_send";
304         int to_send = strlen(test_str) + 1;
305         char recv_buf[10];
306         struct stat st;
307         char *buf;
308
309         EXPECT_GE(filefd, 0);
310         fstat(filefd, &st);
311         buf = (char *)malloc(st.st_size);
312
313         EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
314         EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
315         EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
316
317         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
318         EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
319 }
320
321 static void chunked_sendfile(struct __test_metadata *_metadata,
322                              struct _test_data_tls *self,
323                              uint16_t chunk_size,
324                              uint16_t extra_payload_size)
325 {
326         char buf[TLS_PAYLOAD_MAX_LEN];
327         uint16_t test_payload_size;
328         int size = 0;
329         int ret;
330         char filename[] = "/tmp/mytemp.XXXXXX";
331         int fd = mkstemp(filename);
332         off_t offset = 0;
333
334         unlink(filename);
335         ASSERT_GE(fd, 0);
336         EXPECT_GE(chunk_size, 1);
337         test_payload_size = chunk_size + extra_payload_size;
338         ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
339         memset(buf, 1, test_payload_size);
340         size = write(fd, buf, test_payload_size);
341         EXPECT_EQ(size, test_payload_size);
342         fsync(fd);
343
344         while (size > 0) {
345                 ret = sendfile(self->fd, fd, &offset, chunk_size);
346                 EXPECT_GE(ret, 0);
347                 size -= ret;
348         }
349
350         EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
351                   test_payload_size);
352
353         close(fd);
354 }
355
356 TEST_F(tls, multi_chunk_sendfile)
357 {
358         chunked_sendfile(_metadata, self, 4096, 4096);
359         chunked_sendfile(_metadata, self, 4096, 0);
360         chunked_sendfile(_metadata, self, 4096, 1);
361         chunked_sendfile(_metadata, self, 4096, 2048);
362         chunked_sendfile(_metadata, self, 8192, 2048);
363         chunked_sendfile(_metadata, self, 4096, 8192);
364         chunked_sendfile(_metadata, self, 8192, 4096);
365         chunked_sendfile(_metadata, self, 12288, 1024);
366         chunked_sendfile(_metadata, self, 12288, 2000);
367         chunked_sendfile(_metadata, self, 15360, 100);
368         chunked_sendfile(_metadata, self, 15360, 300);
369         chunked_sendfile(_metadata, self, 1, 4096);
370         chunked_sendfile(_metadata, self, 2048, 4096);
371         chunked_sendfile(_metadata, self, 2048, 8192);
372         chunked_sendfile(_metadata, self, 4096, 8192);
373         chunked_sendfile(_metadata, self, 1024, 12288);
374         chunked_sendfile(_metadata, self, 2000, 12288);
375         chunked_sendfile(_metadata, self, 100, 15360);
376         chunked_sendfile(_metadata, self, 300, 15360);
377 }
378
379 TEST_F(tls, recv_max)
380 {
381         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
382         char recv_mem[TLS_PAYLOAD_MAX_LEN];
383         char buf[TLS_PAYLOAD_MAX_LEN];
384
385         memrnd(buf, sizeof(buf));
386
387         EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
388         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
389         EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
390 }
391
392 TEST_F(tls, recv_small)
393 {
394         char const *test_str = "test_read";
395         int send_len = 10;
396         char buf[10];
397
398         send_len = strlen(test_str) + 1;
399         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
400         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
401         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
402 }
403
404 TEST_F(tls, msg_more)
405 {
406         char const *test_str = "test_read";
407         int send_len = 10;
408         char buf[10 * 2];
409
410         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
411         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
412         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
413         EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
414                   send_len * 2);
415         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
416 }
417
418 TEST_F(tls, msg_more_unsent)
419 {
420         char const *test_str = "test_read";
421         int send_len = 10;
422         char buf[10];
423
424         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
425         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
426 }
427
428 TEST_F(tls, sendmsg_single)
429 {
430         struct msghdr msg;
431
432         char const *test_str = "test_sendmsg";
433         size_t send_len = 13;
434         struct iovec vec;
435         char buf[13];
436
437         vec.iov_base = (char *)test_str;
438         vec.iov_len = send_len;
439         memset(&msg, 0, sizeof(struct msghdr));
440         msg.msg_iov = &vec;
441         msg.msg_iovlen = 1;
442         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
443         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
444         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
445 }
446
447 #define MAX_FRAGS       64
448 #define SEND_LEN        13
449 TEST_F(tls, sendmsg_fragmented)
450 {
451         char const *test_str = "test_sendmsg";
452         char buf[SEND_LEN * MAX_FRAGS];
453         struct iovec vec[MAX_FRAGS];
454         struct msghdr msg;
455         int i, frags;
456
457         for (frags = 1; frags <= MAX_FRAGS; frags++) {
458                 for (i = 0; i < frags; i++) {
459                         vec[i].iov_base = (char *)test_str;
460                         vec[i].iov_len = SEND_LEN;
461                 }
462
463                 memset(&msg, 0, sizeof(struct msghdr));
464                 msg.msg_iov = vec;
465                 msg.msg_iovlen = frags;
466
467                 EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
468                 EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
469                           SEND_LEN * frags);
470
471                 for (i = 0; i < frags; i++)
472                         EXPECT_EQ(memcmp(buf + SEND_LEN * i,
473                                          test_str, SEND_LEN), 0);
474         }
475 }
476 #undef MAX_FRAGS
477 #undef SEND_LEN
478
479 TEST_F(tls, sendmsg_large)
480 {
481         void *mem = malloc(16384);
482         size_t send_len = 16384;
483         size_t sends = 128;
484         struct msghdr msg;
485         size_t recvs = 0;
486         size_t sent = 0;
487
488         memset(&msg, 0, sizeof(struct msghdr));
489         while (sent++ < sends) {
490                 struct iovec vec = { (void *)mem, send_len };
491
492                 msg.msg_iov = &vec;
493                 msg.msg_iovlen = 1;
494                 EXPECT_EQ(sendmsg(self->cfd, &msg, 0), send_len);
495         }
496
497         while (recvs++ < sends) {
498                 EXPECT_NE(recv(self->fd, mem, send_len, 0), -1);
499         }
500
501         free(mem);
502 }
503
504 TEST_F(tls, sendmsg_multiple)
505 {
506         char const *test_str = "test_sendmsg_multiple";
507         struct iovec vec[5];
508         char *test_strs[5];
509         struct msghdr msg;
510         int total_len = 0;
511         int len_cmp = 0;
512         int iov_len = 5;
513         char *buf;
514         int i;
515
516         memset(&msg, 0, sizeof(struct msghdr));
517         for (i = 0; i < iov_len; i++) {
518                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
519                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
520                 vec[i].iov_base = (void *)test_strs[i];
521                 vec[i].iov_len = strlen(test_strs[i]) + 1;
522                 total_len += vec[i].iov_len;
523         }
524         msg.msg_iov = vec;
525         msg.msg_iovlen = iov_len;
526
527         EXPECT_EQ(sendmsg(self->cfd, &msg, 0), total_len);
528         buf = malloc(total_len);
529         EXPECT_NE(recv(self->fd, buf, total_len, 0), -1);
530         for (i = 0; i < iov_len; i++) {
531                 EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
532                                  strlen(test_strs[i])),
533                           0);
534                 len_cmp += strlen(buf + len_cmp) + 1;
535         }
536         for (i = 0; i < iov_len; i++)
537                 free(test_strs[i]);
538         free(buf);
539 }
540
541 TEST_F(tls, sendmsg_multiple_stress)
542 {
543         char const *test_str = "abcdefghijklmno";
544         struct iovec vec[1024];
545         char *test_strs[1024];
546         int iov_len = 1024;
547         int total_len = 0;
548         char buf[1 << 14];
549         struct msghdr msg;
550         int len_cmp = 0;
551         int i;
552
553         memset(&msg, 0, sizeof(struct msghdr));
554         for (i = 0; i < iov_len; i++) {
555                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
556                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
557                 vec[i].iov_base = (void *)test_strs[i];
558                 vec[i].iov_len = strlen(test_strs[i]) + 1;
559                 total_len += vec[i].iov_len;
560         }
561         msg.msg_iov = vec;
562         msg.msg_iovlen = iov_len;
563
564         EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
565         EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
566
567         for (i = 0; i < iov_len; i++)
568                 len_cmp += strlen(buf + len_cmp) + 1;
569
570         for (i = 0; i < iov_len; i++)
571                 free(test_strs[i]);
572 }
573
574 TEST_F(tls, splice_from_pipe)
575 {
576         int send_len = TLS_PAYLOAD_MAX_LEN;
577         char mem_send[TLS_PAYLOAD_MAX_LEN];
578         char mem_recv[TLS_PAYLOAD_MAX_LEN];
579         int p[2];
580
581         ASSERT_GE(pipe(p), 0);
582         EXPECT_GE(write(p[1], mem_send, send_len), 0);
583         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
584         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
585         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
586 }
587
588 TEST_F(tls, splice_from_pipe2)
589 {
590         int send_len = 16000;
591         char mem_send[16000];
592         char mem_recv[16000];
593         int p2[2];
594         int p[2];
595
596         ASSERT_GE(pipe(p), 0);
597         ASSERT_GE(pipe(p2), 0);
598         EXPECT_GE(write(p[1], mem_send, 8000), 0);
599         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, 8000, 0), 0);
600         EXPECT_GE(write(p2[1], mem_send + 8000, 8000), 0);
601         EXPECT_GE(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 0);
602         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
603         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
604 }
605
606 TEST_F(tls, send_and_splice)
607 {
608         int send_len = TLS_PAYLOAD_MAX_LEN;
609         char mem_send[TLS_PAYLOAD_MAX_LEN];
610         char mem_recv[TLS_PAYLOAD_MAX_LEN];
611         char const *test_str = "test_read";
612         int send_len2 = 10;
613         char buf[10];
614         int p[2];
615
616         ASSERT_GE(pipe(p), 0);
617         EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
618         EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
619         EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
620
621         EXPECT_GE(write(p[1], mem_send, send_len), send_len);
622         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
623
624         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
625         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
626 }
627
628 TEST_F(tls, splice_to_pipe)
629 {
630         int send_len = TLS_PAYLOAD_MAX_LEN;
631         char mem_send[TLS_PAYLOAD_MAX_LEN];
632         char mem_recv[TLS_PAYLOAD_MAX_LEN];
633         int p[2];
634
635         ASSERT_GE(pipe(p), 0);
636         EXPECT_GE(send(self->fd, mem_send, send_len, 0), 0);
637         EXPECT_GE(splice(self->cfd, NULL, p[1], NULL, send_len, 0), 0);
638         EXPECT_GE(read(p[0], mem_recv, send_len), 0);
639         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
640 }
641
642 TEST_F(tls, splice_cmsg_to_pipe)
643 {
644         char *test_str = "test_read";
645         char record_type = 100;
646         int send_len = 10;
647         char buf[10];
648         int p[2];
649
650         ASSERT_GE(pipe(p), 0);
651         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
652         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
653         EXPECT_EQ(errno, EINVAL);
654         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
655         EXPECT_EQ(errno, EIO);
656         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
657                                 buf, sizeof(buf), MSG_WAITALL),
658                   send_len);
659         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
660 }
661
662 TEST_F(tls, splice_dec_cmsg_to_pipe)
663 {
664         char *test_str = "test_read";
665         char record_type = 100;
666         int send_len = 10;
667         char buf[10];
668         int p[2];
669
670         ASSERT_GE(pipe(p), 0);
671         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
672         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
673         EXPECT_EQ(errno, EIO);
674         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
675         EXPECT_EQ(errno, EINVAL);
676         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
677                                 buf, sizeof(buf), MSG_WAITALL),
678                   send_len);
679         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
680 }
681
682 TEST_F(tls, recv_and_splice)
683 {
684         int send_len = TLS_PAYLOAD_MAX_LEN;
685         char mem_send[TLS_PAYLOAD_MAX_LEN];
686         char mem_recv[TLS_PAYLOAD_MAX_LEN];
687         int half = send_len / 2;
688         int p[2];
689
690         ASSERT_GE(pipe(p), 0);
691         EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
692         /* Recv hald of the record, splice the other half */
693         EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
694         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
695                   half);
696         EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
697         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
698 }
699
700 TEST_F(tls, peek_and_splice)
701 {
702         int send_len = TLS_PAYLOAD_MAX_LEN;
703         char mem_send[TLS_PAYLOAD_MAX_LEN];
704         char mem_recv[TLS_PAYLOAD_MAX_LEN];
705         int chunk = TLS_PAYLOAD_MAX_LEN / 4;
706         int n, i, p[2];
707
708         memrnd(mem_send, sizeof(mem_send));
709
710         ASSERT_GE(pipe(p), 0);
711         for (i = 0; i < 4; i++)
712                 EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
713                           chunk);
714
715         EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
716                        MSG_WAITALL | MSG_PEEK),
717                   chunk * 5 / 2);
718         EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
719
720         n = 0;
721         while (n < send_len) {
722                 i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
723                 EXPECT_GT(i, 0);
724                 n += i;
725         }
726         EXPECT_EQ(n, send_len);
727         EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
728         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
729 }
730
731 TEST_F(tls, recvmsg_single)
732 {
733         char const *test_str = "test_recvmsg_single";
734         int send_len = strlen(test_str) + 1;
735         char buf[20];
736         struct msghdr hdr;
737         struct iovec vec;
738
739         memset(&hdr, 0, sizeof(hdr));
740         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
741         vec.iov_base = (char *)buf;
742         vec.iov_len = send_len;
743         hdr.msg_iovlen = 1;
744         hdr.msg_iov = &vec;
745         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
746         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
747 }
748
749 TEST_F(tls, recvmsg_single_max)
750 {
751         int send_len = TLS_PAYLOAD_MAX_LEN;
752         char send_mem[TLS_PAYLOAD_MAX_LEN];
753         char recv_mem[TLS_PAYLOAD_MAX_LEN];
754         struct iovec vec;
755         struct msghdr hdr;
756
757         memrnd(send_mem, sizeof(send_mem));
758
759         EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
760         vec.iov_base = (char *)recv_mem;
761         vec.iov_len = TLS_PAYLOAD_MAX_LEN;
762
763         hdr.msg_iovlen = 1;
764         hdr.msg_iov = &vec;
765         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
766         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
767 }
768
769 TEST_F(tls, recvmsg_multiple)
770 {
771         unsigned int msg_iovlen = 1024;
772         struct iovec vec[1024];
773         char *iov_base[1024];
774         unsigned int iov_len = 16;
775         int send_len = 1 << 14;
776         char buf[1 << 14];
777         struct msghdr hdr;
778         int i;
779
780         memrnd(buf, sizeof(buf));
781
782         EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
783         for (i = 0; i < msg_iovlen; i++) {
784                 iov_base[i] = (char *)malloc(iov_len);
785                 vec[i].iov_base = iov_base[i];
786                 vec[i].iov_len = iov_len;
787         }
788
789         hdr.msg_iovlen = msg_iovlen;
790         hdr.msg_iov = vec;
791         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
792
793         for (i = 0; i < msg_iovlen; i++)
794                 free(iov_base[i]);
795 }
796
797 TEST_F(tls, single_send_multiple_recv)
798 {
799         unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
800         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
801         char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
802         char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
803
804         memrnd(send_mem, sizeof(send_mem));
805
806         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
807         memset(recv_mem, 0, total_len);
808
809         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
810         EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
811         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
812 }
813
814 TEST_F(tls, multiple_send_single_recv)
815 {
816         unsigned int total_len = 2 * 10;
817         unsigned int send_len = 10;
818         char recv_mem[2 * 10];
819         char send_mem[10];
820
821         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
822         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
823         memset(recv_mem, 0, total_len);
824         EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
825
826         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
827         EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
828 }
829
830 TEST_F(tls, single_send_multiple_recv_non_align)
831 {
832         const unsigned int total_len = 15;
833         const unsigned int recv_len = 10;
834         char recv_mem[recv_len * 2];
835         char send_mem[total_len];
836
837         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
838         memset(recv_mem, 0, total_len);
839
840         EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
841         EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
842         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
843 }
844
845 TEST_F(tls, recv_partial)
846 {
847         char const *test_str = "test_read_partial";
848         char const *test_str_first = "test_read";
849         char const *test_str_second = "_partial";
850         int send_len = strlen(test_str) + 1;
851         char recv_mem[18];
852
853         memset(recv_mem, 0, sizeof(recv_mem));
854         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
855         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_first),
856                        MSG_WAITALL), -1);
857         EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
858         memset(recv_mem, 0, sizeof(recv_mem));
859         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_second),
860                        MSG_WAITALL), -1);
861         EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
862                   0);
863 }
864
865 TEST_F(tls, recv_nonblock)
866 {
867         char buf[4096];
868         bool err;
869
870         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
871         err = (errno == EAGAIN || errno == EWOULDBLOCK);
872         EXPECT_EQ(err, true);
873 }
874
875 TEST_F(tls, recv_peek)
876 {
877         char const *test_str = "test_read_peek";
878         int send_len = strlen(test_str) + 1;
879         char buf[15];
880
881         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
882         EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
883         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
884         memset(buf, 0, sizeof(buf));
885         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
886         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
887 }
888
889 TEST_F(tls, recv_peek_multiple)
890 {
891         char const *test_str = "test_read_peek";
892         int send_len = strlen(test_str) + 1;
893         unsigned int num_peeks = 100;
894         char buf[15];
895         int i;
896
897         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
898         for (i = 0; i < num_peeks; i++) {
899                 EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
900                 EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
901                 memset(buf, 0, sizeof(buf));
902         }
903         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
904         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
905 }
906
907 TEST_F(tls, recv_peek_multiple_records)
908 {
909         char const *test_str = "test_read_peek_mult_recs";
910         char const *test_str_first = "test_read_peek";
911         char const *test_str_second = "_mult_recs";
912         int len;
913         char buf[64];
914
915         len = strlen(test_str_first);
916         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
917
918         len = strlen(test_str_second) + 1;
919         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
920
921         len = strlen(test_str_first);
922         memset(buf, 0, len);
923         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
924
925         /* MSG_PEEK can only peek into the current record. */
926         len = strlen(test_str_first);
927         EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
928
929         len = strlen(test_str) + 1;
930         memset(buf, 0, len);
931         EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
932
933         /* Non-MSG_PEEK will advance strparser (and therefore record)
934          * however.
935          */
936         len = strlen(test_str) + 1;
937         EXPECT_EQ(memcmp(test_str, buf, len), 0);
938
939         /* MSG_MORE will hold current record open, so later MSG_PEEK
940          * will see everything.
941          */
942         len = strlen(test_str_first);
943         EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
944
945         len = strlen(test_str_second) + 1;
946         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
947
948         len = strlen(test_str) + 1;
949         memset(buf, 0, len);
950         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
951
952         len = strlen(test_str) + 1;
953         EXPECT_EQ(memcmp(test_str, buf, len), 0);
954 }
955
956 TEST_F(tls, recv_peek_large_buf_mult_recs)
957 {
958         char const *test_str = "test_read_peek_mult_recs";
959         char const *test_str_first = "test_read_peek";
960         char const *test_str_second = "_mult_recs";
961         int len;
962         char buf[64];
963
964         len = strlen(test_str_first);
965         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
966
967         len = strlen(test_str_second) + 1;
968         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
969
970         len = strlen(test_str) + 1;
971         memset(buf, 0, len);
972         EXPECT_NE((len = recv(self->cfd, buf, len,
973                               MSG_PEEK | MSG_WAITALL)), -1);
974         len = strlen(test_str) + 1;
975         EXPECT_EQ(memcmp(test_str, buf, len), 0);
976 }
977
978 TEST_F(tls, recv_lowat)
979 {
980         char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
981         char recv_mem[20];
982         int lowat = 8;
983
984         EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
985         EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
986
987         memset(recv_mem, 0, 20);
988         EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
989                              &lowat, sizeof(lowat)), 0);
990         EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
991         EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
992         EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
993
994         EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
995         EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
996 }
997
998 TEST_F(tls, bidir)
999 {
1000         char const *test_str = "test_read";
1001         int send_len = 10;
1002         char buf[10];
1003         int ret;
1004
1005         if (!self->notls) {
1006                 struct tls_crypto_info_keys tls12;
1007
1008                 tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1009                                      &tls12);
1010
1011                 ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1012                                  tls12.len);
1013                 ASSERT_EQ(ret, 0);
1014
1015                 ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1016                                  tls12.len);
1017                 ASSERT_EQ(ret, 0);
1018         }
1019
1020         ASSERT_EQ(strlen(test_str) + 1, send_len);
1021
1022         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1023         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1024         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1025
1026         memset(buf, 0, sizeof(buf));
1027
1028         EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1029         EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1030         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1031 };
1032
1033 TEST_F(tls, pollin)
1034 {
1035         char const *test_str = "test_poll";
1036         struct pollfd fd = { 0, 0, 0 };
1037         char buf[10];
1038         int send_len = 10;
1039
1040         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1041         fd.fd = self->cfd;
1042         fd.events = POLLIN;
1043
1044         EXPECT_EQ(poll(&fd, 1, 20), 1);
1045         EXPECT_EQ(fd.revents & POLLIN, 1);
1046         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1047         /* Test timing out */
1048         EXPECT_EQ(poll(&fd, 1, 20), 0);
1049 }
1050
1051 TEST_F(tls, poll_wait)
1052 {
1053         char const *test_str = "test_poll_wait";
1054         int send_len = strlen(test_str) + 1;
1055         struct pollfd fd = { 0, 0, 0 };
1056         char recv_mem[15];
1057
1058         fd.fd = self->cfd;
1059         fd.events = POLLIN;
1060         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1061         /* Set timeout to inf. secs */
1062         EXPECT_EQ(poll(&fd, 1, -1), 1);
1063         EXPECT_EQ(fd.revents & POLLIN, 1);
1064         EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1065 }
1066
1067 TEST_F(tls, poll_wait_split)
1068 {
1069         struct pollfd fd = { 0, 0, 0 };
1070         char send_mem[20] = {};
1071         char recv_mem[15];
1072
1073         fd.fd = self->cfd;
1074         fd.events = POLLIN;
1075         /* Send 20 bytes */
1076         EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1077                   sizeof(send_mem));
1078         /* Poll with inf. timeout */
1079         EXPECT_EQ(poll(&fd, 1, -1), 1);
1080         EXPECT_EQ(fd.revents & POLLIN, 1);
1081         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1082                   sizeof(recv_mem));
1083
1084         /* Now the remaining 5 bytes of record data are in TLS ULP */
1085         fd.fd = self->cfd;
1086         fd.events = POLLIN;
1087         EXPECT_EQ(poll(&fd, 1, -1), 1);
1088         EXPECT_EQ(fd.revents & POLLIN, 1);
1089         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1090                   sizeof(send_mem) - sizeof(recv_mem));
1091 }
1092
1093 TEST_F(tls, blocking)
1094 {
1095         size_t data = 100000;
1096         int res = fork();
1097
1098         EXPECT_NE(res, -1);
1099
1100         if (res) {
1101                 /* parent */
1102                 size_t left = data;
1103                 char buf[16384];
1104                 int status;
1105                 int pid2;
1106
1107                 while (left) {
1108                         int res = send(self->fd, buf,
1109                                        left > 16384 ? 16384 : left, 0);
1110
1111                         EXPECT_GE(res, 0);
1112                         left -= res;
1113                 }
1114
1115                 pid2 = wait(&status);
1116                 EXPECT_EQ(status, 0);
1117                 EXPECT_EQ(res, pid2);
1118         } else {
1119                 /* child */
1120                 size_t left = data;
1121                 char buf[16384];
1122
1123                 while (left) {
1124                         int res = recv(self->cfd, buf,
1125                                        left > 16384 ? 16384 : left, 0);
1126
1127                         EXPECT_GE(res, 0);
1128                         left -= res;
1129                 }
1130         }
1131 }
1132
1133 TEST_F(tls, nonblocking)
1134 {
1135         size_t data = 100000;
1136         int sendbuf = 100;
1137         int flags;
1138         int res;
1139
1140         flags = fcntl(self->fd, F_GETFL, 0);
1141         fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1142         fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1143
1144         /* Ensure nonblocking behavior by imposing a small send
1145          * buffer.
1146          */
1147         EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1148                              &sendbuf, sizeof(sendbuf)), 0);
1149
1150         res = fork();
1151         EXPECT_NE(res, -1);
1152
1153         if (res) {
1154                 /* parent */
1155                 bool eagain = false;
1156                 size_t left = data;
1157                 char buf[16384];
1158                 int status;
1159                 int pid2;
1160
1161                 while (left) {
1162                         int res = send(self->fd, buf,
1163                                        left > 16384 ? 16384 : left, 0);
1164
1165                         if (res == -1 && errno == EAGAIN) {
1166                                 eagain = true;
1167                                 usleep(10000);
1168                                 continue;
1169                         }
1170                         EXPECT_GE(res, 0);
1171                         left -= res;
1172                 }
1173
1174                 EXPECT_TRUE(eagain);
1175                 pid2 = wait(&status);
1176
1177                 EXPECT_EQ(status, 0);
1178                 EXPECT_EQ(res, pid2);
1179         } else {
1180                 /* child */
1181                 bool eagain = false;
1182                 size_t left = data;
1183                 char buf[16384];
1184
1185                 while (left) {
1186                         int res = recv(self->cfd, buf,
1187                                        left > 16384 ? 16384 : left, 0);
1188
1189                         if (res == -1 && errno == EAGAIN) {
1190                                 eagain = true;
1191                                 usleep(10000);
1192                                 continue;
1193                         }
1194                         EXPECT_GE(res, 0);
1195                         left -= res;
1196                 }
1197                 EXPECT_TRUE(eagain);
1198         }
1199 }
1200
1201 static void
1202 test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1203                bool sendpg, unsigned int n_readers, unsigned int n_writers)
1204 {
1205         const unsigned int n_children = n_readers + n_writers;
1206         const size_t data = 6 * 1000 * 1000;
1207         const size_t file_sz = data / 100;
1208         size_t read_bias, write_bias;
1209         int i, fd, child_id;
1210         char buf[file_sz];
1211         pid_t pid;
1212
1213         /* Only allow multiples for simplicity */
1214         ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1215         read_bias = n_writers / n_readers ?: 1;
1216         write_bias = n_readers / n_writers ?: 1;
1217
1218         /* prep a file to send */
1219         fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1220         ASSERT_GE(fd, 0);
1221
1222         memset(buf, 0xac, file_sz);
1223         ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1224
1225         /* spawn children */
1226         for (child_id = 0; child_id < n_children; child_id++) {
1227                 pid = fork();
1228                 ASSERT_NE(pid, -1);
1229                 if (!pid)
1230                         break;
1231         }
1232
1233         /* parent waits for all children */
1234         if (pid) {
1235                 for (i = 0; i < n_children; i++) {
1236                         int status;
1237
1238                         wait(&status);
1239                         EXPECT_EQ(status, 0);
1240                 }
1241
1242                 return;
1243         }
1244
1245         /* Split threads for reading and writing */
1246         if (child_id < n_readers) {
1247                 size_t left = data * read_bias;
1248                 char rb[8001];
1249
1250                 while (left) {
1251                         int res;
1252
1253                         res = recv(self->cfd, rb,
1254                                    left > sizeof(rb) ? sizeof(rb) : left, 0);
1255
1256                         EXPECT_GE(res, 0);
1257                         left -= res;
1258                 }
1259         } else {
1260                 size_t left = data * write_bias;
1261
1262                 while (left) {
1263                         int res;
1264
1265                         ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1266                         if (sendpg)
1267                                 res = sendfile(self->fd, fd, NULL,
1268                                                left > file_sz ? file_sz : left);
1269                         else
1270                                 res = send(self->fd, buf,
1271                                            left > file_sz ? file_sz : left, 0);
1272
1273                         EXPECT_GE(res, 0);
1274                         left -= res;
1275                 }
1276         }
1277 }
1278
1279 TEST_F(tls, mutliproc_even)
1280 {
1281         test_mutliproc(_metadata, self, false, 6, 6);
1282 }
1283
1284 TEST_F(tls, mutliproc_readers)
1285 {
1286         test_mutliproc(_metadata, self, false, 4, 12);
1287 }
1288
1289 TEST_F(tls, mutliproc_writers)
1290 {
1291         test_mutliproc(_metadata, self, false, 10, 2);
1292 }
1293
1294 TEST_F(tls, mutliproc_sendpage_even)
1295 {
1296         test_mutliproc(_metadata, self, true, 6, 6);
1297 }
1298
1299 TEST_F(tls, mutliproc_sendpage_readers)
1300 {
1301         test_mutliproc(_metadata, self, true, 4, 12);
1302 }
1303
1304 TEST_F(tls, mutliproc_sendpage_writers)
1305 {
1306         test_mutliproc(_metadata, self, true, 10, 2);
1307 }
1308
1309 TEST_F(tls, control_msg)
1310 {
1311         char *test_str = "test_read";
1312         char record_type = 100;
1313         int send_len = 10;
1314         char buf[10];
1315
1316         if (self->notls)
1317                 SKIP(return, "no TLS support");
1318
1319         EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1320                   send_len);
1321         /* Should fail because we didn't provide a control message */
1322         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1323
1324         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1325                                 buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1326                   send_len);
1327         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1328
1329         /* Recv the message again without MSG_PEEK */
1330         memset(buf, 0, sizeof(buf));
1331
1332         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1333                                 buf, sizeof(buf), MSG_WAITALL),
1334                   send_len);
1335         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1336 }
1337
1338 TEST_F(tls, shutdown)
1339 {
1340         char const *test_str = "test_read";
1341         int send_len = 10;
1342         char buf[10];
1343
1344         ASSERT_EQ(strlen(test_str) + 1, send_len);
1345
1346         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1347         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1348         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1349
1350         shutdown(self->fd, SHUT_RDWR);
1351         shutdown(self->cfd, SHUT_RDWR);
1352 }
1353
1354 TEST_F(tls, shutdown_unsent)
1355 {
1356         char const *test_str = "test_read";
1357         int send_len = 10;
1358
1359         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1360
1361         shutdown(self->fd, SHUT_RDWR);
1362         shutdown(self->cfd, SHUT_RDWR);
1363 }
1364
1365 TEST_F(tls, shutdown_reuse)
1366 {
1367         struct sockaddr_in addr;
1368         int ret;
1369
1370         shutdown(self->fd, SHUT_RDWR);
1371         shutdown(self->cfd, SHUT_RDWR);
1372         close(self->cfd);
1373
1374         addr.sin_family = AF_INET;
1375         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1376         addr.sin_port = 0;
1377
1378         ret = bind(self->fd, &addr, sizeof(addr));
1379         EXPECT_EQ(ret, 0);
1380         ret = listen(self->fd, 10);
1381         EXPECT_EQ(ret, -1);
1382         EXPECT_EQ(errno, EINVAL);
1383
1384         ret = connect(self->fd, &addr, sizeof(addr));
1385         EXPECT_EQ(ret, -1);
1386         EXPECT_EQ(errno, EISCONN);
1387 }
1388
1389 FIXTURE(tls_err)
1390 {
1391         int fd, cfd;
1392         int fd2, cfd2;
1393         bool notls;
1394 };
1395
1396 FIXTURE_VARIANT(tls_err)
1397 {
1398         uint16_t tls_version;
1399 };
1400
1401 FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
1402 {
1403         .tls_version = TLS_1_2_VERSION,
1404 };
1405
1406 FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
1407 {
1408         .tls_version = TLS_1_3_VERSION,
1409 };
1410
1411 FIXTURE_SETUP(tls_err)
1412 {
1413         struct tls_crypto_info_keys tls12;
1414         int ret;
1415
1416         tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
1417                              &tls12);
1418
1419         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
1420         ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
1421         if (self->notls)
1422                 return;
1423
1424         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
1425         ASSERT_EQ(ret, 0);
1426
1427         ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
1428         ASSERT_EQ(ret, 0);
1429 }
1430
1431 FIXTURE_TEARDOWN(tls_err)
1432 {
1433         close(self->fd);
1434         close(self->cfd);
1435         close(self->fd2);
1436         close(self->cfd2);
1437 }
1438
1439 TEST_F(tls_err, bad_rec)
1440 {
1441         char buf[64];
1442
1443         if (self->notls)
1444                 SKIP(return, "no TLS support");
1445
1446         memset(buf, 0x55, sizeof(buf));
1447         EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
1448         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1449         EXPECT_EQ(errno, EMSGSIZE);
1450         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
1451         EXPECT_EQ(errno, EAGAIN);
1452 }
1453
1454 TEST_F(tls_err, bad_auth)
1455 {
1456         char buf[128];
1457         int n;
1458
1459         if (self->notls)
1460                 SKIP(return, "no TLS support");
1461
1462         memrnd(buf, sizeof(buf) / 2);
1463         EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
1464         n = recv(self->cfd, buf, sizeof(buf), 0);
1465         EXPECT_GT(n, sizeof(buf) / 2);
1466
1467         buf[n - 1]++;
1468
1469         EXPECT_EQ(send(self->fd2, buf, n, 0), n);
1470         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1471         EXPECT_EQ(errno, EBADMSG);
1472         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1473         EXPECT_EQ(errno, EBADMSG);
1474 }
1475
1476 TEST_F(tls_err, bad_in_large_read)
1477 {
1478         char txt[3][64];
1479         char cip[3][128];
1480         char buf[3 * 128];
1481         int i, n;
1482
1483         if (self->notls)
1484                 SKIP(return, "no TLS support");
1485
1486         /* Put 3 records in the sockets */
1487         for (i = 0; i < 3; i++) {
1488                 memrnd(txt[i], sizeof(txt[i]));
1489                 EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
1490                           sizeof(txt[i]));
1491                 n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
1492                 EXPECT_GT(n, sizeof(txt[i]));
1493                 /* Break the third message */
1494                 if (i == 2)
1495                         cip[2][n - 1]++;
1496                 EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
1497         }
1498
1499         /* We should be able to receive the first two messages */
1500         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
1501         EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
1502         EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
1503         /* Third mesasge is bad */
1504         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1505         EXPECT_EQ(errno, EBADMSG);
1506         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1507         EXPECT_EQ(errno, EBADMSG);
1508 }
1509
1510 TEST_F(tls_err, bad_cmsg)
1511 {
1512         char *test_str = "test_read";
1513         int send_len = 10;
1514         char cip[128];
1515         char buf[128];
1516         char txt[64];
1517         int n;
1518
1519         if (self->notls)
1520                 SKIP(return, "no TLS support");
1521
1522         /* Queue up one data record */
1523         memrnd(txt, sizeof(txt));
1524         EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
1525         n = recv(self->cfd, cip, sizeof(cip), 0);
1526         EXPECT_GT(n, sizeof(txt));
1527         EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1528
1529         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
1530         n = recv(self->cfd, cip, sizeof(cip), 0);
1531         cip[n - 1]++; /* Break it */
1532         EXPECT_GT(n, send_len);
1533         EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1534
1535         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
1536         EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
1537         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1538         EXPECT_EQ(errno, EBADMSG);
1539         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1540         EXPECT_EQ(errno, EBADMSG);
1541 }
1542
1543 TEST(non_established) {
1544         struct tls12_crypto_info_aes_gcm_256 tls12;
1545         struct sockaddr_in addr;
1546         int sfd, ret, fd;
1547         socklen_t len;
1548
1549         len = sizeof(addr);
1550
1551         memset(&tls12, 0, sizeof(tls12));
1552         tls12.info.version = TLS_1_2_VERSION;
1553         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1554
1555         addr.sin_family = AF_INET;
1556         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1557         addr.sin_port = 0;
1558
1559         fd = socket(AF_INET, SOCK_STREAM, 0);
1560         sfd = socket(AF_INET, SOCK_STREAM, 0);
1561
1562         ret = bind(sfd, &addr, sizeof(addr));
1563         ASSERT_EQ(ret, 0);
1564         ret = listen(sfd, 10);
1565         ASSERT_EQ(ret, 0);
1566
1567         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1568         EXPECT_EQ(ret, -1);
1569         /* TLS ULP not supported */
1570         if (errno == ENOENT)
1571                 return;
1572         EXPECT_EQ(errno, ENOTCONN);
1573
1574         ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1575         EXPECT_EQ(ret, -1);
1576         EXPECT_EQ(errno, ENOTCONN);
1577
1578         ret = getsockname(sfd, &addr, &len);
1579         ASSERT_EQ(ret, 0);
1580
1581         ret = connect(fd, &addr, sizeof(addr));
1582         ASSERT_EQ(ret, 0);
1583
1584         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1585         ASSERT_EQ(ret, 0);
1586
1587         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1588         EXPECT_EQ(ret, -1);
1589         EXPECT_EQ(errno, EEXIST);
1590
1591         close(fd);
1592         close(sfd);
1593 }
1594
1595 TEST(keysizes) {
1596         struct tls12_crypto_info_aes_gcm_256 tls12;
1597         int ret, fd, cfd;
1598         bool notls;
1599
1600         memset(&tls12, 0, sizeof(tls12));
1601         tls12.info.version = TLS_1_2_VERSION;
1602         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1603
1604         ulp_sock_pair(_metadata, &fd, &cfd, &notls);
1605
1606         if (!notls) {
1607                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1608                                  sizeof(tls12));
1609                 EXPECT_EQ(ret, 0);
1610
1611                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
1612                                  sizeof(tls12));
1613                 EXPECT_EQ(ret, 0);
1614         }
1615
1616         close(fd);
1617         close(cfd);
1618 }
1619
1620 TEST_HARNESS_MAIN