io_uring: encapsulate task_work state
[linux-block.git] / io_uring / io_uring.c
index 2669aca0ba39d328b378e9a8b092e889e91da18b..536940675c672bbf39c1c7b793a3fe03444d514b 100644 (file)
@@ -247,12 +247,12 @@ static __cold void io_fallback_req_func(struct work_struct *work)
                                                fallback_work.work);
        struct llist_node *node = llist_del_all(&ctx->fallback_llist);
        struct io_kiocb *req, *tmp;
-       bool locked = true;
+       struct io_tw_state ts = { .locked = true, };
 
        mutex_lock(&ctx->uring_lock);
        llist_for_each_entry_safe(req, tmp, node, io_task_work.node)
-               req->io_task_work.func(req, &locked);
-       if (WARN_ON_ONCE(!locked))
+               req->io_task_work.func(req, &ts);
+       if (WARN_ON_ONCE(!ts.locked))
                return;
        io_submit_flush_completions(ctx);
        mutex_unlock(&ctx->uring_lock);
@@ -457,7 +457,7 @@ static void io_prep_async_link(struct io_kiocb *req)
        }
 }
 
-void io_queue_iowq(struct io_kiocb *req, bool *dont_use)
+void io_queue_iowq(struct io_kiocb *req, struct io_tw_state *ts_dont_use)
 {
        struct io_kiocb *link = io_prep_linked_timeout(req);
        struct io_uring_task *tctx = req->task->io_uring;
@@ -1153,22 +1153,23 @@ static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
        return nxt;
 }
 
-static void ctx_flush_and_put(struct io_ring_ctx *ctx, bool *locked)
+static void ctx_flush_and_put(struct io_ring_ctx *ctx, struct io_tw_state *ts)
 {
        if (!ctx)
                return;
        if (ctx->flags & IORING_SETUP_TASKRUN_FLAG)
                atomic_andnot(IORING_SQ_TASKRUN, &ctx->rings->sq_flags);
-       if (*locked) {
+       if (ts->locked) {
                io_submit_flush_completions(ctx);
                mutex_unlock(&ctx->uring_lock);
-               *locked = false;
+               ts->locked = false;
        }
        percpu_ref_put(&ctx->refs);
 }
 
 static unsigned int handle_tw_list(struct llist_node *node,
-                                  struct io_ring_ctx **ctx, bool *locked,
+                                  struct io_ring_ctx **ctx,
+                                  struct io_tw_state *ts,
                                   struct llist_node *last)
 {
        unsigned int count = 0;
@@ -1181,17 +1182,17 @@ static unsigned int handle_tw_list(struct llist_node *node,
                prefetch(container_of(next, struct io_kiocb, io_task_work.node));
 
                if (req->ctx != *ctx) {
-                       ctx_flush_and_put(*ctx, locked);
+                       ctx_flush_and_put(*ctx, ts);
                        *ctx = req->ctx;
                        /* if not contended, grab and improve batching */
-                       *locked = mutex_trylock(&(*ctx)->uring_lock);
+                       ts->locked = mutex_trylock(&(*ctx)->uring_lock);
                        percpu_ref_get(&(*ctx)->refs);
                }
-               req->io_task_work.func(req, locked);
+               req->io_task_work.func(req, ts);
                node = next;
                count++;
                if (unlikely(need_resched())) {
-                       ctx_flush_and_put(*ctx, locked);
+                       ctx_flush_and_put(*ctx, ts);
                        *ctx = NULL;
                        cond_resched();
                }
@@ -1232,7 +1233,7 @@ static inline struct llist_node *io_llist_cmpxchg(struct llist_head *head,
 
 void tctx_task_work(struct callback_head *cb)
 {
-       bool uring_locked = false;
+       struct io_tw_state ts = {};
        struct io_ring_ctx *ctx = NULL;
        struct io_uring_task *tctx = container_of(cb, struct io_uring_task,
                                                  task_work);
@@ -1249,12 +1250,12 @@ void tctx_task_work(struct callback_head *cb)
        do {
                loops++;
                node = io_llist_xchg(&tctx->task_list, &fake);
-               count += handle_tw_list(node, &ctx, &uring_locked, &fake);
+               count += handle_tw_list(node, &ctx, &ts, &fake);
 
                /* skip expensive cmpxchg if there are items in the list */
                if (READ_ONCE(tctx->task_list.first) != &fake)
                        continue;
-               if (uring_locked && !wq_list_empty(&ctx->submit_state.compl_reqs)) {
+               if (ts.locked && !wq_list_empty(&ctx->submit_state.compl_reqs)) {
                        io_submit_flush_completions(ctx);
                        if (READ_ONCE(tctx->task_list.first) != &fake)
                                continue;
@@ -1262,7 +1263,7 @@ void tctx_task_work(struct callback_head *cb)
                node = io_llist_cmpxchg(&tctx->task_list, &fake, NULL);
        } while (node != &fake);
 
-       ctx_flush_and_put(ctx, &uring_locked);
+       ctx_flush_and_put(ctx, &ts);
 
        /* relaxed read is enough as only the task itself sets ->in_cancel */
        if (unlikely(atomic_read(&tctx->in_cancel)))
@@ -1351,7 +1352,7 @@ static void __cold io_move_task_work_from_local(struct io_ring_ctx *ctx)
        }
 }
 
-static int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked)
+static int __io_run_local_work(struct io_ring_ctx *ctx, struct io_tw_state *ts)
 {
        struct llist_node *node;
        unsigned int loops = 0;
@@ -1368,7 +1369,7 @@ again:
                struct io_kiocb *req = container_of(node, struct io_kiocb,
                                                    io_task_work.node);
                prefetch(container_of(next, struct io_kiocb, io_task_work.node));
-               req->io_task_work.func(req, locked);
+               req->io_task_work.func(req, ts);
                ret++;
                node = next;
        }
@@ -1376,7 +1377,7 @@ again:
 
        if (!llist_empty(&ctx->work_llist))
                goto again;
-       if (*locked) {
+       if (ts->locked) {
                io_submit_flush_completions(ctx);
                if (!llist_empty(&ctx->work_llist))
                        goto again;
@@ -1387,46 +1388,46 @@ again:
 
 static inline int io_run_local_work_locked(struct io_ring_ctx *ctx)
 {
-       bool locked;
+       struct io_tw_state ts = { .locked = true, };
        int ret;
 
        if (llist_empty(&ctx->work_llist))
                return 0;
 
-       locked = true;
-       ret = __io_run_local_work(ctx, &locked);
+       ret = __io_run_local_work(ctx, &ts);
        /* shouldn't happen! */
-       if (WARN_ON_ONCE(!locked))
+       if (WARN_ON_ONCE(!ts.locked))
                mutex_lock(&ctx->uring_lock);
        return ret;
 }
 
 static int io_run_local_work(struct io_ring_ctx *ctx)
 {
-       bool locked = mutex_trylock(&ctx->uring_lock);
+       struct io_tw_state ts = {};
        int ret;
 
-       ret = __io_run_local_work(ctx, &locked);
-       if (locked)
+       ts.locked = mutex_trylock(&ctx->uring_lock);
+       ret = __io_run_local_work(ctx, &ts);
+       if (ts.locked)
                mutex_unlock(&ctx->uring_lock);
 
        return ret;
 }
 
-static void io_req_task_cancel(struct io_kiocb *req, bool *locked)
+static void io_req_task_cancel(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       io_tw_lock(req->ctx, locked);
+       io_tw_lock(req->ctx, ts);
        io_req_defer_failed(req, req->cqe.res);
 }
 
-void io_req_task_submit(struct io_kiocb *req, bool *locked)
+void io_req_task_submit(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       io_tw_lock(req->ctx, locked);
+       io_tw_lock(req->ctx, ts);
        /* req->task == current here, checking PF_EXITING is safe */
        if (unlikely(req->task->flags & PF_EXITING))
                io_req_defer_failed(req, -EFAULT);
        else if (req->flags & REQ_F_FORCE_ASYNC)
-               io_queue_iowq(req, locked);
+               io_queue_iowq(req, ts);
        else
                io_queue_sqe(req);
 }
@@ -1652,9 +1653,9 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, long min)
        return ret;
 }
 
-void io_req_task_complete(struct io_kiocb *req, bool *locked)
+void io_req_task_complete(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       if (*locked)
+       if (ts->locked)
                io_req_complete_defer(req);
        else
                io_req_complete_post(req, IO_URING_F_UNLOCKED);
@@ -1933,9 +1934,9 @@ static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
        return 0;
 }
 
-int io_poll_issue(struct io_kiocb *req, bool *locked)
+int io_poll_issue(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       io_tw_lock(req->ctx, locked);
+       io_tw_lock(req->ctx, ts);
        return io_issue_sqe(req, IO_URING_F_NONBLOCK|IO_URING_F_MULTISHOT|
                                 IO_URING_F_COMPLETE_DEFER);
 }