io_uring/kbuf: add helpers for getting/peeking multiple buffers
[linux-2.6-block.git] / io_uring / kbuf.c
index 693c26da4ee1a36b4b0acee0b0cc7a7d0cfde3e6..d2945c9c812b5c7e81845340d3f28192a7aab623 100644 (file)
@@ -7,6 +7,7 @@
 #include <linux/slab.h>
 #include <linux/namei.h>
 #include <linux/poll.h>
+#include <linux/vmalloc.h>
 #include <linux/io_uring.h>
 
 #include <uapi/linux/io_uring.h>
 #include "io_uring.h"
 #include "opdef.h"
 #include "kbuf.h"
-
-#define IO_BUFFER_LIST_BUF_PER_PAGE (PAGE_SIZE / sizeof(struct io_uring_buf))
-
-#define BGID_ARRAY     64
+#include "memmap.h"
 
 /* BIDs are addressed by a 16-bit field in a CQE */
 #define MAX_BIDS_PER_BGID (1 << 16)
@@ -33,29 +31,12 @@ struct io_provide_buf {
        __u16                           bid;
 };
 
-struct io_buf_free {
-       struct hlist_node               list;
-       void                            *mem;
-       size_t                          size;
-       int                             inuse;
-};
-
-static struct io_buffer_list *__io_buffer_get_list(struct io_ring_ctx *ctx,
-                                                  struct io_buffer_list *bl,
-                                                  unsigned int bgid)
-{
-       if (bl && bgid < BGID_ARRAY)
-               return &bl[bgid];
-
-       return xa_load(&ctx->io_bl_xa, bgid);
-}
-
 static inline struct io_buffer_list *io_buffer_get_list(struct io_ring_ctx *ctx,
                                                        unsigned int bgid)
 {
        lockdep_assert_held(&ctx->uring_lock);
 
-       return __io_buffer_get_list(ctx, ctx->io_bl, bgid);
+       return xa_load(&ctx->io_bl_xa, bgid);
 }
 
 static int io_buffer_add_list(struct io_ring_ctx *ctx,
@@ -67,11 +48,7 @@ static int io_buffer_add_list(struct io_ring_ctx *ctx,
         * always under the ->uring_lock, but the RCU lookup from mmap does.
         */
        bl->bgid = bgid;
-       smp_store_release(&bl->is_ready, 1);
-
-       if (bgid < BGID_ARRAY)
-               return 0;
-
+       atomic_set(&bl->refs, 1);
        return xa_err(xa_store(&ctx->io_bl_xa, bgid, bl, GFP_KERNEL));
 }
 
@@ -140,6 +117,27 @@ static void __user *io_provided_buffer_select(struct io_kiocb *req, size_t *len,
        return NULL;
 }
 
+static int io_provided_buffers_select(struct io_kiocb *req, size_t *len,
+                                     struct io_buffer_list *bl,
+                                     struct iovec *iov)
+{
+       void __user *buf;
+
+       buf = io_provided_buffer_select(req, len, bl);
+       if (unlikely(!buf))
+               return -ENOBUFS;
+
+       iov[0].iov_base = buf;
+       iov[0].iov_len = *len;
+       return 0;
+}
+
+static struct io_uring_buf *io_ring_head_to_buf(struct io_uring_buf_ring *br,
+                                               __u16 head, __u16 mask)
+{
+       return &br->bufs[head & mask];
+}
+
 static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
                                          struct io_buffer_list *bl,
                                          unsigned int issue_flags)
@@ -155,19 +153,10 @@ static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
        if (head + 1 == tail)
                req->flags |= REQ_F_BL_EMPTY;
 
-       head &= bl->mask;
-       /* mmaped buffers are always contig */
-       if (bl->is_mmap || head < IO_BUFFER_LIST_BUF_PER_PAGE) {
-               buf = &br->bufs[head];
-       } else {
-               int off = head & (IO_BUFFER_LIST_BUF_PER_PAGE - 1);
-               int index = head / IO_BUFFER_LIST_BUF_PER_PAGE;
-               buf = page_address(bl->buf_pages[index]);
-               buf += off;
-       }
+       buf = io_ring_head_to_buf(br, head, bl->mask);
        if (*len == 0 || *len > buf->len)
                *len = buf->len;
-       req->flags |= REQ_F_BUFFER_RING;
+       req->flags |= REQ_F_BUFFER_RING | REQ_F_BUFFERS_COMMIT;
        req->buf_list = bl;
        req->buf_index = buf->bid;
 
@@ -182,6 +171,7 @@ static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
                 * the transfer completes (or if we get -EAGAIN and must poll of
                 * retry).
                 */
+               req->flags &= ~REQ_F_BUFFERS_COMMIT;
                req->buf_list = NULL;
                bl->head++;
        }
@@ -208,40 +198,134 @@ void __user *io_buffer_select(struct io_kiocb *req, size_t *len,
        return ret;
 }
 
-static __cold int io_init_bl_list(struct io_ring_ctx *ctx)
+/* cap it at a reasonable 256, will be one page even for 4K */
+#define PEEK_MAX_IMPORT                256
+
+static int io_ring_buffers_peek(struct io_kiocb *req, struct buf_sel_arg *arg,
+                               struct io_buffer_list *bl)
 {
-       struct io_buffer_list *bl;
-       int i;
+       struct io_uring_buf_ring *br = bl->buf_ring;
+       struct iovec *iov = arg->iovs;
+       int nr_iovs = arg->nr_iovs;
+       __u16 nr_avail, tail, head;
+       struct io_uring_buf *buf;
 
-       bl = kcalloc(BGID_ARRAY, sizeof(struct io_buffer_list), GFP_KERNEL);
-       if (!bl)
-               return -ENOMEM;
+       tail = smp_load_acquire(&br->tail);
+       head = bl->head;
+       nr_avail = min_t(__u16, tail - head, UIO_MAXIOV);
+       if (unlikely(!nr_avail))
+               return -ENOBUFS;
+
+       buf = io_ring_head_to_buf(br, head, bl->mask);
+       if (arg->max_len) {
+               int needed;
+
+               needed = (arg->max_len + buf->len - 1) / buf->len;
+               needed = min(needed, PEEK_MAX_IMPORT);
+               if (nr_avail > needed)
+                       nr_avail = needed;
+       }
 
-       for (i = 0; i < BGID_ARRAY; i++) {
-               INIT_LIST_HEAD(&bl[i].buf_list);
-               bl[i].bgid = i;
+       /*
+        * only alloc a bigger array if we know we have data to map, eg not
+        * a speculative peek operation.
+        */
+       if (arg->mode & KBUF_MODE_EXPAND && nr_avail > nr_iovs && arg->max_len) {
+               iov = kmalloc_array(nr_avail, sizeof(struct iovec), GFP_KERNEL);
+               if (unlikely(!iov))
+                       return -ENOMEM;
+               if (arg->mode & KBUF_MODE_FREE)
+                       kfree(arg->iovs);
+               arg->iovs = iov;
+               nr_iovs = nr_avail;
+       } else if (nr_avail < nr_iovs) {
+               nr_iovs = nr_avail;
        }
 
-       smp_store_release(&ctx->io_bl, bl);
-       return 0;
+       /* set it to max, if not set, so we can use it unconditionally */
+       if (!arg->max_len)
+               arg->max_len = INT_MAX;
+
+       req->buf_index = buf->bid;
+       do {
+               /* truncate end piece, if needed */
+               if (buf->len > arg->max_len)
+                       buf->len = arg->max_len;
+
+               iov->iov_base = u64_to_user_ptr(buf->addr);
+               iov->iov_len = buf->len;
+               iov++;
+
+               arg->out_len += buf->len;
+               arg->max_len -= buf->len;
+               if (!arg->max_len)
+                       break;
+
+               buf = io_ring_head_to_buf(br, ++head, bl->mask);
+       } while (--nr_iovs);
+
+       if (head == tail)
+               req->flags |= REQ_F_BL_EMPTY;
+
+       req->flags |= REQ_F_BUFFER_RING;
+       req->buf_list = bl;
+       return iov - arg->iovs;
 }
 
-/*
- * Mark the given mapped range as free for reuse
- */
-static void io_kbuf_mark_free(struct io_ring_ctx *ctx, struct io_buffer_list *bl)
+int io_buffers_select(struct io_kiocb *req, struct buf_sel_arg *arg,
+                     unsigned int issue_flags)
 {
-       struct io_buf_free *ibf;
+       struct io_ring_ctx *ctx = req->ctx;
+       struct io_buffer_list *bl;
+       int ret = -ENOENT;
+
+       io_ring_submit_lock(ctx, issue_flags);
+       bl = io_buffer_get_list(ctx, req->buf_index);
+       if (unlikely(!bl))
+               goto out_unlock;
 
-       hlist_for_each_entry(ibf, &ctx->io_buf_list, list) {
-               if (bl->buf_ring == ibf->mem) {
-                       ibf->inuse = 0;
-                       return;
+       if (bl->is_buf_ring) {
+               ret = io_ring_buffers_peek(req, arg, bl);
+               /*
+                * Don't recycle these buffers if we need to go through poll.
+                * Nobody else can use them anyway, and holding on to provided
+                * buffers for a send/write operation would happen on the app
+                * side anyway with normal buffers. Besides, we already
+                * committed them, they cannot be put back in the queue.
+                */
+               if (ret > 0) {
+                       req->flags |= REQ_F_BL_NO_RECYCLE;
+                       req->buf_list->head += ret;
                }
+       } else {
+               ret = io_provided_buffers_select(req, &arg->out_len, bl, arg->iovs);
+       }
+out_unlock:
+       io_ring_submit_unlock(ctx, issue_flags);
+       return ret;
+}
+
+int io_buffers_peek(struct io_kiocb *req, struct buf_sel_arg *arg)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+       struct io_buffer_list *bl;
+       int ret;
+
+       lockdep_assert_held(&ctx->uring_lock);
+
+       bl = io_buffer_get_list(ctx, req->buf_index);
+       if (unlikely(!bl))
+               return -ENOENT;
+
+       if (bl->is_buf_ring) {
+               ret = io_ring_buffers_peek(req, arg, bl);
+               if (ret > 0)
+                       req->flags |= REQ_F_BUFFERS_COMMIT;
+               return ret;
        }
 
-       /* can't happen... */
-       WARN_ON_ONCE(1);
+       /* don't support multiple buffer selections for legacy */
+       return io_provided_buffers_select(req, &arg->max_len, bl, arg->iovs);
 }
 
 static int __io_remove_buffers(struct io_ring_ctx *ctx,
@@ -255,22 +339,16 @@ static int __io_remove_buffers(struct io_ring_ctx *ctx,
 
        if (bl->is_buf_ring) {
                i = bl->buf_ring->tail - bl->head;
-               if (bl->is_mmap) {
-                       /*
-                        * io_kbuf_list_free() will free the page(s) at
-                        * ->release() time.
-                        */
-                       io_kbuf_mark_free(ctx, bl);
-                       bl->buf_ring = NULL;
-                       bl->is_mmap = 0;
-               } else if (bl->buf_nr_pages) {
+               if (bl->buf_nr_pages) {
                        int j;
 
-                       for (j = 0; j < bl->buf_nr_pages; j++)
-                               unpin_user_page(bl->buf_pages[j]);
-                       kvfree(bl->buf_pages);
-                       bl->buf_pages = NULL;
-                       bl->buf_nr_pages = 0;
+                       if (!bl->is_mmap) {
+                               for (j = 0; j < bl->buf_nr_pages; j++)
+                                       unpin_user_page(bl->buf_pages[j]);
+                       }
+                       io_pages_unmap(bl->buf_ring, &bl->buf_pages,
+                                       &bl->buf_nr_pages, bl->is_mmap);
+                       bl->is_mmap = 0;
                }
                /* make sure it's seen as empty */
                INIT_LIST_HEAD(&bl->buf_list);
@@ -294,24 +372,24 @@ static int __io_remove_buffers(struct io_ring_ctx *ctx,
        return i;
 }
 
+void io_put_bl(struct io_ring_ctx *ctx, struct io_buffer_list *bl)
+{
+       if (atomic_dec_and_test(&bl->refs)) {
+               __io_remove_buffers(ctx, bl, -1U);
+               kfree_rcu(bl, rcu);
+       }
+}
+
 void io_destroy_buffers(struct io_ring_ctx *ctx)
 {
        struct io_buffer_list *bl;
        struct list_head *item, *tmp;
        struct io_buffer *buf;
        unsigned long index;
-       int i;
-
-       for (i = 0; i < BGID_ARRAY; i++) {
-               if (!ctx->io_bl)
-                       break;
-               __io_remove_buffers(ctx, &ctx->io_bl[i], -1U);
-       }
 
        xa_for_each(&ctx->io_bl_xa, index, bl) {
                xa_erase(&ctx->io_bl_xa, bl->bgid);
-               __io_remove_buffers(ctx, bl, -1U);
-               kfree_rcu(bl, rcu);
+               io_put_bl(ctx, bl);
        }
 
        /*
@@ -489,12 +567,6 @@ int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
 
        io_ring_submit_lock(ctx, issue_flags);
 
-       if (unlikely(p->bgid < BGID_ARRAY && !ctx->io_bl)) {
-               ret = io_init_bl_list(ctx);
-               if (ret)
-                       goto err;
-       }
-
        bl = io_buffer_get_list(ctx, p->bgid);
        if (unlikely(!bl)) {
                bl = kzalloc(sizeof(*bl), GFP_KERNEL_ACCOUNT);
@@ -507,14 +579,9 @@ int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
                if (ret) {
                        /*
                         * Doesn't need rcu free as it was never visible, but
-                        * let's keep it consistent throughout. Also can't
-                        * be a lower indexed array group, as adding one
-                        * where lookup failed cannot happen.
+                        * let's keep it consistent throughout.
                         */
-                       if (p->bgid >= BGID_ARRAY)
-                               kfree_rcu(bl, rcu);
-                       else
-                               WARN_ON_ONCE(1);
+                       kfree_rcu(bl, rcu);
                        goto err;
                }
        }
@@ -537,9 +604,9 @@ err:
 static int io_pin_pbuf_ring(struct io_uring_buf_reg *reg,
                            struct io_buffer_list *bl)
 {
-       struct io_uring_buf_ring *br;
+       struct io_uring_buf_ring *br = NULL;
        struct page **pages;
-       int i, nr_pages;
+       int nr_pages, ret;
 
        pages = io_pin_pages(reg->ring_addr,
                             flex_array_size(br, bufs, reg->ring_entries),
@@ -547,18 +614,12 @@ static int io_pin_pbuf_ring(struct io_uring_buf_reg *reg,
        if (IS_ERR(pages))
                return PTR_ERR(pages);
 
-       /*
-        * Apparently some 32-bit boxes (ARM) will return highmem pages,
-        * which then need to be mapped. We could support that, but it'd
-        * complicate the code and slowdown the common cases quite a bit.
-        * So just error out, returning -EINVAL just like we did on kernels
-        * that didn't support mapped buffer rings.
-        */
-       for (i = 0; i < nr_pages; i++)
-               if (PageHighMem(pages[i]))
-                       goto error_unpin;
+       br = vmap(pages, nr_pages, VM_MAP, PAGE_KERNEL);
+       if (!br) {
+               ret = -ENOMEM;
+               goto error_unpin;
+       }
 
-       br = page_address(pages[0]);
 #ifdef SHM_COLOUR
        /*
         * On platforms that have specific aliasing requirements, SHM_COLOUR
@@ -569,8 +630,10 @@ static int io_pin_pbuf_ring(struct io_uring_buf_reg *reg,
         * should use IOU_PBUF_RING_MMAP instead, and liburing will handle
         * this transparently.
         */
-       if ((reg->ring_addr | (unsigned long) br) & (SHM_COLOUR - 1))
+       if ((reg->ring_addr | (unsigned long) br) & (SHM_COLOUR - 1)) {
+               ret = -EINVAL;
                goto error_unpin;
+       }
 #endif
        bl->buf_pages = pages;
        bl->buf_nr_pages = nr_pages;
@@ -579,69 +642,24 @@ static int io_pin_pbuf_ring(struct io_uring_buf_reg *reg,
        bl->is_mmap = 0;
        return 0;
 error_unpin:
-       for (i = 0; i < nr_pages; i++)
-               unpin_user_page(pages[i]);
+       unpin_user_pages(pages, nr_pages);
        kvfree(pages);
-       return -EINVAL;
-}
-
-/*
- * See if we have a suitable region that we can reuse, rather than allocate
- * both a new io_buf_free and mem region again. We leave it on the list as
- * even a reused entry will need freeing at ring release.
- */
-static struct io_buf_free *io_lookup_buf_free_entry(struct io_ring_ctx *ctx,
-                                                   size_t ring_size)
-{
-       struct io_buf_free *ibf, *best = NULL;
-       size_t best_dist;
-
-       hlist_for_each_entry(ibf, &ctx->io_buf_list, list) {
-               size_t dist;
-
-               if (ibf->inuse || ibf->size < ring_size)
-                       continue;
-               dist = ibf->size - ring_size;
-               if (!best || dist < best_dist) {
-                       best = ibf;
-                       if (!dist)
-                               break;
-                       best_dist = dist;
-               }
-       }
-
-       return best;
+       vunmap(br);
+       return ret;
 }
 
 static int io_alloc_pbuf_ring(struct io_ring_ctx *ctx,
                              struct io_uring_buf_reg *reg,
                              struct io_buffer_list *bl)
 {
-       struct io_buf_free *ibf;
        size_t ring_size;
-       void *ptr;
 
        ring_size = reg->ring_entries * sizeof(struct io_uring_buf_ring);
 
-       /* Reuse existing entry, if we can */
-       ibf = io_lookup_buf_free_entry(ctx, ring_size);
-       if (!ibf) {
-               ptr = io_mem_alloc(ring_size);
-               if (IS_ERR(ptr))
-                       return PTR_ERR(ptr);
-
-               /* Allocate and store deferred free entry */
-               ibf = kmalloc(sizeof(*ibf), GFP_KERNEL_ACCOUNT);
-               if (!ibf) {
-                       io_mem_free(ptr);
-                       return -ENOMEM;
-               }
-               ibf->mem = ptr;
-               ibf->size = ring_size;
-               hlist_add_head(&ibf->list, &ctx->io_buf_list);
-       }
-       ibf->inuse = 1;
-       bl->buf_ring = ibf->mem;
+       bl->buf_ring = io_pages_map(&bl->buf_pages, &bl->buf_nr_pages, ring_size);
+       if (!bl->buf_ring)
+               return -ENOMEM;
+
        bl->is_buf_ring = 1;
        bl->is_mmap = 1;
        return 0;
@@ -679,12 +697,6 @@ int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
        if (reg.ring_entries >= 65536)
                return -EINVAL;
 
-       if (unlikely(reg.bgid < BGID_ARRAY && !ctx->io_bl)) {
-               int ret = io_init_bl_list(ctx);
-               if (ret)
-                       return ret;
-       }
-
        bl = io_buffer_get_list(ctx, reg.bgid);
        if (bl) {
                /* if mapped buffer ring OR classic exists, don't allow */
@@ -733,11 +745,8 @@ int io_unregister_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
        if (!bl->is_buf_ring)
                return -EINVAL;
 
-       __io_remove_buffers(ctx, bl, -1U);
-       if (bl->bgid >= BGID_ARRAY) {
-               xa_erase(&ctx->io_bl_xa, bl->bgid);
-               kfree_rcu(bl, rcu);
-       }
+       xa_erase(&ctx->io_bl_xa, bl->bgid);
+       io_put_bl(ctx, bl);
        return 0;
 }
 
@@ -767,37 +776,50 @@ int io_register_pbuf_status(struct io_ring_ctx *ctx, void __user *arg)
        return 0;
 }
 
-void *io_pbuf_get_address(struct io_ring_ctx *ctx, unsigned long bgid)
+struct io_buffer_list *io_pbuf_get_bl(struct io_ring_ctx *ctx,
+                                     unsigned long bgid)
 {
        struct io_buffer_list *bl;
+       bool ret;
 
-       bl = __io_buffer_get_list(ctx, smp_load_acquire(&ctx->io_bl), bgid);
-
-       if (!bl || !bl->is_mmap)
-               return NULL;
        /*
-        * Ensure the list is fully setup. Only strictly needed for RCU lookup
-        * via mmap, and in that case only for the array indexed groups. For
-        * the xarray lookups, it's either visible and ready, or not at all.
+        * We have to be a bit careful here - we're inside mmap and cannot grab
+        * the uring_lock. This means the buffer_list could be simultaneously
+        * going away, if someone is trying to be sneaky. Look it up under rcu
+        * so we know it's not going away, and attempt to grab a reference to
+        * it. If the ref is already zero, then fail the mapping. If successful,
+        * the caller will call io_put_bl() to drop the the reference at at the
+        * end. This may then safely free the buffer_list (and drop the pages)
+        * at that point, vm_insert_pages() would've already grabbed the
+        * necessary vma references.
         */
-       if (!smp_load_acquire(&bl->is_ready))
-               return NULL;
-
-       return bl->buf_ring;
+       rcu_read_lock();
+       bl = xa_load(&ctx->io_bl_xa, bgid);
+       /* must be a mmap'able buffer ring and have pages */
+       ret = false;
+       if (bl && bl->is_mmap)
+               ret = atomic_inc_not_zero(&bl->refs);
+       rcu_read_unlock();
+
+       if (ret)
+               return bl;
+
+       return ERR_PTR(-EINVAL);
 }
 
-/*
- * Called at or after ->release(), free the mmap'ed buffers that we used
- * for memory mapped provided buffer rings.
- */
-void io_kbuf_mmap_list_free(struct io_ring_ctx *ctx)
+int io_pbuf_mmap(struct file *file, struct vm_area_struct *vma)
 {
-       struct io_buf_free *ibf;
-       struct hlist_node *tmp;
+       struct io_ring_ctx *ctx = file->private_data;
+       loff_t pgoff = vma->vm_pgoff << PAGE_SHIFT;
+       struct io_buffer_list *bl;
+       int bgid, ret;
 
-       hlist_for_each_entry_safe(ibf, tmp, &ctx->io_buf_list, list) {
-               hlist_del(&ibf->list);
-               io_mem_free(ibf->mem);
-               kfree(ibf);
-       }
+       bgid = (pgoff & ~IORING_OFF_MMAP_MASK) >> IORING_OFF_PBUF_SHIFT;
+       bl = io_pbuf_get_bl(ctx, bgid);
+       if (IS_ERR(bl))
+               return PTR_ERR(bl);
+
+       ret = io_uring_mmap_pages(ctx, vma, bl->buf_pages, bl->buf_nr_pages);
+       io_put_bl(ctx, bl);
+       return ret;
 }