io_uring: skip request refcounting
[linux-block.git] / fs / io_uring.c
index 1b79f6e2da2e27d872de309c8b8cea4c5f747625..dae87c57694395bce90ca6338904fa1ae1679299 100644 (file)
@@ -710,6 +710,7 @@ enum {
        REQ_F_REISSUE_BIT,
        REQ_F_DONT_REISSUE_BIT,
        REQ_F_CREDS_BIT,
+       REQ_F_REFCOUNT_BIT,
        /* keep async read/write and isreg together and in order */
        REQ_F_NOWAIT_READ_BIT,
        REQ_F_NOWAIT_WRITE_BIT,
@@ -765,6 +766,8 @@ enum {
        REQ_F_ISREG             = BIT(REQ_F_ISREG_BIT),
        /* has creds assigned */
        REQ_F_CREDS             = BIT(REQ_F_CREDS_BIT),
+       /* skip refcounting if not set */
+       REQ_F_REFCOUNT          = BIT(REQ_F_REFCOUNT_BIT),
 };
 
 struct async_poll {
@@ -1041,7 +1044,7 @@ static void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 static bool io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data,
                                 long res, unsigned int cflags);
 static void io_put_req(struct io_kiocb *req);
-static void io_put_req_deferred(struct io_kiocb *req, int nr);
+static void io_put_req_deferred(struct io_kiocb *req);
 static void io_dismantle_req(struct io_kiocb *req);
 static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req);
 static void io_queue_linked_timeout(struct io_kiocb *req);
@@ -1087,32 +1090,40 @@ EXPORT_SYMBOL(io_uring_get_socket);
 
 static inline bool req_ref_inc_not_zero(struct io_kiocb *req)
 {
+       WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT));
        return atomic_inc_not_zero(&req->refs);
 }
 
-static inline bool req_ref_sub_and_test(struct io_kiocb *req, int refs)
-{
-       WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req));
-       return atomic_sub_and_test(refs, &req->refs);
-}
-
 static inline bool req_ref_put_and_test(struct io_kiocb *req)
 {
+       if (likely(!(req->flags & REQ_F_REFCOUNT)))
+               return true;
+
        WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req));
        return atomic_dec_and_test(&req->refs);
 }
 
 static inline void req_ref_put(struct io_kiocb *req)
 {
+       WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT));
        WARN_ON_ONCE(req_ref_put_and_test(req));
 }
 
 static inline void req_ref_get(struct io_kiocb *req)
 {
+       WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT));
        WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req));
        atomic_inc(&req->refs);
 }
 
+static inline void io_req_refcount(struct io_kiocb *req)
+{
+       if (!(req->flags & REQ_F_REFCOUNT)) {
+               req->flags |= REQ_F_REFCOUNT;
+               atomic_set(&req->refs, 1);
+       }
+}
+
 static inline void io_req_set_rsrc_node(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
@@ -1377,7 +1388,7 @@ static void io_kill_timeout(struct io_kiocb *req, int status)
                        atomic_read(&req->ctx->cq_timeouts) + 1);
                list_del_init(&req->timeout.list);
                io_cqring_fill_event(req->ctx, req->user_data, status, 0);
-               io_put_req_deferred(req, 1);
+               io_put_req_deferred(req);
        }
 }
 
@@ -1711,7 +1722,6 @@ static inline void io_req_complete(struct io_kiocb *req, long res)
 static void io_req_complete_failed(struct io_kiocb *req, long res)
 {
        req_set_fail(req);
-       io_put_req(req);
        io_req_complete_post(req, res, 0);
 }
 
@@ -1766,7 +1776,14 @@ static bool io_flush_cached_reqs(struct io_ring_ctx *ctx)
        return nr != 0;
 }
 
+/*
+ * A request might get retired back into the request caches even before opcode
+ * handlers and io_issue_sqe() are done with it, e.g. inline completion path.
+ * Because of that, io_alloc_req() should be called only under ->uring_lock
+ * and with extra caution to not get a request that is still worked on.
+ */
 static struct io_kiocb *io_alloc_req(struct io_ring_ctx *ctx)
+       __must_hold(&ctx->uring_lock)
 {
        struct io_submit_state *state = &ctx->submit_state;
        gfp_t gfp = GFP_KERNEL | __GFP_NOWARN;
@@ -1862,7 +1879,7 @@ static bool io_kill_linked_timeout(struct io_kiocb *req)
                if (hrtimer_try_to_cancel(&io->timer) != -1) {
                        io_cqring_fill_event(link->ctx, link->user_data,
                                             -ECANCELED, 0);
-                       io_put_req_deferred(link, 1);
+                       io_put_req_deferred(link);
                        return true;
                }
        }
@@ -1881,7 +1898,7 @@ static void io_fail_links(struct io_kiocb *req)
 
                trace_io_uring_fail_link(req, link);
                io_cqring_fill_event(link->ctx, link->user_data, -ECANCELED, 0);
-               io_put_req_deferred(link, 2);
+               io_put_req_deferred(link);
                link = nxt;
        }
 }
@@ -2162,8 +2179,7 @@ static void io_submit_flush_completions(struct io_ring_ctx *ctx)
        for (i = 0; i < nr; i++) {
                struct io_kiocb *req = state->compl_reqs[i];
 
-               /* submission and completion refs */
-               if (req_ref_sub_and_test(req, 2))
+               if (req_ref_put_and_test(req))
                        io_req_free_batch(&rb, req, &ctx->submit_state);
        }
 
@@ -2192,9 +2208,9 @@ static inline void io_put_req(struct io_kiocb *req)
                io_free_req(req);
 }
 
-static inline void io_put_req_deferred(struct io_kiocb *req, int refs)
+static inline void io_put_req_deferred(struct io_kiocb *req)
 {
-       if (req_ref_sub_and_test(req, refs)) {
+       if (req_ref_put_and_test(req)) {
                req->io_task_work.func = io_free_req;
                io_req_task_work_add(req);
        }
@@ -2267,7 +2283,6 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
                if (READ_ONCE(req->result) == -EAGAIN && resubmit &&
                    !(req->flags & REQ_F_DONT_REISSUE)) {
                        req->iopoll_completed = 0;
-                       req_ref_get(req);
                        io_req_task_queue_reissue(req);
                        continue;
                }
@@ -2765,7 +2780,6 @@ static void kiocb_done(struct kiocb *kiocb, ssize_t ret,
        if (check_reissue && (req->flags & REQ_F_REISSUE)) {
                req->flags &= ~REQ_F_REISSUE;
                if (io_resubmit_prep(req)) {
-                       req_ref_get(req);
                        io_req_task_queue_reissue(req);
                } else {
                        int cflags = 0;
@@ -3191,9 +3205,6 @@ static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
 
        req->rw.kiocb.ki_flags &= ~IOCB_WAITQ;
        list_del_init(&wait->entry);
-
-       /* submit ref gets dropped, acquire a new one */
-       req_ref_get(req);
        io_req_task_queue(req);
        return 1;
 }
@@ -5198,6 +5209,7 @@ static int io_arm_poll_handler(struct io_kiocb *req)
        req->apoll = apoll;
        req->flags |= REQ_F_POLLED;
        ipt.pt._qproc = io_async_queue_proc;
+       io_req_refcount(req);
 
        ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask,
                                        io_async_wake);
@@ -5236,7 +5248,6 @@ static bool __io_poll_remove_one(struct io_kiocb *req,
 static bool io_poll_remove_one(struct io_kiocb *req)
        __must_hold(&req->ctx->completion_lock)
 {
-       int refs;
        bool do_complete;
 
        io_poll_remove_double(req);
@@ -5246,10 +5257,7 @@ static bool io_poll_remove_one(struct io_kiocb *req)
                io_cqring_fill_event(req->ctx, req->user_data, -ECANCELED, 0);
                io_commit_cqring(req->ctx);
                req_set_fail(req);
-
-               /* non-poll requests have submit ref still */
-               refs = 1 + (req->opcode != IORING_OP_POLL_ADD);
-               io_put_req_deferred(req, refs);
+               io_put_req_deferred(req);
        }
        return do_complete;
 }
@@ -5392,6 +5400,7 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
        if (flags & ~IORING_POLL_ADD_MULTI)
                return -EINVAL;
 
+       io_req_refcount(req);
        poll->events = io_poll_parse_events(sqe, flags);
        return 0;
 }
@@ -5550,7 +5559,7 @@ static int io_timeout_cancel(struct io_ring_ctx *ctx, __u64 user_data)
 
        req_set_fail(req);
        io_cqring_fill_event(ctx, req->user_data, -ECANCELED, 0);
-       io_put_req_deferred(req, 1);
+       io_put_req_deferred(req);
        return 0;
 }
 
@@ -6283,6 +6292,10 @@ static void io_wq_submit_work(struct io_wq_work *work)
        struct io_kiocb *timeout;
        int ret = 0;
 
+       io_req_refcount(req);
+       /* will be dropped by ->io_free_work() after returning to io-wq */
+       req_ref_get(req);
+
        timeout = io_prep_linked_timeout(req);
        if (timeout)
                io_queue_linked_timeout(timeout);
@@ -6305,11 +6318,8 @@ static void io_wq_submit_work(struct io_wq_work *work)
        }
 
        /* avoid locking problems by failing it from a clean context */
-       if (ret) {
-               /* io-wq is going to take one down */
-               req_ref_get(req);
+       if (ret)
                io_req_task_queue_fail(req, ret);
-       }
 }
 
 static inline struct io_fixed_file *io_fixed_file_slot(struct io_file_table *table,
@@ -6451,6 +6461,11 @@ static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req)
            nxt->opcode != IORING_OP_LINK_TIMEOUT)
                return NULL;
 
+       /* linked timeouts should have two refs once prep'ed */
+       io_req_refcount(req);
+       io_req_refcount(nxt);
+       req_ref_get(nxt);
+
        nxt->timeout.head = req;
        nxt->flags |= REQ_F_LTIMEOUT_ACTIVE;
        req->flags |= REQ_F_LINK_TIMEOUT;
@@ -6471,7 +6486,6 @@ issue_sqe:
         * doesn't support non-blocking read/write attempts
         */
        if (likely(!ret)) {
-               /* drop submission reference */
                if (req->flags & REQ_F_COMPLETE_INLINE) {
                        struct io_ring_ctx *ctx = req->ctx;
                        struct io_submit_state *state = &ctx->submit_state;
@@ -6479,8 +6493,6 @@ issue_sqe:
                        state->compl_reqs[state->compl_nr++] = req;
                        if (state->compl_nr == ARRAY_SIZE(state->compl_reqs))
                                io_submit_flush_completions(ctx);
-               } else {
-                       io_put_req(req);
                }
        } else if (ret == -EAGAIN && !(req->flags & REQ_F_NOWAIT)) {
                switch (io_arm_poll_handler(req)) {
@@ -6560,8 +6572,6 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        req->user_data = READ_ONCE(sqe->user_data);
        req->file = NULL;
        req->fixed_rsrc_refs = NULL;
-       /* one is dropped after submission, the other at completion */
-       atomic_set(&req->refs, 2);
        req->task = current;
 
        /* enforce forwards compatibility on users */