Merge branch 'for-4.18/mcsafe' into libnvdimm-for-next
[linux-2.6-block.git] / lib / iov_iter.c
index fdae394172fa78efaf3637266492c5a5823ce41a..7e43cd54c84ca3da2d77b02e7112c69386428a2b 100644 (file)
@@ -573,6 +573,67 @@ size_t _copy_to_iter(const void *addr, size_t bytes, struct iov_iter *i)
 }
 EXPORT_SYMBOL(_copy_to_iter);
 
+#ifdef CONFIG_ARCH_HAS_UACCESS_MCSAFE
+static int copyout_mcsafe(void __user *to, const void *from, size_t n)
+{
+       if (access_ok(VERIFY_WRITE, to, n)) {
+               kasan_check_read(from, n);
+               n = copy_to_user_mcsafe((__force void *) to, from, n);
+       }
+       return n;
+}
+
+static unsigned long memcpy_mcsafe_to_page(struct page *page, size_t offset,
+               const char *from, size_t len)
+{
+       unsigned long ret;
+       char *to;
+
+       to = kmap_atomic(page);
+       ret = memcpy_mcsafe(to + offset, from, len);
+       kunmap_atomic(to);
+
+       return ret;
+}
+
+size_t _copy_to_iter_mcsafe(const void *addr, size_t bytes, struct iov_iter *i)
+{
+       const char *from = addr;
+       unsigned long rem, curr_addr, s_addr = (unsigned long) addr;
+
+       if (unlikely(i->type & ITER_PIPE)) {
+               WARN_ON(1);
+               return 0;
+       }
+       if (iter_is_iovec(i))
+               might_fault();
+       iterate_and_advance(i, bytes, v,
+               copyout_mcsafe(v.iov_base, (from += v.iov_len) - v.iov_len, v.iov_len),
+               ({
+               rem = memcpy_mcsafe_to_page(v.bv_page, v.bv_offset,
+                               (from += v.bv_len) - v.bv_len, v.bv_len);
+               if (rem) {
+                       curr_addr = (unsigned long) from;
+                       bytes = curr_addr - s_addr - rem;
+                       return bytes;
+               }
+               }),
+               ({
+               rem = memcpy_mcsafe(v.iov_base, (from += v.iov_len) - v.iov_len,
+                               v.iov_len);
+               if (rem) {
+                       curr_addr = (unsigned long) from;
+                       bytes = curr_addr - s_addr - rem;
+                       return bytes;
+               }
+               })
+       )
+
+       return bytes;
+}
+EXPORT_SYMBOL_GPL(_copy_to_iter_mcsafe);
+#endif /* CONFIG_ARCH_HAS_UACCESS_MCSAFE */
+
 size_t _copy_from_iter(void *addr, size_t bytes, struct iov_iter *i)
 {
        char *to = addr;