Merge tag 'hyperv-next-signed' of git://git.kernel.org/pub/scm/linux/kernel/git/hyper...
[linux-block.git] / net / socket.c
index 3c6d60eadf7abe38815e6654cb0dbb5d50dedcba..c78c3d37c884fff5325e5e342d1ad8902fe5854b 100644 (file)
@@ -1691,24 +1691,13 @@ SYSCALL_DEFINE2(listen, int, fd, int, backlog)
        return __sys_listen(fd, backlog);
 }
 
-/*
- *     For accept, we attempt to create a new socket, set up the link
- *     with the client, wake up the client, then return the new
- *     connected fd. We collect the address of the connector in kernel
- *     space and move it to user at the very end. This is unclean because
- *     we open the socket then return an error.
- *
- *     1003.1g adds the ability to recvmsg() to query connection pending
- *     status to recvmsg. We need to add that support in a way thats
- *     clean when we restructure accept also.
- */
-
-int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
-                 int __user *upeer_addrlen, int flags)
+int __sys_accept4_file(struct file *file, unsigned file_flags,
+                      struct sockaddr __user *upeer_sockaddr,
+                      int __user *upeer_addrlen, int flags)
 {
        struct socket *sock, *newsock;
        struct file *newfile;
-       int err, len, newfd, fput_needed;
+       int err, len, newfd;
        struct sockaddr_storage address;
 
        if (flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
@@ -1717,14 +1706,14 @@ int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
        if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK))
                flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
 
-       sock = sockfd_lookup_light(fd, &err, &fput_needed);
+       sock = sock_from_file(file, &err);
        if (!sock)
                goto out;
 
        err = -ENFILE;
        newsock = sock_alloc();
        if (!newsock)
-               goto out_put;
+               goto out;
 
        newsock->type = sock->type;
        newsock->ops = sock->ops;
@@ -1739,20 +1728,21 @@ int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
        if (unlikely(newfd < 0)) {
                err = newfd;
                sock_release(newsock);
-               goto out_put;
+               goto out;
        }
        newfile = sock_alloc_file(newsock, flags, sock->sk->sk_prot_creator->name);
        if (IS_ERR(newfile)) {
                err = PTR_ERR(newfile);
                put_unused_fd(newfd);
-               goto out_put;
+               goto out;
        }
 
        err = security_socket_accept(sock, newsock);
        if (err)
                goto out_fd;
 
-       err = sock->ops->accept(sock, newsock, sock->file->f_flags, false);
+       err = sock->ops->accept(sock, newsock, sock->file->f_flags | file_flags,
+                                       false);
        if (err < 0)
                goto out_fd;
 
@@ -1773,15 +1763,42 @@ int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
 
        fd_install(newfd, newfile);
        err = newfd;
-
-out_put:
-       fput_light(sock->file, fput_needed);
 out:
        return err;
 out_fd:
        fput(newfile);
        put_unused_fd(newfd);
-       goto out_put;
+       goto out;
+
+}
+
+/*
+ *     For accept, we attempt to create a new socket, set up the link
+ *     with the client, wake up the client, then return the new
+ *     connected fd. We collect the address of the connector in kernel
+ *     space and move it to user at the very end. This is unclean because
+ *     we open the socket then return an error.
+ *
+ *     1003.1g adds the ability to recvmsg() to query connection pending
+ *     status to recvmsg. We need to add that support in a way thats
+ *     clean when we restructure accept also.
+ */
+
+int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
+                 int __user *upeer_addrlen, int flags)
+{
+       int ret = -EBADF;
+       struct fd f;
+
+       f = fdget(fd);
+       if (f.file) {
+               ret = __sys_accept4_file(f.file, 0, upeer_sockaddr,
+                                               upeer_addrlen, flags);
+               if (f.flags)
+                       fput(f.file);
+       }
+
+       return ret;
 }
 
 SYSCALL_DEFINE4(accept4, int, fd, struct sockaddr __user *, upeer_sockaddr,
@@ -1808,32 +1825,46 @@ SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr,
  *     include the -EINPROGRESS status for such sockets.
  */
 
-int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen)
+int __sys_connect_file(struct file *file, struct sockaddr __user *uservaddr,
+                      int addrlen, int file_flags)
 {
        struct socket *sock;
        struct sockaddr_storage address;
-       int err, fput_needed;
+       int err;
 
-       sock = sockfd_lookup_light(fd, &err, &fput_needed);
+       sock = sock_from_file(file, &err);
        if (!sock)
                goto out;
        err = move_addr_to_kernel(uservaddr, addrlen, &address);
        if (err < 0)
-               goto out_put;
+               goto out;
 
        err =
            security_socket_connect(sock, (struct sockaddr *)&address, addrlen);
        if (err)
-               goto out_put;
+               goto out;
 
        err = sock->ops->connect(sock, (struct sockaddr *)&address, addrlen,
-                                sock->file->f_flags);
-out_put:
-       fput_light(sock->file, fput_needed);
+                                sock->file->f_flags | file_flags);
 out:
        return err;
 }
 
+int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen)
+{
+       int ret = -EBADF;
+       struct fd f;
+
+       f = fdget(fd);
+       if (f.file) {
+               ret = __sys_connect_file(f.file, uservaddr, addrlen, 0);
+               if (f.flags)
+                       fput(f.file);
+       }
+
+       return ret;
+}
+
 SYSCALL_DEFINE3(connect, int, fd, struct sockaddr __user *, uservaddr,
                int, addrlen)
 {
@@ -2233,15 +2264,10 @@ static int copy_msghdr_from_user(struct msghdr *kmsg,
        return err < 0 ? err : 0;
 }
 
-static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
-                        struct msghdr *msg_sys, unsigned int flags,
-                        struct used_address *used_address,
-                        unsigned int allowed_msghdr_flags)
+static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
+                          unsigned int flags, struct used_address *used_address,
+                          unsigned int allowed_msghdr_flags)
 {
-       struct compat_msghdr __user *msg_compat =
-           (struct compat_msghdr __user *)msg;
-       struct sockaddr_storage address;
-       struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
        unsigned char ctl[sizeof(struct cmsghdr) + 20]
                                __aligned(sizeof(__kernel_size_t));
        /* 20 is size of ipv6_pktinfo */
@@ -2249,19 +2275,10 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
        int ctl_len;
        ssize_t err;
 
-       msg_sys->msg_name = &address;
-
-       if (MSG_CMSG_COMPAT & flags)
-               err = get_compat_msghdr(msg_sys, msg_compat, NULL, &iov);
-       else
-               err = copy_msghdr_from_user(msg_sys, msg, NULL, &iov);
-       if (err < 0)
-               return err;
-
        err = -ENOBUFS;
 
        if (msg_sys->msg_controllen > INT_MAX)
-               goto out_freeiov;
+               goto out;
        flags |= (msg_sys->msg_flags & allowed_msghdr_flags);
        ctl_len = msg_sys->msg_controllen;
        if ((MSG_CMSG_COMPAT & flags) && ctl_len) {
@@ -2269,7 +2286,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
                    cmsghdr_from_user_compat_to_kern(msg_sys, sock->sk, ctl,
                                                     sizeof(ctl));
                if (err)
-                       goto out_freeiov;
+                       goto out;
                ctl_buf = msg_sys->msg_control;
                ctl_len = msg_sys->msg_controllen;
        } else if (ctl_len) {
@@ -2278,7 +2295,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
                if (ctl_len > sizeof(ctl)) {
                        ctl_buf = sock_kmalloc(sock->sk, ctl_len, GFP_KERNEL);
                        if (ctl_buf == NULL)
-                               goto out_freeiov;
+                               goto out;
                }
                err = -EFAULT;
                /*
@@ -2324,7 +2341,47 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
 out_freectl:
        if (ctl_buf != ctl)
                sock_kfree_s(sock->sk, ctl_buf, ctl_len);
-out_freeiov:
+out:
+       return err;
+}
+
+static int sendmsg_copy_msghdr(struct msghdr *msg,
+                              struct user_msghdr __user *umsg, unsigned flags,
+                              struct iovec **iov)
+{
+       int err;
+
+       if (flags & MSG_CMSG_COMPAT) {
+               struct compat_msghdr __user *msg_compat;
+
+               msg_compat = (struct compat_msghdr __user *) umsg;
+               err = get_compat_msghdr(msg, msg_compat, NULL, iov);
+       } else {
+               err = copy_msghdr_from_user(msg, umsg, NULL, iov);
+       }
+       if (err < 0)
+               return err;
+
+       return 0;
+}
+
+static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
+                        struct msghdr *msg_sys, unsigned int flags,
+                        struct used_address *used_address,
+                        unsigned int allowed_msghdr_flags)
+{
+       struct sockaddr_storage address;
+       struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
+       ssize_t err;
+
+       msg_sys->msg_name = &address;
+
+       err = sendmsg_copy_msghdr(msg_sys, msg, flags, &iov);
+       if (err < 0)
+               return err;
+
+       err = ____sys_sendmsg(sock, msg_sys, flags, used_address,
+                               allowed_msghdr_flags);
        kfree(iov);
        return err;
 }
@@ -2332,12 +2389,27 @@ out_freeiov:
 /*
  *     BSD sendmsg interface
  */
-long __sys_sendmsg_sock(struct socket *sock, struct user_msghdr __user *msg,
+long __sys_sendmsg_sock(struct socket *sock, struct user_msghdr __user *umsg,
                        unsigned int flags)
 {
-       struct msghdr msg_sys;
+       struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
+       struct sockaddr_storage address;
+       struct msghdr msg = { .msg_name = &address };
+       ssize_t err;
+
+       err = sendmsg_copy_msghdr(&msg, umsg, flags, &iov);
+       if (err)
+               return err;
+       /* disallow ancillary data requests from this path */
+       if (msg.msg_control || msg.msg_controllen) {
+               err = -EINVAL;
+               goto out;
+       }
 
-       return ___sys_sendmsg(sock, msg, &msg_sys, flags, NULL, 0);
+       err = ____sys_sendmsg(sock, &msg, flags, NULL, 0);
+out:
+       kfree(iov);
+       return err;
 }
 
 long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,
@@ -2443,33 +2515,41 @@ SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg,
        return __sys_sendmmsg(fd, mmsg, vlen, flags, true);
 }
 
-static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
-                        struct msghdr *msg_sys, unsigned int flags, int nosec)
+static int recvmsg_copy_msghdr(struct msghdr *msg,
+                              struct user_msghdr __user *umsg, unsigned flags,
+                              struct sockaddr __user **uaddr,
+                              struct iovec **iov)
 {
-       struct compat_msghdr __user *msg_compat =
-           (struct compat_msghdr __user *)msg;
-       struct iovec iovstack[UIO_FASTIOV];
-       struct iovec *iov = iovstack;
-       unsigned long cmsg_ptr;
-       int len;
        ssize_t err;
 
-       /* kernel mode address */
-       struct sockaddr_storage addr;
+       if (MSG_CMSG_COMPAT & flags) {
+               struct compat_msghdr __user *msg_compat;
 
-       /* user mode address pointers */
-       struct sockaddr __user *uaddr;
-       int __user *uaddr_len = COMPAT_NAMELEN(msg);
-
-       msg_sys->msg_name = &addr;
-
-       if (MSG_CMSG_COMPAT & flags)
-               err = get_compat_msghdr(msg_sys, msg_compat, &uaddr, &iov);
-       else
-               err = copy_msghdr_from_user(msg_sys, msg, &uaddr, &iov);
+               msg_compat = (struct compat_msghdr __user *) umsg;
+               err = get_compat_msghdr(msg, msg_compat, uaddr, iov);
+       } else {
+               err = copy_msghdr_from_user(msg, umsg, uaddr, iov);
+       }
        if (err < 0)
                return err;
 
+       return 0;
+}
+
+static int ____sys_recvmsg(struct socket *sock, struct msghdr *msg_sys,
+                          struct user_msghdr __user *msg,
+                          struct sockaddr __user *uaddr,
+                          unsigned int flags, int nosec)
+{
+       struct compat_msghdr __user *msg_compat =
+                                       (struct compat_msghdr __user *) msg;
+       int __user *uaddr_len = COMPAT_NAMELEN(msg);
+       struct sockaddr_storage addr;
+       unsigned long cmsg_ptr;
+       int len;
+       ssize_t err;
+
+       msg_sys->msg_name = &addr;
        cmsg_ptr = (unsigned long)msg_sys->msg_control;
        msg_sys->msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
 
@@ -2480,7 +2560,7 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
                flags |= MSG_DONTWAIT;
        err = (nosec ? sock_recvmsg_nosec : sock_recvmsg)(sock, msg_sys, flags);
        if (err < 0)
-               goto out_freeiov;
+               goto out;
        len = err;
 
        if (uaddr != NULL) {
@@ -2488,12 +2568,12 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
                                        msg_sys->msg_namelen, uaddr,
                                        uaddr_len);
                if (err < 0)
-                       goto out_freeiov;
+                       goto out;
        }
        err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT),
                         COMPAT_FLAGS(msg));
        if (err)
-               goto out_freeiov;
+               goto out;
        if (MSG_CMSG_COMPAT & flags)
                err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
                                 &msg_compat->msg_controllen);
@@ -2501,10 +2581,25 @@ static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
                err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr,
                                 &msg->msg_controllen);
        if (err)
-               goto out_freeiov;
+               goto out;
        err = len;
+out:
+       return err;
+}
+
+static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg,
+                        struct msghdr *msg_sys, unsigned int flags, int nosec)
+{
+       struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
+       /* user mode address pointers */
+       struct sockaddr __user *uaddr;
+       ssize_t err;
 
-out_freeiov:
+       err = recvmsg_copy_msghdr(msg_sys, msg, flags, &uaddr, &iov);
+       if (err < 0)
+               return err;
+
+       err = ____sys_recvmsg(sock, msg_sys, msg, uaddr, flags, nosec);
        kfree(iov);
        return err;
 }
@@ -2513,12 +2608,28 @@ out_freeiov:
  *     BSD recvmsg interface
  */
 
-long __sys_recvmsg_sock(struct socket *sock, struct user_msghdr __user *msg,
+long __sys_recvmsg_sock(struct socket *sock, struct user_msghdr __user *umsg,
                        unsigned int flags)
 {
-       struct msghdr msg_sys;
+       struct iovec iovstack[UIO_FASTIOV], *iov = iovstack;
+       struct sockaddr_storage address;
+       struct msghdr msg = { .msg_name = &address };
+       struct sockaddr __user *uaddr;
+       ssize_t err;
 
-       return ___sys_recvmsg(sock, msg, &msg_sys, flags, 0);
+       err = recvmsg_copy_msghdr(&msg, umsg, flags, &uaddr, &iov);
+       if (err)
+               return err;
+       /* disallow ancillary data requests from this path */
+       if (msg.msg_control || msg.msg_controllen) {
+               err = -EINVAL;
+               goto out;
+       }
+
+       err = ____sys_recvmsg(sock, &msg, umsg, uaddr, flags, 0);
+out:
+       kfree(iov);
+       return err;
 }
 
 long __sys_recvmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,