ip6_gre: simplify gre header parsing in ip6gre_err
[linux-block.git] / net / tls / tls_sw.c
CommitLineData
3c4d7559
DW
1/*
2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7 *
8 * This software is available to you under a choice of one of two
9 * licenses. You may choose to be licensed under the terms of the GNU
10 * General Public License (GPL) Version 2, available from the file
11 * COPYING in the main directory of this source tree, or the
12 * OpenIB.org BSD license below:
13 *
14 * Redistribution and use in source and binary forms, with or
15 * without modification, are permitted provided that the following
16 * conditions are met:
17 *
18 * - Redistributions of source code must retain the above
19 * copyright notice, this list of conditions and the following
20 * disclaimer.
21 *
22 * - Redistributions in binary form must reproduce the above
23 * copyright notice, this list of conditions and the following
24 * disclaimer in the documentation and/or other materials
25 * provided with the distribution.
26 *
27 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34 * SOFTWARE.
35 */
36
c46234eb 37#include <linux/sched/signal.h>
3c4d7559
DW
38#include <linux/module.h>
39#include <crypto/aead.h>
40
c46234eb 41#include <net/strparser.h>
3c4d7559
DW
42#include <net/tls.h>
43
/* Upper bound for per-record IV buffers; sized for AES-GCM-128. */
#define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE
45
0927f71d
DRK
46static int __skb_nsg(struct sk_buff *skb, int offset, int len,
47 unsigned int recursion_level)
48{
49 int start = skb_headlen(skb);
50 int i, chunk = start - offset;
51 struct sk_buff *frag_iter;
52 int elt = 0;
53
54 if (unlikely(recursion_level >= 24))
55 return -EMSGSIZE;
56
57 if (chunk > 0) {
58 if (chunk > len)
59 chunk = len;
60 elt++;
61 len -= chunk;
62 if (len == 0)
63 return elt;
64 offset += chunk;
65 }
66
67 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
68 int end;
69
70 WARN_ON(start > offset + len);
71
72 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
73 chunk = end - offset;
74 if (chunk > 0) {
75 if (chunk > len)
76 chunk = len;
77 elt++;
78 len -= chunk;
79 if (len == 0)
80 return elt;
81 offset += chunk;
82 }
83 start = end;
84 }
85
86 if (unlikely(skb_has_frag_list(skb))) {
87 skb_walk_frags(skb, frag_iter) {
88 int end, ret;
89
90 WARN_ON(start > offset + len);
91
92 end = start + frag_iter->len;
93 chunk = end - offset;
94 if (chunk > 0) {
95 if (chunk > len)
96 chunk = len;
97 ret = __skb_nsg(frag_iter, offset - start, chunk,
98 recursion_level + 1);
99 if (unlikely(ret < 0))
100 return ret;
101 elt += ret;
102 len -= chunk;
103 if (len == 0)
104 return elt;
105 offset += chunk;
106 }
107 start = end;
108 }
109 }
110 BUG_ON(len);
111 return elt;
112}
113
/* Number of scatterlist entries needed to map @len bytes of @skb starting
 * at @offset; -EMSGSIZE if frag_list nesting is too deep.
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
	const unsigned int top_level = 0;

	return __skb_nsg(skb, offset, len, top_level);
}
121
94524d8f
VG
122static void tls_decrypt_done(struct crypto_async_request *req, int err)
123{
124 struct aead_request *aead_req = (struct aead_request *)req;
125 struct decrypt_req_ctx *req_ctx =
126 (struct decrypt_req_ctx *)(aead_req + 1);
127
128 struct scatterlist *sgout = aead_req->dst;
129
130 struct tls_context *tls_ctx = tls_get_ctx(req_ctx->sk);
131 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
132 int pending = atomic_dec_return(&ctx->decrypt_pending);
133 struct scatterlist *sg;
134 unsigned int pages;
135
136 /* Propagate if there was an err */
137 if (err) {
138 ctx->async_wait.err = err;
139 tls_err_abort(req_ctx->sk, err);
140 }
141
142 /* Release the skb, pages and memory allocated for crypto req */
143 kfree_skb(req->data);
144
145 /* Skip the first S/G entry as it points to AAD */
146 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
147 if (!sg)
148 break;
149 put_page(sg_page(sg));
150 }
151
152 kfree(aead_req);
153
154 if (!pending && READ_ONCE(ctx->async_notify))
155 complete(&ctx->async_wait.completion);
156}
157
c46234eb 158static int tls_do_decryption(struct sock *sk,
94524d8f 159 struct sk_buff *skb,
c46234eb
DW
160 struct scatterlist *sgin,
161 struct scatterlist *sgout,
162 char *iv_recv,
163 size_t data_len,
94524d8f
VG
164 struct aead_request *aead_req,
165 bool async)
c46234eb
DW
166{
167 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 168 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
c46234eb 169 int ret;
c46234eb 170
0b243d00 171 aead_request_set_tfm(aead_req, ctx->aead_recv);
c46234eb
DW
172 aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
173 aead_request_set_crypt(aead_req, sgin, sgout,
174 data_len + tls_ctx->rx.tag_size,
175 (u8 *)iv_recv);
c46234eb 176
94524d8f
VG
177 if (async) {
178 struct decrypt_req_ctx *req_ctx;
179
180 req_ctx = (struct decrypt_req_ctx *)(aead_req + 1);
181 req_ctx->sk = sk;
182
183 aead_request_set_callback(aead_req,
184 CRYPTO_TFM_REQ_MAY_BACKLOG,
185 tls_decrypt_done, skb);
186 atomic_inc(&ctx->decrypt_pending);
187 } else {
188 aead_request_set_callback(aead_req,
189 CRYPTO_TFM_REQ_MAY_BACKLOG,
190 crypto_req_done, &ctx->async_wait);
191 }
192
193 ret = crypto_aead_decrypt(aead_req);
194 if (ret == -EINPROGRESS) {
195 if (async)
196 return ret;
197
198 ret = crypto_wait_req(ret, &ctx->async_wait);
199 }
200
201 if (async)
202 atomic_dec(&ctx->decrypt_pending);
203
c46234eb
DW
204 return ret;
205}
206
3c4d7559
DW
207static void trim_sg(struct sock *sk, struct scatterlist *sg,
208 int *sg_num_elem, unsigned int *sg_size, int target_size)
209{
210 int i = *sg_num_elem - 1;
211 int trim = *sg_size - target_size;
212
213 if (trim <= 0) {
214 WARN_ON(trim < 0);
215 return;
216 }
217
218 *sg_size = target_size;
219 while (trim >= sg[i].length) {
220 trim -= sg[i].length;
221 sk_mem_uncharge(sk, sg[i].length);
222 put_page(sg_page(&sg[i]));
223 i--;
224
225 if (i < 0)
226 goto out;
227 }
228
229 sg[i].length -= trim;
230 sk_mem_uncharge(sk, trim);
231
232out:
233 *sg_num_elem = i + 1;
234}
235
236static void trim_both_sgl(struct sock *sk, int target_size)
237{
238 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 239 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
3c4d7559
DW
240
241 trim_sg(sk, ctx->sg_plaintext_data,
242 &ctx->sg_plaintext_num_elem,
243 &ctx->sg_plaintext_size,
244 target_size);
245
246 if (target_size > 0)
dbe42559 247 target_size += tls_ctx->tx.overhead_size;
3c4d7559
DW
248
249 trim_sg(sk, ctx->sg_encrypted_data,
250 &ctx->sg_encrypted_num_elem,
251 &ctx->sg_encrypted_size,
252 target_size);
253}
254
3c4d7559
DW
255static int alloc_encrypted_sg(struct sock *sk, int len)
256{
257 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 258 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
3c4d7559
DW
259 int rc = 0;
260
2c3682f0 261 rc = sk_alloc_sg(sk, len,
8c05dbf0 262 ctx->sg_encrypted_data, 0,
2c3682f0
JF
263 &ctx->sg_encrypted_num_elem,
264 &ctx->sg_encrypted_size, 0);
3c4d7559 265
52ea992c
VG
266 if (rc == -ENOSPC)
267 ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data);
268
3c4d7559
DW
269 return rc;
270}
271
272static int alloc_plaintext_sg(struct sock *sk, int len)
273{
274 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 275 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
3c4d7559
DW
276 int rc = 0;
277
8c05dbf0 278 rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
2c3682f0
JF
279 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
280 tls_ctx->pending_open_record_frags);
3c4d7559 281
52ea992c
VG
282 if (rc == -ENOSPC)
283 ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data);
284
3c4d7559
DW
285 return rc;
286}
287
288static void free_sg(struct sock *sk, struct scatterlist *sg,
289 int *sg_num_elem, unsigned int *sg_size)
290{
291 int i, n = *sg_num_elem;
292
293 for (i = 0; i < n; ++i) {
294 sk_mem_uncharge(sk, sg[i].length);
295 put_page(sg_page(&sg[i]));
296 }
297 *sg_num_elem = 0;
298 *sg_size = 0;
299}
300
301static void tls_free_both_sg(struct sock *sk)
302{
303 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 304 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
3c4d7559
DW
305
306 free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
307 &ctx->sg_encrypted_size);
308
309 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
310 &ctx->sg_plaintext_size);
311}
312
313static int tls_do_encryption(struct tls_context *tls_ctx,
a447da7d
DB
314 struct tls_sw_context_tx *ctx,
315 struct aead_request *aead_req,
316 size_t data_len)
3c4d7559 317{
3c4d7559
DW
318 int rc;
319
dbe42559
DW
320 ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
321 ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
3c4d7559
DW
322
323 aead_request_set_tfm(aead_req, ctx->aead_send);
324 aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
325 aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
dbe42559 326 data_len, tls_ctx->tx.iv);
a54667f6
VG
327
328 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
329 crypto_req_done, &ctx->async_wait);
330
331 rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
3c4d7559 332
dbe42559
DW
333 ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
334 ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
3c4d7559 335
3c4d7559
DW
336 return rc;
337}
338
339static int tls_push_record(struct sock *sk, int flags,
340 unsigned char record_type)
341{
342 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 343 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
a447da7d 344 struct aead_request *req;
3c4d7559
DW
345 int rc;
346
d2bdd268 347 req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
a447da7d
DB
348 if (!req)
349 return -ENOMEM;
350
3c4d7559
DW
351 sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
352 sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
353
213ef6e7 354 tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
dbe42559 355 tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
3c4d7559
DW
356 record_type);
357
358 tls_fill_prepend(tls_ctx,
359 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
360 ctx->sg_encrypted_data[0].offset,
361 ctx->sg_plaintext_size, record_type);
362
363 tls_ctx->pending_open_record_frags = 0;
364 set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
365
a447da7d 366 rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
3c4d7559
DW
367 if (rc < 0) {
368 /* If we are called from write_space and
369 * we fail, we need to set this SOCK_NOSPACE
370 * to trigger another write_space in the future.
371 */
372 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
a447da7d 373 goto out_req;
3c4d7559
DW
374 }
375
376 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
377 &ctx->sg_plaintext_size);
378
379 ctx->sg_encrypted_num_elem = 0;
380 ctx->sg_encrypted_size = 0;
381
382 /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
383 rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
384 if (rc < 0 && rc != -EAGAIN)
f4a8e43f 385 tls_err_abort(sk, EBADMSG);
3c4d7559 386
dbe42559 387 tls_advance_record_sn(sk, &tls_ctx->tx);
a447da7d 388out_req:
d2bdd268 389 aead_request_free(req);
3c4d7559
DW
390 return rc;
391}
392
393static int tls_sw_push_pending_record(struct sock *sk, int flags)
394{
395 return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
396}
397
398static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
69ca9293
DW
399 int length, int *pages_used,
400 unsigned int *size_used,
401 struct scatterlist *to, int to_max_pages,
2da19ed3 402 bool charge)
3c4d7559 403{
3c4d7559
DW
404 struct page *pages[MAX_SKB_FRAGS];
405
406 size_t offset;
407 ssize_t copied, use;
408 int i = 0;
69ca9293
DW
409 unsigned int size = *size_used;
410 int num_elem = *pages_used;
3c4d7559
DW
411 int rc = 0;
412 int maxpages;
413
414 while (length > 0) {
415 i = 0;
69ca9293 416 maxpages = to_max_pages - num_elem;
3c4d7559
DW
417 if (maxpages == 0) {
418 rc = -EFAULT;
419 goto out;
420 }
421 copied = iov_iter_get_pages(from, pages,
422 length,
423 maxpages, &offset);
424 if (copied <= 0) {
425 rc = -EFAULT;
426 goto out;
427 }
428
429 iov_iter_advance(from, copied);
430
431 length -= copied;
432 size += copied;
433 while (copied) {
434 use = min_t(int, copied, PAGE_SIZE - offset);
435
69ca9293 436 sg_set_page(&to[num_elem],
3c4d7559 437 pages[i], use, offset);
69ca9293
DW
438 sg_unmark_end(&to[num_elem]);
439 if (charge)
440 sk_mem_charge(sk, use);
3c4d7559
DW
441
442 offset = 0;
443 copied -= use;
444
445 ++i;
446 ++num_elem;
447 }
448 }
449
cfb4099f
VG
450 /* Mark the end in the last sg entry if newly added */
451 if (num_elem > *pages_used)
452 sg_mark_end(&to[num_elem - 1]);
3c4d7559 453out:
2da19ed3
DRK
454 if (rc)
455 iov_iter_revert(from, size - *size_used);
69ca9293
DW
456 *size_used = size;
457 *pages_used = num_elem;
458
3c4d7559
DW
459 return rc;
460}
461
462static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
463 int bytes)
464{
465 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 466 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
3c4d7559
DW
467 struct scatterlist *sg = ctx->sg_plaintext_data;
468 int copy, i, rc = 0;
469
470 for (i = tls_ctx->pending_open_record_frags;
471 i < ctx->sg_plaintext_num_elem; ++i) {
472 copy = sg[i].length;
473 if (copy_from_iter(
474 page_address(sg_page(&sg[i])) + sg[i].offset,
475 copy, from) != copy) {
476 rc = -EFAULT;
477 goto out;
478 }
479 bytes -= copy;
480
481 ++tls_ctx->pending_open_record_frags;
482
483 if (!bytes)
484 break;
485 }
486
487out:
488 return rc;
489}
490
491int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
492{
493 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 494 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
15008579 495 int ret;
3c4d7559
DW
496 int required_size;
497 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
498 bool eor = !(msg->msg_flags & MSG_MORE);
499 size_t try_to_copy, copied = 0;
500 unsigned char record_type = TLS_RECORD_TYPE_DATA;
501 int record_room;
502 bool full_record;
503 int orig_size;
0a26cf3f 504 bool is_kvec = msg->msg_iter.type & ITER_KVEC;
3c4d7559
DW
505
506 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
507 return -ENOTSUPP;
508
509 lock_sock(sk);
510
15008579
VG
511 ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo);
512 if (ret)
3c4d7559
DW
513 goto send_end;
514
515 if (unlikely(msg->msg_controllen)) {
516 ret = tls_proccess_cmsg(sk, msg, &record_type);
517 if (ret)
518 goto send_end;
519 }
520
521 while (msg_data_left(msg)) {
522 if (sk->sk_err) {
30be8f8d 523 ret = -sk->sk_err;
3c4d7559
DW
524 goto send_end;
525 }
526
527 orig_size = ctx->sg_plaintext_size;
528 full_record = false;
529 try_to_copy = msg_data_left(msg);
530 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
531 if (try_to_copy >= record_room) {
532 try_to_copy = record_room;
533 full_record = true;
534 }
535
536 required_size = ctx->sg_plaintext_size + try_to_copy +
dbe42559 537 tls_ctx->tx.overhead_size;
3c4d7559
DW
538
539 if (!sk_stream_memory_free(sk))
540 goto wait_for_sndbuf;
541alloc_encrypted:
542 ret = alloc_encrypted_sg(sk, required_size);
543 if (ret) {
544 if (ret != -ENOSPC)
545 goto wait_for_memory;
546
547 /* Adjust try_to_copy according to the amount that was
548 * actually allocated. The difference is due
549 * to max sg elements limit
550 */
551 try_to_copy -= required_size - ctx->sg_encrypted_size;
552 full_record = true;
553 }
0a26cf3f 554 if (!is_kvec && (full_record || eor)) {
3c4d7559 555 ret = zerocopy_from_iter(sk, &msg->msg_iter,
69ca9293
DW
556 try_to_copy, &ctx->sg_plaintext_num_elem,
557 &ctx->sg_plaintext_size,
558 ctx->sg_plaintext_data,
559 ARRAY_SIZE(ctx->sg_plaintext_data),
2da19ed3 560 true);
3c4d7559
DW
561 if (ret)
562 goto fallback_to_reg_send;
563
564 copied += try_to_copy;
565 ret = tls_push_record(sk, msg->msg_flags, record_type);
5a3611ef 566 if (ret)
3c4d7559 567 goto send_end;
5a3611ef 568 continue;
3c4d7559 569
3c4d7559 570fallback_to_reg_send:
3c4d7559
DW
571 trim_sg(sk, ctx->sg_plaintext_data,
572 &ctx->sg_plaintext_num_elem,
573 &ctx->sg_plaintext_size,
574 orig_size);
575 }
576
577 required_size = ctx->sg_plaintext_size + try_to_copy;
578alloc_plaintext:
579 ret = alloc_plaintext_sg(sk, required_size);
580 if (ret) {
581 if (ret != -ENOSPC)
582 goto wait_for_memory;
583
584 /* Adjust try_to_copy according to the amount that was
585 * actually allocated. The difference is due
586 * to max sg elements limit
587 */
588 try_to_copy -= required_size - ctx->sg_plaintext_size;
589 full_record = true;
590
591 trim_sg(sk, ctx->sg_encrypted_data,
592 &ctx->sg_encrypted_num_elem,
593 &ctx->sg_encrypted_size,
594 ctx->sg_plaintext_size +
dbe42559 595 tls_ctx->tx.overhead_size);
3c4d7559
DW
596 }
597
598 ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
599 if (ret)
600 goto trim_sgl;
601
602 copied += try_to_copy;
603 if (full_record || eor) {
604push_record:
605 ret = tls_push_record(sk, msg->msg_flags, record_type);
606 if (ret) {
607 if (ret == -ENOMEM)
608 goto wait_for_memory;
609
610 goto send_end;
611 }
612 }
613
614 continue;
615
616wait_for_sndbuf:
617 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
618wait_for_memory:
619 ret = sk_stream_wait_memory(sk, &timeo);
620 if (ret) {
621trim_sgl:
622 trim_both_sgl(sk, orig_size);
623 goto send_end;
624 }
625
626 if (tls_is_pending_closed_record(tls_ctx))
627 goto push_record;
628
629 if (ctx->sg_encrypted_size < required_size)
630 goto alloc_encrypted;
631
632 goto alloc_plaintext;
633 }
634
635send_end:
636 ret = sk_stream_error(sk, msg->msg_flags, ret);
637
638 release_sock(sk);
639 return copied ? copied : ret;
640}
641
642int tls_sw_sendpage(struct sock *sk, struct page *page,
643 int offset, size_t size, int flags)
644{
645 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 646 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
15008579 647 int ret;
3c4d7559
DW
648 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
649 bool eor;
650 size_t orig_size = size;
651 unsigned char record_type = TLS_RECORD_TYPE_DATA;
652 struct scatterlist *sg;
653 bool full_record;
654 int record_room;
655
656 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
657 MSG_SENDPAGE_NOTLAST))
658 return -ENOTSUPP;
659
660 /* No MSG_EOR from splice, only look at MSG_MORE */
661 eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
662
663 lock_sock(sk);
664
665 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
666
15008579
VG
667 ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
668 if (ret)
3c4d7559
DW
669 goto sendpage_end;
670
671 /* Call the sk_stream functions to manage the sndbuf mem. */
672 while (size > 0) {
673 size_t copy, required_size;
674
675 if (sk->sk_err) {
30be8f8d 676 ret = -sk->sk_err;
3c4d7559
DW
677 goto sendpage_end;
678 }
679
680 full_record = false;
681 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
682 copy = size;
683 if (copy >= record_room) {
684 copy = record_room;
685 full_record = true;
686 }
687 required_size = ctx->sg_plaintext_size + copy +
dbe42559 688 tls_ctx->tx.overhead_size;
3c4d7559
DW
689
690 if (!sk_stream_memory_free(sk))
691 goto wait_for_sndbuf;
692alloc_payload:
693 ret = alloc_encrypted_sg(sk, required_size);
694 if (ret) {
695 if (ret != -ENOSPC)
696 goto wait_for_memory;
697
698 /* Adjust copy according to the amount that was
699 * actually allocated. The difference is due
700 * to max sg elements limit
701 */
702 copy -= required_size - ctx->sg_plaintext_size;
703 full_record = true;
704 }
705
706 get_page(page);
707 sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
708 sg_set_page(sg, page, copy, offset);
7a8c4dd9
DW
709 sg_unmark_end(sg);
710
3c4d7559
DW
711 ctx->sg_plaintext_num_elem++;
712
713 sk_mem_charge(sk, copy);
714 offset += copy;
715 size -= copy;
716 ctx->sg_plaintext_size += copy;
717 tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
718
719 if (full_record || eor ||
720 ctx->sg_plaintext_num_elem ==
721 ARRAY_SIZE(ctx->sg_plaintext_data)) {
722push_record:
723 ret = tls_push_record(sk, flags, record_type);
724 if (ret) {
725 if (ret == -ENOMEM)
726 goto wait_for_memory;
727
728 goto sendpage_end;
729 }
730 }
731 continue;
732wait_for_sndbuf:
733 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
734wait_for_memory:
735 ret = sk_stream_wait_memory(sk, &timeo);
736 if (ret) {
737 trim_both_sgl(sk, ctx->sg_plaintext_size);
738 goto sendpage_end;
739 }
740
741 if (tls_is_pending_closed_record(tls_ctx))
742 goto push_record;
743
744 goto alloc_payload;
745 }
746
747sendpage_end:
748 if (orig_size > size)
749 ret = orig_size - size;
750 else
751 ret = sk_stream_error(sk, flags, ret);
752
753 release_sock(sk);
754 return ret;
755}
756
c46234eb
DW
757static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
758 long timeo, int *err)
759{
760 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 761 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
c46234eb
DW
762 struct sk_buff *skb;
763 DEFINE_WAIT_FUNC(wait, woken_wake_function);
764
765 while (!(skb = ctx->recv_pkt)) {
766 if (sk->sk_err) {
767 *err = sock_error(sk);
768 return NULL;
769 }
770
fcf4793e
DRK
771 if (sk->sk_shutdown & RCV_SHUTDOWN)
772 return NULL;
773
c46234eb
DW
774 if (sock_flag(sk, SOCK_DONE))
775 return NULL;
776
777 if ((flags & MSG_DONTWAIT) || !timeo) {
778 *err = -EAGAIN;
779 return NULL;
780 }
781
782 add_wait_queue(sk_sleep(sk), &wait);
783 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
784 sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
785 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
786 remove_wait_queue(sk_sleep(sk), &wait);
787
788 /* Handle signals */
789 if (signal_pending(current)) {
790 *err = sock_intr_errno(timeo);
791 return NULL;
792 }
793 }
794
795 return skb;
796}
797
0b243d00
VG
798/* This function decrypts the input skb into either out_iov or in out_sg
799 * or in skb buffers itself. The input parameter 'zc' indicates if
800 * zero-copy mode needs to be tried or not. With zero-copy mode, either
801 * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
802 * NULL, then the decryption happens inside skb buffers itself, i.e.
803 * zero-copy gets disabled and 'zc' is updated.
804 */
805
806static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
807 struct iov_iter *out_iov,
808 struct scatterlist *out_sg,
809 int *chunk, bool *zc)
810{
811 struct tls_context *tls_ctx = tls_get_ctx(sk);
812 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
813 struct strp_msg *rxm = strp_msg(skb);
814 int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
815 struct aead_request *aead_req;
816 struct sk_buff *unused;
817 u8 *aad, *iv, *mem = NULL;
818 struct scatterlist *sgin = NULL;
819 struct scatterlist *sgout = NULL;
820 const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
821
822 if (*zc && (out_iov || out_sg)) {
823 if (out_iov)
824 n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
825 else
826 n_sgout = sg_nents(out_sg);
0927f71d
DRK
827 n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
828 rxm->full_len - tls_ctx->rx.prepend_size);
0b243d00
VG
829 } else {
830 n_sgout = 0;
831 *zc = false;
0927f71d 832 n_sgin = skb_cow_data(skb, 0, &unused);
0b243d00
VG
833 }
834
0b243d00
VG
835 if (n_sgin < 1)
836 return -EBADMSG;
837
838 /* Increment to accommodate AAD */
839 n_sgin = n_sgin + 1;
840
841 nsg = n_sgin + n_sgout;
842
843 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
844 mem_size = aead_size + (nsg * sizeof(struct scatterlist));
845 mem_size = mem_size + TLS_AAD_SPACE_SIZE;
846 mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
847
848 /* Allocate a single block of memory which contains
849 * aead_req || sgin[] || sgout[] || aad || iv.
850 * This order achieves correct alignment for aead_req, sgin, sgout.
851 */
852 mem = kmalloc(mem_size, sk->sk_allocation);
853 if (!mem)
854 return -ENOMEM;
855
856 /* Segment the allocated memory */
857 aead_req = (struct aead_request *)mem;
858 sgin = (struct scatterlist *)(mem + aead_size);
859 sgout = sgin + n_sgin;
860 aad = (u8 *)(sgout + n_sgout);
861 iv = aad + TLS_AAD_SPACE_SIZE;
862
863 /* Prepare IV */
864 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
865 iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
866 tls_ctx->rx.iv_size);
867 if (err < 0) {
868 kfree(mem);
869 return err;
870 }
871 memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
872
873 /* Prepare AAD */
874 tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
875 tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
876 ctx->control);
877
878 /* Prepare sgin */
879 sg_init_table(sgin, n_sgin);
880 sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
881 err = skb_to_sgvec(skb, &sgin[1],
882 rxm->offset + tls_ctx->rx.prepend_size,
883 rxm->full_len - tls_ctx->rx.prepend_size);
884 if (err < 0) {
885 kfree(mem);
886 return err;
887 }
888
889 if (n_sgout) {
890 if (out_iov) {
891 sg_init_table(sgout, n_sgout);
892 sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
893
894 *chunk = 0;
895 err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
896 chunk, &sgout[1],
897 (n_sgout - 1), false);
898 if (err < 0)
899 goto fallback_to_reg_recv;
900 } else if (out_sg) {
901 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
902 } else {
903 goto fallback_to_reg_recv;
904 }
905 } else {
906fallback_to_reg_recv:
907 sgout = sgin;
908 pages = 0;
909 *chunk = 0;
910 *zc = false;
911 }
912
913 /* Prepare and submit AEAD request */
94524d8f
VG
914 err = tls_do_decryption(sk, skb, sgin, sgout, iv,
915 data_len, aead_req, *zc);
916 if (err == -EINPROGRESS)
917 return err;
0b243d00
VG
918
919 /* Release the pages in case iov was mapped to pages */
920 for (; pages > 0; pages--)
921 put_page(sg_page(&sgout[pages]));
922
923 kfree(mem);
924 return err;
925}
926
dafb67f3 927static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
0b243d00 928 struct iov_iter *dest, int *chunk, bool *zc)
dafb67f3
BP
929{
930 struct tls_context *tls_ctx = tls_get_ctx(sk);
931 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
932 struct strp_msg *rxm = strp_msg(skb);
933 int err = 0;
934
4799ac81
BP
935#ifdef CONFIG_TLS_DEVICE
936 err = tls_device_decrypted(sk, skb);
dafb67f3
BP
937 if (err < 0)
938 return err;
4799ac81
BP
939#endif
940 if (!ctx->decrypted) {
0b243d00 941 err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
94524d8f
VG
942 if (err < 0) {
943 if (err == -EINPROGRESS)
944 tls_advance_record_sn(sk, &tls_ctx->rx);
945
4799ac81 946 return err;
94524d8f 947 }
4799ac81
BP
948 } else {
949 *zc = false;
950 }
dafb67f3
BP
951
952 rxm->offset += tls_ctx->rx.prepend_size;
953 rxm->full_len -= tls_ctx->rx.overhead_size;
954 tls_advance_record_sn(sk, &tls_ctx->rx);
955 ctx->decrypted = true;
956 ctx->saved_data_ready(sk);
957
958 return err;
959}
960
961int decrypt_skb(struct sock *sk, struct sk_buff *skb,
962 struct scatterlist *sgout)
c46234eb 963{
0b243d00
VG
964 bool zc = true;
965 int chunk;
c46234eb 966
0b243d00 967 return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
c46234eb
DW
968}
969
970static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
971 unsigned int len)
972{
973 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 974 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
c46234eb 975
94524d8f
VG
976 if (skb) {
977 struct strp_msg *rxm = strp_msg(skb);
c46234eb 978
94524d8f
VG
979 if (len < rxm->full_len) {
980 rxm->offset += len;
981 rxm->full_len -= len;
982 return false;
983 }
984 kfree_skb(skb);
c46234eb
DW
985 }
986
987 /* Finished with message */
988 ctx->recv_pkt = NULL;
7170e604 989 __strp_unpause(&ctx->strp);
c46234eb
DW
990
991 return true;
992}
993
994int tls_sw_recvmsg(struct sock *sk,
995 struct msghdr *msg,
996 size_t len,
997 int nonblock,
998 int flags,
999 int *addr_len)
1000{
1001 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 1002 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
c46234eb
DW
1003 unsigned char control;
1004 struct strp_msg *rxm;
1005 struct sk_buff *skb;
1006 ssize_t copied = 0;
1007 bool cmsg = false;
06030dba 1008 int target, err = 0;
c46234eb 1009 long timeo;
0a26cf3f 1010 bool is_kvec = msg->msg_iter.type & ITER_KVEC;
94524d8f 1011 int num_async = 0;
c46234eb
DW
1012
1013 flags |= nonblock;
1014
1015 if (unlikely(flags & MSG_ERRQUEUE))
1016 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1017
1018 lock_sock(sk);
1019
06030dba 1020 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
c46234eb
DW
1021 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1022 do {
1023 bool zc = false;
94524d8f 1024 bool async = false;
c46234eb
DW
1025 int chunk = 0;
1026
1027 skb = tls_wait_data(sk, flags, timeo, &err);
1028 if (!skb)
1029 goto recv_end;
1030
1031 rxm = strp_msg(skb);
94524d8f 1032
c46234eb
DW
1033 if (!cmsg) {
1034 int cerr;
1035
1036 cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1037 sizeof(ctx->control), &ctx->control);
1038 cmsg = true;
1039 control = ctx->control;
1040 if (ctx->control != TLS_RECORD_TYPE_DATA) {
1041 if (cerr || msg->msg_flags & MSG_CTRUNC) {
1042 err = -EIO;
1043 goto recv_end;
1044 }
1045 }
1046 } else if (control != ctx->control) {
1047 goto recv_end;
1048 }
1049
1050 if (!ctx->decrypted) {
0b243d00 1051 int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
c46234eb 1052
0b243d00
VG
1053 if (!is_kvec && to_copy <= len &&
1054 likely(!(flags & MSG_PEEK)))
c46234eb 1055 zc = true;
0b243d00
VG
1056
1057 err = decrypt_skb_update(sk, skb, &msg->msg_iter,
1058 &chunk, &zc);
94524d8f 1059 if (err < 0 && err != -EINPROGRESS) {
0b243d00
VG
1060 tls_err_abort(sk, EBADMSG);
1061 goto recv_end;
c46234eb 1062 }
94524d8f
VG
1063
1064 if (err == -EINPROGRESS) {
1065 async = true;
1066 num_async++;
1067 goto pick_next_record;
1068 }
1069
c46234eb
DW
1070 ctx->decrypted = true;
1071 }
1072
1073 if (!zc) {
1074 chunk = min_t(unsigned int, rxm->full_len, len);
94524d8f 1075
c46234eb
DW
1076 err = skb_copy_datagram_msg(skb, rxm->offset, msg,
1077 chunk);
1078 if (err < 0)
1079 goto recv_end;
1080 }
1081
94524d8f 1082pick_next_record:
c46234eb
DW
1083 copied += chunk;
1084 len -= chunk;
1085 if (likely(!(flags & MSG_PEEK))) {
1086 u8 control = ctx->control;
1087
94524d8f
VG
1088 /* For async, drop current skb reference */
1089 if (async)
1090 skb = NULL;
1091
c46234eb
DW
1092 if (tls_sw_advance_skb(sk, skb, chunk)) {
1093 /* Return full control message to
1094 * userspace before trying to parse
1095 * another message type
1096 */
1097 msg->msg_flags |= MSG_EOR;
1098 if (control != TLS_RECORD_TYPE_DATA)
1099 goto recv_end;
94524d8f
VG
1100 } else {
1101 break;
c46234eb
DW
1102 }
1103 }
94524d8f 1104
06030dba
DB
1105 /* If we have a new message from strparser, continue now. */
1106 if (copied >= target && !ctx->recv_pkt)
1107 break;
c46234eb
DW
1108 } while (len);
1109
1110recv_end:
94524d8f
VG
1111 if (num_async) {
1112 /* Wait for all previously submitted records to be decrypted */
1113 smp_store_mb(ctx->async_notify, true);
1114 if (atomic_read(&ctx->decrypt_pending)) {
1115 err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1116 if (err) {
1117 /* one of async decrypt failed */
1118 tls_err_abort(sk, err);
1119 copied = 0;
1120 }
1121 } else {
1122 reinit_completion(&ctx->async_wait.completion);
1123 }
1124 WRITE_ONCE(ctx->async_notify, false);
1125 }
1126
c46234eb
DW
1127 release_sock(sk);
1128 return copied ? : err;
1129}
1130
1131ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
1132 struct pipe_inode_info *pipe,
1133 size_t len, unsigned int flags)
1134{
1135 struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
f66de3ee 1136 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
c46234eb
DW
1137 struct strp_msg *rxm = NULL;
1138 struct sock *sk = sock->sk;
1139 struct sk_buff *skb;
1140 ssize_t copied = 0;
1141 int err = 0;
1142 long timeo;
1143 int chunk;
0b243d00 1144 bool zc = false;
c46234eb
DW
1145
1146 lock_sock(sk);
1147
1148 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1149
1150 skb = tls_wait_data(sk, flags, timeo, &err);
1151 if (!skb)
1152 goto splice_read_end;
1153
1154 /* splice does not support reading control messages */
1155 if (ctx->control != TLS_RECORD_TYPE_DATA) {
1156 err = -ENOTSUPP;
1157 goto splice_read_end;
1158 }
1159
1160 if (!ctx->decrypted) {
0b243d00 1161 err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
c46234eb
DW
1162
1163 if (err < 0) {
1164 tls_err_abort(sk, EBADMSG);
1165 goto splice_read_end;
1166 }
1167 ctx->decrypted = true;
1168 }
1169 rxm = strp_msg(skb);
1170
1171 chunk = min_t(unsigned int, rxm->full_len, len);
1172 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
1173 if (copied < 0)
1174 goto splice_read_end;
1175
1176 if (likely(!(flags & MSG_PEEK)))
1177 tls_sw_advance_skb(sk, skb, copied);
1178
1179splice_read_end:
1180 release_sock(sk);
1181 return copied ? : err;
1182}
1183
a11e1d43
LT
1184unsigned int tls_sw_poll(struct file *file, struct socket *sock,
1185 struct poll_table_struct *wait)
c46234eb 1186{
a11e1d43 1187 unsigned int ret;
c46234eb
DW
1188 struct sock *sk = sock->sk;
1189 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 1190 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
c46234eb 1191
a11e1d43
LT
1192 /* Grab POLLOUT and POLLHUP from the underlying socket */
1193 ret = ctx->sk_poll(file, sock, wait);
c46234eb 1194
a11e1d43
LT
1195 /* Clear POLLIN bits, and set based on recv_pkt */
1196 ret &= ~(POLLIN | POLLRDNORM);
c46234eb 1197 if (ctx->recv_pkt)
a11e1d43 1198 ret |= POLLIN | POLLRDNORM;
c46234eb 1199
a11e1d43 1200 return ret;
c46234eb
DW
1201}
1202
1203static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
1204{
1205 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
f66de3ee 1206 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
3463e51d 1207 char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
c46234eb
DW
1208 struct strp_msg *rxm = strp_msg(skb);
1209 size_t cipher_overhead;
1210 size_t data_len = 0;
1211 int ret;
1212
1213 /* Verify that we have a full TLS header, or wait for more data */
1214 if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
1215 return 0;
1216
3463e51d
KC
1217 /* Sanity-check size of on-stack buffer. */
1218 if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
1219 ret = -EINVAL;
1220 goto read_failure;
1221 }
1222
c46234eb
DW
1223 /* Linearize header to local buffer */
1224 ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
1225
1226 if (ret < 0)
1227 goto read_failure;
1228
1229 ctx->control = header[0];
1230
1231 data_len = ((header[4] & 0xFF) | (header[3] << 8));
1232
1233 cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
1234
1235 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
1236 ret = -EMSGSIZE;
1237 goto read_failure;
1238 }
1239 if (data_len < cipher_overhead) {
1240 ret = -EBADMSG;
1241 goto read_failure;
1242 }
1243
1244 if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) ||
1245 header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) {
1246 ret = -EINVAL;
1247 goto read_failure;
1248 }
1249
4799ac81
BP
1250#ifdef CONFIG_TLS_DEVICE
1251 handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1252 *(u64*)tls_ctx->rx.rec_seq);
1253#endif
c46234eb
DW
1254 return data_len + TLS_HEADER_SIZE;
1255
1256read_failure:
1257 tls_err_abort(strp->sk, ret);
1258
1259 return ret;
1260}
1261
1262static void tls_queue(struct strparser *strp, struct sk_buff *skb)
1263{
1264 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
f66de3ee 1265 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
c46234eb
DW
1266
1267 ctx->decrypted = false;
1268
1269 ctx->recv_pkt = skb;
1270 strp_pause(strp);
1271
ad13acce 1272 ctx->saved_data_ready(strp->sk);
c46234eb
DW
1273}
1274
1275static void tls_data_ready(struct sock *sk)
1276{
1277 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 1278 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
c46234eb
DW
1279
1280 strp_data_ready(&ctx->strp);
1281}
1282
f66de3ee 1283void tls_sw_free_resources_tx(struct sock *sk)
3c4d7559
DW
1284{
1285 struct tls_context *tls_ctx = tls_get_ctx(sk);
f66de3ee 1286 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
3c4d7559 1287
201876b3 1288 crypto_free_aead(ctx->aead_send);
f66de3ee
BP
1289 tls_free_both_sg(sk);
1290
1291 kfree(ctx);
1292}
1293
39f56e1a 1294void tls_sw_release_resources_rx(struct sock *sk)
f66de3ee
BP
1295{
1296 struct tls_context *tls_ctx = tls_get_ctx(sk);
1297 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1298
c46234eb 1299 if (ctx->aead_recv) {
201876b3
VG
1300 kfree_skb(ctx->recv_pkt);
1301 ctx->recv_pkt = NULL;
c46234eb
DW
1302 crypto_free_aead(ctx->aead_recv);
1303 strp_stop(&ctx->strp);
1304 write_lock_bh(&sk->sk_callback_lock);
1305 sk->sk_data_ready = ctx->saved_data_ready;
1306 write_unlock_bh(&sk->sk_callback_lock);
1307 release_sock(sk);
1308 strp_done(&ctx->strp);
1309 lock_sock(sk);
1310 }
39f56e1a
BP
1311}
1312
/*
 * tls_sw_free_resources_rx - release the software kTLS receive side and
 * free its private context
 */
void tls_sw_free_resources_rx(struct sock *sk)
{
	/* Look the context up before teardown so it can be freed afterwards. */
	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_get_ctx(sk));

	tls_sw_release_resources_rx(sk);
	kfree(rx_ctx);
}
1322
c46234eb 1323int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
3c4d7559
DW
1324{
1325 char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
1326 struct tls_crypto_info *crypto_info;
1327 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
f66de3ee
BP
1328 struct tls_sw_context_tx *sw_ctx_tx = NULL;
1329 struct tls_sw_context_rx *sw_ctx_rx = NULL;
c46234eb
DW
1330 struct cipher_context *cctx;
1331 struct crypto_aead **aead;
1332 struct strp_callbacks cb;
3c4d7559
DW
1333 u16 nonce_size, tag_size, iv_size, rec_seq_size;
1334 char *iv, *rec_seq;
1335 int rc = 0;
1336
1337 if (!ctx) {
1338 rc = -EINVAL;
1339 goto out;
1340 }
1341
f66de3ee 1342 if (tx) {
b190a587
BP
1343 if (!ctx->priv_ctx_tx) {
1344 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
1345 if (!sw_ctx_tx) {
1346 rc = -ENOMEM;
1347 goto out;
1348 }
1349 ctx->priv_ctx_tx = sw_ctx_tx;
1350 } else {
1351 sw_ctx_tx =
1352 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
c46234eb 1353 }
c46234eb 1354 } else {
b190a587
BP
1355 if (!ctx->priv_ctx_rx) {
1356 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
1357 if (!sw_ctx_rx) {
1358 rc = -ENOMEM;
1359 goto out;
1360 }
1361 ctx->priv_ctx_rx = sw_ctx_rx;
1362 } else {
1363 sw_ctx_rx =
1364 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
f66de3ee 1365 }
3c4d7559
DW
1366 }
1367
c46234eb 1368 if (tx) {
b190a587 1369 crypto_init_wait(&sw_ctx_tx->async_wait);
c46234eb
DW
1370 crypto_info = &ctx->crypto_send;
1371 cctx = &ctx->tx;
f66de3ee 1372 aead = &sw_ctx_tx->aead_send;
c46234eb 1373 } else {
b190a587 1374 crypto_init_wait(&sw_ctx_rx->async_wait);
c46234eb
DW
1375 crypto_info = &ctx->crypto_recv;
1376 cctx = &ctx->rx;
f66de3ee 1377 aead = &sw_ctx_rx->aead_recv;
c46234eb
DW
1378 }
1379
3c4d7559
DW
1380 switch (crypto_info->cipher_type) {
1381 case TLS_CIPHER_AES_GCM_128: {
1382 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1383 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1384 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1385 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1386 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1387 rec_seq =
1388 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1389 gcm_128_info =
1390 (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
1391 break;
1392 }
1393 default:
1394 rc = -EINVAL;
cf6d43ef 1395 goto free_priv;
3c4d7559
DW
1396 }
1397
b16520f7 1398 /* Sanity-check the IV size for stack allocations. */
3463e51d 1399 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
b16520f7
KC
1400 rc = -EINVAL;
1401 goto free_priv;
1402 }
1403
c46234eb
DW
1404 cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1405 cctx->tag_size = tag_size;
1406 cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1407 cctx->iv_size = iv_size;
1408 cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1409 GFP_KERNEL);
1410 if (!cctx->iv) {
3c4d7559 1411 rc = -ENOMEM;
cf6d43ef 1412 goto free_priv;
3c4d7559 1413 }
c46234eb
DW
1414 memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1415 memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1416 cctx->rec_seq_size = rec_seq_size;
969d5090 1417 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
c46234eb 1418 if (!cctx->rec_seq) {
3c4d7559
DW
1419 rc = -ENOMEM;
1420 goto free_iv;
1421 }
c46234eb 1422
f66de3ee
BP
1423 if (sw_ctx_tx) {
1424 sg_init_table(sw_ctx_tx->sg_encrypted_data,
1425 ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
1426 sg_init_table(sw_ctx_tx->sg_plaintext_data,
1427 ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
1428
1429 sg_init_table(sw_ctx_tx->sg_aead_in, 2);
1430 sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
1431 sizeof(sw_ctx_tx->aad_space));
1432 sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
1433 sg_chain(sw_ctx_tx->sg_aead_in, 2,
1434 sw_ctx_tx->sg_plaintext_data);
1435 sg_init_table(sw_ctx_tx->sg_aead_out, 2);
1436 sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
1437 sizeof(sw_ctx_tx->aad_space));
1438 sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
1439 sg_chain(sw_ctx_tx->sg_aead_out, 2,
1440 sw_ctx_tx->sg_encrypted_data);
c46234eb
DW
1441 }
1442
1443 if (!*aead) {
1444 *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
1445 if (IS_ERR(*aead)) {
1446 rc = PTR_ERR(*aead);
1447 *aead = NULL;
3c4d7559
DW
1448 goto free_rec_seq;
1449 }
1450 }
1451
1452 ctx->push_pending_record = tls_sw_push_pending_record;
1453
1454 memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1455
c46234eb 1456 rc = crypto_aead_setkey(*aead, keyval,
3c4d7559
DW
1457 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1458 if (rc)
1459 goto free_aead;
1460
c46234eb
DW
1461 rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
1462 if (rc)
1463 goto free_aead;
1464
f66de3ee 1465 if (sw_ctx_rx) {
94524d8f
VG
1466 (*aead)->reqsize = sizeof(struct decrypt_req_ctx);
1467
c46234eb
DW
1468 /* Set up strparser */
1469 memset(&cb, 0, sizeof(cb));
1470 cb.rcv_msg = tls_queue;
1471 cb.parse_msg = tls_read_size;
1472
f66de3ee 1473 strp_init(&sw_ctx_rx->strp, sk, &cb);
c46234eb
DW
1474
1475 write_lock_bh(&sk->sk_callback_lock);
f66de3ee 1476 sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
c46234eb
DW
1477 sk->sk_data_ready = tls_data_ready;
1478 write_unlock_bh(&sk->sk_callback_lock);
1479
a11e1d43 1480 sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
c46234eb 1481
f66de3ee 1482 strp_check_rcv(&sw_ctx_rx->strp);
c46234eb
DW
1483 }
1484
1485 goto out;
3c4d7559
DW
1486
1487free_aead:
c46234eb
DW
1488 crypto_free_aead(*aead);
1489 *aead = NULL;
3c4d7559 1490free_rec_seq:
c46234eb
DW
1491 kfree(cctx->rec_seq);
1492 cctx->rec_seq = NULL;
3c4d7559 1493free_iv:
f66de3ee
BP
1494 kfree(cctx->iv);
1495 cctx->iv = NULL;
cf6d43ef 1496free_priv:
f66de3ee
BP
1497 if (tx) {
1498 kfree(ctx->priv_ctx_tx);
1499 ctx->priv_ctx_tx = NULL;
1500 } else {
1501 kfree(ctx->priv_ctx_rx);
1502 ctx->priv_ctx_rx = NULL;
1503 }
3c4d7559
DW
1504out:
1505 return rc;
1506}