io_uring: add mapping support for NOMMU archs
[linux-2.6-block.git] / fs / io_uring.c
index e1a3b8b667e09695421d34aa455dc6f465c80f73..e6fc401e341f838152f04804beea33f6dd9f6a27 100644 (file)
@@ -271,7 +271,7 @@ struct io_ring_ctx {
                 * manipulate the list, hence no extra locking is needed there.
                 */
                struct list_head        poll_list;
-               struct list_head        cancel_list;
+               struct rb_root          cancel_tree;
 
                spinlock_t              inflight_lock;
                struct list_head        inflight_list;
@@ -323,7 +323,10 @@ struct io_kiocb {
        struct sqe_submit       submit;
 
        struct io_ring_ctx      *ctx;
-       struct list_head        list;
+       union {
+               struct list_head        list;
+               struct rb_node          rb_node;
+       };
        struct list_head        link_list;
        unsigned int            flags;
        refcount_t              refs;
@@ -340,8 +343,9 @@ struct io_kiocb {
 #define REQ_F_TIMEOUT          1024    /* timeout request */
 #define REQ_F_ISREG            2048    /* regular file */
 #define REQ_F_MUST_PUNT                4096    /* must be punted even for NONBLOCK */
-#define REQ_F_INFLIGHT         8192    /* on inflight list */
-#define REQ_F_COMP_LOCKED      16384   /* completion under lock */
+#define REQ_F_TIMEOUT_NOSEQ    8192    /* no timeout sequence */
+#define REQ_F_INFLIGHT         16384   /* on inflight list */
+#define REQ_F_COMP_LOCKED      32768   /* completion under lock */
        u64                     user_data;
        u32                     result;
        u32                     sequence;
@@ -433,7 +437,7 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        init_waitqueue_head(&ctx->wait);
        spin_lock_init(&ctx->completion_lock);
        INIT_LIST_HEAD(&ctx->poll_list);
-       INIT_LIST_HEAD(&ctx->cancel_list);
+       ctx->cancel_tree = RB_ROOT;
        INIT_LIST_HEAD(&ctx->defer_list);
        INIT_LIST_HEAD(&ctx->timeout_list);
        init_waitqueue_head(&ctx->inflight_wait);
@@ -448,7 +452,7 @@ err:
        return NULL;
 }
 
-static inline bool __io_sequence_defer(struct io_kiocb *req)
+static inline bool __req_need_defer(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
@@ -456,12 +460,12 @@ static inline bool __io_sequence_defer(struct io_kiocb *req)
                                        + atomic_read(&ctx->cached_cq_overflow);
 }
 
-static inline bool io_sequence_defer(struct io_kiocb *req)
+static inline bool req_need_defer(struct io_kiocb *req)
 {
-       if ((req->flags & (REQ_F_IO_DRAIN|REQ_F_IO_DRAINED)) != REQ_F_IO_DRAIN)
-               return false;
+       if ((req->flags & (REQ_F_IO_DRAIN|REQ_F_IO_DRAINED)) == REQ_F_IO_DRAIN)
+               return __req_need_defer(req);
 
-       return __io_sequence_defer(req);
+       return false;
 }
 
 static struct io_kiocb *io_get_deferred_req(struct io_ring_ctx *ctx)
@@ -469,7 +473,7 @@ static struct io_kiocb *io_get_deferred_req(struct io_ring_ctx *ctx)
        struct io_kiocb *req;
 
        req = list_first_entry_or_null(&ctx->defer_list, struct io_kiocb, list);
-       if (req && !io_sequence_defer(req)) {
+       if (req && !req_need_defer(req)) {
                list_del_init(&req->list);
                return req;
        }
@@ -482,9 +486,13 @@ static struct io_kiocb *io_get_timeout_req(struct io_ring_ctx *ctx)
        struct io_kiocb *req;
 
        req = list_first_entry_or_null(&ctx->timeout_list, struct io_kiocb, list);
-       if (req && !__io_sequence_defer(req)) {
-               list_del_init(&req->list);
-               return req;
+       if (req) {
+               if (req->flags & REQ_F_TIMEOUT_NOSEQ)
+                       return NULL;
+               if (!__req_need_defer(req)) {
+                       list_del_init(&req->list);
+                       return req;
+               }
        }
 
        return NULL;
@@ -1412,6 +1420,7 @@ static int io_prep_rw(struct io_kiocb *req, bool force_nonblock)
 
                kiocb->ki_flags |= IOCB_HIPRI;
                kiocb->ki_complete = io_complete_rw_iopoll;
+               req->result = 0;
        } else {
                if (kiocb->ki_flags & IOCB_HIPRI)
                        return -EINVAL;
@@ -1521,7 +1530,7 @@ static int io_import_fixed(struct io_ring_ctx *ctx, int rw,
                }
        }
 
-       return 0;
+       return len;
 }
 
 static ssize_t io_import_iovec(struct io_ring_ctx *ctx, int rw,
@@ -1934,6 +1943,14 @@ static int io_accept(struct io_kiocb *req, const struct io_uring_sqe *sqe,
 #endif
 }
 
+static inline void io_poll_remove_req(struct io_kiocb *req)
+{
+       if (!RB_EMPTY_NODE(&req->rb_node)) {
+               rb_erase(&req->rb_node, &req->ctx->cancel_tree);
+               RB_CLEAR_NODE(&req->rb_node);
+       }
+}
+
 static void io_poll_remove_one(struct io_kiocb *req)
 {
        struct io_poll_iocb *poll = &req->poll;
@@ -1945,17 +1962,17 @@ static void io_poll_remove_one(struct io_kiocb *req)
                io_queue_async_work(req);
        }
        spin_unlock(&poll->head->lock);
-
-       list_del_init(&req->list);
+       io_poll_remove_req(req);
 }
 
 static void io_poll_remove_all(struct io_ring_ctx *ctx)
 {
+       struct rb_node *node;
        struct io_kiocb *req;
 
        spin_lock_irq(&ctx->completion_lock);
-       while (!list_empty(&ctx->cancel_list)) {
-               req = list_first_entry(&ctx->cancel_list, struct io_kiocb,list);
+       while ((node = rb_first(&ctx->cancel_tree)) != NULL) {
+               req = rb_entry(node, struct io_kiocb, rb_node);
                io_poll_remove_one(req);
        }
        spin_unlock_irq(&ctx->completion_lock);
@@ -1963,13 +1980,21 @@ static void io_poll_remove_all(struct io_ring_ctx *ctx)
 
 static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr)
 {
+       struct rb_node *p, *parent = NULL;
        struct io_kiocb *req;
 
-       list_for_each_entry(req, &ctx->cancel_list, list) {
-               if (req->user_data != sqe_addr)
-                       continue;
-               io_poll_remove_one(req);
-               return 0;
+       p = ctx->cancel_tree.rb_node;
+       while (p) {
+               parent = p;
+               req = rb_entry(parent, struct io_kiocb, rb_node);
+               if (sqe_addr < req->user_data) {
+                       p = p->rb_left;
+               } else if (sqe_addr > req->user_data) {
+                       p = p->rb_right;
+               } else {
+                       io_poll_remove_one(req);
+                       return 0;
+               }
        }
 
        return -ENOENT;
@@ -2039,7 +2064,7 @@ static void io_poll_complete_work(struct io_wq_work **workptr)
                spin_unlock_irq(&ctx->completion_lock);
                return;
        }
-       list_del_init(&req->list);
+       io_poll_remove_req(req);
        io_poll_complete(req, mask);
        spin_unlock_irq(&ctx->completion_lock);
 
@@ -2073,7 +2098,7 @@ static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
         * for finalizing the request, mark us as having grabbed that already.
         */
        if (mask && spin_trylock_irqsave(&ctx->completion_lock, flags)) {
-               list_del(&req->list);
+               io_poll_remove_req(req);
                io_poll_complete(req, mask);
                req->flags |= REQ_F_COMP_LOCKED;
                io_put_req(req);
@@ -2108,6 +2133,25 @@ static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
        add_wait_queue(head, &pt->req->poll.wait);
 }
 
+static void io_poll_req_insert(struct io_kiocb *req)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+       struct rb_node **p = &ctx->cancel_tree.rb_node;
+       struct rb_node *parent = NULL;
+       struct io_kiocb *tmp;
+
+       while (*p) {
+               parent = *p;
+               tmp = rb_entry(parent, struct io_kiocb, rb_node);
+               if (req->user_data < tmp->user_data)
+                       p = &(*p)->rb_left;
+               else
+                       p = &(*p)->rb_right;
+       }
+       rb_link_node(&req->rb_node, parent, p);
+       rb_insert_color(&req->rb_node, &ctx->cancel_tree);
+}
+
 static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                       struct io_kiocb **nxt)
 {
@@ -2129,6 +2173,7 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe,
        INIT_IO_WORK(&req->work, io_poll_complete_work);
        events = READ_ONCE(sqe->poll_events);
        poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP;
+       RB_CLEAR_NODE(&req->rb_node);
 
        poll->head = NULL;
        poll->done = false;
@@ -2161,7 +2206,7 @@ static int io_poll_add(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                else if (cancel)
                        WRITE_ONCE(poll->canceled, true);
                else if (!poll->done) /* actually waiting for an event */
-                       list_add_tail(&req->list, &ctx->cancel_list);
+                       io_poll_req_insert(req);
                spin_unlock(&poll->head->lock);
        }
        if (mask) { /* no async, we'd stolen it */
@@ -2301,19 +2346,24 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                mode = HRTIMER_MODE_REL;
 
        hrtimer_init(&req->timeout.timer, CLOCK_MONOTONIC, mode);
+       req->flags |= REQ_F_TIMEOUT;
 
        /*
         * sqe->off holds how many events that need to occur for this
-        * timeout event to be satisfied.
+        * timeout event to be satisfied. If it isn't set, then this is
+        * a pure timeout request, sequence isn't used.
         */
        count = READ_ONCE(sqe->off);
-       if (!count)
-               count = 1;
+       if (!count) {
+               req->flags |= REQ_F_TIMEOUT_NOSEQ;
+               spin_lock_irq(&ctx->completion_lock);
+               entry = ctx->timeout_list.prev;
+               goto add;
+       }
 
        req->sequence = ctx->cached_sq_head + count - 1;
        /* reuse it to store the count */
        req->submit.sequence = count;
-       req->flags |= REQ_F_TIMEOUT;
 
        /*
         * Insertion sort, ensuring the first entry in the list is always
@@ -2325,6 +2375,9 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                unsigned nxt_sq_head;
                long long tmp, tmp_nxt;
 
+               if (nxt->flags & REQ_F_TIMEOUT_NOSEQ)
+                       continue;
+
                /*
                 * Since cached_sq_head + count - 1 can overflow, use type long
                 * long to store it.
@@ -2351,6 +2404,7 @@ static int io_timeout(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                nxt->sequence++;
        }
        req->sequence -= span;
+add:
        list_add(&req->list, entry);
        req->timeout.timer.function = io_timeout_fn;
        hrtimer_start(&req->timeout.timer, timespec64_to_ktime(ts), mode);
@@ -2436,7 +2490,8 @@ static int io_req_defer(struct io_kiocb *req)
        struct io_uring_sqe *sqe_copy;
        struct io_ring_ctx *ctx = req->ctx;
 
-       if (!io_sequence_defer(req) && list_empty(&ctx->defer_list))
+       /* Still need defer if there is pending req in defer list. */
+       if (!req_need_defer(req) && list_empty(&ctx->defer_list))
                return 0;
 
        sqe_copy = kmalloc(sizeof(*sqe_copy), GFP_KERNEL);
@@ -2444,7 +2499,7 @@ static int io_req_defer(struct io_kiocb *req)
                return -EAGAIN;
 
        spin_lock_irq(&ctx->completion_lock);
-       if (!io_sequence_defer(req) && list_empty(&ctx->defer_list)) {
+       if (!req_need_defer(req) && list_empty(&ctx->defer_list)) {
                spin_unlock_irq(&ctx->completion_lock);
                kfree(sqe_copy);
                return 0;
@@ -2598,6 +2653,10 @@ static bool io_op_needs_file(const struct io_uring_sqe *sqe)
        switch (op) {
        case IORING_OP_NOP:
        case IORING_OP_POLL_REMOVE:
+       case IORING_OP_TIMEOUT:
+       case IORING_OP_TIMEOUT_REMOVE:
+       case IORING_OP_ASYNC_CANCEL:
+       case IORING_OP_LINK_TIMEOUT:
                return false;
        default:
                return true;
@@ -4303,7 +4362,6 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
        DEFINE_WAIT(wait);
 
        while (!list_empty_careful(&ctx->inflight_list)) {
-               enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND;
                struct io_kiocb *cancel_req = NULL;
 
                spin_lock_irq(&ctx->inflight_lock);
@@ -4321,14 +4379,12 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
                                                TASK_UNINTERRUPTIBLE);
                spin_unlock_irq(&ctx->inflight_lock);
 
-               if (cancel_req) {
-                       ret = io_wq_cancel_work(ctx->io_wq, &cancel_req->work);
-                       io_put_req(cancel_req);
-               }
-
                /* We need to keep going until we don't find a matching req */
                if (!cancel_req)
                        break;
+
+               io_wq_cancel_work(ctx->io_wq, &cancel_req->work);
+               io_put_req(cancel_req);
                schedule();
        }
        finish_wait(&ctx->inflight_wait, &wait);
@@ -4346,12 +4402,11 @@ static int io_uring_flush(struct file *file, void *data)
        return 0;
 }
 
-static int io_uring_mmap(struct file *file, struct vm_area_struct *vma)
+static void *io_uring_validate_mmap_request(struct file *file,
+                                           loff_t pgoff, size_t sz)
 {
-       loff_t offset = (loff_t) vma->vm_pgoff << PAGE_SHIFT;
-       unsigned long sz = vma->vm_end - vma->vm_start;
        struct io_ring_ctx *ctx = file->private_data;
-       unsigned long pfn;
+       loff_t offset = pgoff << PAGE_SHIFT;
        struct page *page;
        void *ptr;
 
@@ -4364,17 +4419,59 @@ static int io_uring_mmap(struct file *file, struct vm_area_struct *vma)
                ptr = ctx->sq_sqes;
                break;
        default:
-               return -EINVAL;
+               return ERR_PTR(-EINVAL);
        }
 
        page = virt_to_head_page(ptr);
        if (sz > page_size(page))
-               return -EINVAL;
+               return ERR_PTR(-EINVAL);
+
+       return ptr;
+}
+
+#ifdef CONFIG_MMU
+
+static int io_uring_mmap(struct file *file, struct vm_area_struct *vma)
+{
+       size_t sz = vma->vm_end - vma->vm_start;
+       unsigned long pfn;
+       void *ptr;
+
+       ptr = io_uring_validate_mmap_request(file, vma->vm_pgoff, sz);
+       if (IS_ERR(ptr))
+               return PTR_ERR(ptr);
 
        pfn = virt_to_phys(ptr) >> PAGE_SHIFT;
        return remap_pfn_range(vma, vma->vm_start, pfn, sz, vma->vm_page_prot);
 }
 
+#else /* !CONFIG_MMU */
+
+static int io_uring_mmap(struct file *file, struct vm_area_struct *vma)
+{
+       return vma->vm_flags & (VM_SHARED | VM_MAYSHARE) ? 0 : -EINVAL;
+}
+
+static unsigned int io_uring_nommu_mmap_capabilities(struct file *file)
+{
+       return NOMMU_MAP_DIRECT | NOMMU_MAP_READ | NOMMU_MAP_WRITE;
+}
+
+static unsigned long io_uring_nommu_get_unmapped_area(struct file *file,
+       unsigned long addr, unsigned long len,
+       unsigned long pgoff, unsigned long flags)
+{
+       void *ptr;
+
+       ptr = io_uring_validate_mmap_request(file, pgoff, len);
+       if (IS_ERR(ptr))
+               return PTR_ERR(ptr);
+
+       return (unsigned long) ptr;
+}
+
+#endif /* !CONFIG_MMU */
+
 SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
                u32, min_complete, u32, flags, const sigset_t __user *, sig,
                size_t, sigsz)
@@ -4445,6 +4542,10 @@ static const struct file_operations io_uring_fops = {
        .release        = io_uring_release,
        .flush          = io_uring_flush,
        .mmap           = io_uring_mmap,
+#ifndef CONFIG_MMU
+       .get_unmapped_area = io_uring_nommu_get_unmapped_area,
+       .mmap_capabilities = io_uring_nommu_mmap_capabilities,
+#endif
        .poll           = io_uring_poll,
        .fasync         = io_uring_fasync,
 };