bcachefs: Mean and variance
authorDaniel Hill <daniel@gluo.nz>
Sat, 6 Aug 2022 02:48:49 +0000 (14:48 +1200)
committerKent Overstreet <kent.overstreet@linux.dev>
Sun, 22 Oct 2023 21:09:43 +0000 (17:09 -0400)
This module provides a fast 64bit implementation of basic statistics
functions, including mean, variance and standard deviation in both
weighted and unweighted variants, the unweighted variant has a 32bit
limitation per sample to prevent overflow when squaring.

Signed-off-by: Daniel Hill <daniel@gluo.nz>
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
fs/bcachefs/Kconfig
fs/bcachefs/Makefile
fs/bcachefs/mean_and_variance.c [new file with mode: 0644]
fs/bcachefs/mean_and_variance.h [new file with mode: 0644]
fs/bcachefs/mean_and_variance_test.c [new file with mode: 0644]

index 76953e05b240f4f2ffd89522aed872c354bbf0ba..f8e2088269978cf83c01a5c6cc301ed99c5ffcab 100644 (file)
@@ -71,3 +71,12 @@ config BCACHEFS_NO_LATENCY_ACCT
        depends on BCACHEFS_FS
        help
        This disables device latency tracking and time stats, only for performance testing
+
+config MEAN_AND_VARIANCE_UNIT_TEST
+       tristate "mean_and_variance unit tests" if !KUNIT_ALL_TESTS
+       depends on KUNIT
+       select MEAN_AND_VARIANCE
+       default KUNIT_ALL_TESTS
+       help
+         This option enables the kunit tests for mean_and_variance module.
+         If unsure, say N.
index e23667548e09ef47c161e3ecdffa93c1dcd28403..444e79c62b5095cce12b2158fa543719991d4a04 100644 (file)
@@ -46,6 +46,7 @@ bcachefs-y            :=      \
        journal_seq_blacklist.o \
        keylist.o               \
        lru.o                   \
+       mean_and_variance.o     \
        migrate.o               \
        move.o                  \
        movinggc.o              \
@@ -69,3 +70,4 @@ bcachefs-y            :=      \
        xattr.o
 
 bcachefs-$(CONFIG_BCACHEFS_POSIX_ACL) += acl.o
+obj-$(CONFIG_MEAN_AND_VARIANCE_UNIT_TEST)   += mean_and_variance_test.o
diff --git a/fs/bcachefs/mean_and_variance.c b/fs/bcachefs/mean_and_variance.c
new file mode 100644 (file)
index 0000000..1f0801e
--- /dev/null
@@ -0,0 +1,159 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Functions for incremental mean and variance.
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 as published by
+ * the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
+ * more details.
+ *
+ * Copyright © 2022 Daniel B. Hill
+ *
+ * Author: Daniel B. Hill <daniel@gluo.nz>
+ *
+ * Description:
+ *
+ * This is includes some incremental algorithms for mean and variance calculation
+ *
+ * Derived from the paper: https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
+ *
+ * Create a struct and if it's the weighted variant set the w field (weight = 2^k).
+ *
+ * Use mean_and_variance[_weighted]_update() on the struct to update it's state.
+ *
+ * Use the mean_and_variance[_weighted]_get_* functions to calculate the mean and variance, some computation
+ * is deferred to these functions for performance reasons.
+ *
+ * see lib/math/mean_and_variance_test.c for examples of usage.
+ *
+ * DO NOT access the mean and variance fields of the weighted variants directly.
+ * DO NOT change the weight after calling update.
+ */
+
+#include <linux/bug.h>
+#include <linux/compiler.h>
+#include <linux/export.h>
+#include <linux/limits.h>
+#include <linux/math.h>
+#include <linux/math64.h>
+#include <linux/module.h>
+
+#include "mean_and_variance.h"
+
+u128_u u128_div(u128_u n, u64 d)
+{
+       u128_u r;
+       u64 rem;
+       u64 hi = u128_hi(n);
+       u64 lo = u128_lo(n);
+       u64  h =  hi & ((u64) U32_MAX  << 32);
+       u64  l = (hi &  (u64) U32_MAX) << 32;
+
+       r =             u128_shl(u64_to_u128(div64_u64_rem(h,                d, &rem)), 64);
+       r = u128_add(r, u128_shl(u64_to_u128(div64_u64_rem(l  + (rem << 32), d, &rem)), 32));
+       r = u128_add(r,          u64_to_u128(div64_u64_rem(lo + (rem << 32), d, &rem)));
+       return r;
+}
+EXPORT_SYMBOL_GPL(u128_div);
+
+/**
+ * mean_and_variance_get_mean() - get mean from @s
+ */
+s64 mean_and_variance_get_mean(struct mean_and_variance s)
+{
+       return s.n ? div64_u64(s.sum, s.n) : 0;
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_get_mean);
+
+/**
+ * mean_and_variance_get_variance() -  get variance from @s1
+ *
+ * see linked pdf equation 12.
+ */
+u64 mean_and_variance_get_variance(struct mean_and_variance s1)
+{
+       if (s1.n) {
+               u128_u s2 = u128_div(s1.sum_squares, s1.n);
+               u64  s3 = abs(mean_and_variance_get_mean(s1));
+
+               return u128_lo(u128_sub(s2, u128_square(s3)));
+       } else {
+               return 0;
+       }
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_get_variance);
+
+/**
+ * mean_and_variance_get_stddev() - get standard deviation from @s
+ */
+u32 mean_and_variance_get_stddev(struct mean_and_variance s)
+{
+       return int_sqrt64(mean_and_variance_get_variance(s));
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_get_stddev);
+
+/**
+ * mean_and_variance_weighted_update() - exponentially weighted variant of mean_and_variance_update()
+ * @s1: ..
+ * @s2: ..
+ *
+ * see linked pdf: function derived from equations 140-143 where alpha = 2^w.
+ * values are stored bitshifted for performance and added precision.
+ */
+void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 x)
+{
+       // previous weighted variance.
+       u8 w            = s->weight;
+       u64 var_w0      = s->variance;
+       // new value weighted.
+       s64 x_w         = x << w;
+       s64 diff_w      = x_w - s->mean;
+       s64 diff        = fast_divpow2(diff_w, w);
+       // new mean weighted.
+       s64 u_w1        = s->mean + diff;
+
+       if (!s->init) {
+               s->mean = x_w;
+               s->variance = 0;
+       } else {
+               s->mean = u_w1;
+               s->variance = ((var_w0 << w) - var_w0 + ((diff_w * (x_w - u_w1)) >> w)) >> w;
+       }
+       s->init = true;
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_weighted_update);
+
+/**
+ * mean_and_variance_weighted_get_mean() - get mean from @s
+ */
+s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s)
+{
+       return fast_divpow2(s.mean, s.weight);
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_mean);
+
+/**
+ * mean_and_variance_weighted_get_variance() -- get variance from @s
+ */
+u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s)
+{
+       // always positive don't need fast divpow2
+       return s.variance >> s.weight;
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_variance);
+
+/**
+ * mean_and_variance_weighted_get_stddev() - get standard deviation from @s
+ */
+u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s)
+{
+       return int_sqrt64(mean_and_variance_weighted_get_variance(s));
+}
+EXPORT_SYMBOL_GPL(mean_and_variance_weighted_get_stddev);
+
+MODULE_AUTHOR("Daniel B. Hill");
+MODULE_LICENSE("GPL");
diff --git a/fs/bcachefs/mean_and_variance.h b/fs/bcachefs/mean_and_variance.h
new file mode 100644 (file)
index 0000000..880e950
--- /dev/null
@@ -0,0 +1,199 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef MEAN_AND_VARIANCE_H_
+#define MEAN_AND_VARIANCE_H_
+
+#include <linux/types.h>
+#include <linux/limits.h>
+#include <linux/math64.h>
+
+#define SQRT_U64_MAX 4294967295ULL
+
+/*
+ * u128_u: u128 user mode, because not all architectures support a real int128
+ * type
+ */
+
+#ifdef __SIZEOF_INT128__
+
+typedef struct {
+       unsigned __int128 v;
+} __aligned(16) u128_u;
+
+static inline u128_u u64_to_u128(u64 a)
+{
+       return (u128_u) { .v = a };
+}
+
+static inline u64 u128_lo(u128_u a)
+{
+       return a.v;
+}
+
+static inline u64 u128_hi(u128_u a)
+{
+       return a.v >> 64;
+}
+
+static inline u128_u u128_add(u128_u a, u128_u b)
+{
+       a.v += b.v;
+       return a;
+}
+
+static inline u128_u u128_sub(u128_u a, u128_u b)
+{
+       a.v -= b.v;
+       return a;
+}
+
+static inline u128_u u128_shl(u128_u a, s8 shift)
+{
+       a.v <<= shift;
+       return a;
+}
+
+static inline u128_u u128_square(u64 a)
+{
+       u128_u b = u64_to_u128(a);
+
+       b.v *= b.v;
+       return b;
+}
+
+#else
+
+typedef struct {
+       u64 hi, lo;
+} __aligned(16) u128_u;
+
+/* conversions */
+
+static inline u128_u u64_to_u128(u64 a)
+{
+       return (u128_u) { .lo = a };
+}
+
+static inline u64 u128_lo(u128_u a)
+{
+       return a.lo;
+}
+
+static inline u64 u128_hi(u128_u a)
+{
+       return a.hi;
+}
+
+/* arithmetic */
+
+static inline u128_u u128_add(u128_u a, u128_u b)
+{
+       u128_u c;
+
+       c.lo = a.lo + b.lo;
+       c.hi = a.hi + b.hi + (c.lo < a.lo);
+       return c;
+}
+
+static inline u128_u u128_sub(u128_u a, u128_u b)
+{
+       u128_u c;
+
+       c.lo = a.lo - b.lo;
+       c.hi = a.hi - b.hi - (c.lo > a.lo);
+       return c;
+}
+
+static inline u128_u u128_shl(u128_u i, s8 shift)
+{
+       u128_u r;
+
+       r.lo = i.lo << shift;
+       if (shift < 64)
+               r.hi = (i.hi << shift) | (i.lo >> (64 - shift));
+       else {
+               r.hi = i.lo << (shift - 64);
+               r.lo = 0;
+       }
+       return r;
+}
+
+static inline u128_u u128_square(u64 i)
+{
+       u128_u r;
+       u64  h = i >> 32, l = i & U32_MAX;
+
+       r =             u128_shl(u64_to_u128(h*h), 64);
+       r = u128_add(r, u128_shl(u64_to_u128(h*l), 32));
+       r = u128_add(r, u128_shl(u64_to_u128(l*h), 32));
+       r = u128_add(r,          u64_to_u128(l*l));
+       return r;
+}
+
+#endif
+
+static inline u128_u u64s_to_u128(u64 hi, u64 lo)
+{
+       u128_u c = u64_to_u128(hi);
+
+       c = u128_shl(c, 64);
+       c = u128_add(c, u64_to_u128(lo));
+       return c;
+}
+
+u128_u u128_div(u128_u n, u64 d);
+
+struct mean_and_variance {
+       s64     n;
+       s64     sum;
+       u128_u  sum_squares;
+};
+
+/* expontentially weighted variant */
+struct mean_and_variance_weighted {
+       bool    init;
+       u8      weight; /* base 2 logarithim */
+       s64     mean;
+       u64     variance;
+};
+
+/**
+ * fast_divpow2() - fast approximation for n / (1 << d)
+ * @n: numerator
+ * @d: the power of 2 denominator.
+ *
+ * note: this rounds towards 0.
+ */
+static inline s64 fast_divpow2(s64 n, u8 d)
+{
+       return (n + ((n < 0) ? ((1 << d) - 1) : 0)) >> d;
+}
+
+/**
+ * mean_and_variance_update() - update a mean_and_variance struct @s1 with a new sample @v1
+ * and return it.
+ * @s1: the mean_and_variance to update.
+ * @v1: the new sample.
+ *
+ * see linked pdf equation 12.
+ */
+static inline struct mean_and_variance
+mean_and_variance_update(struct mean_and_variance s, s64 v)
+{
+       return (struct mean_and_variance) {
+               .n           = s.n + 1,
+               .sum         = s.sum + v,
+               .sum_squares = u128_add(s.sum_squares, u128_square(abs(v))),
+       };
+}
+
+s64 mean_and_variance_get_mean(struct mean_and_variance s);
+u64 mean_and_variance_get_variance(struct mean_and_variance s1);
+u32 mean_and_variance_get_stddev(struct mean_and_variance s);
+
+void mean_and_variance_weighted_update(struct mean_and_variance_weighted *s, s64 v);
+
+s64 mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s);
+u64 mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s);
+u32 mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s);
+
+#endif // MEAN_AND_VAIRANCE_H_
diff --git a/fs/bcachefs/mean_and_variance_test.c b/fs/bcachefs/mean_and_variance_test.c
new file mode 100644 (file)
index 0000000..2b4cf9b
--- /dev/null
@@ -0,0 +1,153 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <kunit/test.h>
+
+#include "mean_and_variance.h"
+
+#define MAX_SQR (SQRT_U64_MAX*SQRT_U64_MAX)
+
+static void mean_and_variance_basic_test(struct kunit *test)
+{
+       struct mean_and_variance s = {};
+
+       s = mean_and_variance_update(s, 2);
+       s = mean_and_variance_update(s, 2);
+
+       KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 2);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 0);
+       KUNIT_EXPECT_EQ(test, s.n, 2);
+
+       s = mean_and_variance_update(s, 4);
+       s = mean_and_variance_update(s, 4);
+
+       KUNIT_EXPECT_EQ(test, mean_and_variance_get_mean(s), 3);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_get_variance(s), 1);
+       KUNIT_EXPECT_EQ(test, s.n, 4);
+}
+
+/*
+ * Test values computed using a spreadsheet from the psuedocode at the bottom:
+ * https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf
+ */
+
+static void mean_and_variance_weighted_test(struct kunit *test)
+{
+       struct mean_and_variance_weighted s = { .weight = 2 };
+
+       s.weight = 2;
+
+       mean_and_variance_weighted_update(&s, 10);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 10);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 0);
+
+       mean_and_variance_weighted_update(&s, 20);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 12);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 18);
+
+       mean_and_variance_weighted_update(&s, 30);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 16);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 72);
+
+       s = (struct mean_and_variance_weighted) { .weight = 2 };
+
+       mean_and_variance_weighted_update(&s, -10);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -10);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 0);
+
+       mean_and_variance_weighted_update(&s, -20);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -12);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 18);
+
+       mean_and_variance_weighted_update(&s, -30);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -16);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 72);
+
+}
+
+static void mean_and_variance_weighted_advanced_test(struct kunit *test)
+{
+       struct mean_and_variance_weighted s = { .weight = 8 };
+       s64 i;
+
+       for (i = 10; i <= 100; i += 10)
+               mean_and_variance_weighted_update(&s, i);
+
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), 11);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 107);
+
+       s = (struct mean_and_variance_weighted) { .weight = 8 };
+
+       for (i = -10; i >= -100; i -= 10)
+               mean_and_variance_weighted_update(&s, i);
+
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_mean(s), -11);
+       KUNIT_EXPECT_EQ(test, mean_and_variance_weighted_get_variance(s), 107);
+
+}
+
+static void mean_and_variance_fast_divpow2(struct kunit *test)
+{
+       s64 i;
+       u8 d;
+
+       for (i = 0; i < 100; i++) {
+               d = 0;
+               KUNIT_EXPECT_EQ(test, fast_divpow2(i, d), div_u64(i, 1LLU << d));
+               KUNIT_EXPECT_EQ(test, abs(fast_divpow2(-i, d)), div_u64(i, 1LLU << d));
+               for (d = 1; d < 32; d++) {
+                       KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(i, d)),
+                                           div_u64(i, 1 << d), "%lld %u", i, d);
+                       KUNIT_EXPECT_EQ_MSG(test, abs(fast_divpow2(-i, d)),
+                                           div_u64(i, 1 << d), "%lld %u", -i, d);
+               }
+       }
+}
+
+static void mean_and_variance_u128_basic_test(struct kunit *test)
+{
+       u128_u a  = u64s_to_u128(0, U64_MAX);
+       u128_u a1 = u64s_to_u128(0, 1);
+       u128_u b  = u64s_to_u128(1, 0);
+       u128_u c  = u64s_to_u128(0, 1LLU << 63);
+       u128_u c2 = u64s_to_u128(U64_MAX, U64_MAX);
+
+       KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a, a1)), 1);
+       KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a, a1)), 0);
+       KUNIT_EXPECT_EQ(test, u128_hi(u128_add(a1, a)), 1);
+       KUNIT_EXPECT_EQ(test, u128_lo(u128_add(a1, a)), 0);
+
+       KUNIT_EXPECT_EQ(test, u128_lo(u128_sub(b, a1)), U64_MAX);
+       KUNIT_EXPECT_EQ(test, u128_hi(u128_sub(b, a1)), 0);
+
+       KUNIT_EXPECT_EQ(test, u128_hi(u128_shl(c, 1)), 1);
+       KUNIT_EXPECT_EQ(test, u128_lo(u128_shl(c, 1)), 0);
+
+       KUNIT_EXPECT_EQ(test, u128_hi(u128_square(U64_MAX)), U64_MAX - 1);
+       KUNIT_EXPECT_EQ(test, u128_lo(u128_square(U64_MAX)), 1);
+
+       KUNIT_EXPECT_EQ(test, u128_lo(u128_div(b, 2)), 1LLU << 63);
+
+       KUNIT_EXPECT_EQ(test, u128_hi(u128_div(c2, 2)), U64_MAX >> 1);
+       KUNIT_EXPECT_EQ(test, u128_lo(u128_div(c2, 2)), U64_MAX);
+
+       KUNIT_EXPECT_EQ(test, u128_hi(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U32_MAX >> 1);
+       KUNIT_EXPECT_EQ(test, u128_lo(u128_div(u128_shl(u64_to_u128(U64_MAX), 32), 2)), U64_MAX << 31);
+}
+
+static struct kunit_case mean_and_variance_test_cases[] = {
+       KUNIT_CASE(mean_and_variance_fast_divpow2),
+       KUNIT_CASE(mean_and_variance_u128_basic_test),
+       KUNIT_CASE(mean_and_variance_basic_test),
+       KUNIT_CASE(mean_and_variance_weighted_test),
+       KUNIT_CASE(mean_and_variance_weighted_advanced_test),
+       {}
+};
+
+static struct kunit_suite mean_and_variance_test_suite = {
+       .name           = "mean and variance tests",
+       .test_cases     = mean_and_variance_test_cases
+};
+
+kunit_test_suite(mean_and_variance_test_suite);
+
+MODULE_AUTHOR("Daniel B. Hill");
+MODULE_LICENSE("GPL");