io_uring/kbuf: use vm_insert_pages() for mmap'ed pbuf ring
authorJens Axboe <axboe@kernel.dk>
Wed, 13 Mar 2024 02:24:21 +0000 (20:24 -0600)
committerJens Axboe <axboe@kernel.dk>
Mon, 15 Apr 2024 14:10:26 +0000 (08:10 -0600)
Rather than use remap_pfn_range() for this and manually free later,
switch to using vm_insert_page() and have it Just Work.

This requires a bit of effort on the mmap lookup side, as the ctx
uring_lock isn't held, which  otherwise protects buffer_lists from being
torn down, and it's not safe to grab from mmap context that would
introduce an ABBA deadlock between the mmap lock and the ctx uring_lock.
Instead, lookup the buffer_list under RCU, as the the list is RCU freed
already. Use the existing reference count to determine whether it's
possible to safely grab a reference to it (eg if it's not zero already),
and drop that reference when done with the mapping. If the mmap
reference is the last one, the buffer_list and the associated memory can
go away, since the vma insertion has references to the inserted pages at
that point.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
include/linux/io_uring_types.h
io_uring/io_uring.c
io_uring/io_uring.h
io_uring/kbuf.c
io_uring/kbuf.h

index ef45b8bd1b35a7deccf2810f022d659c3a57ab0a..d34c8433caf94f52fad9c679006fc70e48fe6fdb 100644 (file)
@@ -372,9 +372,6 @@ struct io_ring_ctx {
 
        struct list_head        io_buffers_cache;
 
-       /* deferred free list, protected by ->uring_lock */
-       struct hlist_head       io_buf_list;
-
        /* Keep this last, we don't need it for the fast path */
        struct wait_queue_head          poll_wq;
        struct io_restriction           restrictions;
index 0ef418faac334896ee4ff1cffe12f2652ae5c43c..50e859da59e127ed109e988cc9b08006e80875f9 100644 (file)
@@ -303,7 +303,6 @@ static __cold struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        INIT_LIST_HEAD(&ctx->sqd_list);
        INIT_LIST_HEAD(&ctx->cq_overflow_list);
        INIT_LIST_HEAD(&ctx->io_buffers_cache);
-       INIT_HLIST_HEAD(&ctx->io_buf_list);
        ret = io_alloc_cache_init(&ctx->rsrc_node_cache, IO_NODE_ALLOC_CACHE_MAX,
                            sizeof(struct io_rsrc_node));
        ret |= io_alloc_cache_init(&ctx->apoll_cache, IO_POLL_ALLOC_CACHE_MAX,
@@ -2598,15 +2597,15 @@ static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events,
        return READ_ONCE(rings->cq.head) == READ_ONCE(rings->cq.tail) ? ret : 0;
 }
 
-static void io_pages_unmap(void *ptr, struct page ***pages,
-                          unsigned short *npages)
+void io_pages_unmap(void *ptr, struct page ***pages, unsigned short *npages,
+                   bool put_pages)
 {
        bool do_vunmap = false;
 
        if (!ptr)
                return;
 
-       if (*npages) {
+       if (put_pages && *npages) {
                struct page **to_free = *pages;
                int i;
 
@@ -2628,14 +2627,6 @@ static void io_pages_unmap(void *ptr, struct page ***pages,
        *npages = 0;
 }
 
-void io_mem_free(void *ptr)
-{
-       if (!ptr)
-               return;
-
-       folio_put(virt_to_folio(ptr));
-}
-
 static void io_pages_free(struct page ***pages, int npages)
 {
        struct page **page_array = *pages;
@@ -2730,8 +2721,10 @@ static void *io_sqes_map(struct io_ring_ctx *ctx, unsigned long uaddr,
 static void io_rings_free(struct io_ring_ctx *ctx)
 {
        if (!(ctx->flags & IORING_SETUP_NO_MMAP)) {
-               io_pages_unmap(ctx->rings, &ctx->ring_pages, &ctx->n_ring_pages);
-               io_pages_unmap(ctx->sq_sqes, &ctx->sqe_pages, &ctx->n_sqe_pages);
+               io_pages_unmap(ctx->rings, &ctx->ring_pages, &ctx->n_ring_pages,
+                               true);
+               io_pages_unmap(ctx->sq_sqes, &ctx->sqe_pages, &ctx->n_sqe_pages,
+                               true);
        } else {
                io_pages_free(&ctx->ring_pages, ctx->n_ring_pages);
                ctx->n_ring_pages = 0;
@@ -2788,8 +2781,8 @@ err:
        return ERR_PTR(-ENOMEM);
 }
 
-static void *io_pages_map(struct page ***out_pages, unsigned short *npages,
-                         size_t size)
+void *io_pages_map(struct page ***out_pages, unsigned short *npages,
+                  size_t size)
 {
        gfp_t gfp = GFP_KERNEL_ACCOUNT | __GFP_ZERO | __GFP_NOWARN;
        struct page **pages;
@@ -2819,17 +2812,6 @@ done:
        return ret;
 }
 
-void *io_mem_alloc(size_t size)
-{
-       gfp_t gfp = GFP_KERNEL_ACCOUNT | __GFP_ZERO | __GFP_NOWARN | __GFP_COMP;
-       void *ret;
-
-       ret = (void *) __get_free_pages(gfp, get_order(size));
-       if (ret)
-               return ret;
-       return ERR_PTR(-ENOMEM);
-}
-
 static unsigned long rings_size(struct io_ring_ctx *ctx, unsigned int sq_entries,
                                unsigned int cq_entries, size_t *sq_offset)
 {
@@ -2926,7 +2908,6 @@ static __cold void io_ring_ctx_free(struct io_ring_ctx *ctx)
                ctx->mm_account = NULL;
        }
        io_rings_free(ctx);
-       io_kbuf_mmap_list_free(ctx);
 
        percpu_ref_exit(&ctx->refs);
        free_uid(ctx->user);
@@ -3396,10 +3377,8 @@ static void *io_uring_validate_mmap_request(struct file *file,
 {
        struct io_ring_ctx *ctx = file->private_data;
        loff_t offset = pgoff << PAGE_SHIFT;
-       struct page *page;
-       void *ptr;
 
-       switch (offset & IORING_OFF_MMAP_MASK) {
+       switch ((pgoff << PAGE_SHIFT) & IORING_OFF_MMAP_MASK) {
        case IORING_OFF_SQ_RING:
        case IORING_OFF_CQ_RING:
                /* Don't allow mmap if the ring was setup without it */
@@ -3414,6 +3393,7 @@ static void *io_uring_validate_mmap_request(struct file *file,
        case IORING_OFF_PBUF_RING: {
                struct io_buffer_list *bl;
                unsigned int bgid;
+               void *ptr;
 
                bgid = (offset & ~IORING_OFF_MMAP_MASK) >> IORING_OFF_PBUF_SHIFT;
                bl = io_pbuf_get_bl(ctx, bgid);
@@ -3421,17 +3401,11 @@ static void *io_uring_validate_mmap_request(struct file *file,
                        return bl;
                ptr = bl->buf_ring;
                io_put_bl(ctx, bl);
-               break;
+               return ptr;
                }
-       default:
-               return ERR_PTR(-EINVAL);
        }
 
-       page = virt_to_head_page(ptr);
-       if (sz > page_size(page))
-               return ERR_PTR(-EINVAL);
-
-       return ptr;
+       return ERR_PTR(-EINVAL);
 }
 
 int io_uring_mmap_pages(struct io_ring_ctx *ctx, struct vm_area_struct *vma,
@@ -3450,7 +3424,6 @@ static __cold int io_uring_mmap(struct file *file, struct vm_area_struct *vma)
        struct io_ring_ctx *ctx = file->private_data;
        size_t sz = vma->vm_end - vma->vm_start;
        long offset = vma->vm_pgoff << PAGE_SHIFT;
-       unsigned long pfn;
        void *ptr;
 
        ptr = io_uring_validate_mmap_request(file, vma->vm_pgoff, sz);
@@ -3465,10 +3438,11 @@ static __cold int io_uring_mmap(struct file *file, struct vm_area_struct *vma)
        case IORING_OFF_SQES:
                return io_uring_mmap_pages(ctx, vma, ctx->sqe_pages,
                                                ctx->n_sqe_pages);
+       case IORING_OFF_PBUF_RING:
+               return io_pbuf_mmap(file, vma);
        }
 
-       pfn = virt_to_phys(ptr) >> PAGE_SHIFT;
-       return remap_pfn_range(vma, vma->vm_start, pfn, sz, vma->vm_page_prot);
+       return -EINVAL;
 }
 
 static unsigned long io_uring_mmu_get_unmapped_area(struct file *filp,
index 75230d914007fee1bf35afd43df0eaa253f64884..dec996a1c7895f152672c30600da66357a2057ee 100644 (file)
@@ -109,8 +109,10 @@ bool __io_alloc_req_refill(struct io_ring_ctx *ctx);
 bool io_match_task_safe(struct io_kiocb *head, struct task_struct *task,
                        bool cancel_all);
 
-void *io_mem_alloc(size_t size);
-void io_mem_free(void *ptr);
+void *io_pages_map(struct page ***out_pages, unsigned short *npages,
+                  size_t size);
+void io_pages_unmap(void *ptr, struct page ***pages, unsigned short *npages,
+                   bool put_pages);
 
 enum {
        IO_EVENTFD_OP_SIGNAL_BIT,
index 4289f4a926934635ad024630ff58481bc9de3c52..820ac599d003e463fb46e8b13f9a1bfa191b07cd 100644 (file)
@@ -32,25 +32,12 @@ struct io_provide_buf {
        __u16                           bid;
 };
 
-struct io_buf_free {
-       struct hlist_node               list;
-       void                            *mem;
-       size_t                          size;
-       int                             inuse;
-};
-
-static inline struct io_buffer_list *__io_buffer_get_list(struct io_ring_ctx *ctx,
-                                                         unsigned int bgid)
-{
-       return xa_load(&ctx->io_bl_xa, bgid);
-}
-
 static inline struct io_buffer_list *io_buffer_get_list(struct io_ring_ctx *ctx,
                                                        unsigned int bgid)
 {
        lockdep_assert_held(&ctx->uring_lock);
 
-       return __io_buffer_get_list(ctx, bgid);
+       return xa_load(&ctx->io_bl_xa, bgid);
 }
 
 static int io_buffer_add_list(struct io_ring_ctx *ctx,
@@ -191,24 +178,6 @@ void __user *io_buffer_select(struct io_kiocb *req, size_t *len,
        return ret;
 }
 
-/*
- * Mark the given mapped range as free for reuse
- */
-static void io_kbuf_mark_free(struct io_ring_ctx *ctx, struct io_buffer_list *bl)
-{
-       struct io_buf_free *ibf;
-
-       hlist_for_each_entry(ibf, &ctx->io_buf_list, list) {
-               if (bl->buf_ring == ibf->mem) {
-                       ibf->inuse = 0;
-                       return;
-               }
-       }
-
-       /* can't happen... */
-       WARN_ON_ONCE(1);
-}
-
 static int __io_remove_buffers(struct io_ring_ctx *ctx,
                               struct io_buffer_list *bl, unsigned nbufs)
 {
@@ -220,23 +189,16 @@ static int __io_remove_buffers(struct io_ring_ctx *ctx,
 
        if (bl->is_buf_ring) {
                i = bl->buf_ring->tail - bl->head;
-               if (bl->is_mmap) {
-                       /*
-                        * io_kbuf_list_free() will free the page(s) at
-                        * ->release() time.
-                        */
-                       io_kbuf_mark_free(ctx, bl);
-                       bl->buf_ring = NULL;
-                       bl->is_mmap = 0;
-               } else if (bl->buf_nr_pages) {
+               if (bl->buf_nr_pages) {
                        int j;
 
-                       for (j = 0; j < bl->buf_nr_pages; j++)
-                               unpin_user_page(bl->buf_pages[j]);
-                       kvfree(bl->buf_pages);
-                       vunmap(bl->buf_ring);
-                       bl->buf_pages = NULL;
-                       bl->buf_nr_pages = 0;
+                       if (!bl->is_mmap) {
+                               for (j = 0; j < bl->buf_nr_pages; j++)
+                                       unpin_user_page(bl->buf_pages[j]);
+                       }
+                       io_pages_unmap(bl->buf_ring, &bl->buf_pages,
+                                       &bl->buf_nr_pages, bl->is_mmap);
+                       bl->is_mmap = 0;
                }
                /* make sure it's seen as empty */
                INIT_LIST_HEAD(&bl->buf_list);
@@ -537,63 +499,18 @@ error_unpin:
        return ret;
 }
 
-/*
- * See if we have a suitable region that we can reuse, rather than allocate
- * both a new io_buf_free and mem region again. We leave it on the list as
- * even a reused entry will need freeing at ring release.
- */
-static struct io_buf_free *io_lookup_buf_free_entry(struct io_ring_ctx *ctx,
-                                                   size_t ring_size)
-{
-       struct io_buf_free *ibf, *best = NULL;
-       size_t best_dist;
-
-       hlist_for_each_entry(ibf, &ctx->io_buf_list, list) {
-               size_t dist;
-
-               if (ibf->inuse || ibf->size < ring_size)
-                       continue;
-               dist = ibf->size - ring_size;
-               if (!best || dist < best_dist) {
-                       best = ibf;
-                       if (!dist)
-                               break;
-                       best_dist = dist;
-               }
-       }
-
-       return best;
-}
-
 static int io_alloc_pbuf_ring(struct io_ring_ctx *ctx,
                              struct io_uring_buf_reg *reg,
                              struct io_buffer_list *bl)
 {
-       struct io_buf_free *ibf;
        size_t ring_size;
-       void *ptr;
 
        ring_size = reg->ring_entries * sizeof(struct io_uring_buf_ring);
 
-       /* Reuse existing entry, if we can */
-       ibf = io_lookup_buf_free_entry(ctx, ring_size);
-       if (!ibf) {
-               ptr = io_mem_alloc(ring_size);
-               if (IS_ERR(ptr))
-                       return PTR_ERR(ptr);
-
-               /* Allocate and store deferred free entry */
-               ibf = kmalloc(sizeof(*ibf), GFP_KERNEL_ACCOUNT);
-               if (!ibf) {
-                       io_mem_free(ptr);
-                       return -ENOMEM;
-               }
-               ibf->mem = ptr;
-               ibf->size = ring_size;
-               hlist_add_head(&ibf->list, &ctx->io_buf_list);
-       }
-       ibf->inuse = 1;
-       bl->buf_ring = ibf->mem;
+       bl->buf_ring = io_pages_map(&bl->buf_pages, &bl->buf_nr_pages, ring_size);
+       if (!bl->buf_ring)
+               return -ENOMEM;
+
        bl->is_buf_ring = 1;
        bl->is_mmap = 1;
        return 0;
@@ -741,18 +658,19 @@ struct io_buffer_list *io_pbuf_get_bl(struct io_ring_ctx *ctx,
        return ERR_PTR(-EINVAL);
 }
 
-/*
- * Called at or after ->release(), free the mmap'ed buffers that we used
- * for memory mapped provided buffer rings.
- */
-void io_kbuf_mmap_list_free(struct io_ring_ctx *ctx)
+int io_pbuf_mmap(struct file *file, struct vm_area_struct *vma)
 {
-       struct io_buf_free *ibf;
-       struct hlist_node *tmp;
+       struct io_ring_ctx *ctx = file->private_data;
+       loff_t pgoff = vma->vm_pgoff << PAGE_SHIFT;
+       struct io_buffer_list *bl;
+       int bgid, ret;
 
-       hlist_for_each_entry_safe(ibf, tmp, &ctx->io_buf_list, list) {
-               hlist_del(&ibf->list);
-               io_mem_free(ibf->mem);
-               kfree(ibf);
-       }
+       bgid = (pgoff & ~IORING_OFF_MMAP_MASK) >> IORING_OFF_PBUF_SHIFT;
+       bl = io_pbuf_get_bl(ctx, bgid);
+       if (IS_ERR(bl))
+               return PTR_ERR(bl);
+
+       ret = io_uring_mmap_pages(ctx, vma, bl->buf_pages, bl->buf_nr_pages);
+       io_put_bl(ctx, bl);
+       return ret;
 }
index df365b8860cf1eeb7eff261e1f5a5a7fc8d9b77f..53c141d9a8b2689d4db1a8dc4d99775a31d78190 100644 (file)
@@ -55,8 +55,6 @@ int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg);
 int io_unregister_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg);
 int io_register_pbuf_status(struct io_ring_ctx *ctx, void __user *arg);
 
-void io_kbuf_mmap_list_free(struct io_ring_ctx *ctx);
-
 void __io_put_kbuf(struct io_kiocb *req, unsigned issue_flags);
 
 bool io_kbuf_recycle_legacy(struct io_kiocb *req, unsigned issue_flags);
@@ -64,6 +62,7 @@ bool io_kbuf_recycle_legacy(struct io_kiocb *req, unsigned issue_flags);
 void io_put_bl(struct io_ring_ctx *ctx, struct io_buffer_list *bl);
 struct io_buffer_list *io_pbuf_get_bl(struct io_ring_ctx *ctx,
                                      unsigned long bgid);
+int io_pbuf_mmap(struct file *file, struct vm_area_struct *vma);
 
 static inline bool io_kbuf_recycle_ring(struct io_kiocb *req)
 {