tls: rx: strp: factor out copying skb data
[linux-2.6-block.git] / net / tls / tls_strp.c
CommitLineData
c618db2a 1// SPDX-License-Identifier: GPL-2.0-only
84c61fe1 2/* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */
c618db2a
JK
3
4#include <linux/skbuff.h>
84c61fe1
JK
5#include <linux/workqueue.h>
6#include <net/strparser.h>
7#include <net/tcp.h>
8#include <net/sock.h>
9#include <net/tls.h>
c618db2a
JK
10
11#include "tls.h"
12
84c61fe1
JK
/* Workqueue that strp->work items are queued on via queue_work(); it is
 * created in tls_strp_dev_init() and destroyed in tls_strp_dev_exit().
 */
13static struct workqueue_struct *tls_strp_wq;
14
15static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
16{
17 if (strp->stopped)
18 return;
19
20 strp->stopped = 1;
21
22 /* Report an error on the lower socket */
23 strp->sk->sk_err = -err;
24 sk_error_report(strp->sk);
25}
26
27static void tls_strp_anchor_free(struct tls_strparser *strp)
d4e5db64 28{
84c61fe1
JK
29 struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
30
31 DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
32 shinfo->frag_list = NULL;
33 consume_skb(strp->anchor);
34 strp->anchor = NULL;
35}
36
c1c607b1
JK
37static struct sk_buff *
38tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb,
39 int offset, int len)
84c61fe1 40{
d4e5db64 41 struct sk_buff *skb;
c1c607b1 42 int i, err;
84c61fe1 43
c1c607b1 44 skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER,
84c61fe1
JK
45 &err, strp->sk->sk_allocation);
46 if (!skb)
47 return NULL;
48
84c61fe1
JK
49 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
50 skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
d4e5db64 51
c1c607b1 52 WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
84c61fe1
JK
53 skb_frag_address(frag),
54 skb_frag_size(frag)));
55 offset += skb_frag_size(frag);
56 }
57
c1c607b1
JK
58 skb->len = len;
59 skb->data_len = len;
60 skb_copy_header(skb, in_skb);
61 return skb;
62}
63
64/* Create a new skb with the contents of input copied to its page frags */
65static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
66{
67 struct strp_msg *rxm;
68 struct sk_buff *skb;
69
70 skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset,
71 strp->stm.full_len);
72 if (!skb)
73 return NULL;
74
84c61fe1
JK
75 rxm = strp_msg(skb);
76 rxm->offset = 0;
d4e5db64
JK
77 return skb;
78}
79
84c61fe1
JK
80/* Steal the input skb, input msg is invalid after calling this function */
81struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
82{
83 struct tls_strparser *strp = &ctx->strp;
84
85#ifdef CONFIG_TLS_DEVICE
86 DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
87#else
88 /* This function turns an input into an output,
89 * that can only happen if we have offload.
90 */
91 WARN_ON(1);
92#endif
93
94 if (strp->copy_mode) {
95 struct sk_buff *skb;
96
97 /* Replace anchor with an empty skb, this is a little
98 * dangerous but __tls_cur_msg() warns on empty skbs
99 * so hopefully we'll catch abuses.
100 */
101 skb = alloc_skb(0, strp->sk->sk_allocation);
102 if (!skb)
103 return NULL;
104
105 swap(strp->anchor, skb);
106 return skb;
107 }
108
109 return tls_strp_msg_make_copy(strp);
110}
111
112/* Force the input skb to be in copy mode. The data ownership remains
113 * with the input skb itself (meaning unpause will wipe it) but it can
114 * be modified.
115 */
8b3c59a7
JK
116int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
117{
84c61fe1
JK
118 struct tls_strparser *strp = &ctx->strp;
119 struct sk_buff *skb;
120
121 if (strp->copy_mode)
122 return 0;
123
124 skb = tls_strp_msg_make_copy(strp);
125 if (!skb)
126 return -ENOMEM;
127
128 tls_strp_anchor_free(strp);
129 strp->anchor = skb;
130
131 tcp_read_done(strp->sk, strp->stm.full_len);
132 strp->copy_mode = 1;
133
134 return 0;
135}
136
137/* Make a clone (in the skb sense) of the input msg to keep a reference
138 * to the underlying data. The reference-holding skbs get placed on
139 * @dst.
140 */
141int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
142{
143 struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
144
145 if (strp->copy_mode) {
146 struct sk_buff *skb;
147
148 WARN_ON_ONCE(!shinfo->nr_frags);
149
150 /* We can't skb_clone() the anchor, it gets wiped by unpause */
151 skb = alloc_skb(0, strp->sk->sk_allocation);
152 if (!skb)
153 return -ENOMEM;
154
155 __skb_queue_tail(dst, strp->anchor);
156 strp->anchor = skb;
157 } else {
158 struct sk_buff *iter, *clone;
159 int chunk, len, offset;
160
161 offset = strp->stm.offset;
162 len = strp->stm.full_len;
163 iter = shinfo->frag_list;
164
165 while (len > 0) {
166 if (iter->len <= offset) {
167 offset -= iter->len;
168 goto next;
169 }
170
171 chunk = iter->len - offset;
172 offset = 0;
173
174 clone = skb_clone(iter, strp->sk->sk_allocation);
175 if (!clone)
176 return -ENOMEM;
177 __skb_queue_tail(dst, clone);
178
179 len -= chunk;
180next:
181 iter = iter->next;
182 }
183 }
184
185 return 0;
186}
187
188static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
189{
190 struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
191 int i;
192
193 DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
194
195 for (i = 0; i < shinfo->nr_frags; i++)
196 __skb_frag_unref(&shinfo->frags[i], false);
197 shinfo->nr_frags = 0;
198 strp->copy_mode = 0;
199}
200
201static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
202 unsigned int offset, size_t in_len)
203{
204 struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
84c61fe1
JK
205 struct sk_buff *skb;
206 skb_frag_t *frag;
8fd1e151
YL
207 size_t len, chunk;
208 int sz;
84c61fe1
JK
209
210 if (strp->msg_ready)
211 return 0;
212
213 skb = strp->anchor;
214 frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
215
216 len = in_len;
217 /* First make sure we got the header */
218 if (!strp->stm.full_len) {
219 /* Assume one page is more than enough for headers */
220 chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
221 WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
222 skb_frag_address(frag) +
223 skb_frag_size(frag),
224 chunk));
225
8b0c0dc9
JK
226 skb->len += chunk;
227 skb->data_len += chunk;
228 skb_frag_size_add(frag, chunk);
229
230 sz = tls_rx_msg_size(strp, skb);
84c61fe1
JK
231 if (sz < 0) {
232 desc->error = sz;
233 return 0;
234 }
235
236 /* We may have over-read, sz == 0 is guaranteed under-read */
8b0c0dc9
JK
237 if (unlikely(sz && sz < skb->len)) {
238 int over = skb->len - sz;
239
240 WARN_ON_ONCE(over > chunk);
241 skb->len -= over;
242 skb->data_len -= over;
243 skb_frag_size_add(frag, -over);
244
245 chunk -= over;
246 }
84c61fe1 247
84c61fe1
JK
248 frag++;
249 len -= chunk;
250 offset += chunk;
251
252 strp->stm.full_len = sz;
253 if (!strp->stm.full_len)
254 goto read_done;
255 }
256
257 /* Load up more data */
258 while (len && strp->stm.full_len > skb->len) {
259 chunk = min_t(size_t, len, strp->stm.full_len - skb->len);
260 chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
261 WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
262 skb_frag_address(frag) +
263 skb_frag_size(frag),
264 chunk));
265
266 skb->len += chunk;
267 skb->data_len += chunk;
268 skb_frag_size_add(frag, chunk);
269 frag++;
270 len -= chunk;
271 offset += chunk;
272 }
273
274 if (strp->stm.full_len == skb->len) {
275 desc->count = 0;
276
277 strp->msg_ready = 1;
278 tls_rx_msg_ready(strp);
279 }
280
281read_done:
282 return in_len - len;
283}
284
285static int tls_strp_read_copyin(struct tls_strparser *strp)
286{
287 struct socket *sock = strp->sk->sk_socket;
288 read_descriptor_t desc;
289
290 desc.arg.data = strp;
291 desc.error = 0;
292 desc.count = 1; /* give more than one skb per call */
293
294 /* sk should be locked here, so okay to do read_sock */
295 sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);
296
297 return desc.error;
298}
299
0d87bbd3 300static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
84c61fe1
JK
301{
302 struct skb_shared_info *shinfo;
303 struct page *page;
304 int need_spc, len;
305
306 /* If the rbuf is small or rcv window has collapsed to 0 we need
307 * to read the data out. Otherwise the connection will stall.
308 * Without pressure threshold of INT_MAX will never be ready.
309 */
0d87bbd3 310 if (likely(qshort && !tcp_epollin_ready(strp->sk, INT_MAX)))
84c61fe1
JK
311 return 0;
312
313 shinfo = skb_shinfo(strp->anchor);
314 shinfo->frag_list = NULL;
315
316 /* If we don't know the length go max plus page for cipher overhead */
317 need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
318
319 for (len = need_spc; len > 0; len -= PAGE_SIZE) {
320 page = alloc_page(strp->sk->sk_allocation);
321 if (!page) {
322 tls_strp_flush_anchor_copy(strp);
323 return -ENOMEM;
324 }
325
326 skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
327 page, 0, 0);
328 }
329
330 strp->copy_mode = 1;
331 strp->stm.offset = 0;
332
333 strp->anchor->len = 0;
334 strp->anchor->data_len = 0;
335 strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);
336
337 tls_strp_read_copyin(strp);
338
339 return 0;
340}
341
14c4be92 342static bool tls_strp_check_queue_ok(struct tls_strparser *strp)
0d87bbd3
JK
343{
344 unsigned int len = strp->stm.offset + strp->stm.full_len;
14c4be92 345 struct sk_buff *first, *skb;
0d87bbd3
JK
346 u32 seq;
347
14c4be92
JK
348 first = skb_shinfo(strp->anchor)->frag_list;
349 skb = first;
350 seq = TCP_SKB_CB(first)->seq;
0d87bbd3 351
14c4be92
JK
352 /* Make sure there's no duplicate data in the queue,
353 * and the decrypted status matches.
354 */
0d87bbd3
JK
355 while (skb->len < len) {
356 seq += skb->len;
357 len -= skb->len;
358 skb = skb->next;
359
360 if (TCP_SKB_CB(skb)->seq != seq)
361 return false;
14c4be92
JK
362 if (skb_cmp_decrypted(first, skb))
363 return false;
0d87bbd3
JK
364 }
365
366 return true;
367}
368
84c61fe1
JK
369static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
370{
371 struct tcp_sock *tp = tcp_sk(strp->sk);
372 struct sk_buff *first;
373 u32 offset;
374
375 first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
376 if (WARN_ON_ONCE(!first))
377 return;
378
379 /* Bestow the state onto the anchor */
380 strp->anchor->len = offset + len;
381 strp->anchor->data_len = offset + len;
382 strp->anchor->truesize = offset + len;
383
384 skb_shinfo(strp->anchor)->frag_list = first;
385
386 skb_copy_header(strp->anchor, first);
387 strp->anchor->destructor = NULL;
388
389 strp->stm.offset = offset;
390}
391
392void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
393{
394 struct strp_msg *rxm;
395 struct tls_msg *tlm;
396
397 DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
398 DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
399
400 if (!strp->copy_mode && force_refresh) {
401 if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
402 return;
403
404 tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
405 }
406
407 rxm = strp_msg(strp->anchor);
408 rxm->full_len = strp->stm.full_len;
409 rxm->offset = strp->stm.offset;
410 tlm = tls_msg(strp->anchor);
411 tlm->control = strp->mark;
412}
413
414/* Called with lock held on lower socket */
415static int tls_strp_read_sock(struct tls_strparser *strp)
416{
417 int sz, inq;
418
419 inq = tcp_inq(strp->sk);
420 if (inq < 1)
421 return 0;
422
423 if (unlikely(strp->copy_mode))
424 return tls_strp_read_copyin(strp);
425
426 if (inq < strp->stm.full_len)
0d87bbd3 427 return tls_strp_read_copy(strp, true);
84c61fe1
JK
428
429 if (!strp->stm.full_len) {
430 tls_strp_load_anchor_with_queue(strp, inq);
431
432 sz = tls_rx_msg_size(strp, strp->anchor);
433 if (sz < 0) {
434 tls_strp_abort_strp(strp, sz);
435 return sz;
436 }
437
438 strp->stm.full_len = sz;
439
440 if (!strp->stm.full_len || inq < strp->stm.full_len)
0d87bbd3 441 return tls_strp_read_copy(strp, true);
84c61fe1
JK
442 }
443
14c4be92 444 if (!tls_strp_check_queue_ok(strp))
0d87bbd3
JK
445 return tls_strp_read_copy(strp, false);
446
84c61fe1
JK
447 strp->msg_ready = 1;
448 tls_rx_msg_ready(strp);
449
450 return 0;
451}
452
453void tls_strp_check_rcv(struct tls_strparser *strp)
454{
455 if (unlikely(strp->stopped) || strp->msg_ready)
456 return;
457
458 if (tls_strp_read_sock(strp) == -ENOMEM)
459 queue_work(tls_strp_wq, &strp->work);
460}
461
462/* Lower sock lock held */
463void tls_strp_data_ready(struct tls_strparser *strp)
464{
465 /* This check is needed to synchronize with do_tls_strp_work.
466 * do_tls_strp_work acquires a process lock (lock_sock) whereas
467 * the lock held here is bh_lock_sock. The two locks can be
468 * held by different threads at the same time, but bh_lock_sock
469 * allows a thread in BH context to safely check if the process
470 * lock is held. In this case, if the lock is held, queue work.
471 */
472 if (sock_owned_by_user_nocheck(strp->sk)) {
473 queue_work(tls_strp_wq, &strp->work);
474 return;
475 }
476
477 tls_strp_check_rcv(strp);
478}
479
480static void tls_strp_work(struct work_struct *w)
481{
482 struct tls_strparser *strp =
483 container_of(w, struct tls_strparser, work);
484
485 lock_sock(strp->sk);
486 tls_strp_check_rcv(strp);
487 release_sock(strp->sk);
488}
489
490void tls_strp_msg_done(struct tls_strparser *strp)
491{
492 WARN_ON(!strp->stm.full_len);
493
494 if (likely(!strp->copy_mode))
495 tcp_read_done(strp->sk, strp->stm.full_len);
496 else
497 tls_strp_flush_anchor_copy(strp);
498
499 strp->msg_ready = 0;
500 memset(&strp->stm, 0, sizeof(strp->stm));
501
502 tls_strp_check_rcv(strp);
503}
504
505void tls_strp_stop(struct tls_strparser *strp)
506{
507 strp->stopped = 1;
508}
509
510int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
511{
512 memset(strp, 0, sizeof(*strp));
513
514 strp->sk = sk;
515
516 strp->anchor = alloc_skb(0, GFP_KERNEL);
517 if (!strp->anchor)
518 return -ENOMEM;
519
520 INIT_WORK(&strp->work, tls_strp_work);
8b3c59a7 521
8b3c59a7
JK
522 return 0;
523}
524
84c61fe1
JK
525/* strp must already be stopped so that tls_strp_recv will no longer be called.
526 * Note that tls_strp_done is not called with the lower socket held.
527 */
528void tls_strp_done(struct tls_strparser *strp)
c618db2a 529{
84c61fe1 530 WARN_ON(!strp->stopped);
c618db2a 531
84c61fe1
JK
532 cancel_work_sync(&strp->work);
533 tls_strp_anchor_free(strp);
534}
535
536int __init tls_strp_dev_init(void)
537{
d11ef9cc 538 tls_strp_wq = create_workqueue("tls-strp");
84c61fe1 539 if (unlikely(!tls_strp_wq))
c618db2a 540 return -ENOMEM;
84c61fe1 541
c618db2a
JK
542 return 0;
543}
84c61fe1
JK
544
545void tls_strp_dev_exit(void)
546{
547 destroy_workqueue(tls_strp_wq);
548}