tls: cap the output scatter list to something reasonable
authorJakub Kicinski <kuba@kernel.org>
Wed, 2 Feb 2022 22:20:31 +0000 (14:20 -0800)
committerDavid S. Miller <davem@davemloft.net>
Fri, 4 Feb 2022 10:14:07 +0000 (10:14 +0000)
TLS recvmsg() passes user pages as destination for decrypt.
The decrypt operation is repeated record by record, each
record being 16kB, max. TLS allocates an sg_table and uses
iov_iter_get_pages() to populate it with enough pages to
fit the decrypted record.

Even though we decrypt a single message at a time we size
the sg_table based on the entire length of the iovec.
This leads to unnecessarily large allocations, risking
triggering OOM conditions.

Use iov_iter_truncate() / iov_iter_reexpand() to construct
a "capped" version of iov_iter_npages(). Alternatively we
could parametrize iov_iter_npages() to take the size as
arg instead of using i->count, or do something else..

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/linux/uio.h
net/tls/tls_sw.c

index 1198a2bfc9bfcebcf9ab098f52685bc76fedbcba..739285fe5a2f211ea0e668d31f58c5700b5475ab 100644 (file)
@@ -273,6 +273,23 @@ static inline void iov_iter_reexpand(struct iov_iter *i, size_t count)
        i->count = count;
 }
 
+static inline int
+iov_iter_npages_cap(struct iov_iter *i, int maxpages, size_t max_bytes)
+{
+       size_t shorted = 0;
+       int npages;
+
+       if (iov_iter_count(i) > max_bytes) {
+               shorted = iov_iter_count(i) - max_bytes;
+               iov_iter_truncate(i, max_bytes);
+       }
+       npages = iov_iter_npages(i, INT_MAX);
+       if (shorted)
+               iov_iter_reexpand(i, iov_iter_count(i) + shorted);
+
+       return npages;
+}
+
 struct csum_state {
        __wsum csum;
        size_t off;
index efc84845bb6b07d5fbee719dd24c7ebba38b686e..0024a692f0f8e25f51c24e1899267f49e5a734ff 100644 (file)
@@ -1433,7 +1433,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 
        if (*zc && (out_iov || out_sg)) {
                if (out_iov)
-                       n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
+                       n_sgout = 1 +
+                               iov_iter_npages_cap(out_iov, INT_MAX, data_len);
                else
                        n_sgout = sg_nents(out_sg);
                n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,