Add nettest/ tools
[splice.git] / nettest / recv.c
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <unistd.h>
4 #include <netdb.h>
5 #include <unistd.h>
6 #include <fcntl.h>
7 #include <signal.h>
8 #include <netinet/in.h>
9 #include <arpa/inet.h>
10 #include <string.h>
11 #include <sys/time.h>
12 #include <errno.h>
13 #include <sys/poll.h>
14 #include <sys/time.h>
15 #include <sys/resource.h>
16
17 #include "../splice.h"
18 #include "crc32.h"
19 #include "msg.h"
20
21 static unsigned int msg_size = 4096;
22 static int use_splice = 1;
23 static int splice_move;
24
25 static int usage(const char *name)
26 {
27         fprintf(stderr, "%s: [-s(ize)] [-m(ove)] [-r(ecv)] port\n", name);
28         return 1;
29 }
30
31 unsigned long mtime_since(struct timeval *s, struct timeval *e)
32 {
33         long sec, usec, ret;
34
35         sec = e->tv_sec - s->tv_sec;
36         usec = e->tv_usec - s->tv_usec;
37         if (sec > 0 && usec < 0) {
38                 sec--;
39                 usec += 1000000;
40         }
41
42         sec *= 1000UL;
43         usec /= 1000UL;
44         ret = sec + usec;
45
46         /*
47          * time warp bug on some kernels?
48          */
49         if (ret < 0)
50                 ret = 0;
51
52         return ret;
53 }
54
55 unsigned long mtime_since_now(struct timeval *s)
56 {
57         struct timeval t;
58
59         gettimeofday(&t, NULL);
60         return mtime_since(s, &t);
61 }
62
63 static int get_connect(int fd, struct sockaddr_in *addr)
64 {
65         socklen_t socklen = sizeof(*addr);
66         int ret, connfd;
67
68         fprintf(stderr, "Waiting for connect...\n");
69
70         do {
71                 struct pollfd pfd = {
72                         .fd = fd,
73                         .events = POLLIN,
74                 };
75
76                 ret = poll(&pfd, 1, -1);
77                 if (ret < 0)
78                         return error("poll");
79                 else if (!ret)
80                         continue;
81
82                 connfd = accept(fd, (struct sockaddr *) addr, &socklen);
83                 if (connfd < 0)
84                         return error("accept");
85                 break;
86         } while (1);
87
88         fprintf(stderr, "Got connect!\n");
89                         
90         return connfd;
91 }
92
93 static int parse_options(int argc, char *argv[])
94 {
95         int c, index = 1;
96
97         while ((c = getopt(argc, argv, "s:mr")) != -1) {
98                 switch (c) {
99                 case 's':
100                         msg_size = atoi(optarg);
101                         index++;
102                         break;
103                 case 'm':
104                         splice_move = 1;
105                         index++;
106                         break;
107                 case 'r':
108                         use_splice = 0;
109                         index++;
110                         break;
111                 default:
112                         return -1;
113                 }
114         }
115
116         return index;
117 }
118
119 static int verify_crc(struct msg *m)
120 {
121         unsigned long crc;
122         void *data = m;
123
124         data += sizeof(*m);
125         crc = crc32(data, m->msg_size - sizeof(*m));
126
127         if (crc == m->crc32)
128                 return 0;
129
130         fprintf(stderr, "crc error: got %lx, wanted %lx\n", crc, m->crc32);
131         return 1;
132 }
133
134 static int do_recv(int fd, void *buf, unsigned int len)
135 {
136         while (len) {
137                 int ret = recv(fd, buf, len, MSG_WAITALL);
138
139                 if (ret < 0)
140                         return error("recv");
141                 else if (!ret)
142                         break;
143
144                 len -= ret;
145                 buf += ret;
146         }
147
148         return len;
149 }
150
151 static int normal_recv_loop(int fd)
152 {
153         struct msg *m;
154
155         m = malloc(msg_size);
156
157         while (1) {
158                 if (do_recv(fd, m, msg_size))
159                         break;
160
161                 if (m->msg_size != msg_size) {
162                         fprintf(stderr, "Bad packet length: wanted %u, got %lu\n", msg_size, m->msg_size);
163                         break;
164                 }
165
166                 /*
167                  * now verify data
168                  */
169                 if (verify_crc(m))
170                         break;
171         }
172
173         free(m);
174         return 0;
175 }
176
177 static int splice_in(int sockfd, int pipefd, unsigned int size)
178 {
179         while (size) {
180                 int ret = ssplice(sockfd, NULL, pipefd, NULL, size, 0);
181
182                 if (ret < 0)
183                         return error("splice from net");
184                 else if (!ret)
185                         break;
186
187                 size -= ret;
188         }
189
190         if (size)
191                 fprintf(stderr, "splice: %u resid\n", size);
192
193         return size;
194 }
195
196 static int vmsplice_unmap(int pipefd, void *buf, unsigned int len)
197 {
198         struct iovec iov = {
199                 .iov_base = buf,
200                 .iov_len = len,
201         };
202
203         if (svmsplice(pipefd, &iov, 1, SPLICE_F_UNMAP) < 0)
204                 return error("vmsplice unmap");
205
206         return 0;
207 }
208
209 static int vmsplice_out(void **buf, int pipefd, unsigned int len)
210 {
211         struct iovec iov = {
212                 .iov_base = *buf,
213                 .iov_len = len,
214         };
215         int ret, flags = 0;
216
217         if (splice_move)
218                 flags |= SPLICE_F_MOVE;
219
220         while (len) {
221                 ret = svmsplice(pipefd, &iov, 1, flags);
222                 if (ret < 0)
223                         return error("vmsplice");
224                 else if (!ret)
225                         break;
226
227                 *buf = iov.iov_base;
228
229                 len -= ret;
230                 if (len) {
231                         if (splice_move)
232                                 break;
233                         iov.iov_len -= ret;
234                         iov.iov_base += ret;
235                 }
236         }
237
238         if (len)
239                 fprintf(stderr, "vmsplice: %u resid\n", len);
240
241         return len;
242 }
243
244 static int splice_recv_loop(int fd)
245 {
246         struct msg *m;
247         void *buf;
248         int pipes[2];
249
250         if (pipe(pipes) < 0)
251                 return error("pipe");
252
253         if (!splice_move)
254                 m = malloc(msg_size);
255         else
256                 m = NULL;
257
258         while (1) {
259                 /*
260                  * fill pipe with network data
261                  */
262                 if (splice_in(fd, pipes[1], msg_size))
263                         break;
264
265                 /*
266                  * move data to our address space
267                  */
268                 if (!splice_move)
269                         buf = m;
270                 else
271                         buf = NULL;
272
273                 if (vmsplice_out(&buf, pipes[0], msg_size))
274                         break;
275
276                 m = buf;
277
278                 if (m->msg_size != msg_size) {
279                         fprintf(stderr, "Bad packet length: wanted %u, got %lu\n", msg_size, m->msg_size);
280                         break;
281                 }
282
283                 /*
284                  * now verify data
285                  */
286                 if (verify_crc(m))
287                         break;
288
289                 if (splice_move && vmsplice_unmap(pipes[0], buf, msg_size))
290                         break;
291         }
292
293         if (!splice_move)
294                 free(m);
295
296         close(pipes[0]);
297         close(pipes[1]);
298         return 0;
299 }
300
301 static int recv_loop(int fd)
302 {
303         struct rusage ru_s, ru_e;
304         struct timeval start;
305         unsigned long ut, st, rt;
306         int ret;
307
308         gettimeofday(&start, NULL);
309         getrusage(RUSAGE_SELF, &ru_s);
310
311         if (use_splice)
312                 ret = splice_recv_loop(fd);
313         else
314                 ret = normal_recv_loop(fd);
315
316         getrusage(RUSAGE_SELF, &ru_e);
317
318         ut = mtime_since(&ru_s.ru_utime, &ru_e.ru_utime);
319         st = mtime_since(&ru_s.ru_stime, &ru_e.ru_stime);
320         rt = mtime_since_now(&start);
321
322         printf("usr=%lu, sys=%lu, real=%lu\n", ut, st, rt);
323
324         return ret;
325 }
326
327 int main(int argc, char *argv[])
328 {
329         struct sockaddr_in addr;
330         unsigned short port;
331         int connfd, fd, opt, index;
332
333         if (argc < 2)
334                 return usage(argv[0]);
335
336         index = parse_options(argc, argv);
337         if (index == -1 || index + 1 > argc)
338                 return usage(argv[0]);
339
340         printf("recv: msg=%ukb, ", msg_size >> 10);
341         if (use_splice) {
342                 printf("splice() ");
343                 if (splice_move)
344                         printf("zero map ");
345                 else
346                         printf("addr map ");
347         } else
348                 printf("recv()");
349         printf("\n");
350
351         port = atoi(argv[index]);
352
353         fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
354         if (fd < 0)
355                 return error("socket");
356
357         opt = 1;
358         if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0)
359                 return error("setsockopt");
360
361         memset(&addr, 0, sizeof(addr));
362         addr.sin_family = AF_INET;
363         addr.sin_addr.s_addr = htonl(INADDR_ANY);
364         addr.sin_port = htons(port);
365
366         if (bind(fd, (struct sockaddr *) &addr, sizeof(addr)) < 0)
367                 return error("bind");
368         if (listen(fd, 1) < 0)
369                 return error("listen");
370
371         connfd = get_connect(fd, &addr);
372         if (connfd < 0)
373                 return connfd;
374
375         return recv_loop(connfd);
376 }