net/tls: Fixed race condition in async encryption
[linux-2.6-block.git] / net/tls/tls_sw.c
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  *
8  * This software is available to you under a choice of one of two
9  * licenses.  You may choose to be licensed under the terms of the GNU
10  * General Public License (GPL) Version 2, available from the file
11  * COPYING in the main directory of this source tree, or the
12  * OpenIB.org BSD license below:
13  *
14  *     Redistribution and use in source and binary forms, with or
15  *     without modification, are permitted provided that the following
16  *     conditions are met:
17  *
18  *      - Redistributions of source code must retain the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer.
21  *
22  *      - Redistributions in binary form must reproduce the above
23  *        copyright notice, this list of conditions and the following
24  *        disclaimer in the documentation and/or other materials
25  *        provided with the distribution.
26  *
27  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34  * SOFTWARE.
35  */
36
37 #include <linux/sched/signal.h>
38 #include <linux/module.h>
39 #include <crypto/aead.h>
40
41 #include <net/strparser.h>
42 #include <net/tls.h>
43
44 #define MAX_IV_SIZE     TLS_CIPHER_AES_GCM_128_IV_SIZE
45
46 static int __skb_nsg(struct sk_buff *skb, int offset, int len,
47                      unsigned int recursion_level)
48 {
49         int start = skb_headlen(skb);
50         int i, chunk = start - offset;
51         struct sk_buff *frag_iter;
52         int elt = 0;
53
54         if (unlikely(recursion_level >= 24))
55                 return -EMSGSIZE;
56
57         if (chunk > 0) {
58                 if (chunk > len)
59                         chunk = len;
60                 elt++;
61                 len -= chunk;
62                 if (len == 0)
63                         return elt;
64                 offset += chunk;
65         }
66
67         for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
68                 int end;
69
70                 WARN_ON(start > offset + len);
71
72                 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
73                 chunk = end - offset;
74                 if (chunk > 0) {
75                         if (chunk > len)
76                                 chunk = len;
77                         elt++;
78                         len -= chunk;
79                         if (len == 0)
80                                 return elt;
81                         offset += chunk;
82                 }
83                 start = end;
84         }
85
86         if (unlikely(skb_has_frag_list(skb))) {
87                 skb_walk_frags(skb, frag_iter) {
88                         int end, ret;
89
90                         WARN_ON(start > offset + len);
91
92                         end = start + frag_iter->len;
93                         chunk = end - offset;
94                         if (chunk > 0) {
95                                 if (chunk > len)
96                                         chunk = len;
97                                 ret = __skb_nsg(frag_iter, offset - start, chunk,
98                                                 recursion_level + 1);
99                                 if (unlikely(ret < 0))
100                                         return ret;
101                                 elt += ret;
102                                 len -= chunk;
103                                 if (len == 0)
104                                         return elt;
105                                 offset += chunk;
106                         }
107                         start = end;
108                 }
109         }
110         BUG_ON(len);
111         return elt;
112 }
113
114 /* Return the number of scatterlist elements required to completely map the
115  * skb, or -EMSGSIZE if the recursion depth is exceeded.
116  */
117 static int skb_nsg(struct sk_buff *skb, int offset, int len)
118 {
119         return __skb_nsg(skb, offset, len, 0);
120 }
121
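/* Completion callback for an asynchronous AEAD decryption. Propagates any
 * error to the socket, clears skb->sk (set only to carry the socket to
 * this callback) and frees the skb, releases the destination pages (the
 * first sg entry is the AAD and is skipped), frees the request and, once
 * the last pending decryption has finished, wakes a waiter that has set
 * ctx->async_notify.
 */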
122 static void tls_decrypt_done(struct crypto_async_request *req, int err)
123 {
124         struct aead_request *aead_req = (struct aead_request *)req;
125         struct scatterlist *sgout = aead_req->dst;
126         struct tls_sw_context_rx *ctx;
127         struct tls_context *tls_ctx;
128         struct scatterlist *sg;
129         struct sk_buff *skb;
130         unsigned int pages;
131         int pending;
132
133         skb = (struct sk_buff *)req->data;
134         tls_ctx = tls_get_ctx(skb->sk);
135         ctx = tls_sw_ctx_rx(tls_ctx);
136         pending = atomic_dec_return(&ctx->decrypt_pending);
137
138         /* Propagate the error if there was one */
139         if (err) {
140                 ctx->async_wait.err = err;
141                 tls_err_abort(skb->sk, err);
142         }
143
144         /* After using skb->sk to propagate sk through crypto async callback
145          * we need to NULL it again.
146          */
147         skb->sk = NULL;
148
149         /* Release the skb, pages and memory allocated for crypto req */
150         kfree_skb(skb);
151
152         /* Skip the first S/G entry as it points to AAD */
153         for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
154                 if (!sg)
155                         break;
156                 put_page(sg_page(sg));
157         }
158
159         kfree(aead_req);
160
161         if (!pending && READ_ONCE(ctx->async_notify))
162                 complete(&ctx->async_wait.completion);
163 }
164
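/* Build and submit the AEAD request to decrypt one record. In async mode
 * the socket is stashed in skb->sk for tls_decrypt_done() and -EINPROGRESS
 * is returned to the caller; in sync mode the function waits for the
 * crypto layer to complete the request before returning.
 */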
165 static int tls_do_decryption(struct sock *sk,
166                              struct sk_buff *skb,
167                              struct scatterlist *sgin,
168                              struct scatterlist *sgout,
169                              char *iv_recv,
170                              size_t data_len,
171                              struct aead_request *aead_req,
172                              bool async)
173 {
174         struct tls_context *tls_ctx = tls_get_ctx(sk);
175         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
176         int ret;
177
178         aead_request_set_tfm(aead_req, ctx->aead_recv);
179         aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
180         aead_request_set_crypt(aead_req, sgin, sgout,
181                                data_len + tls_ctx->rx.tag_size,
182                                (u8 *)iv_recv);
183
184         if (async) {
185                 /* Using skb->sk to push sk through to crypto async callback
186                  * handler. This allows propagating errors up to the socket
187                  * if needed. It _must_ be cleared in the async handler
188                  * before kfree_skb is called. We _know_ skb->sk is NULL
189                  * because it is a clone from strparser.
190                  */
191                 skb->sk = sk;
192                 aead_request_set_callback(aead_req,
193                                           CRYPTO_TFM_REQ_MAY_BACKLOG,
194                                           tls_decrypt_done, skb);
195                 atomic_inc(&ctx->decrypt_pending);
196         } else {
197                 aead_request_set_callback(aead_req,
198                                           CRYPTO_TFM_REQ_MAY_BACKLOG,
199                                           crypto_req_done, &ctx->async_wait);
200         }
201
202         ret = crypto_aead_decrypt(aead_req);
203         if (ret == -EINPROGRESS) {
204                 if (async)
205                         return ret;
206
207                 ret = crypto_wait_req(ret, &ctx->async_wait);
208         }
209
210         if (async)
211                 atomic_dec(&ctx->decrypt_pending);
212
213         return ret;
214 }
215
216 static void trim_sg(struct sock *sk, struct scatterlist *sg,
217                     int *sg_num_elem, unsigned int *sg_size, int target_size)
218 {
219         int i = *sg_num_elem - 1;
220         int trim = *sg_size - target_size;
221
222         if (trim <= 0) {
223                 WARN_ON(trim < 0);
224                 return;
225         }
226
227         *sg_size = target_size;
228         while (trim >= sg[i].length) {
229                 trim -= sg[i].length;
230                 sk_mem_uncharge(sk, sg[i].length);
231                 put_page(sg_page(&sg[i]));
232                 i--;
233
234                 if (i < 0)
235                         goto out;
236         }
237
238         sg[i].length -= trim;
239         sk_mem_uncharge(sk, trim);
240
241 out:
242         *sg_num_elem = i + 1;
243 }
244
245 static void trim_both_sgl(struct sock *sk, int target_size)
246 {
247         struct tls_context *tls_ctx = tls_get_ctx(sk);
248         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
249         struct tls_rec *rec = ctx->open_rec;
250
251         trim_sg(sk, rec->sg_plaintext_data,
252                 &rec->sg_plaintext_num_elem,
253                 &rec->sg_plaintext_size,
254                 target_size);
255
256         if (target_size > 0)
257                 target_size += tls_ctx->tx.overhead_size;
258
259         trim_sg(sk, rec->sg_encrypted_data,
260                 &rec->sg_encrypted_num_elem,
261                 &rec->sg_encrypted_size,
262                 target_size);
263 }
264
265 static int alloc_encrypted_sg(struct sock *sk, int len)
266 {
267         struct tls_context *tls_ctx = tls_get_ctx(sk);
268         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
269         struct tls_rec *rec = ctx->open_rec;
270         int rc = 0;
271
272         rc = sk_alloc_sg(sk, len,
273                          rec->sg_encrypted_data, 0,
274                          &rec->sg_encrypted_num_elem,
275                          &rec->sg_encrypted_size, 0);
276
277         if (rc == -ENOSPC)
278                 rec->sg_encrypted_num_elem = ARRAY_SIZE(rec->sg_encrypted_data);
279
280         return rc;
281 }
282
283 static int alloc_plaintext_sg(struct sock *sk, int len)
284 {
285         struct tls_context *tls_ctx = tls_get_ctx(sk);
286         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
287         struct tls_rec *rec = ctx->open_rec;
288         int rc = 0;
289
290         rc = sk_alloc_sg(sk, len, rec->sg_plaintext_data, 0,
291                          &rec->sg_plaintext_num_elem, &rec->sg_plaintext_size,
292                          tls_ctx->pending_open_record_frags);
293
294         if (rc == -ENOSPC)
295                 rec->sg_plaintext_num_elem = ARRAY_SIZE(rec->sg_plaintext_data);
296
297         return rc;
298 }
299
300 static void free_sg(struct sock *sk, struct scatterlist *sg,
301                     int *sg_num_elem, unsigned int *sg_size)
302 {
303         int i, n = *sg_num_elem;
304
305         for (i = 0; i < n; ++i) {
306                 sk_mem_uncharge(sk, sg[i].length);
307                 put_page(sg_page(&sg[i]));
308         }
309         *sg_num_elem = 0;
310         *sg_size = 0;
311 }
312
313 static void tls_free_both_sg(struct sock *sk)
314 {
315         struct tls_context *tls_ctx = tls_get_ctx(sk);
316         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
317         struct tls_rec *rec = ctx->open_rec;
318
319         /* Return if there is no open record */
320         if (!rec)
321                 return;
322
323         free_sg(sk, rec->sg_encrypted_data,
324                 &rec->sg_encrypted_num_elem,
325                 &rec->sg_encrypted_size);
326
327         free_sg(sk, rec->sg_plaintext_data,
328                 &rec->sg_plaintext_num_elem,
329                 &rec->sg_plaintext_size);
330 }
331
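/* Transmit records from tx_list: first finish any partially sent record,
 * then push every record already marked tx_ready, in order, stopping at
 * the first record whose encryption has not completed. Any error other
 * than -EAGAIN aborts the connection. flags == -1 means "reuse the flags
 * saved in the record".
 */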
332 int tls_tx_records(struct sock *sk, int flags)
333 {
334         struct tls_context *tls_ctx = tls_get_ctx(sk);
335         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
336         struct tls_rec *rec, *tmp;
337         int tx_flags, rc = 0;
338
339         if (tls_is_partially_sent_record(tls_ctx)) {
340                 rec = list_first_entry(&ctx->tx_list,
341                                        struct tls_rec, list);
342
343                 if (flags == -1)
344                         tx_flags = rec->tx_flags;
345                 else
346                         tx_flags = flags;
347
348                 rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
349                 if (rc)
350                         goto tx_err;
351
352                 /* Full record has been transmitted.
353                  * Remove the head of tx_list
354                  */
355                 list_del(&rec->list);
356                 kfree(rec);
357         }
358
359         /* Tx all ready records */
360         list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
361                 if (READ_ONCE(rec->tx_ready)) {
362                         if (flags == -1)
363                                 tx_flags = rec->tx_flags;
364                         else
365                                 tx_flags = flags;
366
367                         rc = tls_push_sg(sk, tls_ctx,
368                                          &rec->sg_encrypted_data[0],
369                                          0, tx_flags);
370                         if (rc)
371                                 goto tx_err;
372
373                         list_del(&rec->list);
374                         kfree(rec);
375                 } else {
376                         break;
377                 }
378         }
379
380 tx_err:
381         if (rc < 0 && rc != -EAGAIN)
382                 tls_err_abort(sk, EBADMSG);
383
384         return rc;
385 }
386
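/* Completion callback for an asynchronous record encryption. Restores the
 * first ciphertext sg entry to cover the TLS header again and frees the
 * plaintext pages. On error (or when an error is already set on the
 * socket) the record is freed and the error is stored in async_wait;
 * otherwise the record is marked tx_ready and, if it sits at the head of
 * tx_list, the tx work is scheduled. The pending-encryption count is
 * dropped and a waiter that set ctx->async_notify is woken when it
 * reaches zero.
 */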
387 static void tls_encrypt_done(struct crypto_async_request *req, int err)
388 {
389         struct aead_request *aead_req = (struct aead_request *)req;
390         struct sock *sk = req->data;
391         struct tls_context *tls_ctx = tls_get_ctx(sk);
392         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
393         struct tls_rec *rec;
394         bool ready = false;
395         int pending;
396
397         rec = container_of(aead_req, struct tls_rec, aead_req);
398
399         rec->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
400         rec->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
401
402         free_sg(sk, rec->sg_plaintext_data,
403                 &rec->sg_plaintext_num_elem, &rec->sg_plaintext_size);
404
405         /* Free the record on error or if an error was already set on the socket */
406         if (err || sk->sk_err) {
407                 free_sg(sk, rec->sg_encrypted_data,
408                         &rec->sg_encrypted_num_elem, &rec->sg_encrypted_size);
409
410                 kfree(rec);
411                 rec = NULL;
412
413                 /* If err is already set on socket, return the same code */
414                 if (sk->sk_err) {
415                         ctx->async_wait.err = sk->sk_err;
416                 } else {
417                         ctx->async_wait.err = err;
418                         tls_err_abort(sk, err);
419                 }
420         }
421
422         if (rec) {
423                 struct tls_rec *first_rec;
424
425                 /* Mark the record as ready for transmission */
426                 smp_store_mb(rec->tx_ready, true);
427
428                 /* If this record is at the head of tx_list, schedule tx */
429                 first_rec = list_first_entry(&ctx->tx_list,
430                                              struct tls_rec, list);
431                 if (rec == first_rec)
432                         ready = true;
433         }
434
435         pending = atomic_dec_return(&ctx->encrypt_pending);
436
437         if (!pending && READ_ONCE(ctx->async_notify))
438                 complete(&ctx->async_wait.completion);
439
440         if (!ready)
441                 return;
442
443         /* Schedule the transmission */
444         if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
445                 schedule_delayed_work(&ctx->tx_work.work, 1);
446 }
447
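/* Set up and submit the AEAD request for the open record. The first
 * ciphertext sg entry is advanced past the TLS header so the AEAD output
 * lands after it; the record is queued on tx_list and encrypt_pending is
 * bumped before crypto_aead_encrypt() is called. A synchronous return
 * (success or failure) undoes the sg adjustment and the pending count;
 * on failure the record is also removed from tx_list again. Unless
 * encryption failed, the open record is detached from the context and the
 * TX record sequence number is advanced.
 */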
448 static int tls_do_encryption(struct sock *sk,
449                              struct tls_context *tls_ctx,
450                              struct tls_sw_context_tx *ctx,
451                              struct aead_request *aead_req,
452                              size_t data_len)
453 {
454         struct tls_rec *rec = ctx->open_rec;
455         int rc;
456
457         rec->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
458         rec->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
459
460         aead_request_set_tfm(aead_req, ctx->aead_send);
461         aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
462         aead_request_set_crypt(aead_req, rec->sg_aead_in,
463                                rec->sg_aead_out,
464                                data_len, tls_ctx->tx.iv);
465
466         aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
467                                   tls_encrypt_done, sk);
468
469         /* Add the record to tx_list */
470         list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
471         atomic_inc(&ctx->encrypt_pending);
472
473         rc = crypto_aead_encrypt(aead_req);
474         if (!rc || rc != -EINPROGRESS) {
475                 atomic_dec(&ctx->encrypt_pending);
476                 rec->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
477                 rec->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
478         }
479
480         if (!rc) {
481                 WRITE_ONCE(rec->tx_ready, true);
482         } else if (rc != -EINPROGRESS) {
483                 list_del(&rec->list);
484                 return rc;
485         }
486
487         /* Encryption succeeded or is in flight; unhook the open record */
488         ctx->open_rec = NULL;
489         tls_advance_record_sn(sk, &tls_ctx->tx);
490         return rc;
491 }
492
493 static int tls_push_record(struct sock *sk, int flags,
494                            unsigned char record_type)
495 {
496         struct tls_context *tls_ctx = tls_get_ctx(sk);
497         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
498         struct tls_rec *rec = ctx->open_rec;
499         struct aead_request *req;
500         int rc;
501
502         if (!rec)
503                 return 0;
504
505         rec->tx_flags = flags;
506         req = &rec->aead_req;
507
508         sg_mark_end(rec->sg_plaintext_data + rec->sg_plaintext_num_elem - 1);
509         sg_mark_end(rec->sg_encrypted_data + rec->sg_encrypted_num_elem - 1);
510
511         tls_make_aad(rec->aad_space, rec->sg_plaintext_size,
512                      tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
513                      record_type);
514
515         tls_fill_prepend(tls_ctx,
516                          page_address(sg_page(&rec->sg_encrypted_data[0])) +
517                          rec->sg_encrypted_data[0].offset,
518                          rec->sg_plaintext_size, record_type);
519
520         tls_ctx->pending_open_record_frags = 0;
521
522         rc = tls_do_encryption(sk, tls_ctx, ctx, req, rec->sg_plaintext_size);
523         if (rc == -EINPROGRESS)
524                 return -EINPROGRESS;
525
526         free_sg(sk, rec->sg_plaintext_data, &rec->sg_plaintext_num_elem,
527                 &rec->sg_plaintext_size);
528
529         if (rc < 0) {
530                 tls_err_abort(sk, EBADMSG);
531                 return rc;
532         }
533
534         return tls_tx_records(sk, flags);
535 }
536
537 static int tls_sw_push_pending_record(struct sock *sk, int flags)
538 {
539         return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
540 }
541
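/* Pin user pages from the iov with iov_iter_get_pages() and add them to
 * the 'to' scatterlist (at most to_max_pages entries), optionally charging
 * the socket for the mapped bytes. *pages_used and *size_used are updated;
 * on failure the iterator is reverted by the amount consumed so far.
 */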
542 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
543                               int length, int *pages_used,
544                               unsigned int *size_used,
545                               struct scatterlist *to, int to_max_pages,
546                               bool charge)
547 {
548         struct page *pages[MAX_SKB_FRAGS];
549
550         size_t offset;
551         ssize_t copied, use;
552         int i = 0;
553         unsigned int size = *size_used;
554         int num_elem = *pages_used;
555         int rc = 0;
556         int maxpages;
557
558         while (length > 0) {
559                 i = 0;
560                 maxpages = to_max_pages - num_elem;
561                 if (maxpages == 0) {
562                         rc = -EFAULT;
563                         goto out;
564                 }
565                 copied = iov_iter_get_pages(from, pages,
566                                             length,
567                                             maxpages, &offset);
568                 if (copied <= 0) {
569                         rc = -EFAULT;
570                         goto out;
571                 }
572
573                 iov_iter_advance(from, copied);
574
575                 length -= copied;
576                 size += copied;
577                 while (copied) {
578                         use = min_t(int, copied, PAGE_SIZE - offset);
579
580                         sg_set_page(&to[num_elem],
581                                     pages[i], use, offset);
582                         sg_unmark_end(&to[num_elem]);
583                         if (charge)
584                                 sk_mem_charge(sk, use);
585
586                         offset = 0;
587                         copied -= use;
588
589                         ++i;
590                         ++num_elem;
591                 }
592         }
593
594         /* Mark the end in the last sg entry if newly added */
595         if (num_elem > *pages_used)
596                 sg_mark_end(&to[num_elem - 1]);
597 out:
598         if (rc)
599                 iov_iter_revert(from, size - *size_used);
600         *size_used = size;
601         *pages_used = num_elem;
602
603         return rc;
604 }
605
606 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
607                              int bytes)
608 {
609         struct tls_context *tls_ctx = tls_get_ctx(sk);
610         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
611         struct tls_rec *rec = ctx->open_rec;
612         struct scatterlist *sg = rec->sg_plaintext_data;
613         int copy, i, rc = 0;
614
615         for (i = tls_ctx->pending_open_record_frags;
616              i < rec->sg_plaintext_num_elem; ++i) {
617                 copy = sg[i].length;
618                 if (copy_from_iter(
619                                 page_address(sg_page(&sg[i])) + sg[i].offset,
620                                 copy, from) != copy) {
621                         rc = -EFAULT;
622                         goto out;
623                 }
624                 bytes -= copy;
625
626                 ++tls_ctx->pending_open_record_frags;
627
628                 if (!bytes)
629                         break;
630         }
631
632 out:
633         return rc;
634 }
635
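/* Return the currently open record, or allocate a new one. A new record
 * carries the AEAD request and two chained scatterlists: sg_aead_in
 * (AAD followed by the plaintext pages) and sg_aead_out (AAD followed by
 * the ciphertext pages), which serve as AEAD input and output.
 */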
636 struct tls_rec *get_rec(struct sock *sk)
637 {
638         struct tls_context *tls_ctx = tls_get_ctx(sk);
639         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
640         struct tls_rec *rec;
641         int mem_size;
642
643         /* Return if we already have an open record */
644         if (ctx->open_rec)
645                 return ctx->open_rec;
646
647         mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
648
649         rec = kzalloc(mem_size, sk->sk_allocation);
650         if (!rec)
651                 return NULL;
652
653         sg_init_table(&rec->sg_plaintext_data[0],
654                       ARRAY_SIZE(rec->sg_plaintext_data));
655         sg_init_table(&rec->sg_encrypted_data[0],
656                       ARRAY_SIZE(rec->sg_encrypted_data));
657
658         sg_init_table(rec->sg_aead_in, 2);
659         sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
660                    sizeof(rec->aad_space));
661         sg_unmark_end(&rec->sg_aead_in[1]);
662         sg_chain(rec->sg_aead_in, 2, rec->sg_plaintext_data);
663
664         sg_init_table(rec->sg_aead_out, 2);
665         sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
666                    sizeof(rec->aad_space));
667         sg_unmark_end(&rec->sg_aead_out[1]);
668         sg_chain(rec->sg_aead_out, 2, rec->sg_encrypted_data);
669
670         ctx->open_rec = rec;
671
672         return rec;
673 }
674
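/* Send path: data from the user iov is copied into the open record (or
 * mapped zero-copy when the record is being closed, the iter is not a
 * kvec and the cipher is synchronous), full records are pushed for
 * encryption, and any records whose asynchronous encryption has already
 * completed are transmitted before returning.
 */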
675 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
676 {
677         long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
678         struct tls_context *tls_ctx = tls_get_ctx(sk);
679         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
680         struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send);
681         bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
682         unsigned char record_type = TLS_RECORD_TYPE_DATA;
683         bool is_kvec = msg->msg_iter.type & ITER_KVEC;
684         bool eor = !(msg->msg_flags & MSG_MORE);
685         size_t try_to_copy, copied = 0;
686         struct tls_rec *rec;
687         int required_size;
688         int num_async = 0;
689         bool full_record;
690         int record_room;
691         int num_zc = 0;
692         int orig_size;
693         int ret;
694
695         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
696                 return -ENOTSUPP;
697
698         lock_sock(sk);
699
700         /* If there is a pending write on the socket, wait for it to finish */
701         if (unlikely(sk->sk_write_pending)) {
702                 ret = wait_on_pending_writer(sk, &timeo);
703                 if (unlikely(ret))
704                         goto send_end;
705         }
706
707         if (unlikely(msg->msg_controllen)) {
708                 ret = tls_proccess_cmsg(sk, msg, &record_type);
709                 if (ret) {
710                         if (ret == -EINPROGRESS)
711                                 num_async++;
712                         else if (ret != -EAGAIN)
713                                 goto send_end;
714                 }
715         }
716
717         while (msg_data_left(msg)) {
718                 if (sk->sk_err) {
719                         ret = -sk->sk_err;
720                         goto send_end;
721                 }
722
723                 rec = get_rec(sk);
724                 if (!rec) {
725                         ret = -ENOMEM;
726                         goto send_end;
727                 }
728
729                 orig_size = rec->sg_plaintext_size;
730                 full_record = false;
731                 try_to_copy = msg_data_left(msg);
732                 record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size;
733                 if (try_to_copy >= record_room) {
734                         try_to_copy = record_room;
735                         full_record = true;
736                 }
737
738                 required_size = rec->sg_plaintext_size + try_to_copy +
739                                 tls_ctx->tx.overhead_size;
740
741                 if (!sk_stream_memory_free(sk))
742                         goto wait_for_sndbuf;
743
744 alloc_encrypted:
745                 ret = alloc_encrypted_sg(sk, required_size);
746                 if (ret) {
747                         if (ret != -ENOSPC)
748                                 goto wait_for_memory;
749
750                         /* Adjust try_to_copy according to the amount that was
751                          * actually allocated. The difference is due
752                          * to max sg elements limit
753                          */
754                         try_to_copy -= required_size - rec->sg_encrypted_size;
755                         full_record = true;
756                 }
757
758                 if (!is_kvec && (full_record || eor) && !async_capable) {
759                         ret = zerocopy_from_iter(sk, &msg->msg_iter,
760                                 try_to_copy, &rec->sg_plaintext_num_elem,
761                                 &rec->sg_plaintext_size,
762                                 rec->sg_plaintext_data,
763                                 ARRAY_SIZE(rec->sg_plaintext_data),
764                                 true);
765                         if (ret)
766                                 goto fallback_to_reg_send;
767
768                         num_zc++;
769                         copied += try_to_copy;
770                         ret = tls_push_record(sk, msg->msg_flags, record_type);
771                         if (ret) {
772                                 if (ret == -EINPROGRESS)
773                                         num_async++;
774                                 else if (ret != -EAGAIN)
775                                         goto send_end;
776                         }
777                         continue;
778
779 fallback_to_reg_send:
780                         trim_sg(sk, rec->sg_plaintext_data,
781                                 &rec->sg_plaintext_num_elem,
782                                 &rec->sg_plaintext_size,
783                                 orig_size);
784                 }
785
786                 required_size = rec->sg_plaintext_size + try_to_copy;
787 alloc_plaintext:
788                 ret = alloc_plaintext_sg(sk, required_size);
789                 if (ret) {
790                         if (ret != -ENOSPC)
791                                 goto wait_for_memory;
792
793                         /* Adjust try_to_copy according to the amount that was
794                          * actually allocated. The difference is due
795                          * to max sg elements limit
796                          */
797                         try_to_copy -= required_size - rec->sg_plaintext_size;
798                         full_record = true;
799
800                         trim_sg(sk, rec->sg_encrypted_data,
801                                 &rec->sg_encrypted_num_elem,
802                                 &rec->sg_encrypted_size,
803                                 rec->sg_plaintext_size +
804                                 tls_ctx->tx.overhead_size);
805                 }
806
807                 ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
808                 if (ret)
809                         goto trim_sgl;
810
811                 copied += try_to_copy;
812                 if (full_record || eor) {
813                         ret = tls_push_record(sk, msg->msg_flags, record_type);
814                         if (ret) {
815                                 if (ret == -EINPROGRESS)
816                                         num_async++;
817                                 else if (ret != -EAGAIN)
818                                         goto send_end;
819                         }
820                 }
821
822                 continue;
823
824 wait_for_sndbuf:
825                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
826 wait_for_memory:
827                 ret = sk_stream_wait_memory(sk, &timeo);
828                 if (ret) {
829 trim_sgl:
830                         trim_both_sgl(sk, orig_size);
831                         goto send_end;
832                 }
833
834                 if (rec->sg_encrypted_size < required_size)
835                         goto alloc_encrypted;
836
837                 goto alloc_plaintext;
838         }
839
840         if (!num_async) {
841                 goto send_end;
842         } else if (num_zc) {
843                 /* Wait for pending encryptions to get completed */
844                 smp_store_mb(ctx->async_notify, true);
845
846                 if (atomic_read(&ctx->encrypt_pending))
847                         crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
848                 else
849                         reinit_completion(&ctx->async_wait.completion);
850
851                 WRITE_ONCE(ctx->async_notify, false);
852
853                 if (ctx->async_wait.err) {
854                         ret = ctx->async_wait.err;
855                         copied = 0;
856                 }
857         }
858
859         /* Transmit if any encryptions have completed */
860         if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
861                 cancel_delayed_work(&ctx->tx_work.work);
862                 tls_tx_records(sk, msg->msg_flags);
863         }
864
865 send_end:
866         ret = sk_stream_error(sk, msg->msg_flags, ret);
867
868         release_sock(sk);
869         return copied ? copied : ret;
870 }
871
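/* sendpage path: the page is attached by reference to the plaintext
 * scatterlist of the open record (no data copy) and records are pushed
 * the same way as in tls_sw_sendmsg().
 */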
872 int tls_sw_sendpage(struct sock *sk, struct page *page,
873                     int offset, size_t size, int flags)
874 {
875         long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
876         struct tls_context *tls_ctx = tls_get_ctx(sk);
877         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
878         unsigned char record_type = TLS_RECORD_TYPE_DATA;
879         size_t orig_size = size;
880         struct scatterlist *sg;
881         struct tls_rec *rec;
882         int num_async = 0;
883         bool full_record;
884         int record_room;
885         bool eor;
886         int ret;
887
888         if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
889                       MSG_SENDPAGE_NOTLAST))
890                 return -ENOTSUPP;
891
892         /* No MSG_EOR from splice, only look at MSG_MORE */
893         eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
894
895         lock_sock(sk);
896
897         sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
898
899         /* If there is a pending write on the socket, wait for it to finish */
900         if (unlikely(sk->sk_write_pending)) {
901                 ret = wait_on_pending_writer(sk, &timeo);
902                 if (unlikely(ret))
903                         goto sendpage_end;
904         }
905
906         /* Call the sk_stream functions to manage the sndbuf mem. */
907         while (size > 0) {
908                 size_t copy, required_size;
909
910                 if (sk->sk_err) {
911                         ret = -sk->sk_err;
912                         goto sendpage_end;
913                 }
914
915                 rec = get_rec(sk);
916                 if (!rec) {
917                         ret = -ENOMEM;
918                         goto sendpage_end;
919                 }
920
921                 full_record = false;
922                 record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size;
923                 copy = size;
924                 if (copy >= record_room) {
925                         copy = record_room;
926                         full_record = true;
927                 }
928                 required_size = rec->sg_plaintext_size + copy +
929                               tls_ctx->tx.overhead_size;
930
931                 if (!sk_stream_memory_free(sk))
932                         goto wait_for_sndbuf;
933 alloc_payload:
934                 ret = alloc_encrypted_sg(sk, required_size);
935                 if (ret) {
936                         if (ret != -ENOSPC)
937                                 goto wait_for_memory;
938
939                         /* Adjust copy according to the amount that was
940                          * actually allocated. The difference is due
941                          * to max sg elements limit
942                          */
943                         copy -= required_size - rec->sg_plaintext_size;
944                         full_record = true;
945                 }
946
947                 get_page(page);
948                 sg = rec->sg_plaintext_data + rec->sg_plaintext_num_elem;
949                 sg_set_page(sg, page, copy, offset);
950                 sg_unmark_end(sg);
951
952                 rec->sg_plaintext_num_elem++;
953
954                 sk_mem_charge(sk, copy);
955                 offset += copy;
956                 size -= copy;
957                 rec->sg_plaintext_size += copy;
958                 tls_ctx->pending_open_record_frags = rec->sg_plaintext_num_elem;
959
960                 if (full_record || eor ||
961                     rec->sg_plaintext_num_elem ==
962                     ARRAY_SIZE(rec->sg_plaintext_data)) {
963                         ret = tls_push_record(sk, flags, record_type);
964                         if (ret) {
965                                 if (ret == -EINPROGRESS)
966                                         num_async++;
967                                 else if (ret != -EAGAIN)
968                                         goto sendpage_end;
969                         }
970                 }
971                 continue;
972 wait_for_sndbuf:
973                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
974 wait_for_memory:
975                 ret = sk_stream_wait_memory(sk, &timeo);
976                 if (ret) {
977                         trim_both_sgl(sk, rec->sg_plaintext_size);
978                         goto sendpage_end;
979                 }
980
981                 goto alloc_payload;
982         }
983
984         if (num_async) {
985                 /* Transmit if any encryptions have completed */
986                 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
987                         cancel_delayed_work(&ctx->tx_work.work);
988                         tls_tx_records(sk, flags);
989                 }
990         }
991 sendpage_end:
992         if (orig_size > size)
993                 ret = orig_size - size;
994         else
995                 ret = sk_stream_error(sk, flags, ret);
996
997         release_sock(sk);
998         return ret;
999 }
1000
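/* Wait for strparser to deliver the next complete record in ctx->recv_pkt,
 * honouring socket errors, shutdown, non-blocking flags and signals.
 */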
1001 static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
1002                                      long timeo, int *err)
1003 {
1004         struct tls_context *tls_ctx = tls_get_ctx(sk);
1005         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1006         struct sk_buff *skb;
1007         DEFINE_WAIT_FUNC(wait, woken_wake_function);
1008
1009         while (!(skb = ctx->recv_pkt)) {
1010                 if (sk->sk_err) {
1011                         *err = sock_error(sk);
1012                         return NULL;
1013                 }
1014
1015                 if (sk->sk_shutdown & RCV_SHUTDOWN)
1016                         return NULL;
1017
1018                 if (sock_flag(sk, SOCK_DONE))
1019                         return NULL;
1020
1021                 if ((flags & MSG_DONTWAIT) || !timeo) {
1022                         *err = -EAGAIN;
1023                         return NULL;
1024                 }
1025
1026                 add_wait_queue(sk_sleep(sk), &wait);
1027                 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1028                 sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
1029                 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1030                 remove_wait_queue(sk_sleep(sk), &wait);
1031
1032                 /* Handle signals */
1033                 if (signal_pending(current)) {
1034                         *err = sock_intr_errno(timeo);
1035                         return NULL;
1036                 }
1037         }
1038
1039         return skb;
1040 }
1041
1042 /* This function decrypts the input skb either into out_iov, into out_sg,
1043  * or in place in the skb's own buffers. The parameter 'zc' indicates
1044  * whether zero-copy mode should be tried. In zero-copy mode either
1045  * out_iov or out_sg must be non-NULL; if both are NULL, the decryption
1046  * happens in place in the skb's buffers, i.e. zero-copy is disabled and
1047  * '*zc' is updated accordingly.
1048  */
1049
1050 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
1051                             struct iov_iter *out_iov,
1052                             struct scatterlist *out_sg,
1053                             int *chunk, bool *zc)
1054 {
1055         struct tls_context *tls_ctx = tls_get_ctx(sk);
1056         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1057         struct strp_msg *rxm = strp_msg(skb);
1058         int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
1059         struct aead_request *aead_req;
1060         struct sk_buff *unused;
1061         u8 *aad, *iv, *mem = NULL;
1062         struct scatterlist *sgin = NULL;
1063         struct scatterlist *sgout = NULL;
1064         const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
1065
1066         if (*zc && (out_iov || out_sg)) {
1067                 if (out_iov)
1068                         n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
1069                 else
1070                         n_sgout = sg_nents(out_sg);
1071                 n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
1072                                  rxm->full_len - tls_ctx->rx.prepend_size);
1073         } else {
1074                 n_sgout = 0;
1075                 *zc = false;
1076                 n_sgin = skb_cow_data(skb, 0, &unused);
1077         }
1078
1079         if (n_sgin < 1)
1080                 return -EBADMSG;
1081
1082         /* Increment to accommodate AAD */
1083         n_sgin = n_sgin + 1;
1084
1085         nsg = n_sgin + n_sgout;
1086
1087         aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
1088         mem_size = aead_size + (nsg * sizeof(struct scatterlist));
1089         mem_size = mem_size + TLS_AAD_SPACE_SIZE;
1090         mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
1091
1092         /* Allocate a single block of memory which contains
1093          * aead_req || sgin[] || sgout[] || aad || iv.
1094          * This order achieves correct alignment for aead_req, sgin, sgout.
1095          */
1096         mem = kmalloc(mem_size, sk->sk_allocation);
1097         if (!mem)
1098                 return -ENOMEM;
1099
1100         /* Segment the allocated memory */
1101         aead_req = (struct aead_request *)mem;
1102         sgin = (struct scatterlist *)(mem + aead_size);
1103         sgout = sgin + n_sgin;
1104         aad = (u8 *)(sgout + n_sgout);
1105         iv = aad + TLS_AAD_SPACE_SIZE;
1106
1107         /* Prepare IV */
1108         err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
1109                             iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1110                             tls_ctx->rx.iv_size);
1111         if (err < 0) {
1112                 kfree(mem);
1113                 return err;
1114         }
1115         memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1116
1117         /* Prepare AAD */
1118         tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
1119                      tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
1120                      ctx->control);
1121
1122         /* Prepare sgin */
1123         sg_init_table(sgin, n_sgin);
1124         sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
1125         err = skb_to_sgvec(skb, &sgin[1],
1126                            rxm->offset + tls_ctx->rx.prepend_size,
1127                            rxm->full_len - tls_ctx->rx.prepend_size);
1128         if (err < 0) {
1129                 kfree(mem);
1130                 return err;
1131         }
1132
1133         if (n_sgout) {
1134                 if (out_iov) {
1135                         sg_init_table(sgout, n_sgout);
1136                         sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
1137
1138                         *chunk = 0;
1139                         err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
1140                                                  chunk, &sgout[1],
1141                                                  (n_sgout - 1), false);
1142                         if (err < 0)
1143                                 goto fallback_to_reg_recv;
1144                 } else if (out_sg) {
1145                         memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
1146                 } else {
1147                         goto fallback_to_reg_recv;
1148                 }
1149         } else {
1150 fallback_to_reg_recv:
1151                 sgout = sgin;
1152                 pages = 0;
1153                 *chunk = 0;
1154                 *zc = false;
1155         }
1156
1157         /* Prepare and submit AEAD request */
1158         err = tls_do_decryption(sk, skb, sgin, sgout, iv,
1159                                 data_len, aead_req, *zc);
1160         if (err == -EINPROGRESS)
1161                 return err;
1162
1163         /* Release the pages in case iov was mapped to pages */
1164         for (; pages > 0; pages--)
1165                 put_page(sg_page(&sgout[pages]));
1166
1167         kfree(mem);
1168         return err;
1169 }
1170
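/* Decrypt the record unless it has already been decrypted (e.g. by the
 * device offload path), then strip the TLS header and overhead from the
 * strparser message and advance the receive record sequence number. For
 * an asynchronous decrypt (-EINPROGRESS) only the sequence number is
 * advanced and the error code is passed back to the caller.
 */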
1171 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
1172                               struct iov_iter *dest, int *chunk, bool *zc)
1173 {
1174         struct tls_context *tls_ctx = tls_get_ctx(sk);
1175         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1176         struct strp_msg *rxm = strp_msg(skb);
1177         int err = 0;
1178
1179 #ifdef CONFIG_TLS_DEVICE
1180         err = tls_device_decrypted(sk, skb);
1181         if (err < 0)
1182                 return err;
1183 #endif
1184         if (!ctx->decrypted) {
1185                 err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
1186                 if (err < 0) {
1187                         if (err == -EINPROGRESS)
1188                                 tls_advance_record_sn(sk, &tls_ctx->rx);
1189
1190                         return err;
1191                 }
1192         } else {
1193                 *zc = false;
1194         }
1195
1196         rxm->offset += tls_ctx->rx.prepend_size;
1197         rxm->full_len -= tls_ctx->rx.overhead_size;
1198         tls_advance_record_sn(sk, &tls_ctx->rx);
1199         ctx->decrypted = true;
1200         ctx->saved_data_ready(sk);
1201
1202         return err;
1203 }
1204
1205 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
1206                 struct scatterlist *sgout)
1207 {
1208         bool zc = true;
1209         int chunk;
1210
1211         return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
1212 }
1213
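/* Consume 'len' bytes of the current record. Returns false while part of
 * the record is left; once the record (or a NULL skb, as used for async
 * decrypts) is fully consumed, recv_pkt is cleared, strparser is unpaused
 * and true is returned.
 */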
1214 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
1215                                unsigned int len)
1216 {
1217         struct tls_context *tls_ctx = tls_get_ctx(sk);
1218         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1219
1220         if (skb) {
1221                 struct strp_msg *rxm = strp_msg(skb);
1222
1223                 if (len < rxm->full_len) {
1224                         rxm->offset += len;
1225                         rxm->full_len -= len;
1226                         return false;
1227                 }
1228                 kfree_skb(skb);
1229         }
1230
1231         /* Finished with message */
1232         ctx->recv_pkt = NULL;
1233         __strp_unpause(&ctx->strp);
1234
1235         return true;
1236 }
1237
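/* Receive path: records delivered by strparser are decrypted (directly
 * into the user iov when zero-copy is possible, asynchronously when the
 * cipher allows it) and copied out to the caller. A control-message type
 * change terminates the read; pending async decrypts are waited for
 * before returning.
 */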
1238 int tls_sw_recvmsg(struct sock *sk,
1239                    struct msghdr *msg,
1240                    size_t len,
1241                    int nonblock,
1242                    int flags,
1243                    int *addr_len)
1244 {
1245         struct tls_context *tls_ctx = tls_get_ctx(sk);
1246         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1247         unsigned char control;
1248         struct strp_msg *rxm;
1249         struct sk_buff *skb;
1250         ssize_t copied = 0;
1251         bool cmsg = false;
1252         int target, err = 0;
1253         long timeo;
1254         bool is_kvec = msg->msg_iter.type & ITER_KVEC;
1255         int num_async = 0;
1256
1257         flags |= nonblock;
1258
1259         if (unlikely(flags & MSG_ERRQUEUE))
1260                 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1261
1262         lock_sock(sk);
1263
1264         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1265         timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1266         do {
1267                 bool zc = false;
1268                 bool async = false;
1269                 int chunk = 0;
1270
1271                 skb = tls_wait_data(sk, flags, timeo, &err);
1272                 if (!skb)
1273                         goto recv_end;
1274
1275                 rxm = strp_msg(skb);
1276
1277                 if (!cmsg) {
1278                         int cerr;
1279
1280                         cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1281                                         sizeof(ctx->control), &ctx->control);
1282                         cmsg = true;
1283                         control = ctx->control;
1284                         if (ctx->control != TLS_RECORD_TYPE_DATA) {
1285                                 if (cerr || msg->msg_flags & MSG_CTRUNC) {
1286                                         err = -EIO;
1287                                         goto recv_end;
1288                                 }
1289                         }
1290                 } else if (control != ctx->control) {
1291                         goto recv_end;
1292                 }
1293
1294                 if (!ctx->decrypted) {
1295                         int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
1296
1297                         if (!is_kvec && to_copy <= len &&
1298                             likely(!(flags & MSG_PEEK)))
1299                                 zc = true;
1300
1301                         err = decrypt_skb_update(sk, skb, &msg->msg_iter,
1302                                                  &chunk, &zc);
1303                         if (err < 0 && err != -EINPROGRESS) {
1304                                 tls_err_abort(sk, EBADMSG);
1305                                 goto recv_end;
1306                         }
1307
1308                         if (err == -EINPROGRESS) {
1309                                 async = true;
1310                                 num_async++;
1311                                 goto pick_next_record;
1312                         }
1313
1314                         ctx->decrypted = true;
1315                 }
1316
1317                 if (!zc) {
1318                         chunk = min_t(unsigned int, rxm->full_len, len);
1319
1320                         err = skb_copy_datagram_msg(skb, rxm->offset, msg,
1321                                                     chunk);
1322                         if (err < 0)
1323                                 goto recv_end;
1324                 }
1325
1326 pick_next_record:
1327                 copied += chunk;
1328                 len -= chunk;
1329                 if (likely(!(flags & MSG_PEEK))) {
1330                         u8 control = ctx->control;
1331
1332                         /* For async, drop current skb reference */
1333                         if (async)
1334                                 skb = NULL;
1335
1336                         if (tls_sw_advance_skb(sk, skb, chunk)) {
1337                                 /* Return full control message to
1338                                  * userspace before trying to parse
1339                                  * another message type
1340                                  */
1341                                 msg->msg_flags |= MSG_EOR;
1342                                 if (control != TLS_RECORD_TYPE_DATA)
1343                                         goto recv_end;
1344                         } else {
1345                                 break;
1346                         }
1347                 } else {
1348                         /* MSG_PEEK currently cannot look beyond the current
1349                          * skb from strparser, meaning we cannot advance the
1350                          * skb here and thus unpause strparser, since we would
1351                          * lose the original one.
1352                          */
1353                         break;
1354                 }
1355
1356                 /* Stop once the target is met and strparser has no new message queued. */
1357                 if (copied >= target && !ctx->recv_pkt)
1358                         break;
1359         } while (len);
1360
1361 recv_end:
1362         if (num_async) {
1363                 /* Wait for all previously submitted records to be decrypted */
1364                 smp_store_mb(ctx->async_notify, true);
1365                 if (atomic_read(&ctx->decrypt_pending)) {
1366                         err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1367                         if (err) {
1368                                 /* one of the async decrypts failed */
1369                                 tls_err_abort(sk, err);
1370                                 copied = 0;
1371                         }
1372                 } else {
1373                         reinit_completion(&ctx->async_wait.completion);
1374                 }
1375                 WRITE_ONCE(ctx->async_notify, false);
1376         }
1377
1378         release_sock(sk);
1379         return copied ? : err;
1380 }
1381
1382 ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
1383                            struct pipe_inode_info *pipe,
1384                            size_t len, unsigned int flags)
1385 {
1386         struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
1387         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1388         struct strp_msg *rxm = NULL;
1389         struct sock *sk = sock->sk;
1390         struct sk_buff *skb;
1391         ssize_t copied = 0;
1392         int err = 0;
1393         long timeo;
1394         int chunk;
1395         bool zc = false;
1396
1397         lock_sock(sk);
1398
1399         timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1400
1401         skb = tls_wait_data(sk, flags, timeo, &err);
1402         if (!skb)
1403                 goto splice_read_end;
1404
1405         /* splice does not support reading control messages */
1406         if (ctx->control != TLS_RECORD_TYPE_DATA) {
1407                 err = -ENOTSUPP;
1408                 goto splice_read_end;
1409         }
1410
1411         if (!ctx->decrypted) {
1412                 err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
1413
1414                 if (err < 0) {
1415                         tls_err_abort(sk, EBADMSG);
1416                         goto splice_read_end;
1417                 }
1418                 ctx->decrypted = true;
1419         }
1420         rxm = strp_msg(skb);
1421
1422         chunk = min_t(unsigned int, rxm->full_len, len);
1423         copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
1424         if (copied < 0)
1425                 goto splice_read_end;
1426
1427         if (likely(!(flags & MSG_PEEK)))
1428                 tls_sw_advance_skb(sk, skb, copied);
1429
1430 splice_read_end:
1431         release_sock(sk);
1432         return copied ? : err;
1433 }
1434
1435 unsigned int tls_sw_poll(struct file *file, struct socket *sock,
1436                          struct poll_table_struct *wait)
1437 {
1438         unsigned int ret;
1439         struct sock *sk = sock->sk;
1440         struct tls_context *tls_ctx = tls_get_ctx(sk);
1441         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1442
1443         /* Grab POLLOUT and POLLHUP from the underlying socket */
1444         ret = ctx->sk_poll(file, sock, wait);
1445
1446         /* Clear POLLIN bits, and set based on recv_pkt */
1447         ret &= ~(POLLIN | POLLRDNORM);
1448         if (ctx->recv_pkt)
1449                 ret |= POLLIN | POLLRDNORM;
1450
1451         return ret;
1452 }
1453
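/* strparser callback: parse the TLS record header at the head of the
 * message, validate the protocol version and the payload length, and
 * return the full record size (header plus payload), 0 if the header has
 * not fully arrived yet, or a negative error (which aborts the
 * connection).
 */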
1454 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
1455 {
1456         struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1457         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1458         char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
1459         struct strp_msg *rxm = strp_msg(skb);
1460         size_t cipher_overhead;
1461         size_t data_len = 0;
1462         int ret;
1463
1464         /* Verify that we have a full TLS header, or wait for more data */
1465         if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
1466                 return 0;
1467
1468         /* Sanity-check size of on-stack buffer. */
1469         if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
1470                 ret = -EINVAL;
1471                 goto read_failure;
1472         }
1473
1474         /* Linearize header to local buffer */
1475         ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
1476
1477         if (ret < 0)
1478                 goto read_failure;
1479
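        /*
         * TLS record header layout: byte 0 is the content type, bytes 1-2
         * the protocol version, bytes 3-4 the payload length (big endian).
         */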
1480         ctx->control = header[0];
1481
1482         data_len = ((header[4] & 0xFF) | (header[3] << 8));
1483
1484         cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
1485
1486         if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
1487                 ret = -EMSGSIZE;
1488                 goto read_failure;
1489         }
1490         if (data_len < cipher_overhead) {
1491                 ret = -EBADMSG;
1492                 goto read_failure;
1493         }
1494
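        /* The record must carry the protocol version we negotiated. */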
1495         if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
1496             header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
1497                 ret = -EINVAL;
1498                 goto read_failure;
1499         }
1500
1501 #ifdef CONFIG_TLS_DEVICE
1502         handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1503                              *(u64*)tls_ctx->rx.rec_seq);
1504 #endif
1505         return data_len + TLS_HEADER_SIZE;
1506
1507 read_failure:
1508         tls_err_abort(strp->sk, ret);
1509
1510         return ret;
1511 }
1512
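/*
 * strparser rcv_msg callback: hand a complete record to the receive path.
 * The record is stashed as recv_pkt (not yet decrypted) and the strparser
 * is paused until a reader consumes it; readers are then woken up.
 */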
1513 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
1514 {
1515         struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1516         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1517
1518         ctx->decrypted = false;
1519
1520         ctx->recv_pkt = skb;
1521         strp_pause(strp);
1522
1523         ctx->saved_data_ready(strp->sk);
1524 }
1525
1526 static void tls_data_ready(struct sock *sk)
1527 {
1528         struct tls_context *tls_ctx = tls_get_ctx(sk);
1529         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1530
1531         strp_data_ready(&ctx->strp);
1532 }
1533
1534 void tls_sw_free_resources_tx(struct sock *sk)
1535 {
1536         struct tls_context *tls_ctx = tls_get_ctx(sk);
1537         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1538         struct tls_rec *rec, *tmp;
1539
1540         /* Wait for any pending async encryptions to complete */
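        /* Make async_notify visible before checking encrypt_pending, so the
         * last completing encrypt callback knows to signal async_wait.
         */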
1541         smp_store_mb(ctx->async_notify, true);
1542         if (atomic_read(&ctx->encrypt_pending))
1543                 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1544
1545         cancel_delayed_work_sync(&ctx->tx_work.work);
1546
1547         /* Tx whatever records we can transmit and abandon the rest */
1548         tls_tx_records(sk, -1);
1549
1550         /* Free up unsent records in tx_list. First, free the
1551          * partially sent record, if any, at the head of tx_list.
1552          */
1553         if (tls_ctx->partially_sent_record) {
1554                 struct scatterlist *sg = tls_ctx->partially_sent_record;
1555
1556                 while (1) {
1557                         put_page(sg_page(sg));
1558                         sk_mem_uncharge(sk, sg->length);
1559
1560                         if (sg_is_last(sg))
1561                                 break;
1562                         sg++;
1563                 }
1564
1565                 tls_ctx->partially_sent_record = NULL;
1566
1567                 rec = list_first_entry(&ctx->tx_list,
1568                                        struct tls_rec, list);
1569                 list_del(&rec->list);
1570                 kfree(rec);
1571         }
1572
1573         list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
1574                 free_sg(sk, rec->sg_encrypted_data,
1575                         &rec->sg_encrypted_num_elem,
1576                         &rec->sg_encrypted_size);
1577
1578                 list_del(&rec->list);
1579                 kfree(rec);
1580         }
1581
1582         crypto_free_aead(ctx->aead_send);
1583         tls_free_both_sg(sk);
1584
1585         kfree(ctx);
1586 }
1587
1588 void tls_sw_release_resources_rx(struct sock *sk)
1589 {
1590         struct tls_context *tls_ctx = tls_get_ctx(sk);
1591         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1592
1593         if (ctx->aead_recv) {
1594                 kfree_skb(ctx->recv_pkt);
1595                 ctx->recv_pkt = NULL;
1596                 crypto_free_aead(ctx->aead_recv);
1597                 strp_stop(&ctx->strp);
1598                 write_lock_bh(&sk->sk_callback_lock);
1599                 sk->sk_data_ready = ctx->saved_data_ready;
1600                 write_unlock_bh(&sk->sk_callback_lock);
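                /* strp_done() flushes the strparser work, which can take the
                 * socket lock itself, so drop the lock across the call.
                 */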
1601                 release_sock(sk);
1602                 strp_done(&ctx->strp);
1603                 lock_sock(sk);
1604         }
1605 }
1606
1607 void tls_sw_free_resources_rx(struct sock *sk)
1608 {
1609         struct tls_context *tls_ctx = tls_get_ctx(sk);
1610         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1611
1612         tls_sw_release_resources_rx(sk);
1613
1614         kfree(ctx);
1615 }
1616
1617 /* The work handler to transmit the encrypted records in tx_list */
1618 static void tx_work_handler(struct work_struct *work)
1619 {
1620         struct delayed_work *delayed_work = to_delayed_work(work);
1621         struct tx_work *tx_work = container_of(delayed_work,
1622                                                struct tx_work, work);
1623         struct sock *sk = tx_work->sk;
1624         struct tls_context *tls_ctx = tls_get_ctx(sk);
1625         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1626
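        /* Run only if a transmit was actually scheduled; clearing the bit
         * lets the work be scheduled again later.
         */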
1627         if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
1628                 return;
1629
1630         lock_sock(sk);
1631         tls_tx_records(sk, -1);
1632         release_sock(sk);
1633 }
1634
1635 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
1636 {
1637         struct tls_crypto_info *crypto_info;
1638         struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
1639         struct tls_sw_context_tx *sw_ctx_tx = NULL;
1640         struct tls_sw_context_rx *sw_ctx_rx = NULL;
1641         struct cipher_context *cctx;
1642         struct crypto_aead **aead;
1643         struct strp_callbacks cb;
1644         u16 nonce_size, tag_size, iv_size, rec_seq_size;
1645         char *iv, *rec_seq;
1646         int rc = 0;
1647
1648         if (!ctx) {
1649                 rc = -EINVAL;
1650                 goto out;
1651         }
1652
1653         if (tx) {
1654                 if (!ctx->priv_ctx_tx) {
1655                         sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
1656                         if (!sw_ctx_tx) {
1657                                 rc = -ENOMEM;
1658                                 goto out;
1659                         }
1660                         ctx->priv_ctx_tx = sw_ctx_tx;
1661                 } else {
1662                         sw_ctx_tx =
1663                                 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
1664                 }
1665         } else {
1666                 if (!ctx->priv_ctx_rx) {
1667                         sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
1668                         if (!sw_ctx_rx) {
1669                                 rc = -ENOMEM;
1670                                 goto out;
1671                         }
1672                         ctx->priv_ctx_rx = sw_ctx_rx;
1673                 } else {
1674                         sw_ctx_rx =
1675                                 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
1676                 }
1677         }
1678
1679         if (tx) {
1680                 crypto_init_wait(&sw_ctx_tx->async_wait);
1681                 crypto_info = &ctx->crypto_send.info;
1682                 cctx = &ctx->tx;
1683                 aead = &sw_ctx_tx->aead_send;
1684                 INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
1685                 INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
1686                 sw_ctx_tx->tx_work.sk = sk;
1687         } else {
1688                 crypto_init_wait(&sw_ctx_rx->async_wait);
1689                 crypto_info = &ctx->crypto_recv.info;
1690                 cctx = &ctx->rx;
1691                 aead = &sw_ctx_rx->aead_recv;
1692         }
1693
1694         switch (crypto_info->cipher_type) {
1695         case TLS_CIPHER_AES_GCM_128: {
1696                 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1697                 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1698                 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1699                 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1700                 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1701                 rec_seq =
1702                  ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1703                 gcm_128_info =
1704                         (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
1705                 break;
1706         }
1707         default:
1708                 rc = -EINVAL;
1709                 goto free_priv;
1710         }
1711
1712         /* Sanity-check the IV size for stack allocations. */
1713         if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
1714                 rc = -EINVAL;
1715                 goto free_priv;
1716         }
1717
1718         cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1719         cctx->tag_size = tag_size;
1720         cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1721         cctx->iv_size = iv_size;
1722         cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1723                            GFP_KERNEL);
1724         if (!cctx->iv) {
1725                 rc = -ENOMEM;
1726                 goto free_priv;
1727         }
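        /* For TLS 1.2 AES-128-GCM the nonce is the 4-byte implicit salt
         * followed by the 8-byte explicit IV, so store them back to back.
         */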
1728         memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1729         memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1730         cctx->rec_seq_size = rec_seq_size;
1731         cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
1732         if (!cctx->rec_seq) {
1733                 rc = -ENOMEM;
1734                 goto free_iv;
1735         }
1736
1737         if (!*aead) {
1738                 *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
1739                 if (IS_ERR(*aead)) {
1740                         rc = PTR_ERR(*aead);
1741                         *aead = NULL;
1742                         goto free_rec_seq;
1743                 }
1744         }
1745
1746         ctx->push_pending_record = tls_sw_push_pending_record;
1747
1748         rc = crypto_aead_setkey(*aead, gcm_128_info->key,
1749                                 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1750         if (rc)
1751                 goto free_aead;
1752
1753         rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
1754         if (rc)
1755                 goto free_aead;
1756
1757         if (sw_ctx_rx) {
1758                 /* Set up strparser */
1759                 memset(&cb, 0, sizeof(cb));
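                /* tls_read_size reports each record's length to the
                 * strparser; tls_queue receives every complete record.
                 */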
1760                 cb.rcv_msg = tls_queue;
1761                 cb.parse_msg = tls_read_size;
1762
1763                 strp_init(&sw_ctx_rx->strp, sk, &cb);
1764
1765                 write_lock_bh(&sk->sk_callback_lock);
1766                 sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
1767                 sk->sk_data_ready = tls_data_ready;
1768                 write_unlock_bh(&sk->sk_callback_lock);
1769
1770                 sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
1771
1772                 strp_check_rcv(&sw_ctx_rx->strp);
1773         }
1774
1775         goto out;
1776
1777 free_aead:
1778         crypto_free_aead(*aead);
1779         *aead = NULL;
1780 free_rec_seq:
1781         kfree(cctx->rec_seq);
1782         cctx->rec_seq = NULL;
1783 free_iv:
1784         kfree(cctx->iv);
1785         cctx->iv = NULL;
1786 free_priv:
1787         if (tx) {
1788                 kfree(ctx->priv_ctx_tx);
1789                 ctx->priv_ctx_tx = NULL;
1790         } else {
1791                 kfree(ctx->priv_ctx_rx);
1792                 ctx->priv_ctx_rx = NULL;
1793         }
1794 out:
1795         return rc;
1796 }