SUNRPC: Recognize control messages in server-side TCP socket code
authorChuck Lever <chuck.lever@oracle.com>
Mon, 17 Apr 2023 13:42:14 +0000 (09:42 -0400)
committerChuck Lever <chuck.lever@oracle.com>
Thu, 27 Apr 2023 22:49:24 +0000 (18:49 -0400)
To support kTLS, the server-side TCP socket receive path needs to
watch for CMSGs.

Acked-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: Chuck Lever <chuck.lever@oracle.com>
include/net/tls.h
net/sunrpc/svcsock.c

index 154949c7b0c88e9bb23231d3de52a389fdfc0ea8..6056ce5a2aa5fb666b395f8f21f7b501f2ef2dc3 100644 (file)
@@ -69,6 +69,8 @@ extern const struct tls_cipher_size_desc tls_cipher_size_desc[];
 
 #define TLS_CRYPTO_INFO_READY(info)    ((info)->cipher_type)
 
+#define TLS_RECORD_TYPE_ALERT          0x15
+#define TLS_RECORD_TYPE_HANDSHAKE      0x16
 #define TLS_RECORD_TYPE_DATA           0x17
 
 #define TLS_AAD_SPACE_SIZE             13
index 302a14dd7882fc395a3e3c443cf1932dfa7884f5..c5b74f523fc44a17cf2ddaf728c4e74816388d0e 100644 (file)
@@ -43,6 +43,7 @@
 #include <net/udp.h>
 #include <net/tcp.h>
 #include <net/tcp_states.h>
+#include <net/tls.h>
 #include <linux/uaccess.h>
 #include <linux/highmem.h>
 #include <asm/ioctls.h>
@@ -216,6 +217,49 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining)
        return len;
 }
 
+static int
+svc_tcp_sock_process_cmsg(struct svc_sock *svsk, struct msghdr *msg,
+                         struct cmsghdr *cmsg, int ret)
+{
+       if (cmsg->cmsg_level == SOL_TLS &&
+           cmsg->cmsg_type == TLS_GET_RECORD_TYPE) {
+               u8 content_type = *((u8 *)CMSG_DATA(cmsg));
+
+               switch (content_type) {
+               case TLS_RECORD_TYPE_DATA:
+                       /* TLS sets EOR at the end of each application data
+                        * record, even though there might be more frames
+                        * waiting to be decrypted.
+                        */
+                       msg->msg_flags &= ~MSG_EOR;
+                       break;
+               case TLS_RECORD_TYPE_ALERT:
+                       ret = -ENOTCONN;
+                       break;
+               default:
+                       ret = -EAGAIN;
+               }
+       }
+       return ret;
+}
+
+static int
+svc_tcp_sock_recv_cmsg(struct svc_sock *svsk, struct msghdr *msg)
+{
+       union {
+               struct cmsghdr  cmsg;
+               u8              buf[CMSG_SPACE(sizeof(u8))];
+       } u;
+       int ret;
+
+       msg->msg_control = &u;
+       msg->msg_controllen = sizeof(u);
+       ret = sock_recvmsg(svsk->sk_sock, msg, MSG_DONTWAIT);
+       if (unlikely(msg->msg_controllen != sizeof(u)))
+               ret = svc_tcp_sock_process_cmsg(svsk, msg, &u.cmsg, ret);
+       return ret;
+}
+
 #if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE
 static void svc_flush_bvec(const struct bio_vec *bvec, size_t size, size_t seek)
 {
@@ -263,7 +307,7 @@ static ssize_t svc_tcp_read_msg(struct svc_rqst *rqstp, size_t buflen,
                iov_iter_advance(&msg.msg_iter, seek);
                buflen -= seek;
        }
-       len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT);
+       len = svc_tcp_sock_recv_cmsg(svsk, &msg);
        if (len > 0)
                svc_flush_bvec(bvec, len, seek);
 
@@ -877,7 +921,7 @@ static ssize_t svc_tcp_read_marker(struct svc_sock *svsk,
                iov.iov_base = ((char *)&svsk->sk_marker) + svsk->sk_tcplen;
                iov.iov_len  = want;
                iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, want);
-               len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT);
+               len = svc_tcp_sock_recv_cmsg(svsk, &msg);
                if (len < 0)
                        return len;
                svsk->sk_tcplen += len;