Documentation: PM: Drop pme_interrupt reference
[linux-2.6-block.git] / net / mptcp / pm_userspace.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2022, Intel Corporation.
5  */
6
7 #include "protocol.h"
8
9 void mptcp_free_local_addr_list(struct mptcp_sock *msk)
10 {
11         struct mptcp_pm_addr_entry *entry, *tmp;
12         struct sock *sk = (struct sock *)msk;
13         LIST_HEAD(free_list);
14
15         if (!mptcp_pm_is_userspace(msk))
16                 return;
17
18         spin_lock_bh(&msk->pm.lock);
19         list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list);
20         spin_unlock_bh(&msk->pm.lock);
21
22         list_for_each_entry_safe(entry, tmp, &free_list, list) {
23                 sock_kfree_s(sk, entry, sizeof(*entry));
24         }
25 }
26
27 int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk,
28                                              struct mptcp_pm_addr_entry *entry)
29 {
30         DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
31         struct mptcp_pm_addr_entry *match = NULL;
32         struct sock *sk = (struct sock *)msk;
33         struct mptcp_pm_addr_entry *e;
34         bool addr_match = false;
35         bool id_match = false;
36         int ret = -EINVAL;
37
38         bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
39
40         spin_lock_bh(&msk->pm.lock);
41         list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) {
42                 addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true);
43                 if (addr_match && entry->addr.id == 0)
44                         entry->addr.id = e->addr.id;
45                 id_match = (e->addr.id == entry->addr.id);
46                 if (addr_match && id_match) {
47                         match = e;
48                         break;
49                 } else if (addr_match || id_match) {
50                         break;
51                 }
52                 __set_bit(e->addr.id, id_bitmap);
53         }
54
55         if (!match && !addr_match && !id_match) {
56                 /* Memory for the entry is allocated from the
57                  * sock option buffer.
58                  */
59                 e = sock_kmalloc(sk, sizeof(*e), GFP_ATOMIC);
60                 if (!e) {
61                         spin_unlock_bh(&msk->pm.lock);
62                         return -ENOMEM;
63                 }
64
65                 *e = *entry;
66                 if (!e->addr.id)
67                         e->addr.id = find_next_zero_bit(id_bitmap,
68                                                         MPTCP_PM_MAX_ADDR_ID + 1,
69                                                         1);
70                 list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list);
71                 ret = e->addr.id;
72         } else if (match) {
73                 ret = entry->addr.id;
74         }
75
76         spin_unlock_bh(&msk->pm.lock);
77         return ret;
78 }
79
80 int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk,
81                                                    unsigned int id,
82                                                    u8 *flags, int *ifindex)
83 {
84         struct mptcp_pm_addr_entry *entry, *match = NULL;
85
86         *flags = 0;
87         *ifindex = 0;
88
89         spin_lock_bh(&msk->pm.lock);
90         list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
91                 if (id == entry->addr.id) {
92                         match = entry;
93                         break;
94                 }
95         }
96         spin_unlock_bh(&msk->pm.lock);
97         if (match) {
98                 *flags = match->flags;
99                 *ifindex = match->ifindex;
100         }
101
102         return 0;
103 }
104
105 int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk,
106                                     struct mptcp_addr_info *skc)
107 {
108         struct mptcp_pm_addr_entry new_entry;
109         __be16 msk_sport =  ((struct inet_sock *)
110                              inet_sk((struct sock *)msk))->inet_sport;
111
112         memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry));
113         new_entry.addr = *skc;
114         new_entry.addr.id = 0;
115         new_entry.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;
116
117         if (new_entry.addr.port == msk_sport)
118                 new_entry.addr.port = 0;
119
120         return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry);
121 }
122
123 int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info)
124 {
125         struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
126         struct nlattr *addr = info->attrs[MPTCP_PM_ATTR_ADDR];
127         struct mptcp_pm_addr_entry addr_val;
128         struct mptcp_sock *msk;
129         int err = -EINVAL;
130         u32 token_val;
131
132         if (!addr || !token) {
133                 GENL_SET_ERR_MSG(info, "missing required inputs");
134                 return err;
135         }
136
137         token_val = nla_get_u32(token);
138
139         msk = mptcp_token_get_sock(sock_net(skb->sk), token_val);
140         if (!msk) {
141                 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
142                 return err;
143         }
144
145         if (!mptcp_pm_is_userspace(msk)) {
146                 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
147                 goto announce_err;
148         }
149
150         err = mptcp_pm_parse_entry(addr, info, true, &addr_val);
151         if (err < 0) {
152                 GENL_SET_ERR_MSG(info, "error parsing local address");
153                 goto announce_err;
154         }
155
156         if (addr_val.addr.id == 0 || !(addr_val.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
157                 GENL_SET_ERR_MSG(info, "invalid addr id or flags");
158                 goto announce_err;
159         }
160
161         err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val);
162         if (err < 0) {
163                 GENL_SET_ERR_MSG(info, "did not match address and id");
164                 goto announce_err;
165         }
166
167         lock_sock((struct sock *)msk);
168         spin_lock_bh(&msk->pm.lock);
169
170         if (mptcp_pm_alloc_anno_list(msk, &addr_val)) {
171                 mptcp_pm_announce_addr(msk, &addr_val.addr, false);
172                 mptcp_pm_nl_addr_send_ack(msk);
173         }
174
175         spin_unlock_bh(&msk->pm.lock);
176         release_sock((struct sock *)msk);
177
178         err = 0;
179  announce_err:
180         sock_put((struct sock *)msk);
181         return err;
182 }
183
184 int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info)
185 {
186         struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
187         struct nlattr *id = info->attrs[MPTCP_PM_ATTR_LOC_ID];
188         struct mptcp_pm_addr_entry *match = NULL;
189         struct mptcp_pm_addr_entry *entry;
190         struct mptcp_sock *msk;
191         LIST_HEAD(free_list);
192         int err = -EINVAL;
193         u32 token_val;
194         u8 id_val;
195
196         if (!id || !token) {
197                 GENL_SET_ERR_MSG(info, "missing required inputs");
198                 return err;
199         }
200
201         id_val = nla_get_u8(id);
202         token_val = nla_get_u32(token);
203
204         msk = mptcp_token_get_sock(sock_net(skb->sk), token_val);
205         if (!msk) {
206                 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
207                 return err;
208         }
209
210         if (!mptcp_pm_is_userspace(msk)) {
211                 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
212                 goto remove_err;
213         }
214
215         lock_sock((struct sock *)msk);
216
217         list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
218                 if (entry->addr.id == id_val) {
219                         match = entry;
220                         break;
221                 }
222         }
223
224         if (!match) {
225                 GENL_SET_ERR_MSG(info, "address with specified id not found");
226                 release_sock((struct sock *)msk);
227                 goto remove_err;
228         }
229
230         list_move(&match->list, &free_list);
231
232         mptcp_pm_remove_addrs_and_subflows(msk, &free_list);
233
234         release_sock((struct sock *)msk);
235
236         list_for_each_entry_safe(match, entry, &free_list, list) {
237                 sock_kfree_s((struct sock *)msk, match, sizeof(*match));
238         }
239
240         err = 0;
241  remove_err:
242         sock_put((struct sock *)msk);
243         return err;
244 }
245
246 int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info)
247 {
248         struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
249         struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
250         struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
251         struct mptcp_addr_info addr_r;
252         struct mptcp_addr_info addr_l;
253         struct mptcp_sock *msk;
254         int err = -EINVAL;
255         struct sock *sk;
256         u32 token_val;
257
258         if (!laddr || !raddr || !token) {
259                 GENL_SET_ERR_MSG(info, "missing required inputs");
260                 return err;
261         }
262
263         token_val = nla_get_u32(token);
264
265         msk = mptcp_token_get_sock(genl_info_net(info), token_val);
266         if (!msk) {
267                 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
268                 return err;
269         }
270
271         if (!mptcp_pm_is_userspace(msk)) {
272                 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
273                 goto create_err;
274         }
275
276         err = mptcp_pm_parse_addr(laddr, info, &addr_l);
277         if (err < 0) {
278                 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
279                 goto create_err;
280         }
281
282         if (addr_l.id == 0) {
283                 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "missing local addr id");
284                 goto create_err;
285         }
286
287         err = mptcp_pm_parse_addr(raddr, info, &addr_r);
288         if (err < 0) {
289                 NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr");
290                 goto create_err;
291         }
292
293         sk = &msk->sk.icsk_inet.sk;
294         lock_sock(sk);
295
296         err = __mptcp_subflow_connect(sk, &addr_l, &addr_r);
297
298         release_sock(sk);
299
300  create_err:
301         sock_put((struct sock *)msk);
302         return err;
303 }
304
305 static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk,
306                                       const struct mptcp_addr_info *local,
307                                       const struct mptcp_addr_info *remote)
308 {
309         struct sock *sk = &msk->sk.icsk_inet.sk;
310         struct mptcp_subflow_context *subflow;
311         struct sock *found = NULL;
312
313         if (local->family != remote->family)
314                 return NULL;
315
316         lock_sock(sk);
317
318         mptcp_for_each_subflow(msk, subflow) {
319                 const struct inet_sock *issk;
320                 struct sock *ssk;
321
322                 ssk = mptcp_subflow_tcp_sock(subflow);
323
324                 if (local->family != ssk->sk_family)
325                         continue;
326
327                 issk = inet_sk(ssk);
328
329                 switch (ssk->sk_family) {
330                 case AF_INET:
331                         if (issk->inet_saddr != local->addr.s_addr ||
332                             issk->inet_daddr != remote->addr.s_addr)
333                                 continue;
334                         break;
335 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
336                 case AF_INET6: {
337                         const struct ipv6_pinfo *pinfo = inet6_sk(ssk);
338
339                         if (!ipv6_addr_equal(&local->addr6, &pinfo->saddr) ||
340                             !ipv6_addr_equal(&remote->addr6, &ssk->sk_v6_daddr))
341                                 continue;
342                         break;
343                 }
344 #endif
345                 default:
346                         continue;
347                 }
348
349                 if (issk->inet_sport == local->port &&
350                     issk->inet_dport == remote->port) {
351                         found = ssk;
352                         goto found;
353                 }
354         }
355
356 found:
357         release_sock(sk);
358
359         return found;
360 }
361
362 int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info)
363 {
364         struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
365         struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
366         struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
367         struct mptcp_addr_info addr_l;
368         struct mptcp_addr_info addr_r;
369         struct mptcp_sock *msk;
370         struct sock *sk, *ssk;
371         int err = -EINVAL;
372         u32 token_val;
373
374         if (!laddr || !raddr || !token) {
375                 GENL_SET_ERR_MSG(info, "missing required inputs");
376                 return err;
377         }
378
379         token_val = nla_get_u32(token);
380
381         msk = mptcp_token_get_sock(genl_info_net(info), token_val);
382         if (!msk) {
383                 NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
384                 return err;
385         }
386
387         if (!mptcp_pm_is_userspace(msk)) {
388                 GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
389                 goto destroy_err;
390         }
391
392         err = mptcp_pm_parse_addr(laddr, info, &addr_l);
393         if (err < 0) {
394                 NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
395                 goto destroy_err;
396         }
397
398         err = mptcp_pm_parse_addr(raddr, info, &addr_r);
399         if (err < 0) {
400                 NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr");
401                 goto destroy_err;
402         }
403
404         if (addr_l.family != addr_r.family) {
405                 GENL_SET_ERR_MSG(info, "address families do not match");
406                 goto destroy_err;
407         }
408
409         if (!addr_l.port || !addr_r.port) {
410                 GENL_SET_ERR_MSG(info, "missing local or remote port");
411                 goto destroy_err;
412         }
413
414         sk = &msk->sk.icsk_inet.sk;
415         ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r);
416         if (ssk) {
417                 struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
418
419                 mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN);
420                 mptcp_close_ssk(sk, ssk, subflow);
421                 err = 0;
422         } else {
423                 err = -ESRCH;
424         }
425
426  destroy_err:
427         sock_put((struct sock *)msk);
428         return err;
429 }