io_uring: encapsulate task_work state
authorPavel Begunkov <asml.silence@gmail.com>
Mon, 27 Mar 2023 15:38:15 +0000 (16:38 +0100)
committerJens Axboe <axboe@kernel.dk>
Mon, 3 Apr 2023 13:16:15 +0000 (07:16 -0600)
For task works we're passing around a bool pointer for whether the
current ring is locked or not, let's wrap it in a structure, that
will make it more opaque preventing abuse and will also help us
to pass more info in the future if needed.

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Link: https://lore.kernel.org/r/1ecec9483d58696e248d1bfd52cf62b04442df1d.1679931367.git.asml.silence@gmail.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
include/linux/io_uring_types.h
io_uring/io_uring.c
io_uring/io_uring.h
io_uring/notif.c
io_uring/poll.c
io_uring/rw.c
io_uring/timeout.c
io_uring/uring_cmd.c

index 3d152bdcd30a09fc80b1cead358292356e496f2f..561fa421c453869e0d72982801aac374dd82bb66 100644 (file)
@@ -367,6 +367,11 @@ struct io_ring_ctx {
        unsigned                        evfd_last_cq_tail;
 };
 
+struct io_tw_state {
+       /* ->uring_lock is taken, callbacks can use io_tw_lock to lock it */
+       bool locked;
+};
+
 enum {
        REQ_F_FIXED_FILE_BIT    = IOSQE_FIXED_FILE_BIT,
        REQ_F_IO_DRAIN_BIT      = IOSQE_IO_DRAIN_BIT,
@@ -473,7 +478,7 @@ enum {
        REQ_F_HASH_LOCKED       = BIT(REQ_F_HASH_LOCKED_BIT),
 };
 
-typedef void (*io_req_tw_func_t)(struct io_kiocb *req, bool *locked);
+typedef void (*io_req_tw_func_t)(struct io_kiocb *req, struct io_tw_state *ts);
 
 struct io_task_work {
        struct llist_node               node;
index 2669aca0ba39d328b378e9a8b092e889e91da18b..536940675c672bbf39c1c7b793a3fe03444d514b 100644 (file)
@@ -247,12 +247,12 @@ static __cold void io_fallback_req_func(struct work_struct *work)
                                                fallback_work.work);
        struct llist_node *node = llist_del_all(&ctx->fallback_llist);
        struct io_kiocb *req, *tmp;
-       bool locked = true;
+       struct io_tw_state ts = { .locked = true, };
 
        mutex_lock(&ctx->uring_lock);
        llist_for_each_entry_safe(req, tmp, node, io_task_work.node)
-               req->io_task_work.func(req, &locked);
-       if (WARN_ON_ONCE(!locked))
+               req->io_task_work.func(req, &ts);
+       if (WARN_ON_ONCE(!ts.locked))
                return;
        io_submit_flush_completions(ctx);
        mutex_unlock(&ctx->uring_lock);
@@ -457,7 +457,7 @@ static void io_prep_async_link(struct io_kiocb *req)
        }
 }
 
-void io_queue_iowq(struct io_kiocb *req, bool *dont_use)
+void io_queue_iowq(struct io_kiocb *req, struct io_tw_state *ts_dont_use)
 {
        struct io_kiocb *link = io_prep_linked_timeout(req);
        struct io_uring_task *tctx = req->task->io_uring;
@@ -1153,22 +1153,23 @@ static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
        return nxt;
 }
 
-static void ctx_flush_and_put(struct io_ring_ctx *ctx, bool *locked)
+static void ctx_flush_and_put(struct io_ring_ctx *ctx, struct io_tw_state *ts)
 {
        if (!ctx)
                return;
        if (ctx->flags & IORING_SETUP_TASKRUN_FLAG)
                atomic_andnot(IORING_SQ_TASKRUN, &ctx->rings->sq_flags);
-       if (*locked) {
+       if (ts->locked) {
                io_submit_flush_completions(ctx);
                mutex_unlock(&ctx->uring_lock);
-               *locked = false;
+               ts->locked = false;
        }
        percpu_ref_put(&ctx->refs);
 }
 
 static unsigned int handle_tw_list(struct llist_node *node,
-                                  struct io_ring_ctx **ctx, bool *locked,
+                                  struct io_ring_ctx **ctx,
+                                  struct io_tw_state *ts,
                                   struct llist_node *last)
 {
        unsigned int count = 0;
@@ -1181,17 +1182,17 @@ static unsigned int handle_tw_list(struct llist_node *node,
                prefetch(container_of(next, struct io_kiocb, io_task_work.node));
 
                if (req->ctx != *ctx) {
-                       ctx_flush_and_put(*ctx, locked);
+                       ctx_flush_and_put(*ctx, ts);
                        *ctx = req->ctx;
                        /* if not contended, grab and improve batching */
-                       *locked = mutex_trylock(&(*ctx)->uring_lock);
+                       ts->locked = mutex_trylock(&(*ctx)->uring_lock);
                        percpu_ref_get(&(*ctx)->refs);
                }
-               req->io_task_work.func(req, locked);
+               req->io_task_work.func(req, ts);
                node = next;
                count++;
                if (unlikely(need_resched())) {
-                       ctx_flush_and_put(*ctx, locked);
+                       ctx_flush_and_put(*ctx, ts);
                        *ctx = NULL;
                        cond_resched();
                }
@@ -1232,7 +1233,7 @@ static inline struct llist_node *io_llist_cmpxchg(struct llist_head *head,
 
 void tctx_task_work(struct callback_head *cb)
 {
-       bool uring_locked = false;
+       struct io_tw_state ts = {};
        struct io_ring_ctx *ctx = NULL;
        struct io_uring_task *tctx = container_of(cb, struct io_uring_task,
                                                  task_work);
@@ -1249,12 +1250,12 @@ void tctx_task_work(struct callback_head *cb)
        do {
                loops++;
                node = io_llist_xchg(&tctx->task_list, &fake);
-               count += handle_tw_list(node, &ctx, &uring_locked, &fake);
+               count += handle_tw_list(node, &ctx, &ts, &fake);
 
                /* skip expensive cmpxchg if there are items in the list */
                if (READ_ONCE(tctx->task_list.first) != &fake)
                        continue;
-               if (uring_locked && !wq_list_empty(&ctx->submit_state.compl_reqs)) {
+               if (ts.locked && !wq_list_empty(&ctx->submit_state.compl_reqs)) {
                        io_submit_flush_completions(ctx);
                        if (READ_ONCE(tctx->task_list.first) != &fake)
                                continue;
@@ -1262,7 +1263,7 @@ void tctx_task_work(struct callback_head *cb)
                node = io_llist_cmpxchg(&tctx->task_list, &fake, NULL);
        } while (node != &fake);
 
-       ctx_flush_and_put(ctx, &uring_locked);
+       ctx_flush_and_put(ctx, &ts);
 
        /* relaxed read is enough as only the task itself sets ->in_cancel */
        if (unlikely(atomic_read(&tctx->in_cancel)))
@@ -1351,7 +1352,7 @@ static void __cold io_move_task_work_from_local(struct io_ring_ctx *ctx)
        }
 }
 
-static int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked)
+static int __io_run_local_work(struct io_ring_ctx *ctx, struct io_tw_state *ts)
 {
        struct llist_node *node;
        unsigned int loops = 0;
@@ -1368,7 +1369,7 @@ again:
                struct io_kiocb *req = container_of(node, struct io_kiocb,
                                                    io_task_work.node);
                prefetch(container_of(next, struct io_kiocb, io_task_work.node));
-               req->io_task_work.func(req, locked);
+               req->io_task_work.func(req, ts);
                ret++;
                node = next;
        }
@@ -1376,7 +1377,7 @@ again:
 
        if (!llist_empty(&ctx->work_llist))
                goto again;
-       if (*locked) {
+       if (ts->locked) {
                io_submit_flush_completions(ctx);
                if (!llist_empty(&ctx->work_llist))
                        goto again;
@@ -1387,46 +1388,46 @@ again:
 
 static inline int io_run_local_work_locked(struct io_ring_ctx *ctx)
 {
-       bool locked;
+       struct io_tw_state ts = { .locked = true, };
        int ret;
 
        if (llist_empty(&ctx->work_llist))
                return 0;
 
-       locked = true;
-       ret = __io_run_local_work(ctx, &locked);
+       ret = __io_run_local_work(ctx, &ts);
        /* shouldn't happen! */
-       if (WARN_ON_ONCE(!locked))
+       if (WARN_ON_ONCE(!ts.locked))
                mutex_lock(&ctx->uring_lock);
        return ret;
 }
 
 static int io_run_local_work(struct io_ring_ctx *ctx)
 {
-       bool locked = mutex_trylock(&ctx->uring_lock);
+       struct io_tw_state ts = {};
        int ret;
 
-       ret = __io_run_local_work(ctx, &locked);
-       if (locked)
+       ts.locked = mutex_trylock(&ctx->uring_lock);
+       ret = __io_run_local_work(ctx, &ts);
+       if (ts.locked)
                mutex_unlock(&ctx->uring_lock);
 
        return ret;
 }
 
-static void io_req_task_cancel(struct io_kiocb *req, bool *locked)
+static void io_req_task_cancel(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       io_tw_lock(req->ctx, locked);
+       io_tw_lock(req->ctx, ts);
        io_req_defer_failed(req, req->cqe.res);
 }
 
-void io_req_task_submit(struct io_kiocb *req, bool *locked)
+void io_req_task_submit(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       io_tw_lock(req->ctx, locked);
+       io_tw_lock(req->ctx, ts);
        /* req->task == current here, checking PF_EXITING is safe */
        if (unlikely(req->task->flags & PF_EXITING))
                io_req_defer_failed(req, -EFAULT);
        else if (req->flags & REQ_F_FORCE_ASYNC)
-               io_queue_iowq(req, locked);
+               io_queue_iowq(req, ts);
        else
                io_queue_sqe(req);
 }
@@ -1652,9 +1653,9 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, long min)
        return ret;
 }
 
-void io_req_task_complete(struct io_kiocb *req, bool *locked)
+void io_req_task_complete(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       if (*locked)
+       if (ts->locked)
                io_req_complete_defer(req);
        else
                io_req_complete_post(req, IO_URING_F_UNLOCKED);
@@ -1933,9 +1934,9 @@ static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
        return 0;
 }
 
-int io_poll_issue(struct io_kiocb *req, bool *locked)
+int io_poll_issue(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       io_tw_lock(req->ctx, locked);
+       io_tw_lock(req->ctx, ts);
        return io_issue_sqe(req, IO_URING_F_NONBLOCK|IO_URING_F_MULTISHOT|
                                 IO_URING_F_COMPLETE_DEFER);
 }
index 2711865f1e198e6c1273084265cf3c2f70f36fe9..c33f719731ac8c24af0ff1260282a828b89ebfd9 100644 (file)
@@ -52,16 +52,16 @@ void __io_req_task_work_add(struct io_kiocb *req, bool allow_local);
 bool io_is_uring_fops(struct file *file);
 bool io_alloc_async_data(struct io_kiocb *req);
 void io_req_task_queue(struct io_kiocb *req);
-void io_queue_iowq(struct io_kiocb *req, bool *dont_use);
-void io_req_task_complete(struct io_kiocb *req, bool *locked);
+void io_queue_iowq(struct io_kiocb *req, struct io_tw_state *ts_dont_use);
+void io_req_task_complete(struct io_kiocb *req, struct io_tw_state *ts);
 void io_req_task_queue_fail(struct io_kiocb *req, int ret);
-void io_req_task_submit(struct io_kiocb *req, bool *locked);
+void io_req_task_submit(struct io_kiocb *req, struct io_tw_state *ts);
 void tctx_task_work(struct callback_head *cb);
 __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 int io_uring_alloc_task_context(struct task_struct *task,
                                struct io_ring_ctx *ctx);
 
-int io_poll_issue(struct io_kiocb *req, bool *locked);
+int io_poll_issue(struct io_kiocb *req, struct io_tw_state *ts);
 int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr);
 int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin);
 void io_free_batch_list(struct io_ring_ctx *ctx, struct io_wq_work_node *node);
@@ -299,11 +299,11 @@ static inline bool io_task_work_pending(struct io_ring_ctx *ctx)
        return task_work_pending(current) || !wq_list_empty(&ctx->work_llist);
 }
 
-static inline void io_tw_lock(struct io_ring_ctx *ctx, bool *locked)
+static inline void io_tw_lock(struct io_ring_ctx *ctx, struct io_tw_state *ts)
 {
-       if (!*locked) {
+       if (!ts->locked) {
                mutex_lock(&ctx->uring_lock);
-               *locked = true;
+               ts->locked = true;
        }
 }
 
index 09dfd0832d19f9bc022d68919bea10c37a55dd43..172105eb347d5612d42b747a3e5beb4e03985efe 100644 (file)
@@ -9,7 +9,7 @@
 #include "notif.h"
 #include "rsrc.h"
 
-static void io_notif_complete_tw_ext(struct io_kiocb *notif, bool *locked)
+static void io_notif_complete_tw_ext(struct io_kiocb *notif, struct io_tw_state *ts)
 {
        struct io_notif_data *nd = io_notif_to_data(notif);
        struct io_ring_ctx *ctx = notif->ctx;
@@ -21,7 +21,7 @@ static void io_notif_complete_tw_ext(struct io_kiocb *notif, bool *locked)
                __io_unaccount_mem(ctx->user, nd->account_pages);
                nd->account_pages = 0;
        }
-       io_req_task_complete(notif, locked);
+       io_req_task_complete(notif, ts);
 }
 
 static void io_tx_ubuf_callback(struct sk_buff *skb, struct ubuf_info *uarg,
index 55306e801081376acb380487fed2308e09a0f1d9..c90e47dc1e293594b9ab106899a261a3769f945c 100644 (file)
@@ -148,7 +148,7 @@ static void io_poll_req_insert_locked(struct io_kiocb *req)
        hlist_add_head(&req->hash_node, &table->hbs[index].list);
 }
 
-static void io_poll_tw_hash_eject(struct io_kiocb *req, bool *locked)
+static void io_poll_tw_hash_eject(struct io_kiocb *req, struct io_tw_state *ts)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
@@ -159,7 +159,7 @@ static void io_poll_tw_hash_eject(struct io_kiocb *req, bool *locked)
                 * already grabbed the mutex for us, but there is a chance it
                 * failed.
                 */
-               io_tw_lock(ctx, locked);
+               io_tw_lock(ctx, ts);
                hash_del(&req->hash_node);
                req->flags &= ~REQ_F_HASH_LOCKED;
        } else {
@@ -238,7 +238,7 @@ enum {
  * req->cqe.res. IOU_POLL_REMOVE_POLL_USE_RES indicates to remove multishot
  * poll and that the result is stored in req->cqe.
  */
-static int io_poll_check_events(struct io_kiocb *req, bool *locked)
+static int io_poll_check_events(struct io_kiocb *req, struct io_tw_state *ts)
 {
        int v;
 
@@ -300,13 +300,13 @@ static int io_poll_check_events(struct io_kiocb *req, bool *locked)
                        __poll_t mask = mangle_poll(req->cqe.res &
                                                    req->apoll_events);
 
-                       if (!io_aux_cqe(req->ctx, *locked, req->cqe.user_data,
+                       if (!io_aux_cqe(req->ctx, ts->locked, req->cqe.user_data,
                                        mask, IORING_CQE_F_MORE, false)) {
                                io_req_set_res(req, mask, 0);
                                return IOU_POLL_REMOVE_POLL_USE_RES;
                        }
                } else {
-                       int ret = io_poll_issue(req, locked);
+                       int ret = io_poll_issue(req, ts);
                        if (ret == IOU_STOP_MULTISHOT)
                                return IOU_POLL_REMOVE_POLL_USE_RES;
                        if (ret < 0)
@@ -326,15 +326,15 @@ static int io_poll_check_events(struct io_kiocb *req, bool *locked)
        return IOU_POLL_NO_ACTION;
 }
 
-static void io_poll_task_func(struct io_kiocb *req, bool *locked)
+static void io_poll_task_func(struct io_kiocb *req, struct io_tw_state *ts)
 {
        int ret;
 
-       ret = io_poll_check_events(req, locked);
+       ret = io_poll_check_events(req, ts);
        if (ret == IOU_POLL_NO_ACTION)
                return;
        io_poll_remove_entries(req);
-       io_poll_tw_hash_eject(req, locked);
+       io_poll_tw_hash_eject(req, ts);
 
        if (req->opcode == IORING_OP_POLL_ADD) {
                if (ret == IOU_POLL_DONE) {
@@ -343,7 +343,7 @@ static void io_poll_task_func(struct io_kiocb *req, bool *locked)
                        poll = io_kiocb_to_cmd(req, struct io_poll);
                        req->cqe.res = mangle_poll(req->cqe.res & poll->events);
                } else if (ret == IOU_POLL_REISSUE) {
-                       io_req_task_submit(req, locked);
+                       io_req_task_submit(req, ts);
                        return;
                } else if (ret != IOU_POLL_REMOVE_POLL_USE_RES) {
                        req->cqe.res = ret;
@@ -351,14 +351,14 @@ static void io_poll_task_func(struct io_kiocb *req, bool *locked)
                }
 
                io_req_set_res(req, req->cqe.res, 0);
-               io_req_task_complete(req, locked);
+               io_req_task_complete(req, ts);
        } else {
-               io_tw_lock(req->ctx, locked);
+               io_tw_lock(req->ctx, ts);
 
                if (ret == IOU_POLL_REMOVE_POLL_USE_RES)
-                       io_req_task_complete(req, locked);
+                       io_req_task_complete(req, ts);
                else if (ret == IOU_POLL_DONE || ret == IOU_POLL_REISSUE)
-                       io_req_task_submit(req, locked);
+                       io_req_task_submit(req, ts);
                else
                        io_req_defer_failed(req, ret);
        }
@@ -977,7 +977,7 @@ int io_poll_remove(struct io_kiocb *req, unsigned int issue_flags)
        struct io_hash_bucket *bucket;
        struct io_kiocb *preq;
        int ret2, ret = 0;
-       bool locked;
+       struct io_tw_state ts = {};
 
        preq = io_poll_find(ctx, true, &cd, &ctx->cancel_table, &bucket);
        ret2 = io_poll_disarm(preq);
@@ -1027,8 +1027,8 @@ found:
 
        req_set_fail(preq);
        io_req_set_res(preq, -ECANCELED, 0);
-       locked = !(issue_flags & IO_URING_F_UNLOCKED);
-       io_req_task_complete(preq, &locked);
+       ts.locked = !(issue_flags & IO_URING_F_UNLOCKED);
+       io_req_task_complete(preq, &ts);
 out:
        if (ret < 0) {
                req_set_fail(req);
index 4c233910e20097e2833c9e3a6c2f78995c2b7947..f14868624f4182dae5b1ce8bd581303dce6ccde8 100644 (file)
@@ -283,16 +283,16 @@ static inline int io_fixup_rw_res(struct io_kiocb *req, long res)
        return res;
 }
 
-static void io_req_rw_complete(struct io_kiocb *req, bool *locked)
+static void io_req_rw_complete(struct io_kiocb *req, struct io_tw_state *ts)
 {
        io_req_io_end(req);
 
        if (req->flags & (REQ_F_BUFFER_SELECTED|REQ_F_BUFFER_RING)) {
-               unsigned issue_flags = *locked ? 0 : IO_URING_F_UNLOCKED;
+               unsigned issue_flags = ts->locked ? 0 : IO_URING_F_UNLOCKED;
 
                req->cqe.flags |= io_put_kbuf(req, issue_flags);
        }
-       io_req_task_complete(req, locked);
+       io_req_task_complete(req, ts);
 }
 
 static void io_complete_rw(struct kiocb *kiocb, long res)
index 826a51bca3e498cbff5014bca0104a68c5890978..5c6c6f720809c7b259df80069a3ac05ef2040dea 100644 (file)
@@ -101,9 +101,9 @@ __cold void io_flush_timeouts(struct io_ring_ctx *ctx)
        spin_unlock_irq(&ctx->timeout_lock);
 }
 
-static void io_req_tw_fail_links(struct io_kiocb *link, bool *locked)
+static void io_req_tw_fail_links(struct io_kiocb *link, struct io_tw_state *ts)
 {
-       io_tw_lock(link->ctx, locked);
+       io_tw_lock(link->ctx, ts);
        while (link) {
                struct io_kiocb *nxt = link->link;
                long res = -ECANCELED;
@@ -112,7 +112,7 @@ static void io_req_tw_fail_links(struct io_kiocb *link, bool *locked)
                        res = link->cqe.res;
                link->link = NULL;
                io_req_set_res(link, res, 0);
-               io_req_task_complete(link, locked);
+               io_req_task_complete(link, ts);
                link = nxt;
        }
 }
@@ -265,9 +265,9 @@ int io_timeout_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd)
        return 0;
 }
 
-static void io_req_task_link_timeout(struct io_kiocb *req, bool *locked)
+static void io_req_task_link_timeout(struct io_kiocb *req, struct io_tw_state *ts)
 {
-       unsigned issue_flags = *locked ? 0 : IO_URING_F_UNLOCKED;
+       unsigned issue_flags = ts->locked ? 0 : IO_URING_F_UNLOCKED;
        struct io_timeout *timeout = io_kiocb_to_cmd(req, struct io_timeout);
        struct io_kiocb *prev = timeout->prev;
        int ret = -ENOENT;
@@ -282,11 +282,11 @@ static void io_req_task_link_timeout(struct io_kiocb *req, bool *locked)
                        ret = io_try_cancel(req->task->io_uring, &cd, issue_flags);
                }
                io_req_set_res(req, ret ?: -ETIME, 0);
-               io_req_task_complete(req, locked);
+               io_req_task_complete(req, ts);
                io_put_req(prev);
        } else {
                io_req_set_res(req, -ETIME, 0);
-               io_req_task_complete(req, locked);
+               io_req_task_complete(req, ts);
        }
 }
 
index 9a1dee5718724a109461560b39fc789132d6a13b..3d825d939b13cc9038ed580938450fdbe621fd1a 100644 (file)
 #include "rsrc.h"
 #include "uring_cmd.h"
 
-static void io_uring_cmd_work(struct io_kiocb *req, bool *locked)
+static void io_uring_cmd_work(struct io_kiocb *req, struct io_tw_state *ts)
 {
        struct io_uring_cmd *ioucmd = io_kiocb_to_cmd(req, struct io_uring_cmd);
-       unsigned issue_flags = *locked ? 0 : IO_URING_F_UNLOCKED;
+       unsigned issue_flags = ts->locked ? 0 : IO_URING_F_UNLOCKED;
 
        ioucmd->task_work_cb(ioucmd, issue_flags);
 }