riscv: implement user_access_begin() and families
authorJisheng Zhang <jszhang@kernel.org>
Thu, 10 Apr 2025 07:05:23 +0000 (07:05 +0000)
committerPalmer Dabbelt <palmer@rivosinc.com>
Thu, 8 May 2025 17:01:00 +0000 (10:01 -0700)
Currently, when a function like strncpy_from_user() is called,
the userspace access protection is disabled and enabled
for every word read.

By implementing user_access_begin() and families, the protection
is disabled at the beginning of the copy and enabled at the end.

The __inttype macro is borrowed from x86 implementation.

Signed-off-by: Jisheng Zhang <jszhang@kernel.org>
Signed-off-by: Cyril Bur <cyrilbur@tenstorrent.com>
Reviewed-by: Alexandre Ghiti <alexghiti@rivosinc.com>
Link: https://lore.kernel.org/r/20250410070526.3160847-3-cyrilbur@tenstorrent.com
Signed-off-by: Palmer Dabbelt <palmer@rivosinc.com>
arch/riscv/include/asm/uaccess.h

index fee56b0c8058655fa5215d5934b73e2c9936a7d9..c9a461467bf47aa9913cb64fb99e8475adc65b40 100644 (file)
@@ -61,6 +61,19 @@ static inline unsigned long __untagged_addr_remote(struct mm_struct *mm, unsigne
 #define __disable_user_access()                                                        \
        __asm__ __volatile__ ("csrc sstatus, %0" : : "r" (SR_SUM) : "memory")
 
+/*
+ * This is the smallest unsigned integer type that can fit a value
+ * (up to 'long long')
+ */
+#define __inttype(x) __typeof__(               \
+       __typefits(x, char,                     \
+         __typefits(x, short,                  \
+           __typefits(x, int,                  \
+             __typefits(x, long, 0ULL)))))
+
+#define __typefits(x, type, not) \
+       __builtin_choose_expr(sizeof(x) <= sizeof(type), (unsigned type)0, not)
+
 /*
  * The exception table consists of pairs of addresses: the first is the
  * address of an instruction that is allowed to fault, and the second is
@@ -368,6 +381,69 @@ do {                                                                       \
                goto err_label;                                         \
 } while (0)
 
+static __must_check __always_inline bool user_access_begin(const void __user *ptr, size_t len)
+{
+       if (unlikely(!access_ok(ptr, len)))
+               return 0;
+       __enable_user_access();
+       return 1;
+}
+#define user_access_begin user_access_begin
+#define user_access_end __disable_user_access
+
+static inline unsigned long user_access_save(void) { return 0UL; }
+static inline void user_access_restore(unsigned long enabled) { }
+
+/*
+ * We want the unsafe accessors to always be inlined and use
+ * the error labels - thus the macro games.
+ */
+#define unsafe_put_user(x, ptr, label) do {                            \
+       long __err = 0;                                                 \
+       __put_user_nocheck(x, (ptr), __err);                            \
+       if (__err)                                                      \
+               goto label;                                             \
+} while (0)
+
+#define unsafe_get_user(x, ptr, label) do {                            \
+       long __err = 0;                                                 \
+       __inttype(*(ptr)) __gu_val;                                     \
+       __get_user_nocheck(__gu_val, (ptr), __err);                     \
+       (x) = (__force __typeof__(*(ptr)))__gu_val;                     \
+       if (__err)                                                      \
+               goto label;                                             \
+} while (0)
+
+#define unsafe_copy_loop(dst, src, len, type, op, label)               \
+       while (len >= sizeof(type)) {                                   \
+               op(*(type *)(src), (type __user *)(dst), label);        \
+               dst += sizeof(type);                                    \
+               src += sizeof(type);                                    \
+               len -= sizeof(type);                                    \
+       }
+
+#define unsafe_copy_to_user(_dst, _src, _len, label)                   \
+do {                                                                   \
+       char __user *__ucu_dst = (_dst);                                \
+       const char *__ucu_src = (_src);                                 \
+       size_t __ucu_len = (_len);                                      \
+       unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u64, unsafe_put_user, label); \
+       unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u32, unsafe_put_user, label); \
+       unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u16, unsafe_put_user, label); \
+       unsafe_copy_loop(__ucu_dst, __ucu_src, __ucu_len, u8, unsafe_put_user, label);  \
+} while (0)
+
+#define unsafe_copy_from_user(_dst, _src, _len, label)                 \
+do {                                                                   \
+       char *__ucu_dst = (_dst);                                       \
+       const char __user *__ucu_src = (_src);                          \
+       size_t __ucu_len = (_len);                                      \
+       unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u64, unsafe_get_user, label); \
+       unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u32, unsafe_get_user, label); \
+       unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u16, unsafe_get_user, label); \
+       unsafe_copy_loop(__ucu_src, __ucu_dst, __ucu_len, u8, unsafe_get_user, label);  \
+} while (0)
+
 #else /* CONFIG_MMU */
 #include <asm-generic/uaccess.h>
 #endif /* CONFIG_MMU */