net engine: don't pass in flags
[fio.git] / engines / net.c
index 2a9caaa24e8985172fd89764c190fba67c415674..afb3265c6ada813179fa12853ec121d6632ab378 100644 (file)
@@ -28,6 +28,39 @@ struct netio_data {
        struct sockaddr_in addr;
 };
 
+/*
+ * Return -1 for error and 'nr events' for a positive number
+ * of events
+ */
+static int poll_wait(struct thread_data *td, int fd, short events)
+{
+       struct pollfd pfd;
+       int ret;
+
+       while (!td->terminate) {
+               pfd.fd = fd;
+               pfd.events = events;
+               ret = poll(&pfd, 1, -1);
+               if (ret < 0) {
+                       if (errno == EINTR)
+                               continue;
+
+                       td_verror(td, errno, "poll");
+                       return -1;
+               } else if (!ret)
+                       continue;
+
+               break;
+       }
+
+       if (pfd.revents & events)
+               return 1;
+       else if (td->terminate)
+               return 1;
+
+       return -1;
+}
+
 static int fio_netio_prep(struct thread_data *td, struct io_u *io_u)
 {
        struct netio_data *nd = td->io_ops->data;
@@ -182,7 +215,11 @@ static int fio_netio_splice_out(struct thread_data *td, struct io_u *io_u)
 static int fio_netio_send(struct thread_data *td, struct io_u *io_u)
 {
        struct netio_data *nd = td->io_ops->data;
-       int flags = 0;
+       int ret, flags = 0;
+
+       ret = poll_wait(td, io_u->file->fd, POLLOUT);
+       if (ret <= 0)
+               return ret;
 
        /*
         * if we are going to write more, set MSG_MORE
@@ -204,7 +241,11 @@ static int fio_netio_send(struct thread_data *td, struct io_u *io_u)
 static int fio_netio_recv(struct thread_data *td, struct io_u *io_u)
 {
        struct netio_data *nd = td->io_ops->data;
-       int flags = MSG_WAITALL;
+       int ret, flags = MSG_WAITALL;
+
+       ret = poll_wait(td, io_u->file->fd, POLLIN);
+       if (ret <= 0)
+               return ret;
 
        if (nd->net_protocol == IPPROTO_UDP) {
                socklen_t len = sizeof(nd->addr);
@@ -289,8 +330,6 @@ static int fio_netio_accept(struct thread_data *td, struct fio_file *f)
 {
        struct netio_data *nd = td->io_ops->data;
        socklen_t socklen = sizeof(nd->addr);
-       struct pollfd pfd;
-       int ret;
 
        if (nd->net_protocol == IPPROTO_UDP) {
                f->fd = nd->listenfd;
@@ -299,36 +338,13 @@ static int fio_netio_accept(struct thread_data *td, struct fio_file *f)
 
        log_info("fio: waiting for connection\n");
 
-       /*
-        * Accept loop. poll for incoming events, accept them. Repeat until we
-        * have all connections.
-        */
-       while (!td->terminate) {
-               pfd.fd = nd->listenfd;
-               pfd.events = POLLIN;
-
-               ret = poll(&pfd, 1, -1);
-               if (ret < 0) {
-                       if (errno == EINTR)
-                               continue;
-
-                       td_verror(td, errno, "poll");
-                       break;
-               } else if (!ret)
-                       continue;
-
-               /*
-                * should be impossible
-                */
-               if (!(pfd.revents & POLLIN))
-                       continue;
+       if (poll_wait(td, nd->listenfd, POLLIN) < 0)
+               return 1;
 
-               f->fd = accept(nd->listenfd, (struct sockaddr *) &nd->addr, &socklen);
-               if (f->fd < 0) {
-                       td_verror(td, errno, "accept");
-                       return 1;
-               }
-               break;
+       f->fd = accept(nd->listenfd, (struct sockaddr *) &nd->addr, &socklen);
+       if (f->fd < 0) {
+               td_verror(td, errno, "accept");
+               return 1;
        }
 
        return 0;
@@ -452,9 +468,11 @@ static int fio_netio_init(struct thread_data *td)
                goto bad_host;
 
        if (modep) {
-               if (!strncmp("tcp", modep, strlen(modep)))
+               if (!strncmp("tcp", modep, strlen(modep)) ||
+                   !strncmp("TCP", modep, strlen(modep)))
                        nd->net_protocol = IPPROTO_TCP;
-               else if (!strncmp("udp", modep, strlen(modep)))
+               else if (!strncmp("udp", modep, strlen(modep)) ||
+                        !strncmp("UDP", modep, strlen(modep)))
                        nd->net_protocol = IPPROTO_UDP;
                else
                        goto bad_host;