Commit | Line | Data |
---|---|---|
b9873755 AG |
1 | // SPDX-License-Identifier: GPL-2.0 |
2 | /* | |
3 | * Amazon Nitro Secure Module driver. | |
4 | * | |
5 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | |
6 | * | |
7 | * The Nitro Secure Module implements commands via CBOR over virtio. | |
 * This driver exposes a raw-message ioctl on /dev/nsm that user
 * space can use to issue these commands.
10 | */ | |
11 | ||
12 | #include <linux/file.h> | |
13 | #include <linux/fs.h> | |
14 | #include <linux/interrupt.h> | |
15 | #include <linux/hw_random.h> | |
16 | #include <linux/miscdevice.h> | |
17 | #include <linux/module.h> | |
18 | #include <linux/mutex.h> | |
19 | #include <linux/slab.h> | |
20 | #include <linux/string.h> | |
21 | #include <linux/uaccess.h> | |
22 | #include <linux/uio.h> | |
23 | #include <linux/virtio_config.h> | |
24 | #include <linux/virtio_ids.h> | |
25 | #include <linux/virtio.h> | |
26 | #include <linux/wait.h> | |
27 | #include <uapi/linux/nsm.h> | |
28 | ||
/* Timeout for an NSM virtqueue response, in milliseconds (2 minutes). */
#define NSM_DEFAULT_TIMEOUT_MSECS (2 * 60 * 1000)
31 | ||
32 | /* Maximum length input data */ | |
33 | struct nsm_data_req { | |
34 | u32 len; | |
35 | u8 data[NSM_REQUEST_MAX_SIZE]; | |
36 | }; | |
37 | ||
38 | /* Maximum length output data */ | |
39 | struct nsm_data_resp { | |
40 | u32 len; | |
41 | u8 data[NSM_RESPONSE_MAX_SIZE]; | |
42 | }; | |
43 | ||
44 | /* Full NSM request/response message */ | |
45 | struct nsm_msg { | |
46 | struct nsm_data_req req; | |
47 | struct nsm_data_resp resp; | |
48 | }; | |
49 | ||
50 | struct nsm { | |
51 | struct virtio_device *vdev; | |
52 | struct virtqueue *vq; | |
53 | struct mutex lock; | |
54 | struct completion cmd_done; | |
55 | struct miscdevice misc; | |
56 | struct hwrng hwrng; | |
57 | struct work_struct misc_init; | |
58 | struct nsm_msg msg; | |
59 | }; | |
60 | ||
61 | /* NSM device ID */ | |
62 | static const struct virtio_device_id id_table[] = { | |
63 | { VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID }, | |
64 | { 0 }, | |
65 | }; | |
66 | ||
67 | static struct nsm *file_to_nsm(struct file *file) | |
68 | { | |
69 | return container_of(file->private_data, struct nsm, misc); | |
70 | } | |
71 | ||
72 | static struct nsm *hwrng_to_nsm(struct hwrng *rng) | |
73 | { | |
74 | return container_of(rng, struct nsm, hwrng); | |
75 | } | |
76 | ||
/*
 * CBOR (RFC 8949) encoding constants used by the GetRandom helpers.
 * The first byte of a CBOR item carries the major type in its top three
 * bits and "additional information" in its low five bits.
 *
 * NOTE(review): 0x40 is CBOR major type 2 (byte string), not an array
 * (major type 4 is 0x80); the CBOR_TYPE_ARRAY name is kept for callers.
 */
enum {
	CBOR_TYPE_MASK		= 0xE0,
	CBOR_TYPE_MAP		= 0xA0,
	CBOR_TYPE_TEXT		= 0x60,
	CBOR_TYPE_ARRAY		= 0x40,
	CBOR_HEADER_SIZE_SHORT	= 1,
};

/* Meaning of the low five "additional information" bits. */
enum {
	CBOR_SHORT_SIZE_MAX_VALUE	= 23,	/* 0..23: length stored inline */
	CBOR_LONG_SIZE_U8		= 24,	/* 1-byte length follows */
	CBOR_LONG_SIZE_U16		= 25,	/* 2-byte length follows */
	CBOR_LONG_SIZE_U32		= 26,	/* 4-byte length follows */
	CBOR_LONG_SIZE_U64		= 27,	/* 8-byte length follows */
};
88 | ||
89 | static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size) | |
90 | { | |
91 | if (cbor_object_size == 0 || cbor_object == NULL) | |
92 | return false; | |
93 | ||
94 | return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY; | |
95 | } | |
96 | ||
97 | static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array) | |
98 | { | |
99 | u8 cbor_short_size; | |
100 | void *array_len_p; | |
101 | u64 array_len; | |
102 | u64 array_offset; | |
103 | ||
104 | if (!cbor_object_is_array(cbor_object, cbor_object_size)) | |
105 | return -EFAULT; | |
106 | ||
107 | cbor_short_size = (cbor_object[0] & 0x1F); | |
108 | ||
109 | /* Decoding byte array length */ | |
110 | array_offset = CBOR_HEADER_SIZE_SHORT; | |
111 | if (cbor_short_size >= CBOR_LONG_SIZE_U8) | |
112 | array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8); | |
113 | ||
114 | if (cbor_object_size < array_offset) | |
115 | return -EFAULT; | |
116 | ||
117 | array_len_p = &cbor_object[1]; | |
118 | ||
119 | switch (cbor_short_size) { | |
120 | case CBOR_SHORT_SIZE_MAX_VALUE: /* short encoding */ | |
121 | array_len = cbor_short_size; | |
122 | break; | |
123 | case CBOR_LONG_SIZE_U8: | |
124 | array_len = *(u8 *)array_len_p; | |
125 | break; | |
126 | case CBOR_LONG_SIZE_U16: | |
127 | array_len = be16_to_cpup((__be16 *)array_len_p); | |
128 | break; | |
129 | case CBOR_LONG_SIZE_U32: | |
130 | array_len = be32_to_cpup((__be32 *)array_len_p); | |
131 | break; | |
132 | case CBOR_LONG_SIZE_U64: | |
133 | array_len = be64_to_cpup((__be64 *)array_len_p); | |
134 | break; | |
135 | } | |
136 | ||
137 | if (cbor_object_size < array_offset) | |
138 | return -EFAULT; | |
139 | ||
140 | if (cbor_object_size - array_offset < array_len) | |
141 | return -EFAULT; | |
142 | ||
143 | if (array_len > INT_MAX) | |
144 | return -EFAULT; | |
145 | ||
146 | *cbor_array = cbor_object + array_offset; | |
147 | return array_len; | |
148 | } | |
149 | ||
150 | /* Copy the request of a raw message to kernel space */ | |
151 | static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req, | |
152 | struct nsm_raw *raw) | |
153 | { | |
154 | /* Verify the user input size. */ | |
155 | if (raw->request.len > sizeof(req->data)) | |
156 | return -EMSGSIZE; | |
157 | ||
158 | /* Copy the request payload */ | |
159 | if (copy_from_user(req->data, u64_to_user_ptr(raw->request.addr), | |
160 | raw->request.len)) | |
161 | return -EFAULT; | |
162 | ||
163 | req->len = raw->request.len; | |
164 | ||
165 | return 0; | |
166 | } | |
167 | ||
168 | /* Copy the response of a raw message back to user-space */ | |
169 | static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp, | |
170 | struct nsm_raw *raw) | |
171 | { | |
172 | /* Truncate any message that does not fit. */ | |
173 | raw->response.len = min_t(u64, raw->response.len, resp->len); | |
174 | ||
175 | /* Copy the response content to user space */ | |
176 | if (copy_to_user(u64_to_user_ptr(raw->response.addr), | |
177 | resp->data, raw->response.len)) | |
178 | return -EFAULT; | |
179 | ||
180 | return 0; | |
181 | } | |
182 | ||
183 | /* Virtqueue interrupt handler */ | |
184 | static void nsm_vq_callback(struct virtqueue *vq) | |
185 | { | |
186 | struct nsm *nsm = vq->vdev->priv; | |
187 | ||
188 | complete(&nsm->cmd_done); | |
189 | } | |
190 | ||
191 | /* Forward a message to the NSM device and wait for the response from it */ | |
192 | static int nsm_sendrecv_msg_locked(struct nsm *nsm) | |
193 | { | |
194 | struct device *dev = &nsm->vdev->dev; | |
195 | struct scatterlist sg_in, sg_out; | |
196 | struct nsm_msg *msg = &nsm->msg; | |
197 | struct virtqueue *vq = nsm->vq; | |
198 | unsigned int len; | |
199 | void *queue_buf; | |
200 | bool kicked; | |
201 | int rc; | |
202 | ||
203 | /* Initialize scatter-gather lists with request and response buffers. */ | |
204 | sg_init_one(&sg_out, msg->req.data, msg->req.len); | |
205 | sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data)); | |
206 | ||
207 | init_completion(&nsm->cmd_done); | |
208 | /* Add the request buffer (read by the device). */ | |
209 | rc = virtqueue_add_outbuf(vq, &sg_out, 1, msg->req.data, GFP_KERNEL); | |
210 | if (rc) | |
211 | return rc; | |
212 | ||
213 | /* Add the response buffer (written by the device). */ | |
214 | rc = virtqueue_add_inbuf(vq, &sg_in, 1, msg->resp.data, GFP_KERNEL); | |
215 | if (rc) | |
216 | goto cleanup; | |
217 | ||
218 | kicked = virtqueue_kick(vq); | |
219 | if (!kicked) { | |
220 | /* Cannot kick the virtqueue. */ | |
221 | rc = -EIO; | |
222 | goto cleanup; | |
223 | } | |
224 | ||
225 | /* If the kick succeeded, wait for the device's response. */ | |
226 | if (!wait_for_completion_io_timeout(&nsm->cmd_done, | |
227 | msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) { | |
228 | rc = -ETIMEDOUT; | |
229 | goto cleanup; | |
230 | } | |
231 | ||
232 | queue_buf = virtqueue_get_buf(vq, &len); | |
233 | if (!queue_buf || (queue_buf != msg->req.data)) { | |
234 | dev_err(dev, "wrong request buffer."); | |
235 | rc = -ENODATA; | |
236 | goto cleanup; | |
237 | } | |
238 | ||
239 | queue_buf = virtqueue_get_buf(vq, &len); | |
240 | if (!queue_buf || (queue_buf != msg->resp.data)) { | |
241 | dev_err(dev, "wrong response buffer."); | |
242 | rc = -ENODATA; | |
243 | goto cleanup; | |
244 | } | |
245 | ||
246 | msg->resp.len = len; | |
247 | ||
248 | rc = 0; | |
249 | ||
250 | cleanup: | |
251 | if (rc) { | |
252 | /* Clean the virtqueue. */ | |
253 | while (virtqueue_get_buf(vq, &len) != NULL) | |
254 | ; | |
255 | } | |
256 | ||
257 | return rc; | |
258 | } | |
259 | ||
260 | static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req) | |
261 | { | |
262 | /* | |
263 | * 69 # text(9) | |
264 | * 47657452616E646F6D # "GetRandom" | |
265 | */ | |
266 | const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom"), | |
267 | 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' }; | |
268 | ||
269 | memcpy(req->data, request, sizeof(request)); | |
270 | req->len = sizeof(request); | |
271 | ||
272 | return 0; | |
273 | } | |
274 | ||
275 | static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp, | |
276 | void *out, size_t max) | |
277 | { | |
278 | /* | |
279 | * A1 # map(1) | |
280 | * 69 # text(9) - Name of field | |
281 | * 47657452616E646F6D # "GetRandom" | |
282 | * A1 # map(1) - The field itself | |
283 | * 66 # text(6) | |
284 | * 72616E646F6D # "random" | |
285 | * # The rest of the response is random data | |
286 | */ | |
287 | const u8 response[] = { CBOR_TYPE_MAP + 1, | |
288 | CBOR_TYPE_TEXT + strlen("GetRandom"), | |
289 | 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm', | |
290 | CBOR_TYPE_MAP + 1, | |
291 | CBOR_TYPE_TEXT + strlen("random"), | |
292 | 'r', 'a', 'n', 'd', 'o', 'm' }; | |
293 | struct device *dev = &nsm->vdev->dev; | |
294 | u8 *rand_data = NULL; | |
295 | u8 *resp_ptr = resp->data; | |
296 | u64 resp_len = resp->len; | |
297 | int rc; | |
298 | ||
299 | if ((resp->len < sizeof(response) + 1) || | |
300 | (memcmp(resp_ptr, response, sizeof(response)) != 0)) { | |
301 | dev_err(dev, "Invalid response for GetRandom"); | |
302 | return -EFAULT; | |
303 | } | |
304 | ||
305 | resp_ptr += sizeof(response); | |
306 | resp_len -= sizeof(response); | |
307 | ||
308 | rc = cbor_object_get_array(resp_ptr, resp_len, &rand_data); | |
309 | if (rc < 0) { | |
310 | dev_err(dev, "GetRandom: Invalid CBOR encoding\n"); | |
311 | return rc; | |
312 | } | |
313 | ||
314 | rc = min_t(size_t, rc, max); | |
315 | memcpy(out, rand_data, rc); | |
316 | ||
317 | return rc; | |
318 | } | |
319 | ||
320 | /* | |
321 | * HwRNG implementation | |
322 | */ | |
323 | static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait) | |
324 | { | |
325 | struct nsm *nsm = hwrng_to_nsm(rng); | |
326 | struct device *dev = &nsm->vdev->dev; | |
327 | int rc = 0; | |
328 | ||
329 | /* NSM always needs to wait for a response */ | |
330 | if (!wait) | |
331 | return 0; | |
332 | ||
333 | mutex_lock(&nsm->lock); | |
334 | ||
335 | rc = fill_req_get_random(nsm, &nsm->msg.req); | |
336 | if (rc != 0) | |
337 | goto out; | |
338 | ||
339 | rc = nsm_sendrecv_msg_locked(nsm); | |
340 | if (rc != 0) | |
341 | goto out; | |
342 | ||
343 | rc = parse_resp_get_random(nsm, &nsm->msg.resp, data, max); | |
344 | if (rc < 0) | |
345 | goto out; | |
346 | ||
347 | dev_dbg(dev, "RNG: returning rand bytes = %d", rc); | |
348 | out: | |
349 | mutex_unlock(&nsm->lock); | |
350 | return rc; | |
351 | } | |
352 | ||
353 | static long nsm_dev_ioctl(struct file *file, unsigned int cmd, | |
354 | unsigned long arg) | |
355 | { | |
356 | void __user *argp = u64_to_user_ptr((u64)arg); | |
357 | struct nsm *nsm = file_to_nsm(file); | |
358 | struct nsm_raw raw; | |
359 | int r = 0; | |
360 | ||
361 | if (cmd != NSM_IOCTL_RAW) | |
362 | return -EINVAL; | |
363 | ||
364 | if (_IOC_SIZE(cmd) != sizeof(raw)) | |
365 | return -EINVAL; | |
366 | ||
367 | /* Copy user argument struct to kernel argument struct */ | |
368 | r = -EFAULT; | |
369 | if (copy_from_user(&raw, argp, _IOC_SIZE(cmd))) | |
370 | goto out; | |
371 | ||
372 | mutex_lock(&nsm->lock); | |
373 | ||
374 | /* Convert kernel argument struct to device request */ | |
375 | r = fill_req_raw(nsm, &nsm->msg.req, &raw); | |
376 | if (r) | |
377 | goto out; | |
378 | ||
379 | /* Send message to NSM and read reply */ | |
380 | r = nsm_sendrecv_msg_locked(nsm); | |
381 | if (r) | |
382 | goto out; | |
383 | ||
384 | /* Parse device response into kernel argument struct */ | |
385 | r = parse_resp_raw(nsm, &nsm->msg.resp, &raw); | |
386 | if (r) | |
387 | goto out; | |
388 | ||
389 | /* Copy kernel argument struct back to user argument struct */ | |
390 | r = -EFAULT; | |
391 | if (copy_to_user(argp, &raw, sizeof(raw))) | |
392 | goto out; | |
393 | ||
394 | r = 0; | |
395 | ||
396 | out: | |
397 | mutex_unlock(&nsm->lock); | |
398 | return r; | |
399 | } | |
400 | ||
401 | static int nsm_device_init_vq(struct virtio_device *vdev) | |
402 | { | |
403 | struct virtqueue *vq = virtio_find_single_vq(vdev, | |
404 | nsm_vq_callback, "nsm.vq.0"); | |
405 | struct nsm *nsm = vdev->priv; | |
406 | ||
407 | if (IS_ERR(vq)) | |
408 | return PTR_ERR(vq); | |
409 | ||
410 | nsm->vq = vq; | |
411 | ||
412 | return 0; | |
413 | } | |
414 | ||
415 | static const struct file_operations nsm_dev_fops = { | |
416 | .unlocked_ioctl = nsm_dev_ioctl, | |
417 | .compat_ioctl = compat_ptr_ioctl, | |
418 | }; | |
419 | ||
420 | /* Handler for probing the NSM device */ | |
421 | static int nsm_device_probe(struct virtio_device *vdev) | |
422 | { | |
423 | struct device *dev = &vdev->dev; | |
424 | struct nsm *nsm; | |
425 | int rc; | |
426 | ||
427 | nsm = devm_kzalloc(&vdev->dev, sizeof(*nsm), GFP_KERNEL); | |
428 | if (!nsm) | |
429 | return -ENOMEM; | |
430 | ||
431 | vdev->priv = nsm; | |
432 | nsm->vdev = vdev; | |
433 | ||
434 | rc = nsm_device_init_vq(vdev); | |
435 | if (rc) { | |
436 | dev_err(dev, "queue failed to initialize: %d.\n", rc); | |
437 | goto err_init_vq; | |
438 | } | |
439 | ||
440 | mutex_init(&nsm->lock); | |
441 | ||
442 | /* Register as hwrng provider */ | |
443 | nsm->hwrng = (struct hwrng) { | |
444 | .read = nsm_rng_read, | |
445 | .name = "nsm-hwrng", | |
446 | .quality = 1000, | |
447 | }; | |
448 | ||
449 | rc = hwrng_register(&nsm->hwrng); | |
450 | if (rc) { | |
451 | dev_err(dev, "RNG initialization error: %d.\n", rc); | |
452 | goto err_hwrng; | |
453 | } | |
454 | ||
455 | /* Register /dev/nsm device node */ | |
456 | nsm->misc = (struct miscdevice) { | |
457 | .minor = MISC_DYNAMIC_MINOR, | |
458 | .name = "nsm", | |
459 | .fops = &nsm_dev_fops, | |
460 | .mode = 0666, | |
461 | }; | |
462 | ||
463 | rc = misc_register(&nsm->misc); | |
464 | if (rc) { | |
465 | dev_err(dev, "misc device registration error: %d.\n", rc); | |
466 | goto err_misc; | |
467 | } | |
468 | ||
469 | return 0; | |
470 | ||
471 | err_misc: | |
472 | hwrng_unregister(&nsm->hwrng); | |
473 | err_hwrng: | |
474 | vdev->config->del_vqs(vdev); | |
475 | err_init_vq: | |
476 | return rc; | |
477 | } | |
478 | ||
479 | /* Handler for removing the NSM device */ | |
480 | static void nsm_device_remove(struct virtio_device *vdev) | |
481 | { | |
482 | struct nsm *nsm = vdev->priv; | |
483 | ||
484 | hwrng_unregister(&nsm->hwrng); | |
485 | ||
486 | vdev->config->del_vqs(vdev); | |
487 | misc_deregister(&nsm->misc); | |
488 | } | |
489 | ||
490 | /* NSM device configuration structure */ | |
491 | static struct virtio_driver virtio_nsm_driver = { | |
492 | .feature_table = 0, | |
493 | .feature_table_size = 0, | |
494 | .feature_table_legacy = 0, | |
495 | .feature_table_size_legacy = 0, | |
496 | .driver.name = KBUILD_MODNAME, | |
497 | .driver.owner = THIS_MODULE, | |
498 | .id_table = id_table, | |
499 | .probe = nsm_device_probe, | |
500 | .remove = nsm_device_remove, | |
501 | }; | |
502 | ||
503 | module_virtio_driver(virtio_nsm_driver); | |
504 | MODULE_DEVICE_TABLE(virtio, id_table); | |
505 | MODULE_DESCRIPTION("Virtio NSM driver"); | |
506 | MODULE_LICENSE("GPL"); |