Merge tag 'integrity-v6.4' of git://git.kernel.org/pub/scm/linux/kernel/git/zohar...
[linux-block.git] / io_uring / kbuf.c
index a90c820ce99e12e5e72d37dee4fefd2d24665c3a..2f0181521c98e41d6a0af6dc83d3d5491956f92b 100644 (file)
@@ -137,7 +137,8 @@ static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
                return NULL;
 
        head &= bl->mask;
-       if (head < IO_BUFFER_LIST_BUF_PER_PAGE) {
+       /* mmaped buffers are always contig */
+       if (bl->is_mmap || head < IO_BUFFER_LIST_BUF_PER_PAGE) {
                buf = &br->bufs[head];
        } else {
                int off = head & (IO_BUFFER_LIST_BUF_PER_PAGE - 1);
@@ -179,7 +180,7 @@ void __user *io_buffer_select(struct io_kiocb *req, size_t *len,
 
        bl = io_buffer_get_list(ctx, req->buf_index);
        if (likely(bl)) {
-               if (bl->buf_nr_pages)
+               if (bl->is_mapped)
                        ret = io_ring_buffer_select(req, len, bl, issue_flags);
                else
                        ret = io_provided_buffer_select(req, len, bl);
@@ -214,17 +215,28 @@ static int __io_remove_buffers(struct io_ring_ctx *ctx,
        if (!nbufs)
                return 0;
 
-       if (bl->buf_nr_pages) {
-               int j;
-
+       if (bl->is_mapped) {
                i = bl->buf_ring->tail - bl->head;
-               for (j = 0; j < bl->buf_nr_pages; j++)
-                       unpin_user_page(bl->buf_pages[j]);
-               kvfree(bl->buf_pages);
-               bl->buf_pages = NULL;
-               bl->buf_nr_pages = 0;
+               if (bl->is_mmap) {
+                       struct page *page;
+
+                       page = virt_to_head_page(bl->buf_ring);
+                       if (put_page_testzero(page))
+                               free_compound_page(page);
+                       bl->buf_ring = NULL;
+                       bl->is_mmap = 0;
+               } else if (bl->buf_nr_pages) {
+                       int j;
+
+                       for (j = 0; j < bl->buf_nr_pages; j++)
+                               unpin_user_page(bl->buf_pages[j]);
+                       kvfree(bl->buf_pages);
+                       bl->buf_pages = NULL;
+                       bl->buf_nr_pages = 0;
+               }
                /* make sure it's seen as empty */
                INIT_LIST_HEAD(&bl->buf_list);
+               bl->is_mapped = 0;
                return i;
        }
 
@@ -304,7 +316,7 @@ int io_remove_buffers(struct io_kiocb *req, unsigned int issue_flags)
        if (bl) {
                ret = -EINVAL;
                /* can't use provide/remove buffers command on mapped buffers */
-               if (!bl->buf_nr_pages)
+               if (!bl->is_mapped)
                        ret = __io_remove_buffers(ctx, bl, p->nbufs);
        }
        io_ring_submit_unlock(ctx, issue_flags);
@@ -449,7 +461,7 @@ int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
                }
        }
        /* can't add buffers via this command for a mapped buffer ring */
-       if (bl->buf_nr_pages) {
+       if (bl->is_mapped) {
                ret = -EINVAL;
                goto err;
        }
@@ -464,23 +476,87 @@ err:
        return IOU_OK;
 }
 
-int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
+static int io_pin_pbuf_ring(struct io_uring_buf_reg *reg,
+                           struct io_buffer_list *bl)
 {
        struct io_uring_buf_ring *br;
-       struct io_uring_buf_reg reg;
-       struct io_buffer_list *bl, *free_bl = NULL;
        struct page **pages;
        int nr_pages;
 
+       pages = io_pin_pages(reg->ring_addr,
+                            flex_array_size(br, bufs, reg->ring_entries),
+                            &nr_pages);
+       if (IS_ERR(pages))
+               return PTR_ERR(pages);
+
+       br = page_address(pages[0]);
+#ifdef SHM_COLOUR
+       /*
+        * On platforms that have specific aliasing requirements, SHM_COLOUR
+        * is set and we must guarantee that the kernel and user side align
+        * nicely. We cannot do that if IOU_PBUF_RING_MMAP isn't set and
+        * the application mmap's the provided ring buffer. Fail the request
+        * if we, by chance, don't end up with aligned addresses. The app
+        * should use IOU_PBUF_RING_MMAP instead, and liburing will handle
+        * this transparently.
+        */
+       if ((reg->ring_addr | (unsigned long) br) & (SHM_COLOUR - 1)) {
+               int i;
+
+               for (i = 0; i < nr_pages; i++)
+                       unpin_user_page(pages[i]);
+               return -EINVAL;
+       }
+#endif
+       bl->buf_pages = pages;
+       bl->buf_nr_pages = nr_pages;
+       bl->buf_ring = br;
+       bl->is_mapped = 1;
+       bl->is_mmap = 0;
+       return 0;
+}
+
/*
 * Allocate a kernel-resident buffer ring for IOU_PBUF_RING_MMAP: instead of
 * pinning user memory, the kernel allocates the ring and the application
 * later mmap()s it (served via io_pbuf_get_address()).
 *
 * Returns 0 on success with @bl->buf_ring pointing at the zeroed ring and
 * is_mapped/is_mmap set, or -ENOMEM if the allocation fails.
 */
static int io_alloc_pbuf_ring(struct io_uring_buf_reg *reg,
			      struct io_buffer_list *bl)
{
	/*
	 * __GFP_COMP: allocate as a compound page so the teardown path can
	 * virt_to_head_page() + free_compound_page() it; __GFP_ZERO so the
	 * app sees a zeroed ring; GFP_KERNEL_ACCOUNT charges the memcg.
	 */
	gfp_t gfp = GFP_KERNEL_ACCOUNT | __GFP_ZERO | __GFP_NOWARN | __GFP_COMP;
	size_t ring_size;
	void *ptr;

	/*
	 * NOTE(review): sized per-entry with sizeof(struct io_uring_buf_ring);
	 * this is only correct because that struct is the same size as
	 * struct io_uring_buf (the ring header overlays slot 0) —
	 * flex_array_size(br, bufs, ...) as used in the pinned path would
	 * state the intent more clearly. Confirm against io_uring.h.
	 */
	ring_size = reg->ring_entries * sizeof(struct io_uring_buf_ring);
	ptr = (void *) __get_free_pages(gfp, get_order(ring_size));
	if (!ptr)
		return -ENOMEM;

	bl->buf_ring = ptr;
	bl->is_mapped = 1;
	bl->is_mmap = 1;
	return 0;
}
+
+int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
+{
+       struct io_uring_buf_reg reg;
+       struct io_buffer_list *bl, *free_bl = NULL;
+       int ret;
+
        if (copy_from_user(&reg, arg, sizeof(reg)))
                return -EFAULT;
 
-       if (reg.pad || reg.resv[0] || reg.resv[1] || reg.resv[2])
+       if (reg.resv[0] || reg.resv[1] || reg.resv[2])
                return -EINVAL;
-       if (!reg.ring_addr)
-               return -EFAULT;
-       if (reg.ring_addr & ~PAGE_MASK)
+       if (reg.flags & ~IOU_PBUF_RING_MMAP)
                return -EINVAL;
+       if (!(reg.flags & IOU_PBUF_RING_MMAP)) {
+               if (!reg.ring_addr)
+                       return -EFAULT;
+               if (reg.ring_addr & ~PAGE_MASK)
+                       return -EINVAL;
+       } else {
+               if (reg.ring_addr)
+                       return -EINVAL;
+       }
+
        if (!is_power_of_2(reg.ring_entries))
                return -EINVAL;
 
@@ -497,7 +573,7 @@ int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
        bl = io_buffer_get_list(ctx, reg.bgid);
        if (bl) {
                /* if mapped buffer ring OR classic exists, don't allow */
-               if (bl->buf_nr_pages || !list_empty(&bl->buf_list))
+               if (bl->is_mapped || !list_empty(&bl->buf_list))
                        return -EEXIST;
        } else {
                free_bl = bl = kzalloc(sizeof(*bl), GFP_KERNEL);
@@ -505,22 +581,21 @@ int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
                        return -ENOMEM;
        }
 
-       pages = io_pin_pages(reg.ring_addr,
-                            flex_array_size(br, bufs, reg.ring_entries),
-                            &nr_pages);
-       if (IS_ERR(pages)) {
-               kfree(free_bl);
-               return PTR_ERR(pages);
+       if (!(reg.flags & IOU_PBUF_RING_MMAP))
+               ret = io_pin_pbuf_ring(&reg, bl);
+       else
+               ret = io_alloc_pbuf_ring(&reg, bl);
+
+       if (!ret) {
+               bl->nr_entries = reg.ring_entries;
+               bl->mask = reg.ring_entries - 1;
+
+               io_buffer_add_list(ctx, bl, reg.bgid);
+               return 0;
        }
 
-       br = page_address(pages[0]);
-       bl->buf_pages = pages;
-       bl->buf_nr_pages = nr_pages;
-       bl->nr_entries = reg.ring_entries;
-       bl->buf_ring = br;
-       bl->mask = reg.ring_entries - 1;
-       io_buffer_add_list(ctx, bl, reg.bgid);
-       return 0;
+       kfree(free_bl);
+       return ret;
 }
 
 int io_unregister_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
@@ -530,13 +605,15 @@ int io_unregister_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
 
        if (copy_from_user(&reg, arg, sizeof(reg)))
                return -EFAULT;
-       if (reg.pad || reg.resv[0] || reg.resv[1] || reg.resv[2])
+       if (reg.resv[0] || reg.resv[1] || reg.resv[2])
+               return -EINVAL;
+       if (reg.flags)
                return -EINVAL;
 
        bl = io_buffer_get_list(ctx, reg.bgid);
        if (!bl)
                return -ENOENT;
-       if (!bl->buf_nr_pages)
+       if (!bl->is_mapped)
                return -EINVAL;
 
        __io_remove_buffers(ctx, bl, -1U);
@@ -546,3 +623,14 @@ int io_unregister_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
        }
        return 0;
 }
+
+void *io_pbuf_get_address(struct io_ring_ctx *ctx, unsigned long bgid)
+{
+       struct io_buffer_list *bl;
+
+       bl = io_buffer_get_list(ctx, bgid);
+       if (!bl || !bl->is_mmap)
+               return NULL;
+
+       return bl->buf_ring;
+}