Commit | Line | Data |
---|---|---|
92095781 DH |
1 | /* SPDX-License-Identifier: GPL-2.0 */ |
2 | #ifndef MEAN_AND_VARIANCE_H_ | |
3 | #define MEAN_AND_VARIANCE_H_ | |
4 | ||
5 | #include <linux/types.h> | |
6 | #include <linux/limits.h> | |
db32bb9a | 7 | #include <linux/math.h> |
92095781 DH |
8 | #include <linux/math64.h> |
9 | ||
10 | #define SQRT_U64_MAX 4294967295ULL /* 2^32 - 1: the largest value whose square still fits in a u64 */ | |
11 | ||
12 | /* | |
13 | * u128_u: u128 user mode, because not all architectures support a real int128 | |
14 | * type | |
44fd13a4 KO |
15 | * |
16 | * We don't use this version in userspace, because in userspace we link with | |
17 | * Rust and rustc has issues with u128. | |
92095781 DH |
18 | */ |
19 | ||
eba38cc7 | 20 | #if defined(__SIZEOF_INT128__) && defined(__KERNEL__) && !defined(CONFIG_PARISC) |
92095781 DH |
21 | |
22 | typedef struct { /* 128-bit unsigned value backed by the compiler's native __int128 */ | |
23 | unsigned __int128 v; /* the whole 128-bit value */ | |
24 | } __aligned(16) u128_u; | |
25 | ||
26 | static inline u128_u u64_to_u128(u64 a) | |
27 | { | |
28 | return (u128_u) { .v = a }; | |
29 | } | |
30 | ||
31 | static inline u64 u128_lo(u128_u a) | |
32 | { | |
33 | return a.v; | |
34 | } | |
35 | ||
36 | static inline u64 u128_hi(u128_u a) | |
37 | { | |
38 | return a.v >> 64; | |
39 | } | |
40 | ||
41 | static inline u128_u u128_add(u128_u a, u128_u b) | |
42 | { | |
43 | a.v += b.v; | |
44 | return a; | |
45 | } | |
46 | ||
47 | static inline u128_u u128_sub(u128_u a, u128_u b) | |
48 | { | |
49 | a.v -= b.v; | |
50 | return a; | |
51 | } | |
52 | ||
53 | static inline u128_u u128_shl(u128_u a, s8 shift) | |
54 | { | |
55 | a.v <<= shift; | |
56 | return a; | |
57 | } | |
58 | ||
59 | static inline u128_u u128_square(u64 a) | |
60 | { | |
61 | u128_u b = u64_to_u128(a); | |
62 | ||
63 | b.v *= b.v; | |
64 | return b; | |
65 | } | |
66 | ||
67 | #else | |
68 | ||
69 | typedef struct { /* 128-bit unsigned value emulated with two 64-bit words */ | |
70 | u64 hi, lo; /* hi: upper 64 bits, lo: lower 64 bits */ | |
71 | } __aligned(16) u128_u; | |
72 | ||
73 | /* conversions */ | |
74 | ||
75 | static inline u128_u u64_to_u128(u64 a) | |
76 | { | |
77 | return (u128_u) { .lo = a }; | |
78 | } | |
79 | ||
80 | static inline u64 u128_lo(u128_u a) | |
81 | { | |
82 | return a.lo; | |
83 | } | |
84 | ||
85 | static inline u64 u128_hi(u128_u a) | |
86 | { | |
87 | return a.hi; | |
88 | } | |
89 | ||
90 | /* arithmetic */ | |
91 | ||
92 | static inline u128_u u128_add(u128_u a, u128_u b) | |
93 | { | |
94 | u128_u c; | |
95 | ||
96 | c.lo = a.lo + b.lo; | |
97 | c.hi = a.hi + b.hi + (c.lo < a.lo); | |
98 | return c; | |
99 | } | |
100 | ||
101 | static inline u128_u u128_sub(u128_u a, u128_u b) | |
102 | { | |
103 | u128_u c; | |
104 | ||
105 | c.lo = a.lo - b.lo; | |
106 | c.hi = a.hi - b.hi - (c.lo > a.lo); | |
107 | return c; | |
108 | } | |
109 | ||
110 | static inline u128_u u128_shl(u128_u i, s8 shift) | |
111 | { | |
112 | u128_u r; | |
113 | ||
114 | r.lo = i.lo << shift; | |
115 | if (shift < 64) | |
116 | r.hi = (i.hi << shift) | (i.lo >> (64 - shift)); | |
117 | else { | |
118 | r.hi = i.lo << (shift - 64); | |
119 | r.lo = 0; | |
120 | } | |
121 | return r; | |
122 | } | |
123 | ||
124 | static inline u128_u u128_square(u64 i) | |
125 | { | |
126 | u128_u r; | |
127 | u64 h = i >> 32, l = i & U32_MAX; | |
128 | ||
129 | r = u128_shl(u64_to_u128(h*h), 64); | |
130 | r = u128_add(r, u128_shl(u64_to_u128(h*l), 32)); | |
131 | r = u128_add(r, u128_shl(u64_to_u128(l*h), 32)); | |
132 | r = u128_add(r, u64_to_u128(l*l)); | |
133 | return r; | |
134 | } | |
135 | ||
136 | #endif | |
137 | ||
138 | static inline u128_u u64s_to_u128(u64 hi, u64 lo) | |
139 | { | |
140 | u128_u c = u64_to_u128(hi); | |
141 | ||
142 | c = u128_shl(c, 64); | |
143 | c = u128_add(c, u64_to_u128(lo)); | |
144 | return c; | |
145 | } | |
146 | ||
147 | u128_u u128_div(u128_u n, u64 d); /* defined out of line; presumably n / d -- confirm against the implementation */ | |
148 | ||
149 | struct mean_and_variance { /* running totals maintained by mean_and_variance_update() */ | |
150 | s64 n; /* incremented once per sample */ | |
151 | s64 sum; /* sum of the samples */ | |
152 | u128_u sum_squares; /* sum of the squared sample magnitudes, kept in 128 bits */ | |
153 | }; | |
154 | ||
155 | /* exponentially weighted variant */ | |
156 | struct mean_and_variance_weighted { | |
157 | bool init; /* NOTE(review): presumably set once a first sample has been recorded -- confirm */ | |
158 | u8 weight; /* base 2 logarithm */ | |
159 | s64 mean; | |
160 | u64 variance; | |
161 | }; | |
162 | ||
163 | /** | |
164 | * fast_divpow2() - fast approximation for n / (1 << d) | |
165 | * @n: numerator | |
166 | * @d: the power of 2 denominator. | |
167 | * | |
168 | * note: this rounds towards 0. | |
169 | */ | |
170 | static inline s64 fast_divpow2(s64 n, u8 d) | |
171 | { | |
172 | return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d; | |
173 | } | |
174 | ||
175 | /** | |
176 | * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1 | |
177 | * and return it. | |
178 | * @s1: the mean_and_variance to update. | |
179 | * @v1: the new sample. | |
180 | * | |
181 | * see linked pdf equation 12. | |
182 | */ | |
65bc4109 KO |
183 | static inline void |
184 | mean_and_variance_update(struct mean_and_variance *s, s64 v) | |
185 | { | |
186 | s->n++; | |
187 | s->sum += v; | |
188 | s->sum_squares = u128_add(s->sum_squares, u128_square(abs(v))); | |
92095781 DH |
189 | } |
190 | ||
191 | s64 mean_and_variance_get_mean(struct mean_and_variance s); /* accessors below take the accumulator by value; bodies are defined out of line */ | |
192 | u64 mean_and_variance_get_variance(struct mean_and_variance s1); | |
193 | u32 mean_and_variance_get_stddev(struct mean_and_variance s); /* NOTE(review): presumably the square root of the variance -- confirm */ | |
194 | ||
195 | void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v); /* defined out of line; presumably folds sample v into *s -- confirm */ | |
196 | ||
197 | s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s); /* accessors below take the state by value; bodies are defined out of line */ | |
198 | u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s); | |
199 | u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s); | |
200 | ||
201 | #endif // MEAN_AND_VARIANCE_H_ |