Merge tag 'block-6.1-2022-10-20' of git://git.kernel.dk/linux
[linux-block.git] / drivers / crypto / qat / qat_common / qat_asym_algs.c
1 // SPDX-License-Identifier: (BSD-3-Clause OR GPL-2.0-only)
2 /* Copyright(c) 2014 - 2020 Intel Corporation */
3 #include <linux/module.h>
4 #include <crypto/internal/rsa.h>
5 #include <crypto/internal/akcipher.h>
6 #include <crypto/akcipher.h>
7 #include <crypto/kpp.h>
8 #include <crypto/internal/kpp.h>
9 #include <crypto/dh.h>
10 #include <linux/dma-mapping.h>
11 #include <linux/fips.h>
12 #include <crypto/scatterwalk.h>
13 #include "icp_qat_fw_pke.h"
14 #include "adf_accel_devices.h"
15 #include "qat_algs_send.h"
16 #include "adf_transport.h"
17 #include "adf_common_drv.h"
18 #include "qat_crypto.h"
19
20 static DEFINE_MUTEX(algs_lock);
21 static unsigned int active_devs;
22
23 struct qat_rsa_input_params {
24         union {
25                 struct {
26                         dma_addr_t m;
27                         dma_addr_t e;
28                         dma_addr_t n;
29                 } enc;
30                 struct {
31                         dma_addr_t c;
32                         dma_addr_t d;
33                         dma_addr_t n;
34                 } dec;
35                 struct {
36                         dma_addr_t c;
37                         dma_addr_t p;
38                         dma_addr_t q;
39                         dma_addr_t dp;
40                         dma_addr_t dq;
41                         dma_addr_t qinv;
42                 } dec_crt;
43                 u64 in_tab[8];
44         };
45 } __packed __aligned(64);
46
47 struct qat_rsa_output_params {
48         union {
49                 struct {
50                         dma_addr_t c;
51                 } enc;
52                 struct {
53                         dma_addr_t m;
54                 } dec;
55                 u64 out_tab[8];
56         };
57 } __packed __aligned(64);
58
59 struct qat_rsa_ctx {
60         char *n;
61         char *e;
62         char *d;
63         char *p;
64         char *q;
65         char *dp;
66         char *dq;
67         char *qinv;
68         dma_addr_t dma_n;
69         dma_addr_t dma_e;
70         dma_addr_t dma_d;
71         dma_addr_t dma_p;
72         dma_addr_t dma_q;
73         dma_addr_t dma_dp;
74         dma_addr_t dma_dq;
75         dma_addr_t dma_qinv;
76         unsigned int key_sz;
77         bool crt_mode;
78         struct qat_crypto_instance *inst;
79 } __packed __aligned(64);
80
81 struct qat_dh_input_params {
82         union {
83                 struct {
84                         dma_addr_t b;
85                         dma_addr_t xa;
86                         dma_addr_t p;
87                 } in;
88                 struct {
89                         dma_addr_t xa;
90                         dma_addr_t p;
91                 } in_g2;
92                 u64 in_tab[8];
93         };
94 } __packed __aligned(64);
95
96 struct qat_dh_output_params {
97         union {
98                 dma_addr_t r;
99                 u64 out_tab[8];
100         };
101 } __packed __aligned(64);
102
103 struct qat_dh_ctx {
104         char *g;
105         char *xa;
106         char *p;
107         dma_addr_t dma_g;
108         dma_addr_t dma_xa;
109         dma_addr_t dma_p;
110         unsigned int p_size;
111         bool g2;
112         struct qat_crypto_instance *inst;
113 } __packed __aligned(64);
114
115 struct qat_asym_request {
116         union {
117                 struct qat_rsa_input_params rsa;
118                 struct qat_dh_input_params dh;
119         } in;
120         union {
121                 struct qat_rsa_output_params rsa;
122                 struct qat_dh_output_params dh;
123         } out;
124         dma_addr_t phy_in;
125         dma_addr_t phy_out;
126         char *src_align;
127         char *dst_align;
128         struct icp_qat_fw_pke_request req;
129         union {
130                 struct qat_rsa_ctx *rsa;
131                 struct qat_dh_ctx *dh;
132         } ctx;
133         union {
134                 struct akcipher_request *rsa;
135                 struct kpp_request *dh;
136         } areq;
137         int err;
138         void (*cb)(struct icp_qat_fw_pke_resp *resp);
139         struct qat_alg_req alg_req;
140 } __aligned(64);
141
142 static int qat_alg_send_asym_message(struct qat_asym_request *qat_req,
143                                      struct qat_crypto_instance *inst,
144                                      struct crypto_async_request *base)
145 {
146         struct qat_alg_req *alg_req = &qat_req->alg_req;
147
148         alg_req->fw_req = (u32 *)&qat_req->req;
149         alg_req->tx_ring = inst->pke_tx;
150         alg_req->base = base;
151         alg_req->backlog = &inst->backlog;
152
153         return qat_alg_send_message(alg_req);
154 }
155
156 static void qat_dh_cb(struct icp_qat_fw_pke_resp *resp)
157 {
158         struct qat_asym_request *req = (void *)(__force long)resp->opaque;
159         struct kpp_request *areq = req->areq.dh;
160         struct device *dev = &GET_DEV(req->ctx.dh->inst->accel_dev);
161         int err = ICP_QAT_FW_PKE_RESP_PKE_STAT_GET(
162                                 resp->pke_resp_hdr.comn_resp_flags);
163
164         err = (err == ICP_QAT_FW_COMN_STATUS_FLAG_OK) ? 0 : -EINVAL;
165
166         if (areq->src) {
167                 dma_unmap_single(dev, req->in.dh.in.b, req->ctx.dh->p_size,
168                                  DMA_TO_DEVICE);
169                 kfree_sensitive(req->src_align);
170         }
171
172         areq->dst_len = req->ctx.dh->p_size;
173         if (req->dst_align) {
174                 scatterwalk_map_and_copy(req->dst_align, areq->dst, 0,
175                                          areq->dst_len, 1);
176                 kfree_sensitive(req->dst_align);
177         }
178
179         dma_unmap_single(dev, req->out.dh.r, req->ctx.dh->p_size,
180                          DMA_FROM_DEVICE);
181
182         dma_unmap_single(dev, req->phy_in, sizeof(struct qat_dh_input_params),
183                          DMA_TO_DEVICE);
184         dma_unmap_single(dev, req->phy_out,
185                          sizeof(struct qat_dh_output_params),
186                          DMA_TO_DEVICE);
187
188         kpp_request_complete(areq, err);
189 }
190
191 #define PKE_DH_1536 0x390c1a49
192 #define PKE_DH_G2_1536 0x2e0b1a3e
193 #define PKE_DH_2048 0x4d0c1a60
194 #define PKE_DH_G2_2048 0x3e0b1a55
195 #define PKE_DH_3072 0x510c1a77
196 #define PKE_DH_G2_3072 0x3a0b1a6c
197 #define PKE_DH_4096 0x690c1a8e
198 #define PKE_DH_G2_4096 0x4a0b1a83
199
200 static unsigned long qat_dh_fn_id(unsigned int len, bool g2)
201 {
202         unsigned int bitslen = len << 3;
203
204         switch (bitslen) {
205         case 1536:
206                 return g2 ? PKE_DH_G2_1536 : PKE_DH_1536;
207         case 2048:
208                 return g2 ? PKE_DH_G2_2048 : PKE_DH_2048;
209         case 3072:
210                 return g2 ? PKE_DH_G2_3072 : PKE_DH_3072;
211         case 4096:
212                 return g2 ? PKE_DH_G2_4096 : PKE_DH_4096;
213         default:
214                 return 0;
215         }
216 }
217
218 static int qat_dh_compute_value(struct kpp_request *req)
219 {
220         struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);
221         struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
222         struct qat_crypto_instance *inst = ctx->inst;
223         struct device *dev = &GET_DEV(inst->accel_dev);
224         struct qat_asym_request *qat_req =
225                         PTR_ALIGN(kpp_request_ctx(req), 64);
226         struct icp_qat_fw_pke_request *msg = &qat_req->req;
227         gfp_t flags = qat_algs_alloc_flags(&req->base);
228         int n_input_params = 0;
229         u8 *vaddr;
230         int ret;
231
232         if (unlikely(!ctx->xa))
233                 return -EINVAL;
234
235         if (req->dst_len < ctx->p_size) {
236                 req->dst_len = ctx->p_size;
237                 return -EOVERFLOW;
238         }
239
240         if (req->src_len > ctx->p_size)
241                 return -EINVAL;
242
243         memset(msg, '\0', sizeof(*msg));
244         ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
245                                           ICP_QAT_FW_COMN_REQ_FLAG_SET);
246
247         msg->pke_hdr.cd_pars.func_id = qat_dh_fn_id(ctx->p_size,
248                                                     !req->src && ctx->g2);
249         if (unlikely(!msg->pke_hdr.cd_pars.func_id))
250                 return -EINVAL;
251
252         qat_req->cb = qat_dh_cb;
253         qat_req->ctx.dh = ctx;
254         qat_req->areq.dh = req;
255         msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
256         msg->pke_hdr.comn_req_flags =
257                 ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
258                                             QAT_COMN_CD_FLD_TYPE_64BIT_ADR);
259
260         /*
261          * If no source is provided use g as base
262          */
263         if (req->src) {
264                 qat_req->in.dh.in.xa = ctx->dma_xa;
265                 qat_req->in.dh.in.p = ctx->dma_p;
266                 n_input_params = 3;
267         } else {
268                 if (ctx->g2) {
269                         qat_req->in.dh.in_g2.xa = ctx->dma_xa;
270                         qat_req->in.dh.in_g2.p = ctx->dma_p;
271                         n_input_params = 2;
272                 } else {
273                         qat_req->in.dh.in.b = ctx->dma_g;
274                         qat_req->in.dh.in.xa = ctx->dma_xa;
275                         qat_req->in.dh.in.p = ctx->dma_p;
276                         n_input_params = 3;
277                 }
278         }
279
280         ret = -ENOMEM;
281         if (req->src) {
282                 /*
283                  * src can be of any size in valid range, but HW expects it to
284                  * be the same as modulo p so in case it is different we need
285                  * to allocate a new buf and copy src data.
286                  * In other case we just need to map the user provided buffer.
287                  * Also need to make sure that it is in contiguous buffer.
288                  */
289                 if (sg_is_last(req->src) && req->src_len == ctx->p_size) {
290                         qat_req->src_align = NULL;
291                         vaddr = sg_virt(req->src);
292                 } else {
293                         int shift = ctx->p_size - req->src_len;
294
295                         qat_req->src_align = kzalloc(ctx->p_size, flags);
296                         if (unlikely(!qat_req->src_align))
297                                 return ret;
298
299                         scatterwalk_map_and_copy(qat_req->src_align + shift,
300                                                  req->src, 0, req->src_len, 0);
301
302                         vaddr = qat_req->src_align;
303                 }
304
305                 qat_req->in.dh.in.b = dma_map_single(dev, vaddr, ctx->p_size,
306                                                      DMA_TO_DEVICE);
307                 if (unlikely(dma_mapping_error(dev, qat_req->in.dh.in.b)))
308                         goto unmap_src;
309         }
310         /*
311          * dst can be of any size in valid range, but HW expects it to be the
312          * same as modulo m so in case it is different we need to allocate a
313          * new buf and copy src data.
314          * In other case we just need to map the user provided buffer.
315          * Also need to make sure that it is in contiguous buffer.
316          */
317         if (sg_is_last(req->dst) && req->dst_len == ctx->p_size) {
318                 qat_req->dst_align = NULL;
319                 vaddr = sg_virt(req->dst);
320         } else {
321                 qat_req->dst_align = kzalloc(ctx->p_size, flags);
322                 if (unlikely(!qat_req->dst_align))
323                         goto unmap_src;
324
325                 vaddr = qat_req->dst_align;
326         }
327         qat_req->out.dh.r = dma_map_single(dev, vaddr, ctx->p_size,
328                                            DMA_FROM_DEVICE);
329         if (unlikely(dma_mapping_error(dev, qat_req->out.dh.r)))
330                 goto unmap_dst;
331
332         qat_req->in.dh.in_tab[n_input_params] = 0;
333         qat_req->out.dh.out_tab[1] = 0;
334         /* Mapping in.in.b or in.in_g2.xa is the same */
335         qat_req->phy_in = dma_map_single(dev, &qat_req->in.dh,
336                                          sizeof(struct qat_dh_input_params),
337                                          DMA_TO_DEVICE);
338         if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
339                 goto unmap_dst;
340
341         qat_req->phy_out = dma_map_single(dev, &qat_req->out.dh,
342                                           sizeof(struct qat_dh_output_params),
343                                           DMA_TO_DEVICE);
344         if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
345                 goto unmap_in_params;
346
347         msg->pke_mid.src_data_addr = qat_req->phy_in;
348         msg->pke_mid.dest_data_addr = qat_req->phy_out;
349         msg->pke_mid.opaque = (u64)(__force long)qat_req;
350         msg->input_param_count = n_input_params;
351         msg->output_param_count = 1;
352
353         ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
354         if (ret == -ENOSPC)
355                 goto unmap_all;
356
357         return ret;
358
359 unmap_all:
360         if (!dma_mapping_error(dev, qat_req->phy_out))
361                 dma_unmap_single(dev, qat_req->phy_out,
362                                  sizeof(struct qat_dh_output_params),
363                                  DMA_TO_DEVICE);
364 unmap_in_params:
365         if (!dma_mapping_error(dev, qat_req->phy_in))
366                 dma_unmap_single(dev, qat_req->phy_in,
367                                  sizeof(struct qat_dh_input_params),
368                                  DMA_TO_DEVICE);
369 unmap_dst:
370         if (!dma_mapping_error(dev, qat_req->out.dh.r))
371                 dma_unmap_single(dev, qat_req->out.dh.r, ctx->p_size,
372                                  DMA_FROM_DEVICE);
373         kfree_sensitive(qat_req->dst_align);
374 unmap_src:
375         if (req->src) {
376                 if (!dma_mapping_error(dev, qat_req->in.dh.in.b))
377                         dma_unmap_single(dev, qat_req->in.dh.in.b,
378                                          ctx->p_size,
379                                          DMA_TO_DEVICE);
380                 kfree_sensitive(qat_req->src_align);
381         }
382         return ret;
383 }
384
385 static int qat_dh_check_params_length(unsigned int p_len)
386 {
387         switch (p_len) {
388         case 1536:
389         case 2048:
390         case 3072:
391         case 4096:
392                 return 0;
393         }
394         return -EINVAL;
395 }
396
397 static int qat_dh_set_params(struct qat_dh_ctx *ctx, struct dh *params)
398 {
399         struct qat_crypto_instance *inst = ctx->inst;
400         struct device *dev = &GET_DEV(inst->accel_dev);
401
402         if (qat_dh_check_params_length(params->p_size << 3))
403                 return -EINVAL;
404
405         ctx->p_size = params->p_size;
406         ctx->p = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_p, GFP_KERNEL);
407         if (!ctx->p)
408                 return -ENOMEM;
409         memcpy(ctx->p, params->p, ctx->p_size);
410
411         /* If g equals 2 don't copy it */
412         if (params->g_size == 1 && *(char *)params->g == 0x02) {
413                 ctx->g2 = true;
414                 return 0;
415         }
416
417         ctx->g = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_g, GFP_KERNEL);
418         if (!ctx->g)
419                 return -ENOMEM;
420         memcpy(ctx->g + (ctx->p_size - params->g_size), params->g,
421                params->g_size);
422
423         return 0;
424 }
425
426 static void qat_dh_clear_ctx(struct device *dev, struct qat_dh_ctx *ctx)
427 {
428         if (ctx->g) {
429                 memset(ctx->g, 0, ctx->p_size);
430                 dma_free_coherent(dev, ctx->p_size, ctx->g, ctx->dma_g);
431                 ctx->g = NULL;
432         }
433         if (ctx->xa) {
434                 memset(ctx->xa, 0, ctx->p_size);
435                 dma_free_coherent(dev, ctx->p_size, ctx->xa, ctx->dma_xa);
436                 ctx->xa = NULL;
437         }
438         if (ctx->p) {
439                 memset(ctx->p, 0, ctx->p_size);
440                 dma_free_coherent(dev, ctx->p_size, ctx->p, ctx->dma_p);
441                 ctx->p = NULL;
442         }
443         ctx->p_size = 0;
444         ctx->g2 = false;
445 }
446
447 static int qat_dh_set_secret(struct crypto_kpp *tfm, const void *buf,
448                              unsigned int len)
449 {
450         struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
451         struct device *dev = &GET_DEV(ctx->inst->accel_dev);
452         struct dh params;
453         int ret;
454
455         if (crypto_dh_decode_key(buf, len, &params) < 0)
456                 return -EINVAL;
457
458         /* Free old secret if any */
459         qat_dh_clear_ctx(dev, ctx);
460
461         ret = qat_dh_set_params(ctx, &params);
462         if (ret < 0)
463                 goto err_clear_ctx;
464
465         ctx->xa = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_xa,
466                                      GFP_KERNEL);
467         if (!ctx->xa) {
468                 ret = -ENOMEM;
469                 goto err_clear_ctx;
470         }
471         memcpy(ctx->xa + (ctx->p_size - params.key_size), params.key,
472                params.key_size);
473
474         return 0;
475
476 err_clear_ctx:
477         qat_dh_clear_ctx(dev, ctx);
478         return ret;
479 }
480
481 static unsigned int qat_dh_max_size(struct crypto_kpp *tfm)
482 {
483         struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
484
485         return ctx->p_size;
486 }
487
488 static int qat_dh_init_tfm(struct crypto_kpp *tfm)
489 {
490         struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
491         struct qat_crypto_instance *inst =
492                         qat_crypto_get_instance_node(numa_node_id());
493
494         if (!inst)
495                 return -EINVAL;
496
497         ctx->p_size = 0;
498         ctx->g2 = false;
499         ctx->inst = inst;
500         return 0;
501 }
502
503 static void qat_dh_exit_tfm(struct crypto_kpp *tfm)
504 {
505         struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
506         struct device *dev = &GET_DEV(ctx->inst->accel_dev);
507
508         qat_dh_clear_ctx(dev, ctx);
509         qat_crypto_put_instance(ctx->inst);
510 }
511
512 static void qat_rsa_cb(struct icp_qat_fw_pke_resp *resp)
513 {
514         struct qat_asym_request *req = (void *)(__force long)resp->opaque;
515         struct akcipher_request *areq = req->areq.rsa;
516         struct device *dev = &GET_DEV(req->ctx.rsa->inst->accel_dev);
517         int err = ICP_QAT_FW_PKE_RESP_PKE_STAT_GET(
518                                 resp->pke_resp_hdr.comn_resp_flags);
519
520         err = (err == ICP_QAT_FW_COMN_STATUS_FLAG_OK) ? 0 : -EINVAL;
521
522         kfree_sensitive(req->src_align);
523
524         dma_unmap_single(dev, req->in.rsa.enc.m, req->ctx.rsa->key_sz,
525                          DMA_TO_DEVICE);
526
527         areq->dst_len = req->ctx.rsa->key_sz;
528         if (req->dst_align) {
529                 scatterwalk_map_and_copy(req->dst_align, areq->dst, 0,
530                                          areq->dst_len, 1);
531
532                 kfree_sensitive(req->dst_align);
533         }
534
535         dma_unmap_single(dev, req->out.rsa.enc.c, req->ctx.rsa->key_sz,
536                          DMA_FROM_DEVICE);
537
538         dma_unmap_single(dev, req->phy_in, sizeof(struct qat_rsa_input_params),
539                          DMA_TO_DEVICE);
540         dma_unmap_single(dev, req->phy_out,
541                          sizeof(struct qat_rsa_output_params),
542                          DMA_TO_DEVICE);
543
544         akcipher_request_complete(areq, err);
545 }
546
547 void qat_alg_asym_callback(void *_resp)
548 {
549         struct icp_qat_fw_pke_resp *resp = _resp;
550         struct qat_asym_request *areq = (void *)(__force long)resp->opaque;
551         struct qat_instance_backlog *backlog = areq->alg_req.backlog;
552
553         areq->cb(resp);
554
555         qat_alg_send_backlog(backlog);
556 }
557
558 #define PKE_RSA_EP_512 0x1c161b21
559 #define PKE_RSA_EP_1024 0x35111bf7
560 #define PKE_RSA_EP_1536 0x4d111cdc
561 #define PKE_RSA_EP_2048 0x6e111dba
562 #define PKE_RSA_EP_3072 0x7d111ea3
563 #define PKE_RSA_EP_4096 0xa5101f7e
564
565 static unsigned long qat_rsa_enc_fn_id(unsigned int len)
566 {
567         unsigned int bitslen = len << 3;
568
569         switch (bitslen) {
570         case 512:
571                 return PKE_RSA_EP_512;
572         case 1024:
573                 return PKE_RSA_EP_1024;
574         case 1536:
575                 return PKE_RSA_EP_1536;
576         case 2048:
577                 return PKE_RSA_EP_2048;
578         case 3072:
579                 return PKE_RSA_EP_3072;
580         case 4096:
581                 return PKE_RSA_EP_4096;
582         default:
583                 return 0;
584         }
585 }
586
587 #define PKE_RSA_DP1_512 0x1c161b3c
588 #define PKE_RSA_DP1_1024 0x35111c12
589 #define PKE_RSA_DP1_1536 0x4d111cf7
590 #define PKE_RSA_DP1_2048 0x6e111dda
591 #define PKE_RSA_DP1_3072 0x7d111ebe
592 #define PKE_RSA_DP1_4096 0xa5101f98
593
594 static unsigned long qat_rsa_dec_fn_id(unsigned int len)
595 {
596         unsigned int bitslen = len << 3;
597
598         switch (bitslen) {
599         case 512:
600                 return PKE_RSA_DP1_512;
601         case 1024:
602                 return PKE_RSA_DP1_1024;
603         case 1536:
604                 return PKE_RSA_DP1_1536;
605         case 2048:
606                 return PKE_RSA_DP1_2048;
607         case 3072:
608                 return PKE_RSA_DP1_3072;
609         case 4096:
610                 return PKE_RSA_DP1_4096;
611         default:
612                 return 0;
613         }
614 }
615
616 #define PKE_RSA_DP2_512 0x1c131b57
617 #define PKE_RSA_DP2_1024 0x26131c2d
618 #define PKE_RSA_DP2_1536 0x45111d12
619 #define PKE_RSA_DP2_2048 0x59121dfa
620 #define PKE_RSA_DP2_3072 0x81121ed9
621 #define PKE_RSA_DP2_4096 0xb1111fb2
622
623 static unsigned long qat_rsa_dec_fn_id_crt(unsigned int len)
624 {
625         unsigned int bitslen = len << 3;
626
627         switch (bitslen) {
628         case 512:
629                 return PKE_RSA_DP2_512;
630         case 1024:
631                 return PKE_RSA_DP2_1024;
632         case 1536:
633                 return PKE_RSA_DP2_1536;
634         case 2048:
635                 return PKE_RSA_DP2_2048;
636         case 3072:
637                 return PKE_RSA_DP2_3072;
638         case 4096:
639                 return PKE_RSA_DP2_4096;
640         default:
641                 return 0;
642         }
643 }
644
645 static int qat_rsa_enc(struct akcipher_request *req)
646 {
647         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
648         struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
649         struct qat_crypto_instance *inst = ctx->inst;
650         struct device *dev = &GET_DEV(inst->accel_dev);
651         struct qat_asym_request *qat_req =
652                         PTR_ALIGN(akcipher_request_ctx(req), 64);
653         struct icp_qat_fw_pke_request *msg = &qat_req->req;
654         gfp_t flags = qat_algs_alloc_flags(&req->base);
655         u8 *vaddr;
656         int ret;
657
658         if (unlikely(!ctx->n || !ctx->e))
659                 return -EINVAL;
660
661         if (req->dst_len < ctx->key_sz) {
662                 req->dst_len = ctx->key_sz;
663                 return -EOVERFLOW;
664         }
665
666         if (req->src_len > ctx->key_sz)
667                 return -EINVAL;
668
669         memset(msg, '\0', sizeof(*msg));
670         ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
671                                           ICP_QAT_FW_COMN_REQ_FLAG_SET);
672         msg->pke_hdr.cd_pars.func_id = qat_rsa_enc_fn_id(ctx->key_sz);
673         if (unlikely(!msg->pke_hdr.cd_pars.func_id))
674                 return -EINVAL;
675
676         qat_req->cb = qat_rsa_cb;
677         qat_req->ctx.rsa = ctx;
678         qat_req->areq.rsa = req;
679         msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
680         msg->pke_hdr.comn_req_flags =
681                 ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
682                                             QAT_COMN_CD_FLD_TYPE_64BIT_ADR);
683
684         qat_req->in.rsa.enc.e = ctx->dma_e;
685         qat_req->in.rsa.enc.n = ctx->dma_n;
686         ret = -ENOMEM;
687
688         /*
689          * src can be of any size in valid range, but HW expects it to be the
690          * same as modulo n so in case it is different we need to allocate a
691          * new buf and copy src data.
692          * In other case we just need to map the user provided buffer.
693          * Also need to make sure that it is in contiguous buffer.
694          */
695         if (sg_is_last(req->src) && req->src_len == ctx->key_sz) {
696                 qat_req->src_align = NULL;
697                 vaddr = sg_virt(req->src);
698         } else {
699                 int shift = ctx->key_sz - req->src_len;
700
701                 qat_req->src_align = kzalloc(ctx->key_sz, flags);
702                 if (unlikely(!qat_req->src_align))
703                         return ret;
704
705                 scatterwalk_map_and_copy(qat_req->src_align + shift, req->src,
706                                          0, req->src_len, 0);
707                 vaddr = qat_req->src_align;
708         }
709
710         qat_req->in.rsa.enc.m = dma_map_single(dev, vaddr, ctx->key_sz,
711                                                DMA_TO_DEVICE);
712         if (unlikely(dma_mapping_error(dev, qat_req->in.rsa.enc.m)))
713                 goto unmap_src;
714
715         if (sg_is_last(req->dst) && req->dst_len == ctx->key_sz) {
716                 qat_req->dst_align = NULL;
717                 vaddr = sg_virt(req->dst);
718         } else {
719                 qat_req->dst_align = kzalloc(ctx->key_sz, flags);
720                 if (unlikely(!qat_req->dst_align))
721                         goto unmap_src;
722                 vaddr = qat_req->dst_align;
723         }
724
725         qat_req->out.rsa.enc.c = dma_map_single(dev, vaddr, ctx->key_sz,
726                                                 DMA_FROM_DEVICE);
727         if (unlikely(dma_mapping_error(dev, qat_req->out.rsa.enc.c)))
728                 goto unmap_dst;
729
730         qat_req->in.rsa.in_tab[3] = 0;
731         qat_req->out.rsa.out_tab[1] = 0;
732         qat_req->phy_in = dma_map_single(dev, &qat_req->in.rsa,
733                                          sizeof(struct qat_rsa_input_params),
734                                          DMA_TO_DEVICE);
735         if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
736                 goto unmap_dst;
737
738         qat_req->phy_out = dma_map_single(dev, &qat_req->out.rsa,
739                                           sizeof(struct qat_rsa_output_params),
740                                           DMA_TO_DEVICE);
741         if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
742                 goto unmap_in_params;
743
744         msg->pke_mid.src_data_addr = qat_req->phy_in;
745         msg->pke_mid.dest_data_addr = qat_req->phy_out;
746         msg->pke_mid.opaque = (u64)(__force long)qat_req;
747         msg->input_param_count = 3;
748         msg->output_param_count = 1;
749
750         ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
751         if (ret == -ENOSPC)
752                 goto unmap_all;
753
754         return ret;
755
756 unmap_all:
757         if (!dma_mapping_error(dev, qat_req->phy_out))
758                 dma_unmap_single(dev, qat_req->phy_out,
759                                  sizeof(struct qat_rsa_output_params),
760                                  DMA_TO_DEVICE);
761 unmap_in_params:
762         if (!dma_mapping_error(dev, qat_req->phy_in))
763                 dma_unmap_single(dev, qat_req->phy_in,
764                                  sizeof(struct qat_rsa_input_params),
765                                  DMA_TO_DEVICE);
766 unmap_dst:
767         if (!dma_mapping_error(dev, qat_req->out.rsa.enc.c))
768                 dma_unmap_single(dev, qat_req->out.rsa.enc.c,
769                                  ctx->key_sz, DMA_FROM_DEVICE);
770         kfree_sensitive(qat_req->dst_align);
771 unmap_src:
772         if (!dma_mapping_error(dev, qat_req->in.rsa.enc.m))
773                 dma_unmap_single(dev, qat_req->in.rsa.enc.m, ctx->key_sz,
774                                  DMA_TO_DEVICE);
775         kfree_sensitive(qat_req->src_align);
776         return ret;
777 }
778
779 static int qat_rsa_dec(struct akcipher_request *req)
780 {
781         struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
782         struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
783         struct qat_crypto_instance *inst = ctx->inst;
784         struct device *dev = &GET_DEV(inst->accel_dev);
785         struct qat_asym_request *qat_req =
786                         PTR_ALIGN(akcipher_request_ctx(req), 64);
787         struct icp_qat_fw_pke_request *msg = &qat_req->req;
788         gfp_t flags = qat_algs_alloc_flags(&req->base);
789         u8 *vaddr;
790         int ret;
791
792         if (unlikely(!ctx->n || !ctx->d))
793                 return -EINVAL;
794
795         if (req->dst_len < ctx->key_sz) {
796                 req->dst_len = ctx->key_sz;
797                 return -EOVERFLOW;
798         }
799
800         if (req->src_len > ctx->key_sz)
801                 return -EINVAL;
802
803         memset(msg, '\0', sizeof(*msg));
804         ICP_QAT_FW_PKE_HDR_VALID_FLAG_SET(msg->pke_hdr,
805                                           ICP_QAT_FW_COMN_REQ_FLAG_SET);
806         msg->pke_hdr.cd_pars.func_id = ctx->crt_mode ?
807                 qat_rsa_dec_fn_id_crt(ctx->key_sz) :
808                 qat_rsa_dec_fn_id(ctx->key_sz);
809         if (unlikely(!msg->pke_hdr.cd_pars.func_id))
810                 return -EINVAL;
811
812         qat_req->cb = qat_rsa_cb;
813         qat_req->ctx.rsa = ctx;
814         qat_req->areq.rsa = req;
815         msg->pke_hdr.service_type = ICP_QAT_FW_COMN_REQ_CPM_FW_PKE;
816         msg->pke_hdr.comn_req_flags =
817                 ICP_QAT_FW_COMN_FLAGS_BUILD(QAT_COMN_PTR_TYPE_FLAT,
818                                             QAT_COMN_CD_FLD_TYPE_64BIT_ADR);
819
820         if (ctx->crt_mode) {
821                 qat_req->in.rsa.dec_crt.p = ctx->dma_p;
822                 qat_req->in.rsa.dec_crt.q = ctx->dma_q;
823                 qat_req->in.rsa.dec_crt.dp = ctx->dma_dp;
824                 qat_req->in.rsa.dec_crt.dq = ctx->dma_dq;
825                 qat_req->in.rsa.dec_crt.qinv = ctx->dma_qinv;
826         } else {
827                 qat_req->in.rsa.dec.d = ctx->dma_d;
828                 qat_req->in.rsa.dec.n = ctx->dma_n;
829         }
830         ret = -ENOMEM;
831
832         /*
833          * src can be of any size in valid range, but HW expects it to be the
834          * same as modulo n so in case it is different we need to allocate a
835          * new buf and copy src data.
836          * In other case we just need to map the user provided buffer.
837          * Also need to make sure that it is in contiguous buffer.
838          */
839         if (sg_is_last(req->src) && req->src_len == ctx->key_sz) {
840                 qat_req->src_align = NULL;
841                 vaddr = sg_virt(req->src);
842         } else {
843                 int shift = ctx->key_sz - req->src_len;
844
845                 qat_req->src_align = kzalloc(ctx->key_sz, flags);
846                 if (unlikely(!qat_req->src_align))
847                         return ret;
848
849                 scatterwalk_map_and_copy(qat_req->src_align + shift, req->src,
850                                          0, req->src_len, 0);
851                 vaddr = qat_req->src_align;
852         }
853
854         qat_req->in.rsa.dec.c = dma_map_single(dev, vaddr, ctx->key_sz,
855                                                DMA_TO_DEVICE);
856         if (unlikely(dma_mapping_error(dev, qat_req->in.rsa.dec.c)))
857                 goto unmap_src;
858
859         if (sg_is_last(req->dst) && req->dst_len == ctx->key_sz) {
860                 qat_req->dst_align = NULL;
861                 vaddr = sg_virt(req->dst);
862         } else {
863                 qat_req->dst_align = kzalloc(ctx->key_sz, flags);
864                 if (unlikely(!qat_req->dst_align))
865                         goto unmap_src;
866                 vaddr = qat_req->dst_align;
867         }
868         qat_req->out.rsa.dec.m = dma_map_single(dev, vaddr, ctx->key_sz,
869                                                 DMA_FROM_DEVICE);
870         if (unlikely(dma_mapping_error(dev, qat_req->out.rsa.dec.m)))
871                 goto unmap_dst;
872
873         if (ctx->crt_mode)
874                 qat_req->in.rsa.in_tab[6] = 0;
875         else
876                 qat_req->in.rsa.in_tab[3] = 0;
877         qat_req->out.rsa.out_tab[1] = 0;
878         qat_req->phy_in = dma_map_single(dev, &qat_req->in.rsa,
879                                          sizeof(struct qat_rsa_input_params),
880                                          DMA_TO_DEVICE);
881         if (unlikely(dma_mapping_error(dev, qat_req->phy_in)))
882                 goto unmap_dst;
883
884         qat_req->phy_out = dma_map_single(dev, &qat_req->out.rsa,
885                                           sizeof(struct qat_rsa_output_params),
886                                           DMA_TO_DEVICE);
887         if (unlikely(dma_mapping_error(dev, qat_req->phy_out)))
888                 goto unmap_in_params;
889
890         msg->pke_mid.src_data_addr = qat_req->phy_in;
891         msg->pke_mid.dest_data_addr = qat_req->phy_out;
892         msg->pke_mid.opaque = (u64)(__force long)qat_req;
893         if (ctx->crt_mode)
894                 msg->input_param_count = 6;
895         else
896                 msg->input_param_count = 3;
897
898         msg->output_param_count = 1;
899
900         ret = qat_alg_send_asym_message(qat_req, inst, &req->base);
901         if (ret == -ENOSPC)
902                 goto unmap_all;
903
904         return ret;
905
906 unmap_all:
907         if (!dma_mapping_error(dev, qat_req->phy_out))
908                 dma_unmap_single(dev, qat_req->phy_out,
909                                  sizeof(struct qat_rsa_output_params),
910                                  DMA_TO_DEVICE);
911 unmap_in_params:
912         if (!dma_mapping_error(dev, qat_req->phy_in))
913                 dma_unmap_single(dev, qat_req->phy_in,
914                                  sizeof(struct qat_rsa_input_params),
915                                  DMA_TO_DEVICE);
916 unmap_dst:
917         if (!dma_mapping_error(dev, qat_req->out.rsa.dec.m))
918                 dma_unmap_single(dev, qat_req->out.rsa.dec.m,
919                                  ctx->key_sz, DMA_FROM_DEVICE);
920         kfree_sensitive(qat_req->dst_align);
921 unmap_src:
922         if (!dma_mapping_error(dev, qat_req->in.rsa.dec.c))
923                 dma_unmap_single(dev, qat_req->in.rsa.dec.c, ctx->key_sz,
924                                  DMA_TO_DEVICE);
925         kfree_sensitive(qat_req->src_align);
926         return ret;
927 }
928
929 static int qat_rsa_set_n(struct qat_rsa_ctx *ctx, const char *value,
930                          size_t vlen)
931 {
932         struct qat_crypto_instance *inst = ctx->inst;
933         struct device *dev = &GET_DEV(inst->accel_dev);
934         const char *ptr = value;
935         int ret;
936
937         while (!*ptr && vlen) {
938                 ptr++;
939                 vlen--;
940         }
941
942         ctx->key_sz = vlen;
943         ret = -EINVAL;
944         /* invalid key size provided */
945         if (!qat_rsa_enc_fn_id(ctx->key_sz))
946                 goto err;
947
948         ret = -ENOMEM;
949         ctx->n = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_n, GFP_KERNEL);
950         if (!ctx->n)
951                 goto err;
952
953         memcpy(ctx->n, ptr, ctx->key_sz);
954         return 0;
955 err:
956         ctx->key_sz = 0;
957         ctx->n = NULL;
958         return ret;
959 }
960
961 static int qat_rsa_set_e(struct qat_rsa_ctx *ctx, const char *value,
962                          size_t vlen)
963 {
964         struct qat_crypto_instance *inst = ctx->inst;
965         struct device *dev = &GET_DEV(inst->accel_dev);
966         const char *ptr = value;
967
968         while (!*ptr && vlen) {
969                 ptr++;
970                 vlen--;
971         }
972
973         if (!ctx->key_sz || !vlen || vlen > ctx->key_sz) {
974                 ctx->e = NULL;
975                 return -EINVAL;
976         }
977
978         ctx->e = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_e, GFP_KERNEL);
979         if (!ctx->e)
980                 return -ENOMEM;
981
982         memcpy(ctx->e + (ctx->key_sz - vlen), ptr, vlen);
983         return 0;
984 }
985
986 static int qat_rsa_set_d(struct qat_rsa_ctx *ctx, const char *value,
987                          size_t vlen)
988 {
989         struct qat_crypto_instance *inst = ctx->inst;
990         struct device *dev = &GET_DEV(inst->accel_dev);
991         const char *ptr = value;
992         int ret;
993
994         while (!*ptr && vlen) {
995                 ptr++;
996                 vlen--;
997         }
998
999         ret = -EINVAL;
1000         if (!ctx->key_sz || !vlen || vlen > ctx->key_sz)
1001                 goto err;
1002
1003         ret = -ENOMEM;
1004         ctx->d = dma_alloc_coherent(dev, ctx->key_sz, &ctx->dma_d, GFP_KERNEL);
1005         if (!ctx->d)
1006                 goto err;
1007
1008         memcpy(ctx->d + (ctx->key_sz - vlen), ptr, vlen);
1009         return 0;
1010 err:
1011         ctx->d = NULL;
1012         return ret;
1013 }
1014
1015 static void qat_rsa_drop_leading_zeros(const char **ptr, unsigned int *len)
1016 {
1017         while (!**ptr && *len) {
1018                 (*ptr)++;
1019                 (*len)--;
1020         }
1021 }
1022
1023 static void qat_rsa_setkey_crt(struct qat_rsa_ctx *ctx, struct rsa_key *rsa_key)
1024 {
1025         struct qat_crypto_instance *inst = ctx->inst;
1026         struct device *dev = &GET_DEV(inst->accel_dev);
1027         const char *ptr;
1028         unsigned int len;
1029         unsigned int half_key_sz = ctx->key_sz / 2;
1030
1031         /* p */
1032         ptr = rsa_key->p;
1033         len = rsa_key->p_sz;
1034         qat_rsa_drop_leading_zeros(&ptr, &len);
1035         if (!len)
1036                 goto err;
1037         ctx->p = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_p, GFP_KERNEL);
1038         if (!ctx->p)
1039                 goto err;
1040         memcpy(ctx->p + (half_key_sz - len), ptr, len);
1041
1042         /* q */
1043         ptr = rsa_key->q;
1044         len = rsa_key->q_sz;
1045         qat_rsa_drop_leading_zeros(&ptr, &len);
1046         if (!len)
1047                 goto free_p;
1048         ctx->q = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_q, GFP_KERNEL);
1049         if (!ctx->q)
1050                 goto free_p;
1051         memcpy(ctx->q + (half_key_sz - len), ptr, len);
1052
1053         /* dp */
1054         ptr = rsa_key->dp;
1055         len = rsa_key->dp_sz;
1056         qat_rsa_drop_leading_zeros(&ptr, &len);
1057         if (!len)
1058                 goto free_q;
1059         ctx->dp = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_dp,
1060                                      GFP_KERNEL);
1061         if (!ctx->dp)
1062                 goto free_q;
1063         memcpy(ctx->dp + (half_key_sz - len), ptr, len);
1064
1065         /* dq */
1066         ptr = rsa_key->dq;
1067         len = rsa_key->dq_sz;
1068         qat_rsa_drop_leading_zeros(&ptr, &len);
1069         if (!len)
1070                 goto free_dp;
1071         ctx->dq = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_dq,
1072                                      GFP_KERNEL);
1073         if (!ctx->dq)
1074                 goto free_dp;
1075         memcpy(ctx->dq + (half_key_sz - len), ptr, len);
1076
1077         /* qinv */
1078         ptr = rsa_key->qinv;
1079         len = rsa_key->qinv_sz;
1080         qat_rsa_drop_leading_zeros(&ptr, &len);
1081         if (!len)
1082                 goto free_dq;
1083         ctx->qinv = dma_alloc_coherent(dev, half_key_sz, &ctx->dma_qinv,
1084                                        GFP_KERNEL);
1085         if (!ctx->qinv)
1086                 goto free_dq;
1087         memcpy(ctx->qinv + (half_key_sz - len), ptr, len);
1088
1089         ctx->crt_mode = true;
1090         return;
1091
1092 free_dq:
1093         memset(ctx->dq, '\0', half_key_sz);
1094         dma_free_coherent(dev, half_key_sz, ctx->dq, ctx->dma_dq);
1095         ctx->dq = NULL;
1096 free_dp:
1097         memset(ctx->dp, '\0', half_key_sz);
1098         dma_free_coherent(dev, half_key_sz, ctx->dp, ctx->dma_dp);
1099         ctx->dp = NULL;
1100 free_q:
1101         memset(ctx->q, '\0', half_key_sz);
1102         dma_free_coherent(dev, half_key_sz, ctx->q, ctx->dma_q);
1103         ctx->q = NULL;
1104 free_p:
1105         memset(ctx->p, '\0', half_key_sz);
1106         dma_free_coherent(dev, half_key_sz, ctx->p, ctx->dma_p);
1107         ctx->p = NULL;
1108 err:
1109         ctx->crt_mode = false;
1110 }
1111
1112 static void qat_rsa_clear_ctx(struct device *dev, struct qat_rsa_ctx *ctx)
1113 {
1114         unsigned int half_key_sz = ctx->key_sz / 2;
1115
1116         /* Free the old key if any */
1117         if (ctx->n)
1118                 dma_free_coherent(dev, ctx->key_sz, ctx->n, ctx->dma_n);
1119         if (ctx->e)
1120                 dma_free_coherent(dev, ctx->key_sz, ctx->e, ctx->dma_e);
1121         if (ctx->d) {
1122                 memset(ctx->d, '\0', ctx->key_sz);
1123                 dma_free_coherent(dev, ctx->key_sz, ctx->d, ctx->dma_d);
1124         }
1125         if (ctx->p) {
1126                 memset(ctx->p, '\0', half_key_sz);
1127                 dma_free_coherent(dev, half_key_sz, ctx->p, ctx->dma_p);
1128         }
1129         if (ctx->q) {
1130                 memset(ctx->q, '\0', half_key_sz);
1131                 dma_free_coherent(dev, half_key_sz, ctx->q, ctx->dma_q);
1132         }
1133         if (ctx->dp) {
1134                 memset(ctx->dp, '\0', half_key_sz);
1135                 dma_free_coherent(dev, half_key_sz, ctx->dp, ctx->dma_dp);
1136         }
1137         if (ctx->dq) {
1138                 memset(ctx->dq, '\0', half_key_sz);
1139                 dma_free_coherent(dev, half_key_sz, ctx->dq, ctx->dma_dq);
1140         }
1141         if (ctx->qinv) {
1142                 memset(ctx->qinv, '\0', half_key_sz);
1143                 dma_free_coherent(dev, half_key_sz, ctx->qinv, ctx->dma_qinv);
1144         }
1145
1146         ctx->n = NULL;
1147         ctx->e = NULL;
1148         ctx->d = NULL;
1149         ctx->p = NULL;
1150         ctx->q = NULL;
1151         ctx->dp = NULL;
1152         ctx->dq = NULL;
1153         ctx->qinv = NULL;
1154         ctx->crt_mode = false;
1155         ctx->key_sz = 0;
1156 }
1157
1158 static int qat_rsa_setkey(struct crypto_akcipher *tfm, const void *key,
1159                           unsigned int keylen, bool private)
1160 {
1161         struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
1162         struct device *dev = &GET_DEV(ctx->inst->accel_dev);
1163         struct rsa_key rsa_key;
1164         int ret;
1165
1166         qat_rsa_clear_ctx(dev, ctx);
1167
1168         if (private)
1169                 ret = rsa_parse_priv_key(&rsa_key, key, keylen);
1170         else
1171                 ret = rsa_parse_pub_key(&rsa_key, key, keylen);
1172         if (ret < 0)
1173                 goto free;
1174
1175         ret = qat_rsa_set_n(ctx, rsa_key.n, rsa_key.n_sz);
1176         if (ret < 0)
1177                 goto free;
1178         ret = qat_rsa_set_e(ctx, rsa_key.e, rsa_key.e_sz);
1179         if (ret < 0)
1180                 goto free;
1181         if (private) {
1182                 ret = qat_rsa_set_d(ctx, rsa_key.d, rsa_key.d_sz);
1183                 if (ret < 0)
1184                         goto free;
1185                 qat_rsa_setkey_crt(ctx, &rsa_key);
1186         }
1187
1188         if (!ctx->n || !ctx->e) {
1189                 /* invalid key provided */
1190                 ret = -EINVAL;
1191                 goto free;
1192         }
1193         if (private && !ctx->d) {
1194                 /* invalid private key provided */
1195                 ret = -EINVAL;
1196                 goto free;
1197         }
1198
1199         return 0;
1200 free:
1201         qat_rsa_clear_ctx(dev, ctx);
1202         return ret;
1203 }
1204
1205 static int qat_rsa_setpubkey(struct crypto_akcipher *tfm, const void *key,
1206                              unsigned int keylen)
1207 {
1208         return qat_rsa_setkey(tfm, key, keylen, false);
1209 }
1210
1211 static int qat_rsa_setprivkey(struct crypto_akcipher *tfm, const void *key,
1212                               unsigned int keylen)
1213 {
1214         return qat_rsa_setkey(tfm, key, keylen, true);
1215 }
1216
1217 static unsigned int qat_rsa_max_size(struct crypto_akcipher *tfm)
1218 {
1219         struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
1220
1221         return ctx->key_sz;
1222 }
1223
1224 static int qat_rsa_init_tfm(struct crypto_akcipher *tfm)
1225 {
1226         struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
1227         struct qat_crypto_instance *inst =
1228                         qat_crypto_get_instance_node(numa_node_id());
1229
1230         if (!inst)
1231                 return -EINVAL;
1232
1233         ctx->key_sz = 0;
1234         ctx->inst = inst;
1235         return 0;
1236 }
1237
1238 static void qat_rsa_exit_tfm(struct crypto_akcipher *tfm)
1239 {
1240         struct qat_rsa_ctx *ctx = akcipher_tfm_ctx(tfm);
1241         struct device *dev = &GET_DEV(ctx->inst->accel_dev);
1242
1243         qat_rsa_clear_ctx(dev, ctx);
1244         qat_crypto_put_instance(ctx->inst);
1245 }
1246
1247 static struct akcipher_alg rsa = {
1248         .encrypt = qat_rsa_enc,
1249         .decrypt = qat_rsa_dec,
1250         .set_pub_key = qat_rsa_setpubkey,
1251         .set_priv_key = qat_rsa_setprivkey,
1252         .max_size = qat_rsa_max_size,
1253         .init = qat_rsa_init_tfm,
1254         .exit = qat_rsa_exit_tfm,
1255         .reqsize = sizeof(struct qat_asym_request) + 64,
1256         .base = {
1257                 .cra_name = "rsa",
1258                 .cra_driver_name = "qat-rsa",
1259                 .cra_priority = 1000,
1260                 .cra_module = THIS_MODULE,
1261                 .cra_ctxsize = sizeof(struct qat_rsa_ctx),
1262         },
1263 };
1264
1265 static struct kpp_alg dh = {
1266         .set_secret = qat_dh_set_secret,
1267         .generate_public_key = qat_dh_compute_value,
1268         .compute_shared_secret = qat_dh_compute_value,
1269         .max_size = qat_dh_max_size,
1270         .init = qat_dh_init_tfm,
1271         .exit = qat_dh_exit_tfm,
1272         .reqsize = sizeof(struct qat_asym_request) + 64,
1273         .base = {
1274                 .cra_name = "dh",
1275                 .cra_driver_name = "qat-dh",
1276                 .cra_priority = 1000,
1277                 .cra_module = THIS_MODULE,
1278                 .cra_ctxsize = sizeof(struct qat_dh_ctx),
1279         },
1280 };
1281
1282 int qat_asym_algs_register(void)
1283 {
1284         int ret = 0;
1285
1286         mutex_lock(&algs_lock);
1287         if (++active_devs == 1) {
1288                 rsa.base.cra_flags = 0;
1289                 ret = crypto_register_akcipher(&rsa);
1290                 if (ret)
1291                         goto unlock;
1292                 ret = crypto_register_kpp(&dh);
1293         }
1294 unlock:
1295         mutex_unlock(&algs_lock);
1296         return ret;
1297 }
1298
1299 void qat_asym_algs_unregister(void)
1300 {
1301         mutex_lock(&algs_lock);
1302         if (--active_devs == 0) {
1303                 crypto_unregister_akcipher(&rsa);
1304                 crypto_unregister_kpp(&dh);
1305         }
1306         mutex_unlock(&algs_lock);
1307 }