rxrpc: Fix potential race in error handling in afs_make_call()
[linux-block.git] / net / rxrpc / rxperf.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* In-kernel rxperf server for testing purposes.
3  *
4  * Copyright (C) 2022 Red Hat, Inc. All Rights Reserved.
5  * Written by David Howells (dhowells@redhat.com)
6  */
7
8 #define pr_fmt(fmt) "rxperf: " fmt
9 #include <linux/module.h>
10 #include <linux/slab.h>
11 #include <net/sock.h>
12 #include <net/af_rxrpc.h>
13 #define RXRPC_TRACE_ONLY_DEFINE_ENUMS
14 #include <trace/events/rxrpc.h>
15
/* Module metadata. */
MODULE_DESCRIPTION("rxperf test server (afs)");
MODULE_AUTHOR("Red Hat, Inc.");
MODULE_LICENSE("GPL");

/* Transport port the server socket binds to (see rxperf_open_socket()). */
#define RXPERF_PORT             7009
/*
 * Rx service ID the server listens on.  Also stringified into the server key
 * description in rxperf_add_key(), so it must stay a macro.
 */
#define RX_PERF_SERVICE         147
/* Protocol version a client must put in its parameter block. */
#define RX_PERF_VERSION         3
/* Operation IDs carried in the "type" field of the parameter block. */
#define RX_PERF_SEND            0
#define RX_PERF_RECV            1
#define RX_PERF_RPC             3
#define RX_PERF_FILE            4	/* recognised, but rejected with -EOPNOTSUPP */
/*
 * Not referenced in this file as shown; rxperf_magic_cookie[] holds the same
 * value as big-endian bytes.
 */
#define RX_PERF_MAGIC_COOKIE    0x4711
28
/*
 * Parameter block a client sends at the start of every call.  The fields are
 * big-endian on the wire; the server converts them with ntohl().
 */
struct rxperf_proto_params {
	__be32		version;	/* must equal RX_PERF_VERSION */
	__be32		type;		/* operation ID (RX_PERF_*) */
	__be32		rsize;		/* not read by this file as shown */
	__be32		wsize;		/* not read by this file as shown */
} __packed;
35
/* Trailer sent at the end of every reply (see rxperf_process_call()). */
static const u8 rxperf_magic_cookie[] = { 0x00, 0x00, 0x47, 0x11 };
/* Fixed secret used as the payload of the server key added by rxperf_add_key(). */
static const u8 secret[8] = { 0xa7, 0x83, 0x8a, 0xcb, 0xc7, 0x83, 0xec, 0x94 };
38
/*
 * Server-side call lifecycle.  A call starts in RXPERF_CALL_SV_AWAIT_PARAMS
 * (set in rxperf_charge_preallocation()), moves to AWAIT_REQUEST once the
 * parameter block is parsed, to REPLYING when the request data ends, to
 * AWAIT_ACK when the last reply packet has been handed to rxrpc, and finally
 * to COMPLETE via rxperf_set_call_complete().
 */
enum rxperf_call_state {
	RXPERF_CALL_SV_AWAIT_PARAMS,    /* Server: Awaiting parameter block */
	RXPERF_CALL_SV_AWAIT_REQUEST,   /* Server: Awaiting request data */
	RXPERF_CALL_SV_REPLYING,        /* Server: Replying */
	RXPERF_CALL_SV_AWAIT_ACK,       /* Server: Awaiting final ACK */
	RXPERF_CALL_COMPLETE,           /* Completed or failed */
};
46
/*
 * Per-call state for one incoming rxperf call.  A record is allocated in
 * rxperf_charge_preallocation() and handed to rxrpc as the user call ID.  It
 * is freed either by rxperf_rx_discard_new_call() or at the end of
 * rxperf_deliver_to_call().
 */
struct rxperf_call {
	struct rxrpc_call	*rxcall;	/* set by rxperf_rx_attach() */
	struct iov_iter		iter;		/* receive iterator for the current step */
	struct kvec		kvec[1];	/* buffer descriptor backing iter */
	struct work_struct	work;		/* runs rxperf_deliver_to_call() */
	const char		*type;		/* operation name, used in log messages */
	size_t			iov_len;	/* length passed to rxrpc_kernel_recv_data() */
	size_t			req_len;        /* Size of request blob */
	size_t			reply_len;      /* Size of reply blob */
	unsigned int		debug_id;	/* from rxrpc_debug_id, for tracing/logging */
	unsigned int		operation_id;	/* RX_PERF_* decoded from params.type */
	struct rxperf_proto_params params;	/* received parameter block */
	__be32			tmp[2];		/* received request/reply size words */
	s32			abort_code;	/* recorded by rxperf_set_call_complete() */
	enum rxperf_call_state	state;
	short			error;		/* recorded by rxperf_set_call_complete() */
	unsigned short		unmarshal;	/* progress within rxperf_deliver_request() */
	u16			service_id;	/* passed by pointer to rxrpc_kernel_recv_data() */
	int (*deliver)(struct rxperf_call *call);	/* current receive handler */
	void (*processor)(struct work_struct *work);	/* not referenced in this file as shown */
};
68
/* Listening AF_RXRPC server socket; set by rxperf_open_socket(). */
static struct socket *rxperf_socket;
static struct key *rxperf_sec_keyring;  /* Ring of security/crypto keys */
/* Workqueue running per-call work and preallocation refills. */
static struct workqueue_struct *rxperf_workqueue;

static void rxperf_deliver_to_call(struct work_struct *work);
static int rxperf_deliver_param_block(struct rxperf_call *call);
static int rxperf_deliver_request(struct rxperf_call *call);
static int rxperf_process_call(struct rxperf_call *call);
static void rxperf_charge_preallocation(struct work_struct *work);

/* Refills the preallocated incoming-call pool; queued from rxperf_rx_new_call(). */
static DECLARE_WORK(rxperf_charge_preallocation_work,
		    rxperf_charge_preallocation);
81
82 static inline void rxperf_set_call_state(struct rxperf_call *call,
83                                          enum rxperf_call_state to)
84 {
85         call->state = to;
86 }
87
88 static inline void rxperf_set_call_complete(struct rxperf_call *call,
89                                             int error, s32 remote_abort)
90 {
91         if (call->state != RXPERF_CALL_COMPLETE) {
92                 call->abort_code = remote_abort;
93                 call->error = error;
94                 call->state = RXPERF_CALL_COMPLETE;
95         }
96 }
97
/*
 * rxrpc discard hook: free a preallocated call record that will never be
 * used.  The record was passed to rxrpc as @user_call_ID.
 */
static void rxperf_rx_discard_new_call(struct rxrpc_call *rxcall,
				       unsigned long user_call_ID)
{
	struct rxperf_call *unused = (struct rxperf_call *)user_call_ID;

	kfree(unused);
}
103
104 static void rxperf_rx_new_call(struct sock *sk, struct rxrpc_call *rxcall,
105                                unsigned long user_call_ID)
106 {
107         queue_work(rxperf_workqueue, &rxperf_charge_preallocation_work);
108 }
109
110 static void rxperf_queue_call_work(struct rxperf_call *call)
111 {
112         queue_work(rxperf_workqueue, &call->work);
113 }
114
115 static void rxperf_notify_rx(struct sock *sk, struct rxrpc_call *rxcall,
116                              unsigned long call_user_ID)
117 {
118         struct rxperf_call *call = (struct rxperf_call *)call_user_ID;
119
120         if (call->state != RXPERF_CALL_COMPLETE)
121                 rxperf_queue_call_work(call);
122 }
123
124 static void rxperf_rx_attach(struct rxrpc_call *rxcall, unsigned long user_call_ID)
125 {
126         struct rxperf_call *call = (struct rxperf_call *)user_call_ID;
127
128         call->rxcall = rxcall;
129 }
130
131 static void rxperf_notify_end_reply_tx(struct sock *sock,
132                                        struct rxrpc_call *rxcall,
133                                        unsigned long call_user_ID)
134 {
135         rxperf_set_call_state((struct rxperf_call *)call_user_ID,
136                               RXPERF_CALL_SV_AWAIT_ACK);
137 }
138
139 /*
140  * Charge the incoming call preallocation.
141  */
142 static void rxperf_charge_preallocation(struct work_struct *work)
143 {
144         struct rxperf_call *call;
145
146         for (;;) {
147                 call = kzalloc(sizeof(*call), GFP_KERNEL);
148                 if (!call)
149                         break;
150
151                 call->type              = "unset";
152                 call->debug_id          = atomic_inc_return(&rxrpc_debug_id);
153                 call->deliver           = rxperf_deliver_param_block;
154                 call->state             = RXPERF_CALL_SV_AWAIT_PARAMS;
155                 call->service_id        = RX_PERF_SERVICE;
156                 call->iov_len           = sizeof(call->params);
157                 call->kvec[0].iov_len   = sizeof(call->params);
158                 call->kvec[0].iov_base  = &call->params;
159                 iov_iter_kvec(&call->iter, READ, call->kvec, 1, call->iov_len);
160                 INIT_WORK(&call->work, rxperf_deliver_to_call);
161
162                 if (rxrpc_kernel_charge_accept(rxperf_socket,
163                                                rxperf_notify_rx,
164                                                rxperf_rx_attach,
165                                                (unsigned long)call,
166                                                GFP_KERNEL,
167                                                call->debug_id) < 0)
168                         break;
169                 call = NULL;
170         }
171
172         kfree(call);
173 }
174
175 /*
176  * Open an rxrpc socket and bind it to be a server for callback notifications
177  * - the socket is left in blocking mode and non-blocking ops use MSG_DONTWAIT
178  */
179 static int rxperf_open_socket(void)
180 {
181         struct sockaddr_rxrpc srx;
182         struct socket *socket;
183         int ret;
184
185         ret = sock_create_kern(&init_net, AF_RXRPC, SOCK_DGRAM, PF_INET6,
186                                &socket);
187         if (ret < 0)
188                 goto error_1;
189
190         socket->sk->sk_allocation = GFP_NOFS;
191
192         /* bind the callback manager's address to make this a server socket */
193         memset(&srx, 0, sizeof(srx));
194         srx.srx_family                  = AF_RXRPC;
195         srx.srx_service                 = RX_PERF_SERVICE;
196         srx.transport_type              = SOCK_DGRAM;
197         srx.transport_len               = sizeof(srx.transport.sin6);
198         srx.transport.sin6.sin6_family  = AF_INET6;
199         srx.transport.sin6.sin6_port    = htons(RXPERF_PORT);
200
201         ret = rxrpc_sock_set_min_security_level(socket->sk,
202                                                 RXRPC_SECURITY_ENCRYPT);
203         if (ret < 0)
204                 goto error_2;
205
206         ret = rxrpc_sock_set_security_keyring(socket->sk, rxperf_sec_keyring);
207
208         ret = kernel_bind(socket, (struct sockaddr *)&srx, sizeof(srx));
209         if (ret < 0)
210                 goto error_2;
211
212         rxrpc_kernel_new_call_notification(socket, rxperf_rx_new_call,
213                                            rxperf_rx_discard_new_call);
214
215         ret = kernel_listen(socket, INT_MAX);
216         if (ret < 0)
217                 goto error_2;
218
219         rxperf_socket = socket;
220         rxperf_charge_preallocation(&rxperf_charge_preallocation_work);
221         return 0;
222
223 error_2:
224         sock_release(socket);
225 error_1:
226         pr_err("Can't set up rxperf socket: %d\n", ret);
227         return ret;
228 }
229
230 /*
231  * close the rxrpc socket rxperf was using
232  */
233 static void rxperf_close_socket(void)
234 {
235         kernel_listen(rxperf_socket, 0);
236         kernel_sock_shutdown(rxperf_socket, SHUT_RDWR);
237         flush_workqueue(rxperf_workqueue);
238         sock_release(rxperf_socket);
239 }
240
241 /*
242  * Log remote abort codes that indicate that we have a protocol disagreement
243  * with the server.
244  */
245 static void rxperf_log_error(struct rxperf_call *call, s32 remote_abort)
246 {
247         static int max = 0;
248         const char *msg;
249         int m;
250
251         switch (remote_abort) {
252         case RX_EOF:             msg = "unexpected EOF";        break;
253         case RXGEN_CC_MARSHAL:   msg = "client marshalling";    break;
254         case RXGEN_CC_UNMARSHAL: msg = "client unmarshalling";  break;
255         case RXGEN_SS_MARSHAL:   msg = "server marshalling";    break;
256         case RXGEN_SS_UNMARSHAL: msg = "server unmarshalling";  break;
257         case RXGEN_DECODE:       msg = "opcode decode";         break;
258         case RXGEN_SS_XDRFREE:   msg = "server XDR cleanup";    break;
259         case RXGEN_CC_XDRFREE:   msg = "client XDR cleanup";    break;
260         case -32:                msg = "insufficient data";     break;
261         default:
262                 return;
263         }
264
265         m = max;
266         if (m < 3) {
267                 max = m + 1;
268                 pr_info("Peer reported %s failure on %s\n", msg, call->type);
269         }
270 }
271
/*
 * deliver messages to a call
 *
 * Work function for call->work.  It is installed with INIT_WORK() in
 * rxperf_charge_preallocation() and queued from rxperf_notify_rx().
 *
 * It drives the server-side state machine.  It calls call->deliver() to
 * receive the next part of the request and, once that returns 0, calls
 * rxperf_process_call() to send the reply.  Errors are turned into an rxrpc
 * abort where appropriate.  When the call finishes, for any reason, it is
 * shut down, the rxrpc call reference is put, and the rxperf_call record is
 * freed.
 */
static void rxperf_deliver_to_call(struct work_struct *work)
{
	struct rxperf_call *call = container_of(work, struct rxperf_call, work);
	enum rxperf_call_state state;
	u32 abort_code, remote_abort = 0;
	int ret = 0;

	/* Nothing more to do once the call has been completed. */
	if (call->state == RXPERF_CALL_COMPLETE)
		return;

	/* Loop while the call is waiting on something from the peer. */
	while (state = call->state,
	       state == RXPERF_CALL_SV_AWAIT_PARAMS ||
	       state == RXPERF_CALL_SV_AWAIT_REQUEST ||
	       state == RXPERF_CALL_SV_AWAIT_ACK
	       ) {
		if (state == RXPERF_CALL_SV_AWAIT_ACK) {
			/*
			 * The reply has been handed to rxrpc.  Finish up once
			 * rxrpc_kernel_check_life() says the call is no longer
			 * alive; otherwise wait for another notification.
			 */
			if (!rxrpc_kernel_check_life(rxperf_socket, call->rxcall))
				goto call_complete;
			return;
		}

		/*
		 * Receive the next part of the request.  Once the handler
		 * reports it is all in (0), send the reply.
		 */
		ret = call->deliver(call);
		if (ret == 0)
			ret = rxperf_process_call(call);

		switch (ret) {
		case 0:
			/* 'continue' applies to the while loop: re-check state. */
			continue;
		case -EINPROGRESS:
		case -EAGAIN:
			/* Need more data; the work item gets queued again later. */
			return;
		case -ECONNABORTED:
			/*
			 * call->abort_code holds whatever rxperf_set_call_complete()
			 * recorded (e.g. from rxperf_extract_data()).
			 */
			rxperf_log_error(call, call->abort_code);
			goto call_complete;
		case -EOPNOTSUPP:
			/* Unknown or unsupported operation ID. */
			abort_code = RXGEN_OPCODE;
			rxrpc_kernel_abort_call(rxperf_socket, call->rxcall,
						abort_code, ret,
						rxperf_abort_op_not_supported);
			goto call_complete;
		case -ENOTSUPP:
			/* e.g. protocol version mismatch in rxperf_deliver_param_block(). */
			abort_code = RX_USER_ABORT;
			rxrpc_kernel_abort_call(rxperf_socket, call->rxcall,
						abort_code, ret,
						rxperf_abort_op_not_supported);
			goto call_complete;
		case -EIO:
			/* Logged, then handled like an unmarshalling failure. */
			pr_err("Call %u in bad state %u\n",
			       call->debug_id, call->state);
			fallthrough;
		case -ENODATA:
		case -EBADMSG:
		case -EMSGSIZE:
		case -ENOMEM:
		case -EFAULT:
			rxrpc_kernel_abort_call(rxperf_socket, call->rxcall,
						RXGEN_SS_UNMARSHAL, ret,
						rxperf_abort_unmarshal_error);
			goto call_complete;
		default:
			rxrpc_kernel_abort_call(rxperf_socket, call->rxcall,
						RX_CALL_DEAD, ret,
						rxperf_abort_general_error);
			goto call_complete;
		}
	}

	/*
	 * Falling out of the loop means the state is RXPERF_CALL_SV_REPLYING or
	 * RXPERF_CALL_COMPLETE.  NOTE(review): a successful rxperf_process_call()
	 * relies on rxperf_notify_end_reply_tx() having moved the state to
	 * AWAIT_ACK.  Otherwise the call is torn down here while still REPLYING.
	 */
call_complete:
	/*
	 * remote_abort is never updated in this function, so it is always 0
	 * here.  This is a no-op if the call was already completed (e.g. by
	 * rxperf_extract_data()).
	 */
	rxperf_set_call_complete(call, ret, remote_abort);
	/* The call may have been requeued */
	rxrpc_kernel_shutdown_call(rxperf_socket, call->rxcall);
	rxrpc_kernel_put_call(rxperf_socket, call->rxcall);
	cancel_work(&call->work);
	kfree(call);
}
350
351 /*
352  * Extract a piece of data from the received data socket buffers.
353  */
354 static int rxperf_extract_data(struct rxperf_call *call, bool want_more)
355 {
356         u32 remote_abort = 0;
357         int ret;
358
359         ret = rxrpc_kernel_recv_data(rxperf_socket, call->rxcall, &call->iter,
360                                      &call->iov_len, want_more, &remote_abort,
361                                      &call->service_id);
362         pr_debug("Extract i=%zu l=%zu m=%u ret=%d\n",
363                  iov_iter_count(&call->iter), call->iov_len, want_more, ret);
364         if (ret == 0 || ret == -EAGAIN)
365                 return ret;
366
367         if (ret == 1) {
368                 switch (call->state) {
369                 case RXPERF_CALL_SV_AWAIT_REQUEST:
370                         rxperf_set_call_state(call, RXPERF_CALL_SV_REPLYING);
371                         break;
372                 case RXPERF_CALL_COMPLETE:
373                         pr_debug("premature completion %d", call->error);
374                         return call->error;
375                 default:
376                         break;
377                 }
378                 return 0;
379         }
380
381         rxperf_set_call_complete(call, ret, remote_abort);
382         return ret;
383 }
384
385 /*
386  * Grab the operation ID from an incoming manager call.
387  */
388 static int rxperf_deliver_param_block(struct rxperf_call *call)
389 {
390         u32 version;
391         int ret;
392
393         /* Extract the parameter block */
394         ret = rxperf_extract_data(call, true);
395         if (ret < 0)
396                 return ret;
397
398         version                 = ntohl(call->params.version);
399         call->operation_id      = ntohl(call->params.type);
400         call->deliver           = rxperf_deliver_request;
401
402         if (version != RX_PERF_VERSION) {
403                 pr_info("Version mismatch %x\n", version);
404                 return -ENOTSUPP;
405         }
406
407         switch (call->operation_id) {
408         case RX_PERF_SEND:
409                 call->type = "send";
410                 call->reply_len = 0;
411                 call->iov_len = 4;      /* Expect req size */
412                 break;
413         case RX_PERF_RECV:
414                 call->type = "recv";
415                 call->req_len = 0;
416                 call->iov_len = 4;      /* Expect reply size */
417                 break;
418         case RX_PERF_RPC:
419                 call->type = "rpc";
420                 call->iov_len = 8;      /* Expect req size and reply size */
421                 break;
422         case RX_PERF_FILE:
423                 call->type = "file";
424                 fallthrough;
425         default:
426                 return -EOPNOTSUPP;
427         }
428
429         rxperf_set_call_state(call, RXPERF_CALL_SV_AWAIT_REQUEST);
430         return call->deliver(call);
431 }
432
/*
 * Deliver the request data.
 *
 * Resumable receive state machine.  call->unmarshal records how far it got,
 * so a later invocation (after -EAGAIN) continues where it left off:
 *   0 - point call->iter at call->tmp for call->iov_len bytes (4 or 8, as
 *       chosen by rxperf_deliver_param_block());
 *   1 - receive those size words and decode req_len/reply_len for the
 *       operation, then set call->iter up to discard req_len bytes of
 *       request payload;
 *   2 - receive (and discard) the payload with want_more == false, i.e.
 *       expecting the incoming data to end here.
 *
 * Return: 0 once the whole request has been consumed, -EIO for an
 * unexpected operation ID, or -EAGAIN/an error from rxperf_extract_data().
 */
static int rxperf_deliver_request(struct rxperf_call *call)
{
	int ret;

	switch (call->unmarshal) {
	case 0:
		/* Receive the size word(s) into call->tmp. */
		call->kvec[0].iov_len	= call->iov_len;
		call->kvec[0].iov_base	= call->tmp;
		iov_iter_kvec(&call->iter, READ, call->kvec, 1, call->iov_len);
		call->unmarshal++;
		fallthrough;
	case 1:
		ret = rxperf_extract_data(call, true);
		if (ret < 0)
			return ret;

		/* Sizes arrive big-endian; which ones depend on the operation. */
		switch (call->operation_id) {
		case RX_PERF_SEND:
			call->type = "send";
			call->req_len	= ntohl(call->tmp[0]);
			call->reply_len	= 0;
			break;
		case RX_PERF_RECV:
			call->type = "recv";
			call->req_len = 0;
			call->reply_len	= ntohl(call->tmp[0]);
			break;
		case RX_PERF_RPC:
			call->type = "rpc";
			call->req_len	= ntohl(call->tmp[0]);
			call->reply_len	= ntohl(call->tmp[1]);
			break;
		default:
			pr_info("Can't parse extra params\n");
			return -EIO;
		}

		pr_debug("CALL op=%s rq=%zx rp=%zx\n",
			 call->type, call->req_len, call->reply_len);

		/* The request payload content is not used; just drain it. */
		call->iov_len = call->req_len;
		iov_iter_discard(&call->iter, READ, call->req_len);
		call->unmarshal++;
		fallthrough;
	case 2:
		/* want_more == false: this should be the end of the request. */
		ret = rxperf_extract_data(call, false);
		if (ret < 0)
			return ret;
		call->unmarshal++;
		fallthrough;
	default:
		return 0;
	}
}
490
491 /*
492  * Process a call for which we've received the request.
493  */
494 static int rxperf_process_call(struct rxperf_call *call)
495 {
496         struct msghdr msg = {};
497         struct bio_vec bv;
498         struct kvec iov[1];
499         ssize_t n;
500         size_t reply_len = call->reply_len, len;
501
502         rxrpc_kernel_set_tx_length(rxperf_socket, call->rxcall,
503                                    reply_len + sizeof(rxperf_magic_cookie));
504
505         while (reply_len > 0) {
506                 len = min_t(size_t, reply_len, PAGE_SIZE);
507                 bvec_set_page(&bv, ZERO_PAGE(0), len, 0);
508                 iov_iter_bvec(&msg.msg_iter, WRITE, &bv, 1, len);
509                 msg.msg_flags = MSG_MORE;
510                 n = rxrpc_kernel_send_data(rxperf_socket, call->rxcall, &msg,
511                                            len, rxperf_notify_end_reply_tx);
512                 if (n < 0)
513                         return n;
514                 if (n == 0)
515                         return -EIO;
516                 reply_len -= n;
517         }
518
519         len = sizeof(rxperf_magic_cookie);
520         iov[0].iov_base = (void *)rxperf_magic_cookie;
521         iov[0].iov_len  = len;
522         iov_iter_kvec(&msg.msg_iter, WRITE, iov, 1, len);
523         msg.msg_flags = 0;
524         n = rxrpc_kernel_send_data(rxperf_socket, call->rxcall, &msg, len,
525                                    rxperf_notify_end_reply_tx);
526         if (n >= 0)
527                 return 0; /* Success */
528
529         if (n == -ENOMEM)
530                 rxrpc_kernel_abort_call(rxperf_socket, call->rxcall,
531                                         RXGEN_SS_MARSHAL, -ENOMEM,
532                                         rxperf_abort_oom);
533         return n;
534 }
535
536 /*
537  * Add a key to the security keyring.
538  */
539 static int rxperf_add_key(struct key *keyring)
540 {
541         key_ref_t kref;
542         int ret;
543
544         kref = key_create_or_update(make_key_ref(keyring, true),
545                                     "rxrpc_s",
546                                     __stringify(RX_PERF_SERVICE) ":2",
547                                     secret,
548                                     sizeof(secret),
549                                     KEY_POS_VIEW | KEY_POS_READ | KEY_POS_SEARCH
550                                     | KEY_USR_VIEW,
551                                     KEY_ALLOC_NOT_IN_QUOTA);
552
553         if (IS_ERR(kref)) {
554                 pr_err("Can't allocate rxperf server key: %ld\n", PTR_ERR(kref));
555                 return PTR_ERR(kref);
556         }
557
558         ret = key_link(keyring, key_ref_to_ptr(kref));
559         if (ret < 0)
560                 pr_err("Can't link rxperf server key: %d\n", ret);
561         key_ref_put(kref);
562         return ret;
563 }
564
565 /*
566  * Initialise the rxperf server.
567  */
568 static int __init rxperf_init(void)
569 {
570         struct key *keyring;
571         int ret = -ENOMEM;
572
573         pr_info("Server registering\n");
574
575         rxperf_workqueue = alloc_workqueue("rxperf", 0, 0);
576         if (!rxperf_workqueue)
577                 goto error_workqueue;
578
579         keyring = keyring_alloc("rxperf_server",
580                                 GLOBAL_ROOT_UID, GLOBAL_ROOT_GID, current_cred(),
581                                 KEY_POS_VIEW | KEY_POS_READ | KEY_POS_SEARCH |
582                                 KEY_POS_WRITE |
583                                 KEY_USR_VIEW | KEY_USR_READ | KEY_USR_SEARCH |
584                                 KEY_USR_WRITE |
585                                 KEY_OTH_VIEW | KEY_OTH_READ | KEY_OTH_SEARCH,
586                                 KEY_ALLOC_NOT_IN_QUOTA,
587                                 NULL, NULL);
588         if (IS_ERR(keyring)) {
589                 pr_err("Can't allocate rxperf server keyring: %ld\n",
590                        PTR_ERR(keyring));
591                 goto error_keyring;
592         }
593         rxperf_sec_keyring = keyring;
594         ret = rxperf_add_key(keyring);
595         if (ret < 0)
596                 goto error_key;
597
598         ret = rxperf_open_socket();
599         if (ret < 0)
600                 goto error_socket;
601         return 0;
602
603 error_socket:
604 error_key:
605         key_put(rxperf_sec_keyring);
606 error_keyring:
607         destroy_workqueue(rxperf_workqueue);
608         rcu_barrier();
609 error_workqueue:
610         pr_err("Failed to register: %d\n", ret);
611         return ret;
612 }
613 late_initcall(rxperf_init); /* Must be called after net/ to create socket */
614
/*
 * Tear down the rxperf server in the reverse order of rxperf_init(): close
 * the socket, drop the keyring reference, then destroy the workqueue.
 */
static void __exit rxperf_exit(void)
{
	pr_info("Server unregistering.\n");

	rxperf_close_socket();
	key_put(rxperf_sec_keyring);
	destroy_workqueue(rxperf_workqueue);
	/* Presumably waits for pending RCU callbacks before module unload. */
	rcu_barrier();
}
module_exit(rxperf_exit);
625