Merge tag 'tomoyo-pr-20230903' of git://git.osdn.net/gitroot/tomoyo/tomoyo-test1
[linux-block.git] / net / devlink / sb.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Copyright (c) 2016 Mellanox Technologies. All rights reserved.
4  * Copyright (c) 2016 Jiri Pirko <jiri@mellanox.com>
5  */
6
7 #include "devl_internal.h"
8
/*
 * Describes one shared buffer of a devlink instance.
 */
struct devlink_sb {
	struct list_head list;		/* node on devlink->sb_list */
	unsigned int index;		/* matched against DEVLINK_ATTR_SB_INDEX */
	u32 size;			/* reported as DEVLINK_ATTR_SB_SIZE; units not visible here */
	/* ingress + egress pool counts together bound valid pool indices */
	u16 ingress_pools_count;
	u16 egress_pools_count;
	u16 ingress_tc_count;		/* exclusive bound for ingress TC indices */
	u16 egress_tc_count;		/* exclusive bound for egress TC indices */
};
18
19 static u16 devlink_sb_pool_count(struct devlink_sb *devlink_sb)
20 {
21         return devlink_sb->ingress_pools_count + devlink_sb->egress_pools_count;
22 }
23
24 static struct devlink_sb *devlink_sb_get_by_index(struct devlink *devlink,
25                                                   unsigned int sb_index)
26 {
27         struct devlink_sb *devlink_sb;
28
29         list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
30                 if (devlink_sb->index == sb_index)
31                         return devlink_sb;
32         }
33         return NULL;
34 }
35
36 static bool devlink_sb_index_exists(struct devlink *devlink,
37                                     unsigned int sb_index)
38 {
39         return devlink_sb_get_by_index(devlink, sb_index);
40 }
41
42 static struct devlink_sb *devlink_sb_get_from_attrs(struct devlink *devlink,
43                                                     struct nlattr **attrs)
44 {
45         if (attrs[DEVLINK_ATTR_SB_INDEX]) {
46                 u32 sb_index = nla_get_u32(attrs[DEVLINK_ATTR_SB_INDEX]);
47                 struct devlink_sb *devlink_sb;
48
49                 devlink_sb = devlink_sb_get_by_index(devlink, sb_index);
50                 if (!devlink_sb)
51                         return ERR_PTR(-ENODEV);
52                 return devlink_sb;
53         }
54         return ERR_PTR(-EINVAL);
55 }
56
57 static struct devlink_sb *devlink_sb_get_from_info(struct devlink *devlink,
58                                                    struct genl_info *info)
59 {
60         return devlink_sb_get_from_attrs(devlink, info->attrs);
61 }
62
63 static int devlink_sb_pool_index_get_from_attrs(struct devlink_sb *devlink_sb,
64                                                 struct nlattr **attrs,
65                                                 u16 *p_pool_index)
66 {
67         u16 val;
68
69         if (!attrs[DEVLINK_ATTR_SB_POOL_INDEX])
70                 return -EINVAL;
71
72         val = nla_get_u16(attrs[DEVLINK_ATTR_SB_POOL_INDEX]);
73         if (val >= devlink_sb_pool_count(devlink_sb))
74                 return -EINVAL;
75         *p_pool_index = val;
76         return 0;
77 }
78
79 static int devlink_sb_pool_index_get_from_info(struct devlink_sb *devlink_sb,
80                                                struct genl_info *info,
81                                                u16 *p_pool_index)
82 {
83         return devlink_sb_pool_index_get_from_attrs(devlink_sb, info->attrs,
84                                                     p_pool_index);
85 }
86
87 static int
88 devlink_sb_pool_type_get_from_attrs(struct nlattr **attrs,
89                                     enum devlink_sb_pool_type *p_pool_type)
90 {
91         u8 val;
92
93         if (!attrs[DEVLINK_ATTR_SB_POOL_TYPE])
94                 return -EINVAL;
95
96         val = nla_get_u8(attrs[DEVLINK_ATTR_SB_POOL_TYPE]);
97         if (val != DEVLINK_SB_POOL_TYPE_INGRESS &&
98             val != DEVLINK_SB_POOL_TYPE_EGRESS)
99                 return -EINVAL;
100         *p_pool_type = val;
101         return 0;
102 }
103
104 static int
105 devlink_sb_pool_type_get_from_info(struct genl_info *info,
106                                    enum devlink_sb_pool_type *p_pool_type)
107 {
108         return devlink_sb_pool_type_get_from_attrs(info->attrs, p_pool_type);
109 }
110
111 static int
112 devlink_sb_th_type_get_from_attrs(struct nlattr **attrs,
113                                   enum devlink_sb_threshold_type *p_th_type)
114 {
115         u8 val;
116
117         if (!attrs[DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE])
118                 return -EINVAL;
119
120         val = nla_get_u8(attrs[DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE]);
121         if (val != DEVLINK_SB_THRESHOLD_TYPE_STATIC &&
122             val != DEVLINK_SB_THRESHOLD_TYPE_DYNAMIC)
123                 return -EINVAL;
124         *p_th_type = val;
125         return 0;
126 }
127
128 static int
129 devlink_sb_th_type_get_from_info(struct genl_info *info,
130                                  enum devlink_sb_threshold_type *p_th_type)
131 {
132         return devlink_sb_th_type_get_from_attrs(info->attrs, p_th_type);
133 }
134
135 static int
136 devlink_sb_tc_index_get_from_attrs(struct devlink_sb *devlink_sb,
137                                    struct nlattr **attrs,
138                                    enum devlink_sb_pool_type pool_type,
139                                    u16 *p_tc_index)
140 {
141         u16 val;
142
143         if (!attrs[DEVLINK_ATTR_SB_TC_INDEX])
144                 return -EINVAL;
145
146         val = nla_get_u16(attrs[DEVLINK_ATTR_SB_TC_INDEX]);
147         if (pool_type == DEVLINK_SB_POOL_TYPE_INGRESS &&
148             val >= devlink_sb->ingress_tc_count)
149                 return -EINVAL;
150         if (pool_type == DEVLINK_SB_POOL_TYPE_EGRESS &&
151             val >= devlink_sb->egress_tc_count)
152                 return -EINVAL;
153         *p_tc_index = val;
154         return 0;
155 }
156
157 static int
158 devlink_sb_tc_index_get_from_info(struct devlink_sb *devlink_sb,
159                                   struct genl_info *info,
160                                   enum devlink_sb_pool_type pool_type,
161                                   u16 *p_tc_index)
162 {
163         return devlink_sb_tc_index_get_from_attrs(devlink_sb, info->attrs,
164                                                   pool_type, p_tc_index);
165 }
166
167 static int devlink_nl_sb_fill(struct sk_buff *msg, struct devlink *devlink,
168                               struct devlink_sb *devlink_sb,
169                               enum devlink_command cmd, u32 portid,
170                               u32 seq, int flags)
171 {
172         void *hdr;
173
174         hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd);
175         if (!hdr)
176                 return -EMSGSIZE;
177
178         if (devlink_nl_put_handle(msg, devlink))
179                 goto nla_put_failure;
180         if (nla_put_u32(msg, DEVLINK_ATTR_SB_INDEX, devlink_sb->index))
181                 goto nla_put_failure;
182         if (nla_put_u32(msg, DEVLINK_ATTR_SB_SIZE, devlink_sb->size))
183                 goto nla_put_failure;
184         if (nla_put_u16(msg, DEVLINK_ATTR_SB_INGRESS_POOL_COUNT,
185                         devlink_sb->ingress_pools_count))
186                 goto nla_put_failure;
187         if (nla_put_u16(msg, DEVLINK_ATTR_SB_EGRESS_POOL_COUNT,
188                         devlink_sb->egress_pools_count))
189                 goto nla_put_failure;
190         if (nla_put_u16(msg, DEVLINK_ATTR_SB_INGRESS_TC_COUNT,
191                         devlink_sb->ingress_tc_count))
192                 goto nla_put_failure;
193         if (nla_put_u16(msg, DEVLINK_ATTR_SB_EGRESS_TC_COUNT,
194                         devlink_sb->egress_tc_count))
195                 goto nla_put_failure;
196
197         genlmsg_end(msg, hdr);
198         return 0;
199
200 nla_put_failure:
201         genlmsg_cancel(msg, hdr);
202         return -EMSGSIZE;
203 }
204
205 int devlink_nl_sb_get_doit(struct sk_buff *skb, struct genl_info *info)
206 {
207         struct devlink *devlink = info->user_ptr[0];
208         struct devlink_sb *devlink_sb;
209         struct sk_buff *msg;
210         int err;
211
212         devlink_sb = devlink_sb_get_from_info(devlink, info);
213         if (IS_ERR(devlink_sb))
214                 return PTR_ERR(devlink_sb);
215
216         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
217         if (!msg)
218                 return -ENOMEM;
219
220         err = devlink_nl_sb_fill(msg, devlink, devlink_sb,
221                                  DEVLINK_CMD_SB_NEW,
222                                  info->snd_portid, info->snd_seq, 0);
223         if (err) {
224                 nlmsg_free(msg);
225                 return err;
226         }
227
228         return genlmsg_reply(msg, info);
229 }
230
231 static int
232 devlink_nl_sb_get_dump_one(struct sk_buff *msg, struct devlink *devlink,
233                            struct netlink_callback *cb, int flags)
234 {
235         struct devlink_nl_dump_state *state = devlink_dump_state(cb);
236         struct devlink_sb *devlink_sb;
237         int idx = 0;
238         int err = 0;
239
240         list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
241                 if (idx < state->idx) {
242                         idx++;
243                         continue;
244                 }
245                 err = devlink_nl_sb_fill(msg, devlink, devlink_sb,
246                                          DEVLINK_CMD_SB_NEW,
247                                          NETLINK_CB(cb->skb).portid,
248                                          cb->nlh->nlmsg_seq, flags);
249                 if (err) {
250                         state->idx = idx;
251                         break;
252                 }
253                 idx++;
254         }
255
256         return err;
257 }
258
259 int devlink_nl_sb_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb)
260 {
261         return devlink_nl_dumpit(skb, cb, devlink_nl_sb_get_dump_one);
262 }
263
264 static int devlink_nl_sb_pool_fill(struct sk_buff *msg, struct devlink *devlink,
265                                    struct devlink_sb *devlink_sb,
266                                    u16 pool_index, enum devlink_command cmd,
267                                    u32 portid, u32 seq, int flags)
268 {
269         struct devlink_sb_pool_info pool_info;
270         void *hdr;
271         int err;
272
273         err = devlink->ops->sb_pool_get(devlink, devlink_sb->index,
274                                         pool_index, &pool_info);
275         if (err)
276                 return err;
277
278         hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd);
279         if (!hdr)
280                 return -EMSGSIZE;
281
282         if (devlink_nl_put_handle(msg, devlink))
283                 goto nla_put_failure;
284         if (nla_put_u32(msg, DEVLINK_ATTR_SB_INDEX, devlink_sb->index))
285                 goto nla_put_failure;
286         if (nla_put_u16(msg, DEVLINK_ATTR_SB_POOL_INDEX, pool_index))
287                 goto nla_put_failure;
288         if (nla_put_u8(msg, DEVLINK_ATTR_SB_POOL_TYPE, pool_info.pool_type))
289                 goto nla_put_failure;
290         if (nla_put_u32(msg, DEVLINK_ATTR_SB_POOL_SIZE, pool_info.size))
291                 goto nla_put_failure;
292         if (nla_put_u8(msg, DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE,
293                        pool_info.threshold_type))
294                 goto nla_put_failure;
295         if (nla_put_u32(msg, DEVLINK_ATTR_SB_POOL_CELL_SIZE,
296                         pool_info.cell_size))
297                 goto nla_put_failure;
298
299         genlmsg_end(msg, hdr);
300         return 0;
301
302 nla_put_failure:
303         genlmsg_cancel(msg, hdr);
304         return -EMSGSIZE;
305 }
306
307 int devlink_nl_sb_pool_get_doit(struct sk_buff *skb, struct genl_info *info)
308 {
309         struct devlink *devlink = info->user_ptr[0];
310         struct devlink_sb *devlink_sb;
311         struct sk_buff *msg;
312         u16 pool_index;
313         int err;
314
315         devlink_sb = devlink_sb_get_from_info(devlink, info);
316         if (IS_ERR(devlink_sb))
317                 return PTR_ERR(devlink_sb);
318
319         err = devlink_sb_pool_index_get_from_info(devlink_sb, info,
320                                                   &pool_index);
321         if (err)
322                 return err;
323
324         if (!devlink->ops->sb_pool_get)
325                 return -EOPNOTSUPP;
326
327         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
328         if (!msg)
329                 return -ENOMEM;
330
331         err = devlink_nl_sb_pool_fill(msg, devlink, devlink_sb, pool_index,
332                                       DEVLINK_CMD_SB_POOL_NEW,
333                                       info->snd_portid, info->snd_seq, 0);
334         if (err) {
335                 nlmsg_free(msg);
336                 return err;
337         }
338
339         return genlmsg_reply(msg, info);
340 }
341
342 static int __sb_pool_get_dumpit(struct sk_buff *msg, int start, int *p_idx,
343                                 struct devlink *devlink,
344                                 struct devlink_sb *devlink_sb,
345                                 u32 portid, u32 seq, int flags)
346 {
347         u16 pool_count = devlink_sb_pool_count(devlink_sb);
348         u16 pool_index;
349         int err;
350
351         for (pool_index = 0; pool_index < pool_count; pool_index++) {
352                 if (*p_idx < start) {
353                         (*p_idx)++;
354                         continue;
355                 }
356                 err = devlink_nl_sb_pool_fill(msg, devlink,
357                                               devlink_sb,
358                                               pool_index,
359                                               DEVLINK_CMD_SB_POOL_NEW,
360                                               portid, seq, flags);
361                 if (err)
362                         return err;
363                 (*p_idx)++;
364         }
365         return 0;
366 }
367
368 static int
369 devlink_nl_sb_pool_get_dump_one(struct sk_buff *msg, struct devlink *devlink,
370                                 struct netlink_callback *cb, int flags)
371 {
372         struct devlink_nl_dump_state *state = devlink_dump_state(cb);
373         struct devlink_sb *devlink_sb;
374         int err = 0;
375         int idx = 0;
376
377         if (!devlink->ops->sb_pool_get)
378                 return 0;
379
380         list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
381                 err = __sb_pool_get_dumpit(msg, state->idx, &idx,
382                                            devlink, devlink_sb,
383                                            NETLINK_CB(cb->skb).portid,
384                                            cb->nlh->nlmsg_seq, flags);
385                 if (err == -EOPNOTSUPP) {
386                         err = 0;
387                 } else if (err) {
388                         state->idx = idx;
389                         break;
390                 }
391         }
392
393         return err;
394 }
395
396 int devlink_nl_sb_pool_get_dumpit(struct sk_buff *skb,
397                                   struct netlink_callback *cb)
398 {
399         return devlink_nl_dumpit(skb, cb, devlink_nl_sb_pool_get_dump_one);
400 }
401
402 static int devlink_sb_pool_set(struct devlink *devlink, unsigned int sb_index,
403                                u16 pool_index, u32 size,
404                                enum devlink_sb_threshold_type threshold_type,
405                                struct netlink_ext_ack *extack)
406
407 {
408         const struct devlink_ops *ops = devlink->ops;
409
410         if (ops->sb_pool_set)
411                 return ops->sb_pool_set(devlink, sb_index, pool_index,
412                                         size, threshold_type, extack);
413         return -EOPNOTSUPP;
414 }
415
416 int devlink_nl_cmd_sb_pool_set_doit(struct sk_buff *skb, struct genl_info *info)
417 {
418         struct devlink *devlink = info->user_ptr[0];
419         enum devlink_sb_threshold_type threshold_type;
420         struct devlink_sb *devlink_sb;
421         u16 pool_index;
422         u32 size;
423         int err;
424
425         devlink_sb = devlink_sb_get_from_info(devlink, info);
426         if (IS_ERR(devlink_sb))
427                 return PTR_ERR(devlink_sb);
428
429         err = devlink_sb_pool_index_get_from_info(devlink_sb, info,
430                                                   &pool_index);
431         if (err)
432                 return err;
433
434         err = devlink_sb_th_type_get_from_info(info, &threshold_type);
435         if (err)
436                 return err;
437
438         if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_SB_POOL_SIZE))
439                 return -EINVAL;
440
441         size = nla_get_u32(info->attrs[DEVLINK_ATTR_SB_POOL_SIZE]);
442         return devlink_sb_pool_set(devlink, devlink_sb->index,
443                                    pool_index, size, threshold_type,
444                                    info->extack);
445 }
446
447 static int devlink_nl_sb_port_pool_fill(struct sk_buff *msg,
448                                         struct devlink *devlink,
449                                         struct devlink_port *devlink_port,
450                                         struct devlink_sb *devlink_sb,
451                                         u16 pool_index,
452                                         enum devlink_command cmd,
453                                         u32 portid, u32 seq, int flags)
454 {
455         const struct devlink_ops *ops = devlink->ops;
456         u32 threshold;
457         void *hdr;
458         int err;
459
460         err = ops->sb_port_pool_get(devlink_port, devlink_sb->index,
461                                     pool_index, &threshold);
462         if (err)
463                 return err;
464
465         hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd);
466         if (!hdr)
467                 return -EMSGSIZE;
468
469         if (devlink_nl_put_handle(msg, devlink))
470                 goto nla_put_failure;
471         if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, devlink_port->index))
472                 goto nla_put_failure;
473         if (nla_put_u32(msg, DEVLINK_ATTR_SB_INDEX, devlink_sb->index))
474                 goto nla_put_failure;
475         if (nla_put_u16(msg, DEVLINK_ATTR_SB_POOL_INDEX, pool_index))
476                 goto nla_put_failure;
477         if (nla_put_u32(msg, DEVLINK_ATTR_SB_THRESHOLD, threshold))
478                 goto nla_put_failure;
479
480         if (ops->sb_occ_port_pool_get) {
481                 u32 cur;
482                 u32 max;
483
484                 err = ops->sb_occ_port_pool_get(devlink_port, devlink_sb->index,
485                                                 pool_index, &cur, &max);
486                 if (err && err != -EOPNOTSUPP)
487                         goto sb_occ_get_failure;
488                 if (!err) {
489                         if (nla_put_u32(msg, DEVLINK_ATTR_SB_OCC_CUR, cur))
490                                 goto nla_put_failure;
491                         if (nla_put_u32(msg, DEVLINK_ATTR_SB_OCC_MAX, max))
492                                 goto nla_put_failure;
493                 }
494         }
495
496         genlmsg_end(msg, hdr);
497         return 0;
498
499 nla_put_failure:
500         err = -EMSGSIZE;
501 sb_occ_get_failure:
502         genlmsg_cancel(msg, hdr);
503         return err;
504 }
505
506 int devlink_nl_sb_port_pool_get_doit(struct sk_buff *skb,
507                                      struct genl_info *info)
508 {
509         struct devlink_port *devlink_port = info->user_ptr[1];
510         struct devlink *devlink = devlink_port->devlink;
511         struct devlink_sb *devlink_sb;
512         struct sk_buff *msg;
513         u16 pool_index;
514         int err;
515
516         devlink_sb = devlink_sb_get_from_info(devlink, info);
517         if (IS_ERR(devlink_sb))
518                 return PTR_ERR(devlink_sb);
519
520         err = devlink_sb_pool_index_get_from_info(devlink_sb, info,
521                                                   &pool_index);
522         if (err)
523                 return err;
524
525         if (!devlink->ops->sb_port_pool_get)
526                 return -EOPNOTSUPP;
527
528         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
529         if (!msg)
530                 return -ENOMEM;
531
532         err = devlink_nl_sb_port_pool_fill(msg, devlink, devlink_port,
533                                            devlink_sb, pool_index,
534                                            DEVLINK_CMD_SB_PORT_POOL_NEW,
535                                            info->snd_portid, info->snd_seq, 0);
536         if (err) {
537                 nlmsg_free(msg);
538                 return err;
539         }
540
541         return genlmsg_reply(msg, info);
542 }
543
544 static int __sb_port_pool_get_dumpit(struct sk_buff *msg, int start, int *p_idx,
545                                      struct devlink *devlink,
546                                      struct devlink_sb *devlink_sb,
547                                      u32 portid, u32 seq, int flags)
548 {
549         struct devlink_port *devlink_port;
550         u16 pool_count = devlink_sb_pool_count(devlink_sb);
551         unsigned long port_index;
552         u16 pool_index;
553         int err;
554
555         xa_for_each(&devlink->ports, port_index, devlink_port) {
556                 for (pool_index = 0; pool_index < pool_count; pool_index++) {
557                         if (*p_idx < start) {
558                                 (*p_idx)++;
559                                 continue;
560                         }
561                         err = devlink_nl_sb_port_pool_fill(msg, devlink,
562                                                            devlink_port,
563                                                            devlink_sb,
564                                                            pool_index,
565                                                            DEVLINK_CMD_SB_PORT_POOL_NEW,
566                                                            portid, seq, flags);
567                         if (err)
568                                 return err;
569                         (*p_idx)++;
570                 }
571         }
572         return 0;
573 }
574
575 static int
576 devlink_nl_sb_port_pool_get_dump_one(struct sk_buff *msg,
577                                      struct devlink *devlink,
578                                      struct netlink_callback *cb, int flags)
579 {
580         struct devlink_nl_dump_state *state = devlink_dump_state(cb);
581         struct devlink_sb *devlink_sb;
582         int idx = 0;
583         int err = 0;
584
585         if (!devlink->ops->sb_port_pool_get)
586                 return 0;
587
588         list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
589                 err = __sb_port_pool_get_dumpit(msg, state->idx, &idx,
590                                                 devlink, devlink_sb,
591                                                 NETLINK_CB(cb->skb).portid,
592                                                 cb->nlh->nlmsg_seq, flags);
593                 if (err == -EOPNOTSUPP) {
594                         err = 0;
595                 } else if (err) {
596                         state->idx = idx;
597                         break;
598                 }
599         }
600
601         return err;
602 }
603
604 int devlink_nl_sb_port_pool_get_dumpit(struct sk_buff *skb,
605                                        struct netlink_callback *cb)
606 {
607         return devlink_nl_dumpit(skb, cb, devlink_nl_sb_port_pool_get_dump_one);
608 }
609
610 static int devlink_sb_port_pool_set(struct devlink_port *devlink_port,
611                                     unsigned int sb_index, u16 pool_index,
612                                     u32 threshold,
613                                     struct netlink_ext_ack *extack)
614
615 {
616         const struct devlink_ops *ops = devlink_port->devlink->ops;
617
618         if (ops->sb_port_pool_set)
619                 return ops->sb_port_pool_set(devlink_port, sb_index,
620                                              pool_index, threshold, extack);
621         return -EOPNOTSUPP;
622 }
623
624 int devlink_nl_cmd_sb_port_pool_set_doit(struct sk_buff *skb,
625                                          struct genl_info *info)
626 {
627         struct devlink_port *devlink_port = info->user_ptr[1];
628         struct devlink *devlink = info->user_ptr[0];
629         struct devlink_sb *devlink_sb;
630         u16 pool_index;
631         u32 threshold;
632         int err;
633
634         devlink_sb = devlink_sb_get_from_info(devlink, info);
635         if (IS_ERR(devlink_sb))
636                 return PTR_ERR(devlink_sb);
637
638         err = devlink_sb_pool_index_get_from_info(devlink_sb, info,
639                                                   &pool_index);
640         if (err)
641                 return err;
642
643         if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_SB_THRESHOLD))
644                 return -EINVAL;
645
646         threshold = nla_get_u32(info->attrs[DEVLINK_ATTR_SB_THRESHOLD]);
647         return devlink_sb_port_pool_set(devlink_port, devlink_sb->index,
648                                         pool_index, threshold, info->extack);
649 }
650
651 static int
652 devlink_nl_sb_tc_pool_bind_fill(struct sk_buff *msg, struct devlink *devlink,
653                                 struct devlink_port *devlink_port,
654                                 struct devlink_sb *devlink_sb, u16 tc_index,
655                                 enum devlink_sb_pool_type pool_type,
656                                 enum devlink_command cmd,
657                                 u32 portid, u32 seq, int flags)
658 {
659         const struct devlink_ops *ops = devlink->ops;
660         u16 pool_index;
661         u32 threshold;
662         void *hdr;
663         int err;
664
665         err = ops->sb_tc_pool_bind_get(devlink_port, devlink_sb->index,
666                                        tc_index, pool_type,
667                                        &pool_index, &threshold);
668         if (err)
669                 return err;
670
671         hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd);
672         if (!hdr)
673                 return -EMSGSIZE;
674
675         if (devlink_nl_put_handle(msg, devlink))
676                 goto nla_put_failure;
677         if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, devlink_port->index))
678                 goto nla_put_failure;
679         if (nla_put_u32(msg, DEVLINK_ATTR_SB_INDEX, devlink_sb->index))
680                 goto nla_put_failure;
681         if (nla_put_u16(msg, DEVLINK_ATTR_SB_TC_INDEX, tc_index))
682                 goto nla_put_failure;
683         if (nla_put_u8(msg, DEVLINK_ATTR_SB_POOL_TYPE, pool_type))
684                 goto nla_put_failure;
685         if (nla_put_u16(msg, DEVLINK_ATTR_SB_POOL_INDEX, pool_index))
686                 goto nla_put_failure;
687         if (nla_put_u32(msg, DEVLINK_ATTR_SB_THRESHOLD, threshold))
688                 goto nla_put_failure;
689
690         if (ops->sb_occ_tc_port_bind_get) {
691                 u32 cur;
692                 u32 max;
693
694                 err = ops->sb_occ_tc_port_bind_get(devlink_port,
695                                                    devlink_sb->index,
696                                                    tc_index, pool_type,
697                                                    &cur, &max);
698                 if (err && err != -EOPNOTSUPP)
699                         return err;
700                 if (!err) {
701                         if (nla_put_u32(msg, DEVLINK_ATTR_SB_OCC_CUR, cur))
702                                 goto nla_put_failure;
703                         if (nla_put_u32(msg, DEVLINK_ATTR_SB_OCC_MAX, max))
704                                 goto nla_put_failure;
705                 }
706         }
707
708         genlmsg_end(msg, hdr);
709         return 0;
710
711 nla_put_failure:
712         genlmsg_cancel(msg, hdr);
713         return -EMSGSIZE;
714 }
715
716 int devlink_nl_sb_tc_pool_bind_get_doit(struct sk_buff *skb,
717                                         struct genl_info *info)
718 {
719         struct devlink_port *devlink_port = info->user_ptr[1];
720         struct devlink *devlink = devlink_port->devlink;
721         struct devlink_sb *devlink_sb;
722         struct sk_buff *msg;
723         enum devlink_sb_pool_type pool_type;
724         u16 tc_index;
725         int err;
726
727         devlink_sb = devlink_sb_get_from_info(devlink, info);
728         if (IS_ERR(devlink_sb))
729                 return PTR_ERR(devlink_sb);
730
731         err = devlink_sb_pool_type_get_from_info(info, &pool_type);
732         if (err)
733                 return err;
734
735         err = devlink_sb_tc_index_get_from_info(devlink_sb, info,
736                                                 pool_type, &tc_index);
737         if (err)
738                 return err;
739
740         if (!devlink->ops->sb_tc_pool_bind_get)
741                 return -EOPNOTSUPP;
742
743         msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
744         if (!msg)
745                 return -ENOMEM;
746
747         err = devlink_nl_sb_tc_pool_bind_fill(msg, devlink, devlink_port,
748                                               devlink_sb, tc_index, pool_type,
749                                               DEVLINK_CMD_SB_TC_POOL_BIND_NEW,
750                                               info->snd_portid,
751                                               info->snd_seq, 0);
752         if (err) {
753                 nlmsg_free(msg);
754                 return err;
755         }
756
757         return genlmsg_reply(msg, info);
758 }
759
760 static int __sb_tc_pool_bind_get_dumpit(struct sk_buff *msg,
761                                         int start, int *p_idx,
762                                         struct devlink *devlink,
763                                         struct devlink_sb *devlink_sb,
764                                         u32 portid, u32 seq, int flags)
765 {
766         struct devlink_port *devlink_port;
767         unsigned long port_index;
768         u16 tc_index;
769         int err;
770
771         xa_for_each(&devlink->ports, port_index, devlink_port) {
772                 for (tc_index = 0;
773                      tc_index < devlink_sb->ingress_tc_count; tc_index++) {
774                         if (*p_idx < start) {
775                                 (*p_idx)++;
776                                 continue;
777                         }
778                         err = devlink_nl_sb_tc_pool_bind_fill(msg, devlink,
779                                                               devlink_port,
780                                                               devlink_sb,
781                                                               tc_index,
782                                                               DEVLINK_SB_POOL_TYPE_INGRESS,
783                                                               DEVLINK_CMD_SB_TC_POOL_BIND_NEW,
784                                                               portid, seq,
785                                                               flags);
786                         if (err)
787                                 return err;
788                         (*p_idx)++;
789                 }
790                 for (tc_index = 0;
791                      tc_index < devlink_sb->egress_tc_count; tc_index++) {
792                         if (*p_idx < start) {
793                                 (*p_idx)++;
794                                 continue;
795                         }
796                         err = devlink_nl_sb_tc_pool_bind_fill(msg, devlink,
797                                                               devlink_port,
798                                                               devlink_sb,
799                                                               tc_index,
800                                                               DEVLINK_SB_POOL_TYPE_EGRESS,
801                                                               DEVLINK_CMD_SB_TC_POOL_BIND_NEW,
802                                                               portid, seq,
803                                                               flags);
804                         if (err)
805                                 return err;
806                         (*p_idx)++;
807                 }
808         }
809         return 0;
810 }
811
812 static int devlink_nl_sb_tc_pool_bind_get_dump_one(struct sk_buff *msg,
813                                                    struct devlink *devlink,
814                                                    struct netlink_callback *cb,
815                                                    int flags)
816 {
817         struct devlink_nl_dump_state *state = devlink_dump_state(cb);
818         struct devlink_sb *devlink_sb;
819         int idx = 0;
820         int err = 0;
821
822         if (!devlink->ops->sb_tc_pool_bind_get)
823                 return 0;
824
825         list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
826                 err = __sb_tc_pool_bind_get_dumpit(msg, state->idx, &idx,
827                                                    devlink, devlink_sb,
828                                                    NETLINK_CB(cb->skb).portid,
829                                                    cb->nlh->nlmsg_seq, flags);
830                 if (err == -EOPNOTSUPP) {
831                         err = 0;
832                 } else if (err) {
833                         state->idx = idx;
834                         break;
835                 }
836         }
837
838         return err;
839 }
840
841 int devlink_nl_sb_tc_pool_bind_get_dumpit(struct sk_buff *skb,
842                                           struct netlink_callback *cb)
843 {
844         return devlink_nl_dumpit(skb, cb,
845                                  devlink_nl_sb_tc_pool_bind_get_dump_one);
846 }
847
848 static int devlink_sb_tc_pool_bind_set(struct devlink_port *devlink_port,
849                                        unsigned int sb_index, u16 tc_index,
850                                        enum devlink_sb_pool_type pool_type,
851                                        u16 pool_index, u32 threshold,
852                                        struct netlink_ext_ack *extack)
853
854 {
855         const struct devlink_ops *ops = devlink_port->devlink->ops;
856
857         if (ops->sb_tc_pool_bind_set)
858                 return ops->sb_tc_pool_bind_set(devlink_port, sb_index,
859                                                 tc_index, pool_type,
860                                                 pool_index, threshold, extack);
861         return -EOPNOTSUPP;
862 }
863
864 int devlink_nl_cmd_sb_tc_pool_bind_set_doit(struct sk_buff *skb,
865                                             struct genl_info *info)
866 {
867         struct devlink_port *devlink_port = info->user_ptr[1];
868         struct devlink *devlink = info->user_ptr[0];
869         enum devlink_sb_pool_type pool_type;
870         struct devlink_sb *devlink_sb;
871         u16 tc_index;
872         u16 pool_index;
873         u32 threshold;
874         int err;
875
876         devlink_sb = devlink_sb_get_from_info(devlink, info);
877         if (IS_ERR(devlink_sb))
878                 return PTR_ERR(devlink_sb);
879
880         err = devlink_sb_pool_type_get_from_info(info, &pool_type);
881         if (err)
882                 return err;
883
884         err = devlink_sb_tc_index_get_from_info(devlink_sb, info,
885                                                 pool_type, &tc_index);
886         if (err)
887                 return err;
888
889         err = devlink_sb_pool_index_get_from_info(devlink_sb, info,
890                                                   &pool_index);
891         if (err)
892                 return err;
893
894         if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_SB_THRESHOLD))
895                 return -EINVAL;
896
897         threshold = nla_get_u32(info->attrs[DEVLINK_ATTR_SB_THRESHOLD]);
898         return devlink_sb_tc_pool_bind_set(devlink_port, devlink_sb->index,
899                                            tc_index, pool_type,
900                                            pool_index, threshold, info->extack);
901 }
902
903 int devlink_nl_cmd_sb_occ_snapshot_doit(struct sk_buff *skb,
904                                         struct genl_info *info)
905 {
906         struct devlink *devlink = info->user_ptr[0];
907         const struct devlink_ops *ops = devlink->ops;
908         struct devlink_sb *devlink_sb;
909
910         devlink_sb = devlink_sb_get_from_info(devlink, info);
911         if (IS_ERR(devlink_sb))
912                 return PTR_ERR(devlink_sb);
913
914         if (ops->sb_occ_snapshot)
915                 return ops->sb_occ_snapshot(devlink, devlink_sb->index);
916         return -EOPNOTSUPP;
917 }
918
919 int devlink_nl_cmd_sb_occ_max_clear_doit(struct sk_buff *skb,
920                                          struct genl_info *info)
921 {
922         struct devlink *devlink = info->user_ptr[0];
923         const struct devlink_ops *ops = devlink->ops;
924         struct devlink_sb *devlink_sb;
925
926         devlink_sb = devlink_sb_get_from_info(devlink, info);
927         if (IS_ERR(devlink_sb))
928                 return PTR_ERR(devlink_sb);
929
930         if (ops->sb_occ_max_clear)
931                 return ops->sb_occ_max_clear(devlink, devlink_sb->index);
932         return -EOPNOTSUPP;
933 }
934
935 int devl_sb_register(struct devlink *devlink, unsigned int sb_index,
936                      u32 size, u16 ingress_pools_count,
937                      u16 egress_pools_count, u16 ingress_tc_count,
938                      u16 egress_tc_count)
939 {
940         struct devlink_sb *devlink_sb;
941
942         lockdep_assert_held(&devlink->lock);
943
944         if (devlink_sb_index_exists(devlink, sb_index))
945                 return -EEXIST;
946
947         devlink_sb = kzalloc(sizeof(*devlink_sb), GFP_KERNEL);
948         if (!devlink_sb)
949                 return -ENOMEM;
950         devlink_sb->index = sb_index;
951         devlink_sb->size = size;
952         devlink_sb->ingress_pools_count = ingress_pools_count;
953         devlink_sb->egress_pools_count = egress_pools_count;
954         devlink_sb->ingress_tc_count = ingress_tc_count;
955         devlink_sb->egress_tc_count = egress_tc_count;
956         list_add_tail(&devlink_sb->list, &devlink->sb_list);
957         return 0;
958 }
959 EXPORT_SYMBOL_GPL(devl_sb_register);
960
961 int devlink_sb_register(struct devlink *devlink, unsigned int sb_index,
962                         u32 size, u16 ingress_pools_count,
963                         u16 egress_pools_count, u16 ingress_tc_count,
964                         u16 egress_tc_count)
965 {
966         int err;
967
968         devl_lock(devlink);
969         err = devl_sb_register(devlink, sb_index, size, ingress_pools_count,
970                                egress_pools_count, ingress_tc_count,
971                                egress_tc_count);
972         devl_unlock(devlink);
973         return err;
974 }
975 EXPORT_SYMBOL_GPL(devlink_sb_register);
976
977 void devl_sb_unregister(struct devlink *devlink, unsigned int sb_index)
978 {
979         struct devlink_sb *devlink_sb;
980
981         lockdep_assert_held(&devlink->lock);
982
983         devlink_sb = devlink_sb_get_by_index(devlink, sb_index);
984         WARN_ON(!devlink_sb);
985         list_del(&devlink_sb->list);
986         kfree(devlink_sb);
987 }
988 EXPORT_SYMBOL_GPL(devl_sb_unregister);
989
/*
 * Locking wrapper around devl_sb_unregister() for callers that do not
 * already hold the devlink instance lock: it takes the lock, unregisters
 * @sb_index, and releases the lock.
 */
void devlink_sb_unregister(struct devlink *devlink, unsigned int sb_index)
{
	devl_lock(devlink);
	devl_sb_unregister(devlink, sb_index);
	devl_unlock(devlink);
}
EXPORT_SYMBOL_GPL(devlink_sb_unregister);