net/tls/tls_device.c
1 /* Copyright (c) 2018, Mellanox Technologies All rights reserved.
2  *
3  * This software is available to you under a choice of one of two
4  * licenses.  You may choose to be licensed under the terms of the GNU
5  * General Public License (GPL) Version 2, available from the file
6  * COPYING in the main directory of this source tree, or the
7  * OpenIB.org BSD license below:
8  *
9  *     Redistribution and use in source and binary forms, with or
10  *     without modification, are permitted provided that the following
11  *     conditions are met:
12  *
13  *      - Redistributions of source code must retain the above
14  *        copyright notice, this list of conditions and the following
15  *        disclaimer.
16  *
17  *      - Redistributions in binary form must reproduce the above
18  *        copyright notice, this list of conditions and the following
19  *        disclaimer in the documentation and/or other materials
20  *        provided with the distribution.
21  *
22  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
26  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
27  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
28  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29  * SOFTWARE.
30  */
31
32 #include <crypto/aead.h>
33 #include <linux/highmem.h>
34 #include <linux/module.h>
35 #include <linux/netdevice.h>
36 #include <net/dst.h>
37 #include <net/inet_connection_sock.h>
38 #include <net/tcp.h>
39 #include <net/tls.h>
40
41 /* device_offload_lock is used to synchronize tls_dev_add
42  * against NETDEV_DOWN notifications.
43  */
44 static DECLARE_RWSEM(device_offload_lock);
45
46 static void tls_device_gc_task(struct work_struct *work);
47
48 static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
49 static LIST_HEAD(tls_device_gc_list);
50 static LIST_HEAD(tls_device_list);
51 static DEFINE_SPINLOCK(tls_device_lock);
52
53 static void tls_device_free_ctx(struct tls_context *ctx)
54 {
55         if (ctx->tx_conf == TLS_HW) {
56                 kfree(tls_offload_ctx_tx(ctx));
57                 kfree(ctx->tx.rec_seq);
58                 kfree(ctx->tx.iv);
59         }
60
61         if (ctx->rx_conf == TLS_HW)
62                 kfree(tls_offload_ctx_rx(ctx));
63
64         tls_ctx_free(ctx);
65 }
66
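/* Work item that frees contexts queued on tls_device_gc_list: tear down
 * the driver's TX offload state (if the context is still bound to a
 * netdev), drop the device reference and free the context.
 */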
67 static void tls_device_gc_task(struct work_struct *work)
68 {
69         struct tls_context *ctx, *tmp;
70         unsigned long flags;
71         LIST_HEAD(gc_list);
72
73         spin_lock_irqsave(&tls_device_lock, flags);
74         list_splice_init(&tls_device_gc_list, &gc_list);
75         spin_unlock_irqrestore(&tls_device_lock, flags);
76
77         list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
78                 struct net_device *netdev = ctx->netdev;
79
80                 if (netdev && ctx->tx_conf == TLS_HW) {
81                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
82                                                         TLS_OFFLOAD_CTX_DIR_TX);
83                         dev_put(netdev);
84                         ctx->netdev = NULL;
85                 }
86
87                 list_del(&ctx->list);
88                 tls_device_free_ctx(ctx);
89         }
90 }
91
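/* Called once the last reference on the context is dropped: hand the
 * context to the garbage collection work so teardown happens outside
 * the socket destructor.
 */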
92 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
93 {
94         unsigned long flags;
95
96         spin_lock_irqsave(&tls_device_lock, flags);
97         list_move_tail(&ctx->list, &tls_device_gc_list);
98
99         /* Schedule the work while still holding the spinlock so that
100          * tls_device_down() is guaranteed to wait for this work.
101          */
102         schedule_work(&tls_device_gc_work);
103
104         spin_unlock_irqrestore(&tls_device_lock, flags);
105 }
106
107 /* We assume that the socket is already connected */
108 static struct net_device *get_netdev_for_sock(struct sock *sk)
109 {
110         struct dst_entry *dst = sk_dst_get(sk);
111         struct net_device *netdev = NULL;
112
113         if (likely(dst)) {
114                 netdev = dst->dev;
115                 dev_hold(netdev);
116         }
117
118         dst_release(dst);
119
120         return netdev;
121 }
122
123 static void destroy_record(struct tls_record_info *record)
124 {
125         int nr_frags = record->num_frags;
126         skb_frag_t *frag;
127
128         while (nr_frags-- > 0) {
129                 frag = &record->frags[nr_frags];
130                 __skb_frag_unref(frag);
131         }
132         kfree(record);
133 }
134
135 static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
136 {
137         struct tls_record_info *info, *temp;
138
139         list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
140                 list_del(&info->list);
141                 destroy_record(info);
142         }
143
144         offload_ctx->retransmit_hint = NULL;
145 }
146
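/* clean_acked_data hook called from TCP when the peer ACKs data: free
 * all TX offload records that now lie entirely below acked_seq.
 */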
147 static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
148 {
149         struct tls_context *tls_ctx = tls_get_ctx(sk);
150         struct tls_record_info *info, *temp;
151         struct tls_offload_context_tx *ctx;
152         u64 deleted_records = 0;
153         unsigned long flags;
154
155         if (!tls_ctx)
156                 return;
157
158         ctx = tls_offload_ctx_tx(tls_ctx);
159
160         spin_lock_irqsave(&ctx->lock, flags);
161         info = ctx->retransmit_hint;
162         if (info && !before(acked_seq, info->end_seq)) {
163                 ctx->retransmit_hint = NULL;
164                 list_del(&info->list);
165                 destroy_record(info);
166                 deleted_records++;
167         }
168
169         list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
170                 if (before(acked_seq, info->end_seq))
171                         break;
172                 list_del(&info->list);
173
174                 destroy_record(info);
175                 deleted_records++;
176         }
177
178         ctx->unacked_record_sn += deleted_records;
179         spin_unlock_irqrestore(&ctx->lock, flags);
180 }
181
182 /* At this point, there should be no references on this
183  * socket and no in-flight SKBs associated with this
184  * socket, so it is safe to free all the resources.
185  */
186 static void tls_device_sk_destruct(struct sock *sk)
187 {
188         struct tls_context *tls_ctx = tls_get_ctx(sk);
189         struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
190
191         tls_ctx->sk_destruct(sk);
192
193         if (tls_ctx->tx_conf == TLS_HW) {
194                 if (ctx->open_record)
195                         destroy_record(ctx->open_record);
196                 delete_all_records(ctx);
197                 crypto_free_aead(ctx->aead_send);
198                 clean_acked_data_disable(inet_csk(sk));
199         }
200
201         if (refcount_dec_and_test(&tls_ctx->refcount))
202                 tls_device_queue_ctx_destruction(tls_ctx);
203 }
204
205 void tls_device_free_resources_tx(struct sock *sk)
206 {
207         struct tls_context *tls_ctx = tls_get_ctx(sk);
208
209         tls_free_partial_record(sk, tls_ctx);
210 }
211
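/* The driver scheduled a TX resync (TLS_TX_SYNC_SCHED): report the TCP
 * sequence and TLS record sequence number of the record being pushed so
 * the device can re-synchronize its crypto state.
 */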
212 static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
213                                  u32 seq)
214 {
215         struct net_device *netdev;
216         struct sk_buff *skb;
217         u8 *rcd_sn;
218
219         skb = tcp_write_queue_tail(sk);
220         if (skb)
221                 TCP_SKB_CB(skb)->eor = 1;
222
223         rcd_sn = tls_ctx->tx.rec_seq;
224
225         down_read(&device_offload_lock);
226         netdev = tls_ctx->netdev;
227         if (netdev)
228                 netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
229                                                    TLS_OFFLOAD_CTX_DIR_TX);
230         up_read(&device_offload_lock);
231
232         clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags);
233 }
234
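/* Append @size bytes from @pfrag to the open record: extend the last
 * frag when the new data is contiguous within the same page, otherwise
 * start a new frag and take an extra page reference.
 */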
235 static void tls_append_frag(struct tls_record_info *record,
236                             struct page_frag *pfrag,
237                             int size)
238 {
239         skb_frag_t *frag;
240
241         frag = &record->frags[record->num_frags - 1];
242         if (frag->page.p == pfrag->page &&
243             frag->page_offset + frag->size == pfrag->offset) {
244                 frag->size += size;
245         } else {
246                 ++frag;
247                 frag->page.p = pfrag->page;
248                 frag->page_offset = pfrag->offset;
249                 frag->size = size;
250                 ++record->num_frags;
251                 get_page(pfrag->page);
252         }
253
254         pfrag->offset += size;
255         record->len += size;
256 }
257
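/* Close the open record: write the TLS header into the first frag,
 * append a dummy frag for the authentication tag (the device fills in
 * the real tag), queue the record on records_list for retransmission
 * lookups, advance the record sequence number and hand the data to TCP.
 */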
258 static int tls_push_record(struct sock *sk,
259                            struct tls_context *ctx,
260                            struct tls_offload_context_tx *offload_ctx,
261                            struct tls_record_info *record,
262                            struct page_frag *pfrag,
263                            int flags,
264                            unsigned char record_type)
265 {
266         struct tls_prot_info *prot = &ctx->prot_info;
267         struct tcp_sock *tp = tcp_sk(sk);
268         struct page_frag dummy_tag_frag;
269         skb_frag_t *frag;
270         int i;
271
272         /* fill prepend */
273         frag = &record->frags[0];
274         tls_fill_prepend(ctx,
275                          skb_frag_address(frag),
276                          record->len - prot->prepend_size,
277                          record_type,
278                          prot->version);
279
280         /* HW doesn't care about the data in the tag, because it fills it. */
281         dummy_tag_frag.page = skb_frag_page(frag);
282         dummy_tag_frag.offset = 0;
283
284         tls_append_frag(record, &dummy_tag_frag, prot->tag_size);
285         record->end_seq = tp->write_seq + record->len;
286         spin_lock_irq(&offload_ctx->lock);
287         list_add_tail(&record->list, &offload_ctx->records_list);
288         spin_unlock_irq(&offload_ctx->lock);
289         offload_ctx->open_record = NULL;
290
291         if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
292                 tls_device_resync_tx(sk, ctx, tp->write_seq);
293
294         tls_advance_record_sn(sk, prot, &ctx->tx);
295
296         for (i = 0; i < record->num_frags; i++) {
297                 frag = &record->frags[i];
298                 sg_unmark_end(&offload_ctx->sg_tx_data[i]);
299                 sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
300                             frag->size, frag->page_offset);
301                 sk_mem_charge(sk, frag->size);
302                 get_page(skb_frag_page(frag));
303         }
304         sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
305
306         /* all ready, send */
307         return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
308 }
309
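/* Allocate a new open record and reserve prepend_size bytes of @pfrag
 * for the TLS header.
 */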
310 static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
311                                  struct page_frag *pfrag,
312                                  size_t prepend_size)
313 {
314         struct tls_record_info *record;
315         skb_frag_t *frag;
316
317         record = kmalloc(sizeof(*record), GFP_KERNEL);
318         if (!record)
319                 return -ENOMEM;
320
321         frag = &record->frags[0];
322         __skb_frag_set_page(frag, pfrag->page);
323         frag->page_offset = pfrag->offset;
324         skb_frag_size_set(frag, prepend_size);
325
326         get_page(pfrag->page);
327         pfrag->offset += prepend_size;
328
329         record->num_frags = 1;
330         record->len = prepend_size;
331         offload_ctx->open_record = record;
332         return 0;
333 }
334
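/* Make sure there is an open record and page-frag space to copy into,
 * returning -ENOMEM when memory is tight so the caller can wait.
 */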
335 static int tls_do_allocation(struct sock *sk,
336                              struct tls_offload_context_tx *offload_ctx,
337                              struct page_frag *pfrag,
338                              size_t prepend_size)
339 {
340         int ret;
341
342         if (!offload_ctx->open_record) {
343                 if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
344                                                    sk->sk_allocation))) {
345                         sk->sk_prot->enter_memory_pressure(sk);
346                         sk_stream_moderate_sndbuf(sk);
347                         return -ENOMEM;
348                 }
349
350                 ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
351                 if (ret)
352                         return ret;
353
354                 if (pfrag->size > pfrag->offset)
355                         return 0;
356         }
357
358         if (!sk_page_frag_refill(sk, pfrag))
359                 return -ENOMEM;
360
361         return 0;
362 }
363
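/* Main transmit path for device offload: copy data from @msg_iter into
 * page frags of the open record and push records to TCP as they fill up
 * or when the message ends (unless MSG_MORE keeps the record open).
 */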
364 static int tls_push_data(struct sock *sk,
365                          struct iov_iter *msg_iter,
366                          size_t size, int flags,
367                          unsigned char record_type)
368 {
369         struct tls_context *tls_ctx = tls_get_ctx(sk);
370         struct tls_prot_info *prot = &tls_ctx->prot_info;
371         struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
372         int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
373         int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
374         struct tls_record_info *record = ctx->open_record;
375         struct page_frag *pfrag;
376         size_t orig_size = size;
377         u32 max_open_record_len;
378         int copy, rc = 0;
379         bool done = false;
380         long timeo;
381
382         if (flags &
383             ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
384                 return -ENOTSUPP;
385
386         if (sk->sk_err)
387                 return -sk->sk_err;
388
389         timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
390         if (tls_is_partially_sent_record(tls_ctx)) {
391                 rc = tls_push_partial_record(sk, tls_ctx, flags);
392                 if (rc < 0)
393                         return rc;
394         }
395
396         pfrag = sk_page_frag(sk);
397
398         /* TLS_HEADER_SIZE is not counted as part of the TLS record, and
399          * we need to leave room for an authentication tag.
400          */
401         max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
402                               prot->prepend_size;
403         do {
404                 rc = tls_do_allocation(sk, ctx, pfrag,
405                                        prot->prepend_size);
406                 if (rc) {
407                         rc = sk_stream_wait_memory(sk, &timeo);
408                         if (!rc)
409                                 continue;
410
411                         record = ctx->open_record;
412                         if (!record)
413                                 break;
414 handle_error:
415                         if (record_type != TLS_RECORD_TYPE_DATA) {
416                                 /* avoid sending partial
417                                  * record with type !=
418                                  * application_data
419                                  */
420                                 size = orig_size;
421                                 destroy_record(record);
422                                 ctx->open_record = NULL;
423                         } else if (record->len > prot->prepend_size) {
424                                 goto last_record;
425                         }
426
427                         break;
428                 }
429
430                 record = ctx->open_record;
431                 copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
432                 copy = min_t(size_t, copy, (max_open_record_len - record->len));
433
434                 if (copy_from_iter_nocache(page_address(pfrag->page) +
435                                                pfrag->offset,
436                                            copy, msg_iter) != copy) {
437                         rc = -EFAULT;
438                         goto handle_error;
439                 }
440                 tls_append_frag(record, pfrag, copy);
441
442                 size -= copy;
443                 if (!size) {
444 last_record:
445                         tls_push_record_flags = flags;
446                         if (more) {
447                                 tls_ctx->pending_open_record_frags =
448                                                 !!record->num_frags;
449                                 break;
450                         }
451
452                         done = true;
453                 }
454
455                 if (done || record->len >= max_open_record_len ||
456                     (record->num_frags >= MAX_SKB_FRAGS - 1)) {
457                         rc = tls_push_record(sk,
458                                              tls_ctx,
459                                              ctx,
460                                              record,
461                                              pfrag,
462                                              tls_push_record_flags,
463                                              record_type);
464                         if (rc < 0)
465                                 break;
466                 }
467         } while (!done);
468
469         if (orig_size - size > 0)
470                 rc = orig_size - size;
471
472         return rc;
473 }
474
475 int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
476 {
477         unsigned char record_type = TLS_RECORD_TYPE_DATA;
478         int rc;
479
480         lock_sock(sk);
481
482         if (unlikely(msg->msg_controllen)) {
483                 rc = tls_proccess_cmsg(sk, msg, &record_type);
484                 if (rc)
485                         goto out;
486         }
487
488         rc = tls_push_data(sk, &msg->msg_iter, size,
489                            msg->msg_flags, record_type);
490
491 out:
492         release_sock(sk);
493         return rc;
494 }
495
496 int tls_device_sendpage(struct sock *sk, struct page *page,
497                         int offset, size_t size, int flags)
498 {
499         struct iov_iter msg_iter;
500         char *kaddr = kmap(page);
501         struct kvec iov;
502         int rc;
503
504         if (flags & MSG_SENDPAGE_NOTLAST)
505                 flags |= MSG_MORE;
506
507         lock_sock(sk);
508
509         if (flags & MSG_OOB) {
510                 rc = -ENOTSUPP;
511                 goto out;
512         }
513
514         iov.iov_base = kaddr + offset;
515         iov.iov_len = size;
516         iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
517         rc = tls_push_data(sk, &msg_iter, size,
518                            flags, TLS_RECORD_TYPE_DATA);
519         kunmap(page);
520
521 out:
522         release_sock(sk);
523         return rc;
524 }
525
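/* Find the offload record containing TCP sequence number @seq (used by
 * drivers on retransmission) and return its record sequence number via
 * @p_record_sn, updating retransmit_hint when the match is newer than
 * the cached one.
 */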
526 struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
527                                        u32 seq, u64 *p_record_sn)
528 {
529         u64 record_sn = context->hint_record_sn;
530         struct tls_record_info *info;
531
532         info = context->retransmit_hint;
533         if (!info ||
534             before(seq, info->end_seq - info->len)) {
535                 /* if retransmit_hint is irrelevant start
536          * from the beginning of the list
537                  */
538                 info = list_first_entry(&context->records_list,
539                                         struct tls_record_info, list);
540                 record_sn = context->unacked_record_sn;
541         }
542
543         list_for_each_entry_from(info, &context->records_list, list) {
544                 if (before(seq, info->end_seq)) {
545                         if (!context->retransmit_hint ||
546                             after(info->end_seq,
547                                   context->retransmit_hint->end_seq)) {
548                                 context->hint_record_sn = record_sn;
549                                 context->retransmit_hint = info;
550                         }
551                         *p_record_sn = record_sn;
552                         return info;
553                 }
554                 record_sn++;
555         }
556
557         return NULL;
558 }
559 EXPORT_SYMBOL(tls_get_record);
560
561 static int tls_device_push_pending_record(struct sock *sk, int flags)
562 {
563         struct iov_iter msg_iter;
564
565         iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
566         return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
567 }
568
569 void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
570 {
571         if (!sk->sk_write_pending && tls_is_partially_sent_record(ctx)) {
572                 gfp_t sk_allocation = sk->sk_allocation;
573
574                 sk->sk_allocation = GFP_ATOMIC;
575                 tls_push_partial_record(sk, ctx, MSG_DONTWAIT | MSG_NOSIGNAL);
576                 sk->sk_allocation = sk_allocation;
577         }
578 }
579
580 static void tls_device_resync_rx(struct tls_context *tls_ctx,
581                                  struct sock *sk, u32 seq, u8 *rcd_sn)
582 {
583         struct net_device *netdev;
584
585         if (WARN_ON(test_and_set_bit(TLS_RX_SYNC_RUNNING, &tls_ctx->flags)))
586                 return;
587         netdev = READ_ONCE(tls_ctx->netdev);
588         if (netdev)
589                 netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
590                                                    TLS_OFFLOAD_CTX_DIR_RX);
591         clear_bit_unlock(TLS_RX_SYNC_RUNNING, &tls_ctx->flags);
592 }
593
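/* Called from the RX path for each new record in TLS_HW rx mode: if a
 * resync was requested by the driver, or armed by the core next-hint
 * logic, tell the device which TCP sequence and record sequence number
 * to re-synchronize at.
 */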
594 void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
595 {
596         struct tls_context *tls_ctx = tls_get_ctx(sk);
597         struct tls_offload_context_rx *rx_ctx;
598         u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
599         struct tls_prot_info *prot;
600         u32 is_req_pending;
601         s64 resync_req;
602         u32 req_seq;
603
604         if (tls_ctx->rx_conf != TLS_HW)
605                 return;
606
607         prot = &tls_ctx->prot_info;
608         rx_ctx = tls_offload_ctx_rx(tls_ctx);
609         memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
610
611         switch (rx_ctx->resync_type) {
612         case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ:
613                 resync_req = atomic64_read(&rx_ctx->resync_req);
614                 req_seq = resync_req >> 32;
615                 seq += TLS_HEADER_SIZE - 1;
616                 is_req_pending = resync_req;
617
618                 if (likely(!is_req_pending) || req_seq != seq ||
619                     !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
620                         return;
621                 break;
622         case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT:
623                 if (likely(!rx_ctx->resync_nh_do_now))
624                         return;
625
626                 /* head of next rec is already in; note that tcp_inq()
627                  * includes the currently parsed record when called from the parser
628                  */
629                 if (tcp_inq(sk) > rcd_len)
630                         return;
631
632                 rx_ctx->resync_nh_do_now = 0;
633                 seq += rcd_len;
634                 tls_bigint_increment(rcd_sn, prot->rec_seq_size);
635                 break;
636         }
637
638         tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
639 }
640
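/* Core-driven resync (TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT): count
 * records the device failed to decrypt and, with increasing backoff,
 * point the device at the start of the next record.
 */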
641 static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
642                                            struct tls_offload_context_rx *ctx,
643                                            struct sock *sk, struct sk_buff *skb)
644 {
645         struct strp_msg *rxm;
646
647         /* device will request resyncs by itself based on stream scan */
648         if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT)
649                 return;
650         /* already scheduled */
651         if (ctx->resync_nh_do_now)
652                 return;
653         /* seen decrypted fragments since last fully-failed record */
654         if (ctx->resync_nh_reset) {
655                 ctx->resync_nh_reset = 0;
656                 ctx->resync_nh.decrypted_failed = 1;
657                 ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL;
658                 return;
659         }
660
661         if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt)
662                 return;
663
664         /* doing resync, bump the next target in case it fails */
665         if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL)
666                 ctx->resync_nh.decrypted_tgt *= 2;
667         else
668                 ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL;
669
670         rxm = strp_msg(skb);
671
672         /* head of next rec is already in, parser will sync for us */
673         if (tcp_inq(sk) > rxm->full_len) {
674                 ctx->resync_nh_do_now = 1;
675         } else {
676                 struct tls_prot_info *prot = &tls_ctx->prot_info;
677                 u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
678
679                 memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
680                 tls_bigint_increment(rcd_sn, prot->rec_seq_size);
681
682                 tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq,
683                                      rcd_sn);
684         }
685 }
686
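/* The device decrypted only parts of this record.  Run the record
 * through the software AES-GCM decrypt: for the segments that are
 * already plaintext the CTR keystream XOR regenerates the original
 * ciphertext, which is copied back over those segments so the software
 * RX path sees a uniformly encrypted record it can decrypt normally.
 */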
687 static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
688 {
689         struct strp_msg *rxm = strp_msg(skb);
690         int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
691         struct sk_buff *skb_iter, *unused;
692         struct scatterlist sg[1];
693         char *orig_buf, *buf;
694
695         orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
696                            TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
697         if (!orig_buf)
698                 return -ENOMEM;
699         buf = orig_buf;
700
701         nsg = skb_cow_data(skb, 0, &unused);
702         if (unlikely(nsg < 0)) {
703                 err = nsg;
704                 goto free_buf;
705         }
706
707         sg_init_table(sg, 1);
708         sg_set_buf(&sg[0], buf,
709                    rxm->full_len + TLS_HEADER_SIZE +
710                    TLS_CIPHER_AES_GCM_128_IV_SIZE);
711         err = skb_copy_bits(skb, offset, buf,
712                             TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
713         if (err)
714                 goto free_buf;
715
716         /* We are interested only in the decrypted data, not the auth tag */
717         err = decrypt_skb(sk, skb, sg);
718         if (err != -EBADMSG)
719                 goto free_buf;
720         else
721                 err = 0;
722
723         data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
724
725         if (skb_pagelen(skb) > offset) {
726                 copy = min_t(int, skb_pagelen(skb) - offset, data_len);
727
728                 if (skb->decrypted) {
729                         err = skb_store_bits(skb, offset, buf, copy);
730                         if (err)
731                                 goto free_buf;
732                 }
733
734                 offset += copy;
735                 buf += copy;
736         }
737
738         pos = skb_pagelen(skb);
739         skb_walk_frags(skb, skb_iter) {
740                 int frag_pos;
741
742                 /* Practically all frags must belong to msg if reencrypt
743                  * is needed with current strparser and coalescing logic,
744                  * but strparser may "get optimized", so let's be safe.
745                  */
746                 if (pos + skb_iter->len <= offset)
747                         goto done_with_frag;
748                 if (pos >= data_len + rxm->offset)
749                         break;
750
751                 frag_pos = offset - pos;
752                 copy = min_t(int, skb_iter->len - frag_pos,
753                              data_len + rxm->offset - offset);
754
755                 if (skb_iter->decrypted) {
756                         err = skb_store_bits(skb_iter, frag_pos, buf, copy);
757                         if (err)
758                                 goto free_buf;
759                 }
760
761                 offset += copy;
762                 buf += copy;
763 done_with_frag:
764                 pos += skb_iter->len;
765         }
766
767 free_buf:
768         kfree(orig_buf);
769         return err;
770 }
771
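/* RX hook called for every record in TLS_HW rx mode.  Decide whether the
 * device decrypted the whole record (nothing more to do), none of it
 * (fall back to software decryption) or only parts of it (rebuild the
 * ciphertext so software can decrypt the full record).
 */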
772 int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
773 {
774         struct tls_context *tls_ctx = tls_get_ctx(sk);
775         struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
776         int is_decrypted = skb->decrypted;
777         int is_encrypted = !is_decrypted;
778         struct sk_buff *skb_iter;
779
780         /* Check if all the data is decrypted already */
781         skb_walk_frags(skb, skb_iter) {
782                 is_decrypted &= skb_iter->decrypted;
783                 is_encrypted &= !skb_iter->decrypted;
784         }
785
786         ctx->sw.decrypted |= is_decrypted;
787
788         /* Return immediately if the record is either entirely plaintext or
789          * entirely ciphertext. Otherwise handle reencrypt partially decrypted
790          * record.
791          */
792         if (is_decrypted) {
793                 ctx->resync_nh_reset = 1;
794                 return 0;
795         }
796         if (is_encrypted) {
797                 tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
798                 return 0;
799         }
800
801         ctx->resync_nh_reset = 1;
802         return tls_device_reencrypt(sk, skb);
803 }
804
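/* First time offload is enabled on this socket: take a reference on the
 * netdev, add the context to tls_device_list and install the offload
 * socket destructor.
 */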
805 static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
806                               struct net_device *netdev)
807 {
808         if (sk->sk_destruct != tls_device_sk_destruct) {
809                 refcount_set(&ctx->refcount, 1);
810                 dev_hold(netdev);
811                 ctx->netdev = netdev;
812                 spin_lock_irq(&tls_device_lock);
813                 list_add_tail(&ctx->list, &tls_device_list);
814                 spin_unlock_irq(&tls_device_lock);
815
816                 ctx->sk_destruct = sk->sk_destruct;
817                 sk->sk_destruct = tls_device_sk_destruct;
818         }
819 }
820
821 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
822 {
823         u16 nonce_size, tag_size, iv_size, rec_seq_size;
824         struct tls_context *tls_ctx = tls_get_ctx(sk);
825         struct tls_prot_info *prot = &tls_ctx->prot_info;
826         struct tls_record_info *start_marker_record;
827         struct tls_offload_context_tx *offload_ctx;
828         struct tls_crypto_info *crypto_info;
829         struct net_device *netdev;
830         char *iv, *rec_seq;
831         struct sk_buff *skb;
832         int rc = -EINVAL;
833         __be64 rcd_sn;
834
835         if (!ctx)
836                 goto out;
837
838         if (ctx->priv_ctx_tx) {
839                 rc = -EEXIST;
840                 goto out;
841         }
842
843         start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
844         if (!start_marker_record) {
845                 rc = -ENOMEM;
846                 goto out;
847         }
848
849         offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
850         if (!offload_ctx) {
851                 rc = -ENOMEM;
852                 goto free_marker_record;
853         }
854
855         crypto_info = &ctx->crypto_send.info;
856         if (crypto_info->version != TLS_1_2_VERSION) {
857                 rc = -EOPNOTSUPP;
858                 goto free_offload_ctx;
859         }
860
861         switch (crypto_info->cipher_type) {
862         case TLS_CIPHER_AES_GCM_128:
863                 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
864                 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
865                 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
866                 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
867                 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
868                 rec_seq =
869                  ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
870                 break;
871         default:
872                 rc = -EINVAL;
873                 goto free_offload_ctx;
874         }
875
876         /* Sanity-check the rec_seq_size for stack allocations */
877         if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
878                 rc = -EINVAL;
879                 goto free_offload_ctx;
880         }
881
882         prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
883         prot->tag_size = tag_size;
884         prot->overhead_size = prot->prepend_size + prot->tag_size;
885         prot->iv_size = iv_size;
886         ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
887                              GFP_KERNEL);
888         if (!ctx->tx.iv) {
889                 rc = -ENOMEM;
890                 goto free_offload_ctx;
891         }
892
893         memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
894
895         prot->rec_seq_size = rec_seq_size;
896         ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
897         if (!ctx->tx.rec_seq) {
898                 rc = -ENOMEM;
899                 goto free_iv;
900         }
901
902         rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
903         if (rc)
904                 goto free_rec_seq;
905
906         /* start at rec_seq - 1 to account for the start marker record */
907         memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
908         offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
909
910         start_marker_record->end_seq = tcp_sk(sk)->write_seq;
911         start_marker_record->len = 0;
912         start_marker_record->num_frags = 0;
913
914         INIT_LIST_HEAD(&offload_ctx->records_list);
915         list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
916         spin_lock_init(&offload_ctx->lock);
917         sg_init_table(offload_ctx->sg_tx_data,
918                       ARRAY_SIZE(offload_ctx->sg_tx_data));
919
920         clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
921         ctx->push_pending_record = tls_device_push_pending_record;
922
923         /* TLS offload is greatly simplified if we don't send
924          * SKBs where only part of the payload needs to be encrypted.
925          * So mark the last skb in the write queue as end of record.
926          */
927         skb = tcp_write_queue_tail(sk);
928         if (skb)
929                 TCP_SKB_CB(skb)->eor = 1;
930
931         /* We support starting offload on multiple sockets
932          * concurrently, so we only need a read lock here.
933          * This lock must precede get_netdev_for_sock to prevent races between
934          * NETDEV_DOWN and setsockopt.
935          */
936         down_read(&device_offload_lock);
937         netdev = get_netdev_for_sock(sk);
938         if (!netdev) {
939                 pr_err_ratelimited("%s: netdev not found\n", __func__);
940                 rc = -EINVAL;
941                 goto release_lock;
942         }
943
944         if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
945                 rc = -ENOTSUPP;
946                 goto release_netdev;
947         }
948
949         /* Avoid offloading if the device is down
950          * We don't want to offload new flows after
951          * the NETDEV_DOWN event
952          */
953         if (!(netdev->flags & IFF_UP)) {
954                 rc = -EINVAL;
955                 goto release_netdev;
956         }
957
958         ctx->priv_ctx_tx = offload_ctx;
959         rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
960                                              &ctx->crypto_send.info,
961                                              tcp_sk(sk)->write_seq);
962         if (rc)
963                 goto release_netdev;
964
965         tls_device_attach(ctx, sk, netdev);
966
967         /* following this assignment tls_is_sk_tx_device_offloaded
968          * will return true and the context might be accessed
969          * by the netdev's xmit function.
970          */
971         smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
972         dev_put(netdev);
973         up_read(&device_offload_lock);
974         goto out;
975
976 release_netdev:
977         dev_put(netdev);
978 release_lock:
979         up_read(&device_offload_lock);
980         clean_acked_data_disable(inet_csk(sk));
981         crypto_free_aead(offload_ctx->aead_send);
982 free_rec_seq:
983         kfree(ctx->tx.rec_seq);
984 free_iv:
985         kfree(ctx->tx.iv);
986 free_offload_ctx:
987         kfree(offload_ctx);
988         ctx->priv_ctx_tx = NULL;
989 free_marker_record:
990         kfree(start_marker_record);
991 out:
992         return rc;
993 }
994
995 int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
996 {
997         struct tls_offload_context_rx *context;
998         struct net_device *netdev;
999         int rc = 0;
1000
1001         if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
1002                 return -EOPNOTSUPP;
1003
1004         /* We support starting offload on multiple sockets
1005          * concurrently, so we only need a read lock here.
1006          * This lock must precede get_netdev_for_sock to prevent races between
1007          * NETDEV_DOWN and setsockopt.
1008          */
1009         down_read(&device_offload_lock);
1010         netdev = get_netdev_for_sock(sk);
1011         if (!netdev) {
1012                 pr_err_ratelimited("%s: netdev not found\n", __func__);
1013                 rc = -EINVAL;
1014                 goto release_lock;
1015         }
1016
1017         if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
1018                 rc = -ENOTSUPP;
1019                 goto release_netdev;
1020         }
1021
1022         /* Avoid offloading if the device is down
1023          * We don't want to offload new flows after
1024          * the NETDEV_DOWN event
1025          */
1026         if (!(netdev->flags & IFF_UP)) {
1027                 rc = -EINVAL;
1028                 goto release_netdev;
1029         }
1030
1031         context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
1032         if (!context) {
1033                 rc = -ENOMEM;
1034                 goto release_netdev;
1035         }
1036         context->resync_nh_reset = 1;
1037
1038         ctx->priv_ctx_rx = context;
1039         rc = tls_set_sw_offload(sk, ctx, 0);
1040         if (rc)
1041                 goto release_ctx;
1042
1043         rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
1044                                              &ctx->crypto_recv.info,
1045                                              tcp_sk(sk)->copied_seq);
1046         if (rc)
1047                 goto free_sw_resources;
1048
1049         tls_device_attach(ctx, sk, netdev);
1050         goto release_netdev;
1051
1052 free_sw_resources:
1053         up_read(&device_offload_lock);
1054         tls_sw_free_resources_rx(sk);
1055         down_read(&device_offload_lock);
1056 release_ctx:
1057         ctx->priv_ctx_rx = NULL;
1058 release_netdev:
1059         dev_put(netdev);
1060 release_lock:
1061         up_read(&device_offload_lock);
1062         return rc;
1063 }
1064
1065 void tls_device_offload_cleanup_rx(struct sock *sk)
1066 {
1067         struct tls_context *tls_ctx = tls_get_ctx(sk);
1068         struct net_device *netdev;
1069
1070         down_read(&device_offload_lock);
1071         netdev = tls_ctx->netdev;
1072         if (!netdev)
1073                 goto out;
1074
1075         netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
1076                                         TLS_OFFLOAD_CTX_DIR_RX);
1077
1078         if (tls_ctx->tx_conf != TLS_HW) {
1079                 dev_put(netdev);
1080                 tls_ctx->netdev = NULL;
1081         }
1082 out:
1083         up_read(&device_offload_lock);
1084         tls_sw_release_resources_rx(sk);
1085 }
1086
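/* NETDEV_DOWN handler: detach every offloaded context bound to this
 * device, tell the driver to drop its TX/RX state and free contexts
 * whose last reference is gone.
 */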
1087 static int tls_device_down(struct net_device *netdev)
1088 {
1089         struct tls_context *ctx, *tmp;
1090         unsigned long flags;
1091         LIST_HEAD(list);
1092
1093         /* Request a write lock to block new offload attempts */
1094         down_write(&device_offload_lock);
1095
1096         spin_lock_irqsave(&tls_device_lock, flags);
1097         list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
1098                 if (ctx->netdev != netdev ||
1099                     !refcount_inc_not_zero(&ctx->refcount))
1100                         continue;
1101
1102                 list_move(&ctx->list, &list);
1103         }
1104         spin_unlock_irqrestore(&tls_device_lock, flags);
1105
1106         list_for_each_entry_safe(ctx, tmp, &list, list) {
1107                 if (ctx->tx_conf == TLS_HW)
1108                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
1109                                                         TLS_OFFLOAD_CTX_DIR_TX);
1110                 if (ctx->rx_conf == TLS_HW)
1111                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
1112                                                         TLS_OFFLOAD_CTX_DIR_RX);
1113                 WRITE_ONCE(ctx->netdev, NULL);
1114                 smp_mb__before_atomic(); /* pairs with test_and_set_bit() */
1115                 while (test_bit(TLS_RX_SYNC_RUNNING, &ctx->flags))
1116                         usleep_range(10, 200);
1117                 dev_put(netdev);
1118                 list_del_init(&ctx->list);
1119
1120                 if (refcount_dec_and_test(&ctx->refcount))
1121                         tls_device_free_ctx(ctx);
1122         }
1123
1124         up_write(&device_offload_lock);
1125
1126         flush_work(&tls_device_gc_work);
1127
1128         return NOTIFY_DONE;
1129 }
1130
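/* Netdevice notifier: ensure devices advertising TLS offload provide the
 * required tlsdev_ops, and tear down offload state on NETDEV_DOWN.
 */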
1131 static int tls_dev_event(struct notifier_block *this, unsigned long event,
1132                          void *ptr)
1133 {
1134         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
1135
1136         if (!dev->tlsdev_ops &&
1137             !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
1138                 return NOTIFY_DONE;
1139
1140         switch (event) {
1141         case NETDEV_REGISTER:
1142         case NETDEV_FEAT_CHANGE:
1143                 if ((dev->features & NETIF_F_HW_TLS_RX) &&
1144                     !dev->tlsdev_ops->tls_dev_resync)
1145                         return NOTIFY_BAD;
1146
1147                 if (dev->tlsdev_ops &&
1148                     dev->tlsdev_ops->tls_dev_add &&
1149                     dev->tlsdev_ops->tls_dev_del)
1150                         return NOTIFY_DONE;
1151                 else
1152                         return NOTIFY_BAD;
1153         case NETDEV_DOWN:
1154                 return tls_device_down(dev);
1155         }
1156         return NOTIFY_DONE;
1157 }
1158
1159 static struct notifier_block tls_dev_notifier = {
1160         .notifier_call  = tls_dev_event,
1161 };
1162
1163 void __init tls_device_init(void)
1164 {
1165         register_netdevice_notifier(&tls_dev_notifier);
1166 }
1167
1168 void __exit tls_device_cleanup(void)
1169 {
1170         unregister_netdevice_notifier(&tls_dev_notifier);
1171         flush_work(&tls_device_gc_work);
1172         clean_acked_data_flush();
1173 }