mptcp: Add key generation and token tree
[linux-2.6-block.git] / net / mptcp / protocol.c
CommitLineData
f870fa0b
MM
1// SPDX-License-Identifier: GPL-2.0
2/* Multipath TCP
3 *
4 * Copyright (c) 2017 - 2019, Intel Corporation.
5 */
6
7#define pr_fmt(fmt) "MPTCP: " fmt
8
9#include <linux/kernel.h>
10#include <linux/module.h>
11#include <linux/netdevice.h>
12#include <net/sock.h>
13#include <net/inet_common.h>
14#include <net/inet_hashtables.h>
15#include <net/protocol.h>
16#include <net/tcp.h>
cf7da0d6
PK
17#if IS_ENABLED(CONFIG_MPTCP_IPV6)
18#include <net/transp_v6.h>
19#endif
f870fa0b
MM
20#include <net/mptcp.h>
21#include "protocol.h"
22
2303f994
PK
23#define MPTCP_SAME_STATE TCP_MAX_STATES
24
25/* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
26 * completed yet or has failed, return the subflow socket.
27 * Otherwise return NULL.
28 */
29static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
30{
cec37a6e 31 if (!msk->subflow || mptcp_subflow_ctx(msk->subflow->sk)->fourth_ack)
2303f994
PK
32 return NULL;
33
34 return msk->subflow;
35}
36
cec37a6e
PK
37/* if msk has a single subflow, and the mp_capable handshake is failed,
38 * return it.
39 * Otherwise returns NULL
40 */
41static struct socket *__mptcp_tcp_fallback(const struct mptcp_sock *msk)
42{
43 struct socket *ssock = __mptcp_nmpc_socket(msk);
44
45 sock_owned_by_me((const struct sock *)msk);
46
47 if (!ssock || sk_is_mptcp(ssock->sk))
48 return NULL;
49
50 return ssock;
51}
52
2303f994
PK
53static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
54{
55 return ((struct sock *)msk)->sk_state == TCP_CLOSE;
56}
57
58static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
59{
60 struct mptcp_subflow_context *subflow;
61 struct sock *sk = (struct sock *)msk;
62 struct socket *ssock;
63 int err;
64
65 ssock = __mptcp_nmpc_socket(msk);
66 if (ssock)
67 goto set_state;
68
69 if (!__mptcp_can_create_subflow(msk))
70 return ERR_PTR(-EINVAL);
71
72 err = mptcp_subflow_create_socket(sk, &ssock);
73 if (err)
74 return ERR_PTR(err);
75
76 msk->subflow = ssock;
77 subflow = mptcp_subflow_ctx(ssock->sk);
cec37a6e 78 list_add(&subflow->node, &msk->conn_list);
2303f994
PK
79 subflow->request_mptcp = 1;
80
81set_state:
82 if (state != MPTCP_SAME_STATE)
83 inet_sk_state_store(sk, state);
84 return ssock;
85}
86
cec37a6e
PK
87static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk)
88{
89 struct mptcp_subflow_context *subflow;
90
91 sock_owned_by_me((const struct sock *)msk);
92
93 mptcp_for_each_subflow(msk, subflow) {
94 return mptcp_subflow_tcp_sock(subflow);
95 }
96
97 return NULL;
98}
99
f870fa0b
MM
100static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
101{
102 struct mptcp_sock *msk = mptcp_sk(sk);
cec37a6e
PK
103 struct socket *ssock;
104 struct sock *ssk;
105 int ret;
f870fa0b
MM
106
107 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
108 return -EOPNOTSUPP;
109
cec37a6e
PK
110 lock_sock(sk);
111 ssock = __mptcp_tcp_fallback(msk);
112 if (ssock) {
113 pr_debug("fallback passthrough");
114 ret = sock_sendmsg(ssock, msg);
115 release_sock(sk);
116 return ret;
117 }
118
119 ssk = mptcp_subflow_get(msk);
120 if (!ssk) {
121 release_sock(sk);
122 return -ENOTCONN;
123 }
124
125 ret = sock_sendmsg(ssk->sk_socket, msg);
126
127 release_sock(sk);
128 return ret;
f870fa0b
MM
129}
130
131static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
132 int nonblock, int flags, int *addr_len)
133{
134 struct mptcp_sock *msk = mptcp_sk(sk);
cec37a6e
PK
135 struct socket *ssock;
136 struct sock *ssk;
137 int copied = 0;
f870fa0b
MM
138
139 if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
140 return -EOPNOTSUPP;
141
cec37a6e
PK
142 lock_sock(sk);
143 ssock = __mptcp_tcp_fallback(msk);
144 if (ssock) {
145 pr_debug("fallback-read subflow=%p",
146 mptcp_subflow_ctx(ssock->sk));
147 copied = sock_recvmsg(ssock, msg, flags);
148 release_sock(sk);
149 return copied;
150 }
151
152 ssk = mptcp_subflow_get(msk);
153 if (!ssk) {
154 release_sock(sk);
155 return -ENOTCONN;
156 }
157
158 copied = sock_recvmsg(ssk->sk_socket, msg, flags);
159
160 release_sock(sk);
161
162 return copied;
163}
164
165/* subflow sockets can be either outgoing (connect) or incoming
166 * (accept).
167 *
168 * Outgoing subflows use in-kernel sockets.
169 * Incoming subflows do not have their own 'struct socket' allocated,
170 * so we need to use tcp_close() after detaching them from the mptcp
171 * parent socket.
172 */
173static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
174 struct mptcp_subflow_context *subflow,
175 long timeout)
176{
177 struct socket *sock = READ_ONCE(ssk->sk_socket);
178
179 list_del(&subflow->node);
180
181 if (sock && sock != sk->sk_socket) {
182 /* outgoing subflow */
183 sock_release(sock);
184 } else {
185 /* incoming subflow */
186 tcp_close(ssk, timeout);
187 }
f870fa0b
MM
188}
189
190static int mptcp_init_sock(struct sock *sk)
191{
cec37a6e
PK
192 struct mptcp_sock *msk = mptcp_sk(sk);
193
194 INIT_LIST_HEAD(&msk->conn_list);
195
f870fa0b
MM
196 return 0;
197}
198
199static void mptcp_close(struct sock *sk, long timeout)
200{
cec37a6e 201 struct mptcp_subflow_context *subflow, *tmp;
f870fa0b
MM
202 struct mptcp_sock *msk = mptcp_sk(sk);
203
79c0949e 204 mptcp_token_destroy(msk->token);
f870fa0b
MM
205 inet_sk_state_store(sk, TCP_CLOSE);
206
cec37a6e
PK
207 lock_sock(sk);
208
209 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
210 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
211
212 __mptcp_close_ssk(sk, ssk, subflow, timeout);
f870fa0b
MM
213 }
214
cec37a6e
PK
215 release_sock(sk);
216 sk_common_release(sk);
f870fa0b
MM
217}
218
cf7da0d6
PK
219static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
220{
221#if IS_ENABLED(CONFIG_MPTCP_IPV6)
222 const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
223 struct ipv6_pinfo *msk6 = inet6_sk(msk);
224
225 msk->sk_v6_daddr = ssk->sk_v6_daddr;
226 msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;
227
228 if (msk6 && ssk6) {
229 msk6->saddr = ssk6->saddr;
230 msk6->flow_label = ssk6->flow_label;
231 }
232#endif
233
234 inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
235 inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
236 inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
237 inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
238 inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
239 inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
240}
241
242static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
243 bool kern)
244{
245 struct mptcp_sock *msk = mptcp_sk(sk);
246 struct socket *listener;
247 struct sock *newsk;
248
249 listener = __mptcp_nmpc_socket(msk);
250 if (WARN_ON_ONCE(!listener)) {
251 *err = -EINVAL;
252 return NULL;
253 }
254
255 pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
256 newsk = inet_csk_accept(listener->sk, flags, err, kern);
257 if (!newsk)
258 return NULL;
259
260 pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
261
262 if (sk_is_mptcp(newsk)) {
263 struct mptcp_subflow_context *subflow;
264 struct sock *new_mptcp_sock;
265 struct sock *ssk = newsk;
266
267 subflow = mptcp_subflow_ctx(newsk);
268 lock_sock(sk);
269
270 local_bh_disable();
271 new_mptcp_sock = sk_clone_lock(sk, GFP_ATOMIC);
272 if (!new_mptcp_sock) {
273 *err = -ENOBUFS;
274 local_bh_enable();
275 release_sock(sk);
276 tcp_close(newsk, 0);
277 return NULL;
278 }
279
280 mptcp_init_sock(new_mptcp_sock);
281
282 msk = mptcp_sk(new_mptcp_sock);
283 msk->remote_key = subflow->remote_key;
284 msk->local_key = subflow->local_key;
79c0949e 285 msk->token = subflow->token;
cf7da0d6
PK
286 msk->subflow = NULL;
287
79c0949e 288 mptcp_token_update_accept(newsk, new_mptcp_sock);
cf7da0d6
PK
289 newsk = new_mptcp_sock;
290 mptcp_copy_inaddrs(newsk, ssk);
291 list_add(&subflow->node, &msk->conn_list);
292
293 /* will be fully established at mptcp_stream_accept()
294 * completion.
295 */
296 inet_sk_state_store(new_mptcp_sock, TCP_SYN_RECV);
297 bh_unlock_sock(new_mptcp_sock);
298 local_bh_enable();
299 release_sock(sk);
300 }
301
302 return newsk;
303}
304
/* ->destroy() for MPTCP sockets; also called from mptcp_v6_destroy(). */
static void mptcp_destroy(struct sock *sk)
{
	/* intentionally empty: nothing to release here yet */
}
cec37a6e 309static int mptcp_get_port(struct sock *sk, unsigned short snum)
f870fa0b
MM
310{
311 struct mptcp_sock *msk = mptcp_sk(sk);
cec37a6e 312 struct socket *ssock;
f870fa0b 313
cec37a6e
PK
314 ssock = __mptcp_nmpc_socket(msk);
315 pr_debug("msk=%p, subflow=%p", msk, ssock);
316 if (WARN_ON_ONCE(!ssock))
317 return -EINVAL;
f870fa0b 318
cec37a6e
PK
319 return inet_csk_get_port(ssock->sk, snum);
320}
f870fa0b 321
cec37a6e
PK
322void mptcp_finish_connect(struct sock *ssk)
323{
324 struct mptcp_subflow_context *subflow;
325 struct mptcp_sock *msk;
326 struct sock *sk;
f870fa0b 327
cec37a6e 328 subflow = mptcp_subflow_ctx(ssk);
f870fa0b 329
cec37a6e
PK
330 if (!subflow->mp_capable)
331 return;
332
333 sk = subflow->conn;
334 msk = mptcp_sk(sk);
335
336 /* the socket is not connected yet, no msk/subflow ops can access/race
337 * accessing the field below
338 */
339 WRITE_ONCE(msk->remote_key, subflow->remote_key);
340 WRITE_ONCE(msk->local_key, subflow->local_key);
79c0949e 341 WRITE_ONCE(msk->token, subflow->token);
f870fa0b
MM
342}
343
cf7da0d6
PK
344static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
345{
346 write_lock_bh(&sk->sk_callback_lock);
347 rcu_assign_pointer(sk->sk_wq, &parent->wq);
348 sk_set_socket(sk, parent);
349 sk->sk_uid = SOCK_INODE(parent)->i_uid;
350 write_unlock_bh(&sk->sk_callback_lock);
351}
352
f870fa0b
MM
353static struct proto mptcp_prot = {
354 .name = "MPTCP",
355 .owner = THIS_MODULE,
356 .init = mptcp_init_sock,
357 .close = mptcp_close,
cf7da0d6 358 .accept = mptcp_accept,
f870fa0b 359 .shutdown = tcp_shutdown,
79c0949e 360 .destroy = mptcp_destroy,
f870fa0b
MM
361 .sendmsg = mptcp_sendmsg,
362 .recvmsg = mptcp_recvmsg,
363 .hash = inet_hash,
364 .unhash = inet_unhash,
cec37a6e 365 .get_port = mptcp_get_port,
f870fa0b
MM
366 .obj_size = sizeof(struct mptcp_sock),
367 .no_autobind = true,
368};
369
2303f994
PK
370static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
371{
372 struct mptcp_sock *msk = mptcp_sk(sock->sk);
373 struct socket *ssock;
cf7da0d6 374 int err;
2303f994
PK
375
376 lock_sock(sock->sk);
377 ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
378 if (IS_ERR(ssock)) {
379 err = PTR_ERR(ssock);
380 goto unlock;
381 }
382
383 err = ssock->ops->bind(ssock, uaddr, addr_len);
cf7da0d6
PK
384 if (!err)
385 mptcp_copy_inaddrs(sock->sk, ssock->sk);
2303f994
PK
386
387unlock:
388 release_sock(sock->sk);
389 return err;
390}
391
392static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
393 int addr_len, int flags)
394{
395 struct mptcp_sock *msk = mptcp_sk(sock->sk);
396 struct socket *ssock;
397 int err;
398
399 lock_sock(sock->sk);
400 ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
401 if (IS_ERR(ssock)) {
402 err = PTR_ERR(ssock);
403 goto unlock;
404 }
405
cf7da0d6
PK
406#ifdef CONFIG_TCP_MD5SIG
407 /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
408 * TCP option space.
409 */
410 if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
411 mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
412#endif
413
2303f994
PK
414 err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
415 inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
cf7da0d6 416 mptcp_copy_inaddrs(sock->sk, ssock->sk);
2303f994
PK
417
418unlock:
419 release_sock(sock->sk);
420 return err;
421}
422
cf7da0d6
PK
423static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
424 int peer)
425{
426 if (sock->sk->sk_prot == &tcp_prot) {
427 /* we are being invoked from __sys_accept4, after
428 * mptcp_accept() has just accepted a non-mp-capable
429 * flow: sk is a tcp_sk, not an mptcp one.
430 *
431 * Hand the socket over to tcp so all further socket ops
432 * bypass mptcp.
433 */
434 sock->ops = &inet_stream_ops;
435 }
436
437 return inet_getname(sock, uaddr, peer);
438}
439
440#if IS_ENABLED(CONFIG_MPTCP_IPV6)
441static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
442 int peer)
443{
444 if (sock->sk->sk_prot == &tcpv6_prot) {
445 /* we are being invoked from __sys_accept4 after
446 * mptcp_accept() has accepted a non-mp-capable
447 * subflow: sk is a tcp_sk, not mptcp.
448 *
449 * Hand the socket over to tcp so all further
450 * socket ops bypass mptcp.
451 */
452 sock->ops = &inet6_stream_ops;
453 }
454
455 return inet6_getname(sock, uaddr, peer);
456}
457#endif
458
459static int mptcp_listen(struct socket *sock, int backlog)
460{
461 struct mptcp_sock *msk = mptcp_sk(sock->sk);
462 struct socket *ssock;
463 int err;
464
465 pr_debug("msk=%p", msk);
466
467 lock_sock(sock->sk);
468 ssock = __mptcp_socket_create(msk, TCP_LISTEN);
469 if (IS_ERR(ssock)) {
470 err = PTR_ERR(ssock);
471 goto unlock;
472 }
473
474 err = ssock->ops->listen(ssock, backlog);
475 inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
476 if (!err)
477 mptcp_copy_inaddrs(sock->sk, ssock->sk);
478
479unlock:
480 release_sock(sock->sk);
481 return err;
482}
483
484static bool is_tcp_proto(const struct proto *p)
485{
486#if IS_ENABLED(CONFIG_MPTCP_IPV6)
487 return p == &tcp_prot || p == &tcpv6_prot;
488#else
489 return p == &tcp_prot;
490#endif
491}
492
493static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
494 int flags, bool kern)
495{
496 struct mptcp_sock *msk = mptcp_sk(sock->sk);
497 struct socket *ssock;
498 int err;
499
500 pr_debug("msk=%p", msk);
501
502 lock_sock(sock->sk);
503 if (sock->sk->sk_state != TCP_LISTEN)
504 goto unlock_fail;
505
506 ssock = __mptcp_nmpc_socket(msk);
507 if (!ssock)
508 goto unlock_fail;
509
510 sock_hold(ssock->sk);
511 release_sock(sock->sk);
512
513 err = ssock->ops->accept(sock, newsock, flags, kern);
514 if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
515 struct mptcp_sock *msk = mptcp_sk(newsock->sk);
516 struct mptcp_subflow_context *subflow;
517
518 /* set ssk->sk_socket of accept()ed flows to mptcp socket.
519 * This is needed so NOSPACE flag can be set from tcp stack.
520 */
521 list_for_each_entry(subflow, &msk->conn_list, node) {
522 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
523
524 if (!ssk->sk_socket)
525 mptcp_sock_graft(ssk, newsock);
526 }
527
528 inet_sk_state_store(newsock->sk, TCP_ESTABLISHED);
529 }
530
531 sock_put(ssock->sk);
532 return err;
533
534unlock_fail:
535 release_sock(sock->sk);
536 return -EINVAL;
537}
538
2303f994
PK
539static __poll_t mptcp_poll(struct file *file, struct socket *sock,
540 struct poll_table_struct *wait)
541{
542 __poll_t mask = 0;
543
544 return mask;
545}
546
547static struct proto_ops mptcp_stream_ops;
548
f870fa0b
MM
549static struct inet_protosw mptcp_protosw = {
550 .type = SOCK_STREAM,
551 .protocol = IPPROTO_MPTCP,
552 .prot = &mptcp_prot,
2303f994
PK
553 .ops = &mptcp_stream_ops,
554 .flags = INET_PROTOSW_ICSK,
f870fa0b
MM
555};
556
557void __init mptcp_init(void)
558{
2303f994
PK
559 mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
560 mptcp_stream_ops = inet_stream_ops;
561 mptcp_stream_ops.bind = mptcp_bind;
562 mptcp_stream_ops.connect = mptcp_stream_connect;
563 mptcp_stream_ops.poll = mptcp_poll;
cf7da0d6
PK
564 mptcp_stream_ops.accept = mptcp_stream_accept;
565 mptcp_stream_ops.getname = mptcp_v4_getname;
566 mptcp_stream_ops.listen = mptcp_listen;
2303f994
PK
567
568 mptcp_subflow_init();
569
f870fa0b
MM
570 if (proto_register(&mptcp_prot, 1) != 0)
571 panic("Failed to register MPTCP proto.\n");
572
573 inet_register_protosw(&mptcp_protosw);
574}
575
576#if IS_ENABLED(CONFIG_MPTCP_IPV6)
2303f994 577static struct proto_ops mptcp_v6_stream_ops;
f870fa0b
MM
578static struct proto mptcp_v6_prot;
579
79c0949e
PK
580static void mptcp_v6_destroy(struct sock *sk)
581{
582 mptcp_destroy(sk);
583 inet6_destroy_sock(sk);
584}
585
f870fa0b
MM
586static struct inet_protosw mptcp_v6_protosw = {
587 .type = SOCK_STREAM,
588 .protocol = IPPROTO_MPTCP,
589 .prot = &mptcp_v6_prot,
2303f994 590 .ops = &mptcp_v6_stream_ops,
f870fa0b
MM
591 .flags = INET_PROTOSW_ICSK,
592};
593
594int mptcpv6_init(void)
595{
596 int err;
597
598 mptcp_v6_prot = mptcp_prot;
599 strcpy(mptcp_v6_prot.name, "MPTCPv6");
600 mptcp_v6_prot.slab = NULL;
79c0949e 601 mptcp_v6_prot.destroy = mptcp_v6_destroy;
f870fa0b
MM
602 mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
603 sizeof(struct ipv6_pinfo);
604
605 err = proto_register(&mptcp_v6_prot, 1);
606 if (err)
607 return err;
608
2303f994
PK
609 mptcp_v6_stream_ops = inet6_stream_ops;
610 mptcp_v6_stream_ops.bind = mptcp_bind;
611 mptcp_v6_stream_ops.connect = mptcp_stream_connect;
612 mptcp_v6_stream_ops.poll = mptcp_poll;
cf7da0d6
PK
613 mptcp_v6_stream_ops.accept = mptcp_stream_accept;
614 mptcp_v6_stream_ops.getname = mptcp_v6_getname;
615 mptcp_v6_stream_ops.listen = mptcp_listen;
2303f994 616
f870fa0b
MM
617 err = inet6_register_protosw(&mptcp_v6_protosw);
618 if (err)
619 proto_unregister(&mptcp_v6_prot);
620
621 return err;
622}
623#endif