io_uring: skip request refcounting
[linux-block.git] / fs / io_uring.c
index 8a1c461559ac3d4a8e34a3a45f05b1cd671d6ac4..dae87c57694395bce90ca6338904fa1ae1679299 100644 (file)
@@ -710,6 +710,7 @@ enum {
        REQ_F_REISSUE_BIT,
        REQ_F_DONT_REISSUE_BIT,
        REQ_F_CREDS_BIT,
+       REQ_F_REFCOUNT_BIT,
        /* keep async read/write and isreg together and in order */
        REQ_F_NOWAIT_READ_BIT,
        REQ_F_NOWAIT_WRITE_BIT,
@@ -765,6 +766,8 @@ enum {
        REQ_F_ISREG             = BIT(REQ_F_ISREG_BIT),
        /* has creds assigned */
        REQ_F_CREDS             = BIT(REQ_F_CREDS_BIT),
+       /* skip refcounting if not set */
+       REQ_F_REFCOUNT          = BIT(REQ_F_REFCOUNT_BIT),
 };
 
 struct async_poll {
@@ -1041,7 +1044,7 @@ static void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 static bool io_cqring_fill_event(struct io_ring_ctx *ctx, u64 user_data,
                                 long res, unsigned int cflags);
 static void io_put_req(struct io_kiocb *req);
-static void io_put_req_deferred(struct io_kiocb *req, int nr);
+static void io_put_req_deferred(struct io_kiocb *req);
 static void io_dismantle_req(struct io_kiocb *req);
 static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req);
 static void io_queue_linked_timeout(struct io_kiocb *req);
@@ -1078,6 +1081,49 @@ EXPORT_SYMBOL(io_uring_get_socket);
 #define io_for_each_link(pos, head) \
        for (pos = (head); pos; pos = pos->link)
 
+/*
+ * Shamelessly stolen from the mm implementation of page reference checking,
+ * see commit f958d7b528b1 for details.
+ */
+#define req_ref_zero_or_close_to_overflow(req) \
+       ((unsigned int) atomic_read(&(req->refs)) + 127u <= 127u)
+
+static inline bool req_ref_inc_not_zero(struct io_kiocb *req)
+{
+       WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT));
+       return atomic_inc_not_zero(&req->refs);
+}
+
+static inline bool req_ref_put_and_test(struct io_kiocb *req)
+{
+       if (likely(!(req->flags & REQ_F_REFCOUNT)))
+               return true;
+
+       WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req));
+       return atomic_dec_and_test(&req->refs);
+}
+
+static inline void req_ref_put(struct io_kiocb *req)
+{
+       WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT));
+       WARN_ON_ONCE(req_ref_put_and_test(req));
+}
+
+static inline void req_ref_get(struct io_kiocb *req)
+{
+       WARN_ON_ONCE(!(req->flags & REQ_F_REFCOUNT));
+       WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req));
+       atomic_inc(&req->refs);
+}
+
+static inline void io_req_refcount(struct io_kiocb *req)
+{
+       if (!(req->flags & REQ_F_REFCOUNT)) {
+               req->flags |= REQ_F_REFCOUNT;
+               atomic_set(&req->refs, 1);
+       }
+}
+
 static inline void io_req_set_rsrc_node(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
@@ -1292,10 +1338,10 @@ static void io_prep_async_link(struct io_kiocb *req)
        if (req->flags & REQ_F_LINK_TIMEOUT) {
                struct io_ring_ctx *ctx = req->ctx;
 
-               spin_lock_irq(&ctx->completion_lock);
+               spin_lock(&ctx->completion_lock);
                io_for_each_link(cur, req)
                        io_prep_async_work(cur);
-               spin_unlock_irq(&ctx->completion_lock);
+               spin_unlock(&ctx->completion_lock);
        } else {
                io_for_each_link(cur, req)
                        io_prep_async_work(cur);
@@ -1342,7 +1388,7 @@ static void io_kill_timeout(struct io_kiocb *req, int status)
                        atomic_read(&req->ctx->cq_timeouts) + 1);
                list_del_init(&req->timeout.list);
                io_cqring_fill_event(req->ctx, req->user_data, status, 0);
-               io_put_req_deferred(req, 1);
+               io_put_req_deferred(req);
        }
 }
 
@@ -1364,9 +1410,8 @@ static void io_flush_timeouts(struct io_ring_ctx *ctx)
        __must_hold(&ctx->completion_lock)
 {
        u32 seq = ctx->cached_cq_tail - atomic_read(&ctx->cq_timeouts);
-       unsigned long flags;
 
-       spin_lock_irqsave(&ctx->timeout_lock, flags);
+       spin_lock_irq(&ctx->timeout_lock);
        while (!list_empty(&ctx->timeout_list)) {
                u32 events_needed, events_got;
                struct io_kiocb *req = list_first_entry(&ctx->timeout_list,
@@ -1391,7 +1436,7 @@ static void io_flush_timeouts(struct io_ring_ctx *ctx)
                io_kill_timeout(req, 0);
        }
        ctx->cq_last_tm_flush = seq;
-       spin_unlock_irqrestore(&ctx->timeout_lock, flags);
+       spin_unlock_irq(&ctx->timeout_lock);
 }
 
 static void __io_commit_cqring_flush(struct io_ring_ctx *ctx)
@@ -1484,14 +1529,13 @@ static void io_cqring_ev_posted_iopoll(struct io_ring_ctx *ctx)
 /* Returns true if there are no backlogged entries after the flush */
 static bool __io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
 {
-       unsigned long flags;
        bool all_flushed, posted;
 
        if (!force && __io_cqring_events(ctx) == ctx->cq_entries)
                return false;
 
        posted = false;
-       spin_lock_irqsave(&ctx->completion_lock, flags);
+       spin_lock(&ctx->completion_lock);
        while (!list_empty(&ctx->cq_overflow_list)) {
                struct io_uring_cqe *cqe = io_get_cqe(ctx);
                struct io_overflow_cqe *ocqe;
@@ -1519,7 +1563,7 @@ static bool __io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
 
        if (posted)
                io_commit_cqring(ctx);
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
+       spin_unlock(&ctx->completion_lock);
        if (posted)
                io_cqring_ev_posted(ctx);
        return all_flushed;
@@ -1541,41 +1585,6 @@ static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx)
        return ret;
 }
 
-/*
- * Shamelessly stolen from the mm implementation of page reference checking,
- * see commit f958d7b528b1 for details.
- */
-#define req_ref_zero_or_close_to_overflow(req) \
-       ((unsigned int) atomic_read(&(req->refs)) + 127u <= 127u)
-
-static inline bool req_ref_inc_not_zero(struct io_kiocb *req)
-{
-       return atomic_inc_not_zero(&req->refs);
-}
-
-static inline bool req_ref_sub_and_test(struct io_kiocb *req, int refs)
-{
-       WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req));
-       return atomic_sub_and_test(refs, &req->refs);
-}
-
-static inline bool req_ref_put_and_test(struct io_kiocb *req)
-{
-       WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req));
-       return atomic_dec_and_test(&req->refs);
-}
-
-static inline void req_ref_put(struct io_kiocb *req)
-{
-       WARN_ON_ONCE(req_ref_put_and_test(req));
-}
-
-static inline void req_ref_get(struct io_kiocb *req)
-{
-       WARN_ON_ONCE(req_ref_zero_or_close_to_overflow(req));
-       atomic_inc(&req->refs);
-}
-
 /* must to be called somewhat shortly after putting a request */
 static inline void io_put_task(struct task_struct *task, int nr)
 {
@@ -1648,9 +1657,8 @@ static void io_req_complete_post(struct io_kiocb *req, long res,
                                 unsigned int cflags)
 {
        struct io_ring_ctx *ctx = req->ctx;
-       unsigned long flags;
 
-       spin_lock_irqsave(&ctx->completion_lock, flags);
+       spin_lock(&ctx->completion_lock);
        __io_cqring_fill_event(ctx, req->user_data, res, cflags);
        /*
         * If we're the last reference to this request, add to our locked
@@ -1674,7 +1682,7 @@ static void io_req_complete_post(struct io_kiocb *req, long res,
                        req = NULL;
        }
        io_commit_cqring(ctx);
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
+       spin_unlock(&ctx->completion_lock);
 
        if (req) {
                io_cqring_ev_posted(ctx);
@@ -1714,7 +1722,6 @@ static inline void io_req_complete(struct io_kiocb *req, long res)
 static void io_req_complete_failed(struct io_kiocb *req, long res)
 {
        req_set_fail(req);
-       io_put_req(req);
        io_req_complete_post(req, res, 0);
 }
 
@@ -1734,10 +1741,10 @@ static void io_preinit_req(struct io_kiocb *req, struct io_ring_ctx *ctx)
 static void io_flush_cached_locked_reqs(struct io_ring_ctx *ctx,
                                        struct io_submit_state *state)
 {
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        list_splice_init(&ctx->locked_free_list, &state->free_list);
        ctx->locked_free_nr = 0;
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 }
 
 /* Returns true IFF there are requests in the cache */
@@ -1769,7 +1776,14 @@ static bool io_flush_cached_reqs(struct io_ring_ctx *ctx)
        return nr != 0;
 }
 
+/*
+ * A request might get retired back into the request caches even before opcode
+ * handlers and io_issue_sqe() are done with it, e.g. inline completion path.
+ * Because of that, io_alloc_req() should be called only under ->uring_lock
+ * and with extra caution to not get a request that is still worked on.
+ */
 static struct io_kiocb *io_alloc_req(struct io_ring_ctx *ctx)
+       __must_hold(&ctx->uring_lock)
 {
        struct io_submit_state *state = &ctx->submit_state;
        gfp_t gfp = GFP_KERNEL | __GFP_NOWARN;
@@ -1827,15 +1841,14 @@ static void io_dismantle_req(struct io_kiocb *req)
 static void __io_free_req(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
-       unsigned long flags;
 
        io_dismantle_req(req);
        io_put_task(req->task, 1);
 
-       spin_lock_irqsave(&ctx->completion_lock, flags);
+       spin_lock(&ctx->completion_lock);
        list_add(&req->inflight_entry, &ctx->locked_free_list);
        ctx->locked_free_nr++;
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
+       spin_unlock(&ctx->completion_lock);
 
        percpu_ref_put(&ctx->refs);
 }
@@ -1866,7 +1879,7 @@ static bool io_kill_linked_timeout(struct io_kiocb *req)
                if (hrtimer_try_to_cancel(&io->timer) != -1) {
                        io_cqring_fill_event(link->ctx, link->user_data,
                                             -ECANCELED, 0);
-                       io_put_req_deferred(link, 1);
+                       io_put_req_deferred(link);
                        return true;
                }
        }
@@ -1885,7 +1898,7 @@ static void io_fail_links(struct io_kiocb *req)
 
                trace_io_uring_fail_link(req, link);
                io_cqring_fill_event(link->ctx, link->user_data, -ECANCELED, 0);
-               io_put_req_deferred(link, 2);
+               io_put_req_deferred(link);
                link = nxt;
        }
 }
@@ -1922,14 +1935,13 @@ static struct io_kiocb *__io_req_find_next(struct io_kiocb *req)
         */
        if (req->flags & (REQ_F_LINK_TIMEOUT | REQ_F_FAIL)) {
                struct io_ring_ctx *ctx = req->ctx;
-               unsigned long flags;
                bool posted;
 
-               spin_lock_irqsave(&ctx->completion_lock, flags);
+               spin_lock(&ctx->completion_lock);
                posted = io_disarm_next(req);
                if (posted)
                        io_commit_cqring(req->ctx);
-               spin_unlock_irqrestore(&ctx->completion_lock, flags);
+               spin_unlock(&ctx->completion_lock);
                if (posted)
                        io_cqring_ev_posted(ctx);
        }
@@ -2152,7 +2164,7 @@ static void io_submit_flush_completions(struct io_ring_ctx *ctx)
        int i, nr = state->compl_nr;
        struct req_batch rb;
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        for (i = 0; i < nr; i++) {
                struct io_kiocb *req = state->compl_reqs[i];
 
@@ -2160,15 +2172,14 @@ static void io_submit_flush_completions(struct io_ring_ctx *ctx)
                                        req->compl.cflags);
        }
        io_commit_cqring(ctx);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
        io_cqring_ev_posted(ctx);
 
        io_init_req_batch(&rb);
        for (i = 0; i < nr; i++) {
                struct io_kiocb *req = state->compl_reqs[i];
 
-               /* submission and completion refs */
-               if (req_ref_sub_and_test(req, 2))
+               if (req_ref_put_and_test(req))
                        io_req_free_batch(&rb, req, &ctx->submit_state);
        }
 
@@ -2197,9 +2208,9 @@ static inline void io_put_req(struct io_kiocb *req)
                io_free_req(req);
 }
 
-static inline void io_put_req_deferred(struct io_kiocb *req, int refs)
+static inline void io_put_req_deferred(struct io_kiocb *req)
 {
-       if (req_ref_sub_and_test(req, refs)) {
+       if (req_ref_put_and_test(req)) {
                req->io_task_work.func = io_free_req;
                io_req_task_work_add(req);
        }
@@ -2272,7 +2283,6 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
                if (READ_ONCE(req->result) == -EAGAIN && resubmit &&
                    !(req->flags & REQ_F_DONT_REISSUE)) {
                        req->iopoll_completed = 0;
-                       req_ref_get(req);
                        io_req_task_queue_reissue(req);
                        continue;
                }
@@ -2482,31 +2492,48 @@ static bool io_rw_should_reissue(struct io_kiocb *req)
 }
 #endif
 
-static void __io_complete_rw(struct io_kiocb *req, long res, long res2,
-                            unsigned int issue_flags)
+static bool __io_complete_rw_common(struct io_kiocb *req, long res)
 {
-       int cflags = 0;
-
        if (req->rw.kiocb.ki_flags & IOCB_WRITE)
                kiocb_end_write(req);
        if (res != req->result) {
                if ((res == -EAGAIN || res == -EOPNOTSUPP) &&
                    io_rw_should_reissue(req)) {
                        req->flags |= REQ_F_REISSUE;
-                       return;
+                       return true;
                }
                req_set_fail(req);
+               req->result = res;
        }
+       return false;
+}
+
+static void io_req_task_complete(struct io_kiocb *req)
+{
+       int cflags = 0;
+
        if (req->flags & REQ_F_BUFFER_SELECTED)
                cflags = io_put_rw_kbuf(req);
-       __io_req_complete(req, issue_flags, res, cflags);
+       __io_req_complete(req, 0, req->result, cflags);
+}
+
+static void __io_complete_rw(struct io_kiocb *req, long res, long res2,
+                            unsigned int issue_flags)
+{
+       if (__io_complete_rw_common(req, res))
+               return;
+       io_req_task_complete(req);
 }
 
 static void io_complete_rw(struct kiocb *kiocb, long res, long res2)
 {
        struct io_kiocb *req = container_of(kiocb, struct io_kiocb, rw.kiocb);
 
-       __io_complete_rw(req, res, res2, 0);
+       if (__io_complete_rw_common(req, res))
+               return;
+       req->result = res;
+       req->io_task_work.func = io_req_task_complete;
+       io_req_task_work_add(req);
 }
 
 static void io_complete_rw_iopoll(struct kiocb *kiocb, long res, long res2)
@@ -2753,7 +2780,6 @@ static void kiocb_done(struct kiocb *kiocb, ssize_t ret,
        if (check_reissue && (req->flags & REQ_F_REISSUE)) {
                req->flags &= ~REQ_F_REISSUE;
                if (io_resubmit_prep(req)) {
-                       req_ref_get(req);
                        io_req_task_queue_reissue(req);
                } else {
                        int cflags = 0;
@@ -3179,9 +3205,6 @@ static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
 
        req->rw.kiocb.ki_flags &= ~IOCB_WAITQ;
        list_del_init(&wait->entry);
-
-       /* submit ref gets dropped, acquire a new one */
-       req_ref_get(req);
        io_req_task_queue(req);
        return 1;
 }
@@ -4848,7 +4871,7 @@ static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
                req->result = vfs_poll(req->file, &pt) & poll->events;
        }
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        if (!req->result && !READ_ONCE(poll->canceled)) {
                add_wait_queue(poll->head, &poll->wait);
                return true;
@@ -4882,12 +4905,12 @@ static void io_poll_remove_double(struct io_kiocb *req)
        if (poll && poll->head) {
                struct wait_queue_head *head = poll->head;
 
-               spin_lock(&head->lock);
+               spin_lock_irq(&head->lock);
                list_del_init(&poll->wait.entry);
                if (poll->wait.private)
                        req_ref_put(req);
                poll->head = NULL;
-               spin_unlock(&head->lock);
+               spin_unlock_irq(&head->lock);
        }
 }
 
@@ -4923,7 +4946,7 @@ static void io_poll_task_func(struct io_kiocb *req)
        struct io_kiocb *nxt;
 
        if (io_poll_rewait(req, &req->poll)) {
-               spin_unlock_irq(&ctx->completion_lock);
+               spin_unlock(&ctx->completion_lock);
        } else {
                bool done;
 
@@ -4935,7 +4958,7 @@ static void io_poll_task_func(struct io_kiocb *req)
                        req->result = 0;
                        add_wait_queue(req->poll.head, &req->poll.wait);
                }
-               spin_unlock_irq(&ctx->completion_lock);
+               spin_unlock(&ctx->completion_lock);
                io_cqring_ev_posted(ctx);
 
                if (done) {
@@ -4952,6 +4975,7 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
        struct io_kiocb *req = wait->private;
        struct io_poll_iocb *poll = io_poll_get_single(req);
        __poll_t mask = key_to_poll(key);
+       unsigned long flags;
 
        /* for instances that support it check for an event match first: */
        if (mask && !(mask & poll->events))
@@ -4964,13 +4988,13 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
        if (poll->head) {
                bool done;
 
-               spin_lock(&poll->head->lock);
+               spin_lock_irqsave(&poll->head->lock, flags);
                done = list_empty(&poll->wait.entry);
                if (!done)
                        list_del_init(&poll->wait.entry);
                /* make sure double remove sees this as being gone */
                wait->private = NULL;
-               spin_unlock(&poll->head->lock);
+               spin_unlock_irqrestore(&poll->head->lock, flags);
                if (!done) {
                        /* use wait func handler, so it matches the rq type */
                        poll->wait.func(&poll->wait, mode, sync, key);
@@ -5058,13 +5082,13 @@ static void io_async_task_func(struct io_kiocb *req)
        trace_io_uring_task_run(req->ctx, req, req->opcode, req->user_data);
 
        if (io_poll_rewait(req, &apoll->poll)) {
-               spin_unlock_irq(&ctx->completion_lock);
+               spin_unlock(&ctx->completion_lock);
                return;
        }
 
        hash_del(&req->hash_node);
        io_poll_remove_double(req);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 
        if (!READ_ONCE(apoll->poll.canceled))
                io_req_task_submit(req);
@@ -5116,11 +5140,11 @@ static __poll_t __io_arm_poll_handler(struct io_kiocb *req,
        if (unlikely(!ipt->nr_entries) && !ipt->error)
                ipt->error = -EINVAL;
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        if (ipt->error || (mask && (poll->events & EPOLLONESHOT)))
                io_poll_remove_double(req);
        if (likely(poll->head)) {
-               spin_lock(&poll->head->lock);
+               spin_lock_irq(&poll->head->lock);
                if (unlikely(list_empty(&poll->wait.entry))) {
                        if (ipt->error)
                                cancel = true;
@@ -5133,7 +5157,7 @@ static __poll_t __io_arm_poll_handler(struct io_kiocb *req,
                        WRITE_ONCE(poll->canceled, true);
                else if (!poll->done) /* actually waiting for an event */
                        io_poll_req_insert(req);
-               spin_unlock(&poll->head->lock);
+               spin_unlock_irq(&poll->head->lock);
        }
 
        return mask;
@@ -5185,16 +5209,17 @@ static int io_arm_poll_handler(struct io_kiocb *req)
        req->apoll = apoll;
        req->flags |= REQ_F_POLLED;
        ipt.pt._qproc = io_async_queue_proc;
+       io_req_refcount(req);
 
        ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask,
                                        io_async_wake);
        if (ret || ipt.error) {
-               spin_unlock_irq(&ctx->completion_lock);
+               spin_unlock(&ctx->completion_lock);
                if (ret)
                        return IO_APOLL_READY;
                return IO_APOLL_ABORTED;
        }
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
        trace_io_uring_poll_arm(ctx, req, req->opcode, req->user_data,
                                mask, apoll->poll.events);
        return IO_APOLL_OK;
@@ -5208,14 +5233,14 @@ static bool __io_poll_remove_one(struct io_kiocb *req,
 
        if (!poll->head)
                return false;
-       spin_lock(&poll->head->lock);
+       spin_lock_irq(&poll->head->lock);
        if (do_cancel)
                WRITE_ONCE(poll->canceled, true);
        if (!list_empty(&poll->wait.entry)) {
                list_del_init(&poll->wait.entry);
                do_complete = true;
        }
-       spin_unlock(&poll->head->lock);
+       spin_unlock_irq(&poll->head->lock);
        hash_del(&req->hash_node);
        return do_complete;
 }
@@ -5223,7 +5248,6 @@ static bool __io_poll_remove_one(struct io_kiocb *req,
 static bool io_poll_remove_one(struct io_kiocb *req)
        __must_hold(&req->ctx->completion_lock)
 {
-       int refs;
        bool do_complete;
 
        io_poll_remove_double(req);
@@ -5233,10 +5257,7 @@ static bool io_poll_remove_one(struct io_kiocb *req)
                io_cqring_fill_event(req->ctx, req->user_data, -ECANCELED, 0);
                io_commit_cqring(req->ctx);
                req_set_fail(req);
-
-               /* non-poll requests have submit ref still */
-               refs = 1 + (req->opcode != IORING_OP_POLL_ADD);
-               io_put_req_deferred(req, refs);
+               io_put_req_deferred(req);
        }
        return do_complete;
 }
@@ -5251,7 +5272,7 @@ static bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
        struct io_kiocb *req;
        int posted = 0, i;
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        for (i = 0; i < (1U << ctx->cancel_hash_bits); i++) {
                struct hlist_head *list;
 
@@ -5261,7 +5282,7 @@ static bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
                                posted += io_poll_remove_one(req);
                }
        }
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 
        if (posted)
                io_cqring_ev_posted(ctx);
@@ -5379,6 +5400,7 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
        if (flags & ~IORING_POLL_ADD_MULTI)
                return -EINVAL;
 
+       io_req_refcount(req);
        poll->events = io_poll_parse_events(sqe, flags);
        return 0;
 }
@@ -5399,7 +5421,7 @@ static int io_poll_add(struct io_kiocb *req, unsigned int issue_flags)
                ipt.error = 0;
                io_poll_complete(req, mask);
        }
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 
        if (mask) {
                io_cqring_ev_posted(ctx);
@@ -5416,7 +5438,7 @@ static int io_poll_update(struct io_kiocb *req, unsigned int issue_flags)
        bool completing;
        int ret;
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        preq = io_poll_find(ctx, req->poll_update.old_user_data, true);
        if (!preq) {
                ret = -ENOENT;
@@ -5443,7 +5465,7 @@ static int io_poll_update(struct io_kiocb *req, unsigned int issue_flags)
        ret = 0;
 err:
        if (ret < 0) {
-               spin_unlock_irq(&ctx->completion_lock);
+               spin_unlock(&ctx->completion_lock);
                req_set_fail(req);
                io_req_complete(req, ret);
                return 0;
@@ -5456,7 +5478,7 @@ err:
        }
        if (req->poll_update.update_user_data)
                preq->user_data = req->poll_update.new_user_data;
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 
        /* complete update request, we're done with it */
        io_req_complete(req, ret);
@@ -5475,10 +5497,10 @@ static void io_req_task_timeout(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        io_cqring_fill_event(ctx, req->user_data, -ETIME, 0);
        io_commit_cqring(ctx);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 
        io_cqring_ev_posted(ctx);
        req_set_fail(req);
@@ -5537,7 +5559,7 @@ static int io_timeout_cancel(struct io_ring_ctx *ctx, __u64 user_data)
 
        req_set_fail(req);
        io_cqring_fill_event(ctx, req->user_data, -ECANCELED, 0);
-       io_put_req_deferred(req, 1);
+       io_put_req_deferred(req);
        return 0;
 }
 
@@ -5610,10 +5632,10 @@ static int io_timeout_remove(struct io_kiocb *req, unsigned int issue_flags)
                                        io_translate_timeout_mode(tr->flags));
        spin_unlock_irq(&ctx->timeout_lock);
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        io_cqring_fill_event(ctx, req->user_data, ret, 0);
        io_commit_cqring(ctx);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
        io_cqring_ev_posted(ctx);
        if (ret < 0)
                req_set_fail(req);
@@ -5751,16 +5773,15 @@ static void io_async_find_and_cancel(struct io_ring_ctx *ctx,
                                     struct io_kiocb *req, __u64 sqe_addr,
                                     int success_ret)
 {
-       unsigned long flags;
        int ret;
 
        ret = io_async_cancel_one(req->task->io_uring, sqe_addr, ctx);
-       spin_lock_irqsave(&ctx->completion_lock, flags);
+       spin_lock(&ctx->completion_lock);
        if (ret != -ENOENT)
                goto done;
-       spin_lock(&ctx->timeout_lock);
+       spin_lock_irq(&ctx->timeout_lock);
        ret = io_timeout_cancel(ctx, sqe_addr);
-       spin_unlock(&ctx->timeout_lock);
+       spin_unlock_irq(&ctx->timeout_lock);
        if (ret != -ENOENT)
                goto done;
        ret = io_poll_cancel(ctx, sqe_addr, false);
@@ -5769,7 +5790,7 @@ done:
                ret = success_ret;
        io_cqring_fill_event(ctx, req->user_data, ret, 0);
        io_commit_cqring(ctx);
-       spin_unlock_irqrestore(&ctx->completion_lock, flags);
+       spin_unlock(&ctx->completion_lock);
        io_cqring_ev_posted(ctx);
 
        if (ret < 0)
@@ -5799,18 +5820,18 @@ static int io_async_cancel(struct io_kiocb *req, unsigned int issue_flags)
 
        /* tasks should wait for their io-wq threads, so safe w/o sync */
        ret = io_async_cancel_one(req->task->io_uring, sqe_addr, ctx);
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        if (ret != -ENOENT)
                goto done;
-       spin_lock(&ctx->timeout_lock);
+       spin_lock_irq(&ctx->timeout_lock);
        ret = io_timeout_cancel(ctx, sqe_addr);
-       spin_unlock(&ctx->timeout_lock);
+       spin_unlock_irq(&ctx->timeout_lock);
        if (ret != -ENOENT)
                goto done;
        ret = io_poll_cancel(ctx, sqe_addr, false);
        if (ret != -ENOENT)
                goto done;
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 
        /* slow path, try all io-wq's */
        io_ring_submit_lock(ctx, !(issue_flags & IO_URING_F_NONBLOCK));
@@ -5824,11 +5845,11 @@ static int io_async_cancel(struct io_kiocb *req, unsigned int issue_flags)
        }
        io_ring_submit_unlock(ctx, !(issue_flags & IO_URING_F_NONBLOCK));
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
 done:
        io_cqring_fill_event(ctx, req->user_data, ret, 0);
        io_commit_cqring(ctx);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
        io_cqring_ev_posted(ctx);
 
        if (ret < 0)
@@ -6044,9 +6065,9 @@ fail:
                return true;
        }
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        if (!req_need_defer(req, seq) && list_empty(&ctx->defer_list)) {
-               spin_unlock_irq(&ctx->completion_lock);
+               spin_unlock(&ctx->completion_lock);
                kfree(de);
                io_queue_async_work(req);
                return true;
@@ -6056,7 +6077,7 @@ fail:
        de->req = req;
        de->seq = seq;
        list_add_tail(&de->list, &ctx->defer_list);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
        return true;
 }
 
@@ -6271,6 +6292,10 @@ static void io_wq_submit_work(struct io_wq_work *work)
        struct io_kiocb *timeout;
        int ret = 0;
 
+       io_req_refcount(req);
+       /* will be dropped by ->io_free_work() after returning to io-wq */
+       req_ref_get(req);
+
        timeout = io_prep_linked_timeout(req);
        if (timeout)
                io_queue_linked_timeout(timeout);
@@ -6293,11 +6318,8 @@ static void io_wq_submit_work(struct io_wq_work *work)
        }
 
        /* avoid locking problems by failing it from a clean context */
-       if (ret) {
-               /* io-wq is going to take one down */
-               req_ref_get(req);
+       if (ret)
                io_req_task_queue_fail(req, ret);
-       }
 }
 
 static inline struct io_fixed_file *io_fixed_file_slot(struct io_file_table *table,
@@ -6439,6 +6461,11 @@ static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req)
            nxt->opcode != IORING_OP_LINK_TIMEOUT)
                return NULL;
 
+       /* linked timeouts should have two refs once prep'ed */
+       io_req_refcount(req);
+       io_req_refcount(nxt);
+       req_ref_get(nxt);
+
        nxt->timeout.head = req;
        nxt->flags |= REQ_F_LTIMEOUT_ACTIVE;
        req->flags |= REQ_F_LINK_TIMEOUT;
@@ -6459,7 +6486,6 @@ issue_sqe:
         * doesn't support non-blocking read/write attempts
         */
        if (likely(!ret)) {
-               /* drop submission reference */
                if (req->flags & REQ_F_COMPLETE_INLINE) {
                        struct io_ring_ctx *ctx = req->ctx;
                        struct io_submit_state *state = &ctx->submit_state;
@@ -6467,8 +6493,6 @@ issue_sqe:
                        state->compl_reqs[state->compl_nr++] = req;
                        if (state->compl_nr == ARRAY_SIZE(state->compl_reqs))
                                io_submit_flush_completions(ctx);
-               } else {
-                       io_put_req(req);
                }
        } else if (ret == -EAGAIN && !(req->flags & REQ_F_NOWAIT)) {
                switch (io_arm_poll_handler(req)) {
@@ -6548,8 +6572,6 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        req->user_data = READ_ONCE(sqe->user_data);
        req->file = NULL;
        req->fixed_rsrc_refs = NULL;
-       /* one is dropped after submission, the other at completion */
-       atomic_set(&req->refs, 2);
        req->task = current;
 
        /* enforce forwards compatibility on users */
@@ -6796,18 +6818,18 @@ static inline bool io_sqd_events_pending(struct io_sq_data *sqd)
 static inline void io_ring_set_wakeup_flag(struct io_ring_ctx *ctx)
 {
        /* Tell userspace we may need a wakeup call */
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        WRITE_ONCE(ctx->rings->sq_flags,
                   ctx->rings->sq_flags | IORING_SQ_NEED_WAKEUP);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 }
 
 static inline void io_ring_clear_wakeup_flag(struct io_ring_ctx *ctx)
 {
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        WRITE_ONCE(ctx->rings->sq_flags,
                   ctx->rings->sq_flags & ~IORING_SQ_NEED_WAKEUP);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 }
 
 static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
@@ -7654,11 +7676,11 @@ static void __io_rsrc_put_work(struct io_rsrc_node *ref_node)
                        bool lock_ring = ctx->flags & IORING_SETUP_IOPOLL;
 
                        io_ring_submit_lock(ctx, lock_ring);
-                       spin_lock_irq(&ctx->completion_lock);
+                       spin_lock(&ctx->completion_lock);
                        io_cqring_fill_event(ctx, prsrc->tag, 0, 0);
                        ctx->cq_extra++;
                        io_commit_cqring(ctx);
-                       spin_unlock_irq(&ctx->completion_lock);
+                       spin_unlock(&ctx->completion_lock);
                        io_cqring_ev_posted(ctx);
                        io_ring_submit_unlock(ctx, lock_ring);
                }
@@ -8829,8 +8851,8 @@ static void io_ring_exit_work(struct work_struct *work)
                mutex_lock(&ctx->uring_lock);
        }
        mutex_unlock(&ctx->uring_lock);
-       spin_lock_irq(&ctx->completion_lock);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
 
        io_ring_ctx_free(ctx);
 }
@@ -8842,18 +8864,18 @@ static bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk,
        struct io_kiocb *req, *tmp;
        int canceled = 0;
 
-       spin_lock_irq(&ctx->completion_lock);
-       spin_lock(&ctx->timeout_lock);
+       spin_lock(&ctx->completion_lock);
+       spin_lock_irq(&ctx->timeout_lock);
        list_for_each_entry_safe(req, tmp, &ctx->timeout_list, timeout.list) {
                if (io_match_task(req, tsk, cancel_all)) {
                        io_kill_timeout(req, -ECANCELED);
                        canceled++;
                }
        }
-       spin_unlock(&ctx->timeout_lock);
+       spin_unlock_irq(&ctx->timeout_lock);
        if (canceled != 0)
                io_commit_cqring(ctx);
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
        if (canceled != 0)
                io_cqring_ev_posted(ctx);
        return canceled != 0;
@@ -8909,13 +8931,12 @@ static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
        bool ret;
 
        if (!cancel->all && (req->flags & REQ_F_LINK_TIMEOUT)) {
-               unsigned long flags;
                struct io_ring_ctx *ctx = req->ctx;
 
                /* protect against races with linked timeouts */
-               spin_lock_irqsave(&ctx->completion_lock, flags);
+               spin_lock(&ctx->completion_lock);
                ret = io_match_task(req, cancel->task, cancel->all);
-               spin_unlock_irqrestore(&ctx->completion_lock, flags);
+               spin_unlock(&ctx->completion_lock);
        } else {
                ret = io_match_task(req, cancel->task, cancel->all);
        }
@@ -8928,14 +8949,14 @@ static bool io_cancel_defer_files(struct io_ring_ctx *ctx,
        struct io_defer_entry *de;
        LIST_HEAD(list);
 
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        list_for_each_entry_reverse(de, &ctx->defer_list, list) {
                if (io_match_task(de->req, task, cancel_all)) {
                        list_cut_position(&list, &ctx->defer_list, &de->list);
                        break;
                }
        }
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
        if (list_empty(&list))
                return false;
 
@@ -9485,7 +9506,7 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
                        io_uring_show_cred(m, index, cred);
        }
        seq_printf(m, "PollList:\n");
-       spin_lock_irq(&ctx->completion_lock);
+       spin_lock(&ctx->completion_lock);
        for (i = 0; i < (1U << ctx->cancel_hash_bits); i++) {
                struct hlist_head *list = &ctx->cancel_hash[i];
                struct io_kiocb *req;
@@ -9494,7 +9515,7 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
                        seq_printf(m, "  op=%d, task_works=%d\n", req->opcode,
                                        req->task->task_works != NULL);
        }
-       spin_unlock_irq(&ctx->completion_lock);
+       spin_unlock(&ctx->completion_lock);
        if (has_lock)
                mutex_unlock(&ctx->uring_lock);
 }