kasan: preassign tags to objects with ctors or SLAB_TYPESAFE_BY_RCU
[linux-block.git] / mm / slub.c
index e3629cd7aff1640854ffa30b50ac012f8c78ff68..08740c3f374513c28d2963207791be22744f4e18 100644 (file)
--- a/mm/slub.c
+++ b/mm/slub.c
@@ -1372,10 +1372,10 @@ static inline void dec_slabs_node(struct kmem_cache *s, int node,
  * Hooks for other subsystems that check memory allocations. In a typical
  * production configuration these hooks all should produce no code at all.
  */
-static inline void kmalloc_large_node_hook(void *ptr, size_t size, gfp_t flags)
+static inline void *kmalloc_large_node_hook(void *ptr, size_t size, gfp_t flags)
 {
        kmemleak_alloc(ptr, size, 1, flags);
-       kasan_kmalloc_large(ptr, size, flags);
+       return kasan_kmalloc_large(ptr, size, flags);
 }
 
 static __always_inline void kfree_hook(void *x)
@@ -1451,16 +1451,17 @@ static inline bool slab_free_freelist_hook(struct kmem_cache *s,
 #endif
 }
 
-static void setup_object(struct kmem_cache *s, struct page *page,
+static void *setup_object(struct kmem_cache *s, struct page *page,
                                void *object)
 {
        setup_object_debug(s, page, object);
-       kasan_init_slab_obj(s, object);
+       object = kasan_init_slab_obj(s, object);
        if (unlikely(s->ctor)) {
                kasan_unpoison_object_data(s, object);
                s->ctor(object);
                kasan_poison_object_data(s, object);
        }
+       return object;
 }
 
 /*
@@ -1568,16 +1569,16 @@ static bool shuffle_freelist(struct kmem_cache *s, struct page *page)
        /* First entry is used as the base of the freelist */
        cur = next_freelist_entry(s, page, &pos, start, page_limit,
                                freelist_count);
+       cur = setup_object(s, page, cur);
        page->freelist = cur;
 
        for (idx = 1; idx < page->objects; idx++) {
-               setup_object(s, page, cur);
                next = next_freelist_entry(s, page, &pos, start, page_limit,
                        freelist_count);
+               next = setup_object(s, page, next);
                set_freepointer(s, cur, next);
                cur = next;
        }
-       setup_object(s, page, cur);
        set_freepointer(s, cur, NULL);
 
        return true;
@@ -1599,7 +1600,7 @@ static struct page *allocate_slab(struct kmem_cache *s, gfp_t flags, int node)
        struct page *page;
        struct kmem_cache_order_objects oo = s->oo;
        gfp_t alloc_gfp;
-       void *start, *p;
+       void *start, *p, *next;
        int idx, order;
        bool shuffle;
 
@@ -1651,13 +1652,16 @@ static struct page *allocate_slab(struct kmem_cache *s, gfp_t flags, int node)
 
        if (!shuffle) {
                for_each_object_idx(p, idx, s, start, page->objects) {
-                       setup_object(s, page, p);
-                       if (likely(idx < page->objects))
-                               set_freepointer(s, p, p + s->size);
-                       else
+                       if (likely(idx < page->objects)) {
+                               next = p + s->size;
+                               next = setup_object(s, page, next);
+                               set_freepointer(s, p, next);
+                       } else
                                set_freepointer(s, p, NULL);
                }
-               page->freelist = fixup_red_left(s, start);
+               start = fixup_red_left(s, start);
+               start = setup_object(s, page, start);
+               page->freelist = start;
        }
 
        page->inuse = page->objects;
@@ -2768,7 +2772,7 @@ void *kmem_cache_alloc_trace(struct kmem_cache *s, gfp_t gfpflags, size_t size)
 {
        void *ret = slab_alloc(s, gfpflags, _RET_IP_);
        trace_kmalloc(_RET_IP_, ret, size, s->size, gfpflags);
-       kasan_kmalloc(s, ret, size, gfpflags);
+       ret = kasan_kmalloc(s, ret, size, gfpflags);
        return ret;
 }
 EXPORT_SYMBOL(kmem_cache_alloc_trace);
@@ -2796,7 +2800,7 @@ void *kmem_cache_alloc_node_trace(struct kmem_cache *s,
        trace_kmalloc_node(_RET_IP_, ret,
                           size, s->size, gfpflags, node);
 
-       kasan_kmalloc(s, ret, size, gfpflags);
+       ret = kasan_kmalloc(s, ret, size, gfpflags);
        return ret;
 }
 EXPORT_SYMBOL(kmem_cache_alloc_node_trace);
@@ -2992,7 +2996,7 @@ static __always_inline void slab_free(struct kmem_cache *s, struct page *page,
                do_slab_free(s, page, head, tail, cnt, addr);
 }
 
-#ifdef CONFIG_KASAN
+#ifdef CONFIG_KASAN_GENERIC
 void ___cache_free(struct kmem_cache *cache, void *x, unsigned long addr)
 {
        do_slab_free(cache, virt_to_head_page(x), x, NULL, 1, addr);
@@ -3364,16 +3368,16 @@ static void early_kmem_cache_node_alloc(int node)
 
        n = page->freelist;
        BUG_ON(!n);
-       page->freelist = get_freepointer(kmem_cache_node, n);
-       page->inuse = 1;
-       page->frozen = 0;
-       kmem_cache_node->node[node] = n;
 #ifdef CONFIG_SLUB_DEBUG
        init_object(kmem_cache_node, n, SLUB_RED_ACTIVE);
        init_tracking(kmem_cache_node, n);
 #endif
-       kasan_kmalloc(kmem_cache_node, n, sizeof(struct kmem_cache_node),
+       n = kasan_kmalloc(kmem_cache_node, n, sizeof(struct kmem_cache_node),
                      GFP_KERNEL);
+       page->freelist = get_freepointer(kmem_cache_node, n);
+       page->inuse = 1;
+       page->frozen = 0;
+       kmem_cache_node->node[node] = n;
        init_kmem_cache_node(n);
        inc_slabs_node(kmem_cache_node, node, page->objects);
 
@@ -3784,7 +3788,7 @@ void *__kmalloc(size_t size, gfp_t flags)
 
        trace_kmalloc(_RET_IP_, ret, size, s->size, flags);
 
-       kasan_kmalloc(s, ret, size, flags);
+       ret = kasan_kmalloc(s, ret, size, flags);
 
        return ret;
 }
@@ -3801,8 +3805,7 @@ static void *kmalloc_large_node(size_t size, gfp_t flags, int node)
        if (page)
                ptr = page_address(page);
 
-       kmalloc_large_node_hook(ptr, size, flags);
-       return ptr;
+       return kmalloc_large_node_hook(ptr, size, flags);
 }
 
 void *__kmalloc_node(size_t size, gfp_t flags, int node)
@@ -3829,7 +3832,7 @@ void *__kmalloc_node(size_t size, gfp_t flags, int node)
 
        trace_kmalloc_node(_RET_IP_, ret, size, s->size, flags, node);
 
-       kasan_kmalloc(s, ret, size, flags);
+       ret = kasan_kmalloc(s, ret, size, flags);
 
        return ret;
 }