io_uring: use private ctx wait queue entries for SQPOLL
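
This range folds together several related changes, not just the one in the title: each ctx gets its own SQPOLL wait queue head and wait queue entry, the offload path stops relying on ring_fd/ring_file and instead pins the submitting task (ctx->sqo_task) and grabs current->files per request, sqo_mm becomes mm_account and is used purely for memory accounting, a per-task io_uring context (task_struct->io_uring) is introduced for cancelation on task/files exit, and rings can now be created disabled and restricted (IORING_SETUP_R_DISABLED, IORING_REGISTER_RESTRICTIONS, IORING_REGISTER_ENABLE_RINGS), among other fixes.

The title change itself boils down to the pattern below, shown here as a condensed sketch rather than the verbatim kernel code: the wait queue head and a private wait queue entry both live in the ctx, and io_sq_thread() waits with that embedded entry instead of an on-stack DEFINE_WAIT(), which prepares for pointing several ctxs' ->sqo_wait at one shared head later.

    struct io_ring_ctx {
            /* ... */
            struct wait_queue_head  *sqo_wait;      /* points at __sqo_wait for now */
            struct wait_queue_head  __sqo_wait;
            struct wait_queue_entry sqo_wait_entry;
            /* ... */
    };

    static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
    {
            /* ... */
            init_waitqueue_head(&ctx->__sqo_wait);
            ctx->sqo_wait = &ctx->__sqo_wait;
            /* ... */
    }

    static int io_sq_thread(void *data)
    {
            struct io_ring_ctx *ctx = data;

            init_wait(&ctx->sqo_wait_entry);        /* replaces DEFINE_WAIT(wait) */

            while (!kthread_should_park()) {
                    /* ... */
                    prepare_to_wait(ctx->sqo_wait, &ctx->sqo_wait_entry,
                                    TASK_INTERRUPTIBLE);
                    if (!io_sqring_entries(ctx))
                            schedule();
                    finish_wait(ctx->sqo_wait, &ctx->sqo_wait_entry);
                    /* ... */
            }
            return 0;
    }

All wakers (io_cqring_ev_posted(), iopoll submission, IORING_ENTER_SQ_WAKEUP) now go through the pointer, e.g.:

    if (waitqueue_active(ctx->sqo_wait))
            wake_up(ctx->sqo_wait);
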
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 3790c7fe9fee2253b2a608015b099c05c0b0f8ae..128ffa79d9d38d3c4b17150a4b79111d082fabb5 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -79,6 +79,7 @@
 #include <linux/splice.h>
 #include <linux/task_work.h>
 #include <linux/pagemap.h>
+#include <linux/io_uring.h>
 
 #define CREATE_TRACE_POINTS
 #include <trace/events/io_uring.h>
@@ -98,6 +99,8 @@
 #define IORING_MAX_FILES_TABLE (1U << IORING_FILE_TABLE_SHIFT)
 #define IORING_FILE_TABLE_MASK (IORING_MAX_FILES_TABLE - 1)
 #define IORING_MAX_FIXED_FILES (64 * IORING_MAX_FILES_TABLE)
+#define IORING_MAX_RESTRICTIONS        (IORING_RESTRICTION_LAST + \
+                                IORING_REGISTER_LAST + IORING_OP_LAST)
 
 struct io_uring {
        u32 head ____cacheline_aligned_in_smp;
@@ -219,6 +222,14 @@ struct io_buffer {
        __u16 bid;
 };
 
+struct io_restriction {
+       DECLARE_BITMAP(register_op, IORING_REGISTER_LAST);
+       DECLARE_BITMAP(sqe_op, IORING_OP_LAST);
+       u8 sqe_flags_allowed;
+       u8 sqe_flags_required;
+       bool registered;
+};
+
 struct io_ring_ctx {
        struct {
                struct percpu_ref       refs;
@@ -231,6 +242,7 @@ struct io_ring_ctx {
                unsigned int            cq_overflow_flushed: 1;
                unsigned int            drain_next: 1;
                unsigned int            eventfd_async: 1;
+               unsigned int            restricted: 1;
 
                /*
                 * Ring buffer of indices into array of io_uring_sqe, which is
@@ -265,8 +277,19 @@ struct io_ring_ctx {
        /* IO offload */
        struct io_wq            *io_wq;
        struct task_struct      *sqo_thread;    /* if using sq thread polling */
-       struct mm_struct        *sqo_mm;
-       wait_queue_head_t       sqo_wait;
+
+       /*
+        * For SQPOLL usage - we hold a reference to the parent task, so we
+        * have access to the ->files
+        */
+       struct task_struct      *sqo_task;
+
+       /* Only used for accounting purposes */
+       struct mm_struct        *mm_account;
+
+       struct wait_queue_head  *sqo_wait;
+       struct wait_queue_head  __sqo_wait;
+       struct wait_queue_entry sqo_wait_entry;
 
        /*
         * If used, fixed file set. Writers must ensure that ->refs is dead,
@@ -275,8 +298,6 @@ struct io_ring_ctx {
         */
        struct fixed_file_data  *file_data;
        unsigned                nr_user_files;
-       int                     ring_fd;
-       struct file             *ring_file;
 
        /* if used, fixed mapped user buffers */
        unsigned                nr_user_bufs;
@@ -338,6 +359,7 @@ struct io_ring_ctx {
        struct llist_head               file_put_llist;
 
        struct work_struct              exit_work;
+       struct io_restriction           restrictions;
 };
 
 /*
@@ -544,7 +566,6 @@ enum {
        REQ_F_BUFFER_SELECTED_BIT,
        REQ_F_NO_FILE_TABLE_BIT,
        REQ_F_WORK_INITIALIZED_BIT,
-       REQ_F_TASK_PINNED_BIT,
 
        /* not a real bit, just to check we're not overflowing the space */
        __REQ_F_LAST_BIT,
@@ -590,8 +611,6 @@ enum {
        REQ_F_NO_FILE_TABLE     = BIT(REQ_F_NO_FILE_TABLE_BIT),
        /* io_wq_work is initialized */
        REQ_F_WORK_INITIALIZED  = BIT(REQ_F_WORK_INITIALIZED_BIT),
-       /* req->task is refcounted */
-       REQ_F_TASK_PINNED       = BIT(REQ_F_TASK_PINNED_BIT),
 };
 
 struct async_poll {
@@ -933,14 +952,6 @@ struct sock *io_uring_get_socket(struct file *file)
 }
 EXPORT_SYMBOL(io_uring_get_socket);
 
-static void io_get_req_task(struct io_kiocb *req)
-{
-       if (req->flags & REQ_F_TASK_PINNED)
-               return;
-       get_task_struct(req->task);
-       req->flags |= REQ_F_TASK_PINNED;
-}
-
 static inline void io_clean_op(struct io_kiocb *req)
 {
        if (req->flags & (REQ_F_NEED_CLEANUP | REQ_F_BUFFER_SELECTED |
@@ -948,13 +959,6 @@ static inline void io_clean_op(struct io_kiocb *req)
                __io_clean_op(req);
 }
 
-/* not idempotent -- it doesn't clear REQ_F_TASK_PINNED */
-static void __io_put_req_task(struct io_kiocb *req)
-{
-       if (req->flags & REQ_F_TASK_PINNED)
-               put_task_struct(req->task);
-}
-
 static void io_sq_thread_drop_mm(void)
 {
        struct mm_struct *mm = current->mm;
@@ -969,9 +973,10 @@ static int __io_sq_thread_acquire_mm(struct io_ring_ctx *ctx)
 {
        if (!current->mm) {
                if (unlikely(!(ctx->flags & IORING_SETUP_SQPOLL) ||
-                            !mmget_not_zero(ctx->sqo_mm)))
+                            !ctx->sqo_task->mm ||
+                            !mmget_not_zero(ctx->sqo_task->mm)))
                        return -EFAULT;
-               kthread_use_mm(ctx->sqo_mm);
+               kthread_use_mm(ctx->sqo_task->mm);
        }
 
        return 0;
@@ -1054,7 +1059,8 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
                goto err;
 
        ctx->flags = p->flags;
-       init_waitqueue_head(&ctx->sqo_wait);
+       init_waitqueue_head(&ctx->__sqo_wait);
+       ctx->sqo_wait = &ctx->__sqo_wait;
        init_waitqueue_head(&ctx->cq_wait);
        INIT_LIST_HEAD(&ctx->cq_overflow_list);
        init_completion(&ctx->ref_comp);
@@ -1226,14 +1232,34 @@ static void io_kill_timeout(struct io_kiocb *req)
        }
 }
 
-static void io_kill_timeouts(struct io_ring_ctx *ctx)
+static bool io_task_match(struct io_kiocb *req, struct task_struct *tsk)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+
+       if (!tsk || req->task == tsk)
+               return true;
+       if ((ctx->flags & IORING_SETUP_SQPOLL) && req->task == ctx->sqo_thread)
+               return true;
+       return false;
+}
+
+/*
+ * Returns true if we found and killed one or more timeouts
+ */
+static bool io_kill_timeouts(struct io_ring_ctx *ctx, struct task_struct *tsk)
 {
        struct io_kiocb *req, *tmp;
+       int canceled = 0;
 
        spin_lock_irq(&ctx->completion_lock);
-       list_for_each_entry_safe(req, tmp, &ctx->timeout_list, timeout.list)
-               io_kill_timeout(req);
+       list_for_each_entry_safe(req, tmp, &ctx->timeout_list, timeout.list) {
+               if (io_task_match(req, tsk)) {
+                       io_kill_timeout(req);
+                       canceled++;
+               }
+       }
        spin_unlock_irq(&ctx->completion_lock);
+       return canceled != 0;
 }
 
 static void __io_queue_deferred(struct io_ring_ctx *ctx)
@@ -1317,8 +1343,8 @@ static void io_cqring_ev_posted(struct io_ring_ctx *ctx)
 {
        if (waitqueue_active(&ctx->wait))
                wake_up(&ctx->wait);
-       if (waitqueue_active(&ctx->sqo_wait))
-               wake_up(&ctx->sqo_wait);
+       if (waitqueue_active(ctx->sqo_wait))
+               wake_up(ctx->sqo_wait);
        if (io_should_trigger_evfd(ctx))
                eventfd_signal(ctx->cq_ev_fd, 1);
 }
@@ -1332,12 +1358,24 @@ static void io_cqring_mark_overflow(struct io_ring_ctx *ctx)
        }
 }
 
+static inline bool io_match_files(struct io_kiocb *req,
+                                      struct files_struct *files)
+{
+       if (!files)
+               return true;
+       if (req->flags & REQ_F_WORK_INITIALIZED)
+               return req->work.files == files;
+       return false;
+}
+
 /* Returns true if there are no backlogged entries after the flush */
-static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
+static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force,
+                                    struct task_struct *tsk,
+                                    struct files_struct *files)
 {
        struct io_rings *rings = ctx->rings;
+       struct io_kiocb *req, *tmp;
        struct io_uring_cqe *cqe;
-       struct io_kiocb *req;
        unsigned long flags;
        LIST_HEAD(list);
 
@@ -1356,13 +1394,16 @@ static bool io_cqring_overflow_flush(struct io_ring_ctx *ctx, bool force)
                ctx->cq_overflow_flushed = 1;
 
        cqe = NULL;
-       while (!list_empty(&ctx->cq_overflow_list)) {
+       list_for_each_entry_safe(req, tmp, &ctx->cq_overflow_list, compl.list) {
+               if (tsk && req->task != tsk)
+                       continue;
+               if (!io_match_files(req, files))
+                       continue;
+
                cqe = io_get_cqring(ctx);
                if (!cqe && !force)
                        break;
 
-               req = list_first_entry(&ctx->cq_overflow_list, struct io_kiocb,
-                                               compl.list);
                list_move(&req->compl.list, &list);
                if (cqe) {
                        WRITE_ONCE(cqe->user_data, req->user_data);
@@ -1406,7 +1447,12 @@ static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
                WRITE_ONCE(cqe->user_data, req->user_data);
                WRITE_ONCE(cqe->res, res);
                WRITE_ONCE(cqe->flags, cflags);
-       } else if (ctx->cq_overflow_flushed) {
+       } else if (ctx->cq_overflow_flushed || req->task->io_uring->in_idle) {
+               /*
+                * If we're in ring overflow flush mode, or in task cancel mode,
+                * then we cannot store the request for later flushing, we need
+                * to drop it on the floor.
+                */
                WRITE_ONCE(ctx->rings->cq_overflow,
                                atomic_inc_return(&ctx->cached_cq_overflow));
        } else {
@@ -1564,9 +1610,14 @@ static bool io_dismantle_req(struct io_kiocb *req)
 
 static void __io_free_req_finish(struct io_kiocb *req)
 {
+       struct io_uring_task *tctx = req->task->io_uring;
        struct io_ring_ctx *ctx = req->ctx;
 
-       __io_put_req_task(req);
+       atomic_long_inc(&tctx->req_complete);
+       if (tctx->in_idle)
+               wake_up(&tctx->wait);
+       put_task_struct(req->task);
+
        if (likely(!io_is_fallback_req(req)))
                kmem_cache_free(req_cachep, req);
        else
@@ -1753,6 +1804,9 @@ static int io_req_task_work_add(struct io_kiocb *req, struct callback_head *cb,
        struct io_ring_ctx *ctx = req->ctx;
        int ret, notify;
 
+       if (tsk->flags & PF_EXITING)
+               return -ESRCH;
+
        /*
         * SQPOLL kernel thread doesn't need notification, just a wakeup. For
         * all other cases, use TWA_SIGNAL unconditionally to ensure we're
@@ -1787,8 +1841,10 @@ static void __io_req_task_cancel(struct io_kiocb *req, int error)
 static void io_req_task_cancel(struct callback_head *cb)
 {
        struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
+       struct io_ring_ctx *ctx = req->ctx;
 
        __io_req_task_cancel(req, -ECANCELED);
+       percpu_ref_put(&ctx->refs);
 }
 
 static void __io_req_task_submit(struct io_kiocb *req)
@@ -1874,6 +1930,7 @@ static void io_req_free_batch_finish(struct io_ring_ctx *ctx,
        if (rb->to_free)
                __io_req_free_batch_flush(ctx, rb);
        if (rb->task) {
+               atomic_long_add(rb->task_refs, &rb->task->io_uring->req_complete);
                put_task_struct_many(rb->task, rb->task_refs);
                rb->task = NULL;
        }
@@ -1888,16 +1945,15 @@ static void io_req_free_batch(struct req_batch *rb, struct io_kiocb *req)
        if (req->flags & REQ_F_LINK_HEAD)
                io_queue_next(req);
 
-       if (req->flags & REQ_F_TASK_PINNED) {
-               if (req->task != rb->task) {
-                       if (rb->task)
-                               put_task_struct_many(rb->task, rb->task_refs);
-                       rb->task = req->task;
-                       rb->task_refs = 0;
+       if (req->task != rb->task) {
+               if (rb->task) {
+                       atomic_long_add(rb->task_refs, &rb->task->io_uring->req_complete);
+                       put_task_struct_many(rb->task, rb->task_refs);
                }
-               rb->task_refs++;
-               req->flags &= ~REQ_F_TASK_PINNED;
+               rb->task = req->task;
+               rb->task_refs = 0;
        }
+       rb->task_refs++;
 
        WARN_ON_ONCE(io_dismantle_req(req));
        rb->reqs[rb->to_free++] = req;
@@ -1973,7 +2029,7 @@ static unsigned io_cqring_events(struct io_ring_ctx *ctx, bool noflush)
                if (noflush && !list_empty(&ctx->cq_overflow_list))
                        return -1U;
 
-               io_cqring_overflow_flush(ctx, false);
+               io_cqring_overflow_flush(ctx, false, NULL, NULL);
        }
 
        /* See comment at the top of this file */
@@ -2010,6 +2066,12 @@ static inline unsigned int io_put_rw_kbuf(struct io_kiocb *req)
 
 static inline bool io_run_task_work(void)
 {
+       /*
+        * Not safe to run on exiting task, and the task_work handling will
+        * not add work to such a task.
+        */
+       if (unlikely(current->flags & PF_EXITING))
+               return false;
        if (current->task_works) {
                __set_current_state(TASK_RUNNING);
                task_work_run();
@@ -2283,13 +2345,17 @@ static bool io_resubmit_prep(struct io_kiocb *req, int error)
                goto end_req;
        }
 
-       ret = io_import_iovec(rw, req, &iovec, &iter, false);
-       if (ret < 0)
-               goto end_req;
-       ret = io_setup_async_rw(req, iovec, inline_vecs, &iter, false);
-       if (!ret)
+       if (!req->io) {
+               ret = io_import_iovec(rw, req, &iovec, &iter, false);
+               if (ret < 0)
+                       goto end_req;
+               ret = io_setup_async_rw(req, iovec, inline_vecs, &iter, false);
+               if (!ret)
+                       return true;
+               kfree(iovec);
+       } else {
                return true;
-       kfree(iovec);
+       }
 end_req:
        req_set_fail_links(req);
        io_req_complete(req, ret);
@@ -2385,9 +2451,8 @@ static void io_iopoll_req_issued(struct io_kiocb *req)
        else
                list_add_tail(&req->inflight_entry, &ctx->iopoll_list);
 
-       if ((ctx->flags & IORING_SETUP_SQPOLL) &&
-           wq_has_sleeper(&ctx->sqo_wait))
-               wake_up(&ctx->sqo_wait);
+       if ((ctx->flags & IORING_SETUP_SQPOLL) && wq_has_sleeper(ctx->sqo_wait))
+               wake_up(ctx->sqo_wait);
 }
 
 static void __io_state_file_put(struct io_submit_state *state)
@@ -2512,9 +2577,6 @@ static int io_prep_rw(struct io_kiocb *req, const struct io_uring_sqe *sqe,
        if (kiocb->ki_flags & IOCB_NOWAIT)
                req->flags |= REQ_F_NOWAIT;
 
-       if (kiocb->ki_flags & IOCB_DIRECT)
-               io_get_req_task(req);
-
        if (force_nonblock)
                kiocb->ki_flags |= IOCB_NOWAIT;
 
@@ -2526,7 +2588,6 @@ static int io_prep_rw(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                kiocb->ki_flags |= IOCB_HIPRI;
                kiocb->ki_complete = io_complete_rw_iopoll;
                req->iopoll_completed = 0;
-               io_get_req_task(req);
        } else {
                if (kiocb->ki_flags & IOCB_HIPRI)
                        return -EINVAL;
@@ -3034,6 +3095,7 @@ static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
        if (!wake_page_match(wpq, key))
                return 0;
 
+       req->rw.kiocb.ki_flags &= ~IOCB_WAITQ;
        list_del_init(&wait->entry);
 
        init_task_work(&req->task_work, io_req_task_submit);
@@ -3091,9 +3153,8 @@ static bool io_rw_should_retry(struct io_kiocb *req)
        wait->wait.flags = 0;
        INIT_LIST_HEAD(&wait->wait.entry);
        kiocb->ki_flags |= IOCB_WAITQ;
+       kiocb->ki_flags &= ~IOCB_NOWAIT;
        kiocb->ki_waitq = wait;
-
-       io_get_req_task(req);
        return true;
 }
 
@@ -3115,6 +3176,7 @@ static int io_read(struct io_kiocb *req, bool force_nonblock,
        struct iov_iter __iter, *iter = &__iter;
        ssize_t io_size, ret, ret2;
        size_t iov_count;
+       bool no_async;
 
        if (req->io)
                iter = &req->io->rw.iter;
@@ -3132,7 +3194,8 @@ static int io_read(struct io_kiocb *req, bool force_nonblock,
                kiocb->ki_flags &= ~IOCB_NOWAIT;
 
        /* If the file doesn't support async, just async punt */
-       if (force_nonblock && !io_file_supports_async(req->file, READ))
+       no_async = force_nonblock && !io_file_supports_async(req->file, READ);
+       if (no_async)
                goto copy_iov;
 
        ret = rw_verify_area(READ, req->file, io_kiocb_ppos(kiocb), iov_count);
@@ -3155,10 +3218,8 @@ static int io_read(struct io_kiocb *req, bool force_nonblock,
                        goto done;
                /* some cases will consume bytes even on error returns */
                iov_iter_revert(iter, iov_count - iov_iter_count(iter));
-               ret = io_setup_async_rw(req, iovec, inline_vecs, iter, false);
-               if (ret)
-                       goto out_free;
-               return -EAGAIN;
+               ret = 0;
+               goto copy_iov;
        } else if (ret < 0) {
                /* make sure -ERESTARTSYS -> -EINTR is done */
                goto done;
@@ -3176,6 +3237,8 @@ copy_iov:
                ret = ret2;
                goto out_free;
        }
+       if (no_async)
+               return -EAGAIN;
        /* it's copied and will be cleaned with ->io */
        iovec = NULL;
        /* now use our persistent iterator, if we aren't already */
@@ -3508,8 +3571,6 @@ static int __io_openat_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
        const char __user *fname;
        int ret;
 
-       if (unlikely(req->ctx->flags & (IORING_SETUP_IOPOLL|IORING_SETUP_SQPOLL)))
-               return -EINVAL;
        if (unlikely(sqe->ioprio || sqe->buf_index))
                return -EINVAL;
        if (unlikely(req->flags & REQ_F_FIXED_FILE))
@@ -3536,6 +3597,8 @@ static int io_openat_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
        u64 flags, mode;
 
+       if (unlikely(req->ctx->flags & (IORING_SETUP_IOPOLL|IORING_SETUP_SQPOLL)))
+               return -EINVAL;
        if (req->flags & REQ_F_NEED_CLEANUP)
                return 0;
        mode = READ_ONCE(sqe->len);
@@ -3550,6 +3613,8 @@ static int io_openat2_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
        size_t len;
        int ret;
 
+       if (unlikely(req->ctx->flags & (IORING_SETUP_IOPOLL|IORING_SETUP_SQPOLL)))
+               return -EINVAL;
        if (req->flags & REQ_F_NEED_CLEANUP)
                return 0;
        how = u64_to_user_ptr(READ_ONCE(sqe->addr2));
@@ -3767,7 +3832,7 @@ static int io_epoll_ctl_prep(struct io_kiocb *req,
 #if defined(CONFIG_EPOLL)
        if (sqe->ioprio || sqe->buf_index)
                return -EINVAL;
-       if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
+       if (unlikely(req->ctx->flags & (IORING_SETUP_IOPOLL | IORING_SETUP_SQPOLL)))
                return -EINVAL;
 
        req->epoll.epfd = READ_ONCE(sqe->fd);
@@ -3882,7 +3947,7 @@ static int io_fadvise(struct io_kiocb *req, bool force_nonblock)
 
 static int io_statx_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 {
-       if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL))
+       if (unlikely(req->ctx->flags & (IORING_SETUP_IOPOLL | IORING_SETUP_SQPOLL)))
                return -EINVAL;
        if (sqe->ioprio || sqe->buf_index)
                return -EINVAL;
@@ -3938,8 +4003,7 @@ static int io_close_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
                return -EBADF;
 
        req->close.fd = READ_ONCE(sqe->fd);
-       if ((req->file && req->file->f_op == &io_uring_fops) ||
-           req->close.fd == req->ctx->ring_fd)
+       if ((req->file && req->file->f_op == &io_uring_fops))
                return -EBADF;
 
        req->close.put_file = NULL;
@@ -4724,6 +4788,8 @@ static int io_poll_double_wake(struct wait_queue_entry *wait, unsigned mode,
        if (mask && !(mask & poll->events))
                return 0;
 
+       list_del_init(&wait->entry);
+
        if (poll && poll->head) {
                bool done;
 
@@ -4919,7 +4985,6 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
        apoll->double_poll = NULL;
 
        req->flags |= REQ_F_POLLED;
-       io_get_req_task(req);
        req->apoll = apoll;
        INIT_HLIST_NODE(&req->hash_node);
 
@@ -4994,7 +5059,10 @@ static bool io_poll_remove_one(struct io_kiocb *req)
        return do_complete;
 }
 
-static void io_poll_remove_all(struct io_ring_ctx *ctx)
+/*
+ * Returns true if we found and killed one or more poll requests
+ */
+static bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk)
 {
        struct hlist_node *tmp;
        struct io_kiocb *req;
@@ -5005,13 +5073,17 @@ static void io_poll_remove_all(struct io_ring_ctx *ctx)
                struct hlist_head *list;
 
                list = &ctx->cancel_hash[i];
-               hlist_for_each_entry_safe(req, tmp, list, hash_node)
-                       posted += io_poll_remove_one(req);
+               hlist_for_each_entry_safe(req, tmp, list, hash_node) {
+                       if (io_task_match(req, tsk))
+                               posted += io_poll_remove_one(req);
+               }
        }
        spin_unlock_irq(&ctx->completion_lock);
 
        if (posted)
                io_cqring_ev_posted(ctx);
+
+       return posted != 0;
 }
 
 static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr)
@@ -5100,8 +5172,6 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
 #endif
        poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP |
                       (events & EPOLLEXCLUSIVE);
-
-       io_get_req_task(req);
        return 0;
 }
 
@@ -5399,6 +5469,8 @@ static int io_async_cancel(struct io_kiocb *req)
 static int io_files_update_prep(struct io_kiocb *req,
                                const struct io_uring_sqe *sqe)
 {
+       if (unlikely(req->ctx->flags & IORING_SETUP_SQPOLL))
+               return -EINVAL;
        if (unlikely(req->flags & (REQ_F_FIXED_FILE | REQ_F_BUFFER_SELECT)))
                return -EINVAL;
        if (sqe->ioprio || sqe->rw_flags)
@@ -5449,6 +5521,8 @@ static int io_req_defer_prep(struct io_kiocb *req,
        if (unlikely(ret))
                return ret;
 
+       io_prep_async_work(req);
+
        switch (req->opcode) {
        case IORING_OP_NOP:
                break;
@@ -5606,6 +5680,22 @@ static int io_req_defer(struct io_kiocb *req, const struct io_uring_sqe *sqe)
        return -EIOCBQUEUED;
 }
 
+static void io_req_drop_files(struct io_kiocb *req)
+{
+       struct io_ring_ctx *ctx = req->ctx;
+       unsigned long flags;
+
+       spin_lock_irqsave(&ctx->inflight_lock, flags);
+       list_del(&req->inflight_entry);
+       if (waitqueue_active(&ctx->inflight_wait))
+               wake_up(&ctx->inflight_wait);
+       spin_unlock_irqrestore(&ctx->inflight_lock, flags);
+       req->flags &= ~REQ_F_INFLIGHT;
+       put_files_struct(req->work.files);
+       put_nsproxy(req->work.nsproxy);
+       req->work.files = NULL;
+}
+
 static void __io_clean_op(struct io_kiocb *req)
 {
        struct io_async_ctx *io = req->io;
@@ -5646,21 +5736,17 @@ static void __io_clean_op(struct io_kiocb *req)
                        io_put_file(req, req->splice.file_in,
                                    (req->splice.flags & SPLICE_F_FD_IN_FIXED));
                        break;
+               case IORING_OP_OPENAT:
+               case IORING_OP_OPENAT2:
+                       if (req->open.filename)
+                               putname(req->open.filename);
+                       break;
                }
                req->flags &= ~REQ_F_NEED_CLEANUP;
        }
 
-       if (req->flags & REQ_F_INFLIGHT) {
-               struct io_ring_ctx *ctx = req->ctx;
-               unsigned long flags;
-
-               spin_lock_irqsave(&ctx->inflight_lock, flags);
-               list_del(&req->inflight_entry);
-               if (waitqueue_active(&ctx->inflight_wait))
-                       wake_up(&ctx->inflight_wait);
-               spin_unlock_irqrestore(&ctx->inflight_lock, flags);
-               req->flags &= ~REQ_F_INFLIGHT;
-       }
+       if (req->flags & REQ_F_INFLIGHT)
+               io_req_drop_files(req);
 }
 
 static int io_issue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
@@ -6007,34 +6093,22 @@ static int io_req_set_file(struct io_submit_state *state, struct io_kiocb *req,
 
 static int io_grab_files(struct io_kiocb *req)
 {
-       int ret = -EBADF;
        struct io_ring_ctx *ctx = req->ctx;
 
        io_req_init_async(req);
 
        if (req->work.files || (req->flags & REQ_F_NO_FILE_TABLE))
                return 0;
-       if (!ctx->ring_file)
-               return -EBADF;
 
-       rcu_read_lock();
+       req->work.files = get_files_struct(current);
+       get_nsproxy(current->nsproxy);
+       req->work.nsproxy = current->nsproxy;
+       req->flags |= REQ_F_INFLIGHT;
+
        spin_lock_irq(&ctx->inflight_lock);
-       /*
-        * We use the f_ops->flush() handler to ensure that we can flush
-        * out work accessing these files if the fd is closed. Check if
-        * the fd has changed since we started down this path, and disallow
-        * this operation if it has.
-        */
-       if (fcheck(ctx->ring_fd) == ctx->ring_file) {
-               list_add(&req->inflight_entry, &ctx->inflight_list);
-               req->flags |= REQ_F_INFLIGHT;
-               req->work.files = current->files;
-               ret = 0;
-       }
+       list_add(&req->inflight_entry, &ctx->inflight_list);
        spin_unlock_irq(&ctx->inflight_lock);
-       rcu_read_unlock();
-
-       return ret;
+       return 0;
 }
 
 static inline int io_prep_work_files(struct io_kiocb *req)
@@ -6274,7 +6348,6 @@ static int io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
                        return ret;
                }
                trace_io_uring_link(ctx, req, head);
-               io_get_req_task(req);
                list_add_tail(&req->link_list, &head->link_list);
 
                /* last request of a link, enqueue the link */
@@ -6323,9 +6396,6 @@ static void io_submit_state_start(struct io_submit_state *state,
                                  struct io_ring_ctx *ctx, unsigned int max_ios)
 {
        blk_start_plug(&state->plug);
-#ifdef CONFIG_BLOCK
-       state->plug.nowait = true;
-#endif
        state->comp.nr = 0;
        INIT_LIST_HEAD(&state->comp.list);
        state->comp.ctx = ctx;
@@ -6382,6 +6452,32 @@ static inline void io_consume_sqe(struct io_ring_ctx *ctx)
        ctx->cached_sq_head++;
 }
 
+/*
+ * Check SQE restrictions (opcode and flags).
+ *
+ * Returns 'true' if SQE is allowed, 'false' otherwise.
+ */
+static inline bool io_check_restriction(struct io_ring_ctx *ctx,
+                                       struct io_kiocb *req,
+                                       unsigned int sqe_flags)
+{
+       if (!ctx->restricted)
+               return true;
+
+       if (!test_bit(req->opcode, ctx->restrictions.sqe_op))
+               return false;
+
+       if ((sqe_flags & ctx->restrictions.sqe_flags_required) !=
+           ctx->restrictions.sqe_flags_required)
+               return false;
+
+       if (sqe_flags & ~(ctx->restrictions.sqe_flags_allowed |
+                         ctx->restrictions.sqe_flags_required))
+               return false;
+
+       return true;
+}
+
 #define SQE_VALID_FLAGS        (IOSQE_FIXED_FILE|IOSQE_IO_DRAIN|IOSQE_IO_LINK| \
                                IOSQE_IO_HARDLINK | IOSQE_ASYNC | \
                                IOSQE_BUFFER_SELECT)
@@ -6402,6 +6498,8 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        /* one is dropped after submission, the other at completion */
        refcount_set(&req->refs, 2);
        req->task = current;
+       get_task_struct(req->task);
+       atomic_long_inc(&req->task->io_uring->req_issue);
        req->result = 0;
 
        if (unlikely(req->opcode >= IORING_OP_LAST))
@@ -6415,6 +6513,9 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        if (unlikely(sqe_flags & ~SQE_VALID_FLAGS))
                return -EINVAL;
 
+       if (unlikely(!io_check_restriction(ctx, req, sqe_flags)))
+               return -EACCES;
+
        if ((sqe_flags & IOSQE_BUFFER_SELECT) &&
            !io_op_defs[req->opcode].buffer_select)
                return -EOPNOTSUPP;
@@ -6437,8 +6538,7 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
        return io_req_set_file(state, req, READ_ONCE(sqe->fd));
 }
 
-static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
-                         struct file *ring_file, int ring_fd)
+static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
 {
        struct io_submit_state state;
        struct io_kiocb *link = NULL;
@@ -6447,7 +6547,7 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
        /* if we have a backlog and couldn't flush it all, return BUSY */
        if (test_bit(0, &ctx->sq_check_overflow)) {
                if (!list_empty(&ctx->cq_overflow_list) &&
-                   !io_cqring_overflow_flush(ctx, false))
+                   !io_cqring_overflow_flush(ctx, false, NULL, NULL))
                        return -EBUSY;
        }
 
@@ -6459,9 +6559,6 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
 
        io_submit_state_start(&state, ctx, nr);
 
-       ctx->ring_fd = ring_fd;
-       ctx->ring_file = ring_file;
-
        for (i = 0; i < nr; i++) {
                const struct io_uring_sqe *sqe;
                struct io_kiocb *req;
@@ -6532,10 +6629,11 @@ static int io_sq_thread(void *data)
 {
        struct io_ring_ctx *ctx = data;
        const struct cred *old_cred;
-       DEFINE_WAIT(wait);
        unsigned long timeout;
        int ret = 0;
 
+       init_wait(&ctx->sqo_wait_entry);
+
        complete(&ctx->sq_thread_comp);
 
        old_cred = override_creds(ctx->creds);
@@ -6585,7 +6683,7 @@ static int io_sq_thread(void *data)
                                continue;
                        }
 
-                       prepare_to_wait(&ctx->sqo_wait, &wait,
+                       prepare_to_wait(ctx->sqo_wait, &ctx->sqo_wait_entry,
                                                TASK_INTERRUPTIBLE);
 
                        /*
@@ -6597,7 +6695,7 @@ static int io_sq_thread(void *data)
                         */
                        if ((ctx->flags & IORING_SETUP_IOPOLL) &&
                            !list_empty_careful(&ctx->iopoll_list)) {
-                               finish_wait(&ctx->sqo_wait, &wait);
+                               finish_wait(ctx->sqo_wait, &ctx->sqo_wait_entry);
                                continue;
                        }
 
@@ -6606,31 +6704,29 @@ static int io_sq_thread(void *data)
                        to_submit = io_sqring_entries(ctx);
                        if (!to_submit || ret == -EBUSY) {
                                if (kthread_should_park()) {
-                                       finish_wait(&ctx->sqo_wait, &wait);
+                                       finish_wait(ctx->sqo_wait, &ctx->sqo_wait_entry);
                                        break;
                                }
                                if (io_run_task_work()) {
-                                       finish_wait(&ctx->sqo_wait, &wait);
+                                       finish_wait(ctx->sqo_wait, &ctx->sqo_wait_entry);
                                        io_ring_clear_wakeup_flag(ctx);
                                        continue;
                                }
-                               if (signal_pending(current))
-                                       flush_signals(current);
                                schedule();
-                               finish_wait(&ctx->sqo_wait, &wait);
+                               finish_wait(ctx->sqo_wait, &ctx->sqo_wait_entry);
 
                                io_ring_clear_wakeup_flag(ctx);
                                ret = 0;
                                continue;
                        }
-                       finish_wait(&ctx->sqo_wait, &wait);
+                       finish_wait(ctx->sqo_wait, &ctx->sqo_wait_entry);
 
                        io_ring_clear_wakeup_flag(ctx);
                }
 
                mutex_lock(&ctx->uring_lock);
                if (likely(!percpu_ref_is_dying(&ctx->refs)))
-                       ret = io_submit_sqes(ctx, to_submit, NULL, -1);
+                       ret = io_submit_sqes(ctx, to_submit);
                mutex_unlock(&ctx->uring_lock);
                timeout = jiffies + ctx->sq_thread_idle;
        }
@@ -6816,6 +6912,14 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 static void io_sq_thread_stop(struct io_ring_ctx *ctx)
 {
        if (ctx->sqo_thread) {
+               /*
+                * We may arrive here from the error branch in
+                * io_sq_offload_create() where the kthread is created
+                * without being waked up, thus wake it up now to make
+                * without being woken up, thus wake it up now to make
+                */
+               wake_up_process(ctx->sqo_thread);
+
                wait_for_completion(&ctx->sq_thread_comp);
                /*
                 * The park is a bit of a work-around, without it we get
@@ -7459,8 +7563,36 @@ out_fput:
        return ret;
 }
 
-static int io_sq_offload_start(struct io_ring_ctx *ctx,
-                              struct io_uring_params *p)
+static int io_uring_alloc_task_context(struct task_struct *task)
+{
+       struct io_uring_task *tctx;
+
+       tctx = kmalloc(sizeof(*tctx), GFP_KERNEL);
+       if (unlikely(!tctx))
+               return -ENOMEM;
+
+       xa_init(&tctx->xa);
+       init_waitqueue_head(&tctx->wait);
+       tctx->last = NULL;
+       tctx->in_idle = 0;
+       atomic_long_set(&tctx->req_issue, 0);
+       atomic_long_set(&tctx->req_complete, 0);
+       task->io_uring = tctx;
+       return 0;
+}
+
+void __io_uring_free(struct task_struct *tsk)
+{
+       struct io_uring_task *tctx = tsk->io_uring;
+
+       WARN_ON_ONCE(!xa_empty(&tctx->xa));
+       xa_destroy(&tctx->xa);
+       kfree(tctx);
+       tsk->io_uring = NULL;
+}
+
+static int io_sq_offload_create(struct io_ring_ctx *ctx,
+                               struct io_uring_params *p)
 {
        int ret;
 
@@ -7494,7 +7626,9 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
                        ctx->sqo_thread = NULL;
                        goto err;
                }
-               wake_up_process(ctx->sqo_thread);
+               ret = io_uring_alloc_task_context(ctx->sqo_thread);
+               if (ret)
+                       goto err;
        } else if (p->flags & IORING_SETUP_SQ_AFF) {
                /* Can't have SQ_AFF without SQPOLL */
                ret = -EINVAL;
@@ -7511,6 +7645,12 @@ err:
        return ret;
 }
 
+static void io_sq_offload_start(struct io_ring_ctx *ctx)
+{
+       if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sqo_thread)
+               wake_up_process(ctx->sqo_thread);
+}
+
 static inline void __io_unaccount_mem(struct user_struct *user,
                                      unsigned long nr_pages)
 {
@@ -7542,11 +7682,11 @@ static void io_unaccount_mem(struct io_ring_ctx *ctx, unsigned long nr_pages,
        if (ctx->limit_mem)
                __io_unaccount_mem(ctx->user, nr_pages);
 
-       if (ctx->sqo_mm) {
+       if (ctx->mm_account) {
                if (acct == ACCT_LOCKED)
-                       ctx->sqo_mm->locked_vm -= nr_pages;
+                       ctx->mm_account->locked_vm -= nr_pages;
                else if (acct == ACCT_PINNED)
-                       atomic64_sub(nr_pages, &ctx->sqo_mm->pinned_vm);
+                       atomic64_sub(nr_pages, &ctx->mm_account->pinned_vm);
        }
 }
 
@@ -7561,11 +7701,11 @@ static int io_account_mem(struct io_ring_ctx *ctx, unsigned long nr_pages,
                        return ret;
        }
 
-       if (ctx->sqo_mm) {
+       if (ctx->mm_account) {
                if (acct == ACCT_LOCKED)
-                       ctx->sqo_mm->locked_vm += nr_pages;
+                       ctx->mm_account->locked_vm += nr_pages;
                else if (acct == ACCT_PINNED)
-                       atomic64_add(nr_pages, &ctx->sqo_mm->pinned_vm);
+                       atomic64_add(nr_pages, &ctx->mm_account->pinned_vm);
        }
 
        return 0;
@@ -7869,9 +8009,12 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
 {
        io_finish_async(ctx);
        io_sqe_buffer_unregister(ctx);
-       if (ctx->sqo_mm) {
-               mmdrop(ctx->sqo_mm);
-               ctx->sqo_mm = NULL;
+
+       if (ctx->sqo_task) {
+               put_task_struct(ctx->sqo_task);
+               ctx->sqo_task = NULL;
+               mmdrop(ctx->mm_account);
+               ctx->mm_account = NULL;
        }
 
        io_sqe_files_unregister(ctx);
@@ -7948,7 +8091,7 @@ static void io_ring_exit_work(struct work_struct *work)
         */
        do {
                if (ctx->rings)
-                       io_cqring_overflow_flush(ctx, true);
+                       io_cqring_overflow_flush(ctx, true, NULL, NULL);
                io_iopoll_try_reap_events(ctx);
        } while (!wait_for_completion_timeout(&ctx->ref_comp, HZ/20));
        io_ring_ctx_free(ctx);
@@ -7960,15 +8103,15 @@ static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
        percpu_ref_kill(&ctx->refs);
        mutex_unlock(&ctx->uring_lock);
 
-       io_kill_timeouts(ctx);
-       io_poll_remove_all(ctx);
+       io_kill_timeouts(ctx, NULL);
+       io_poll_remove_all(ctx, NULL);
 
        if (ctx->io_wq)
                io_wq_cancel_all(ctx->io_wq);
 
        /* if we failed setting up the ctx, we might not have any rings */
        if (ctx->rings)
-               io_cqring_overflow_flush(ctx, true);
+               io_cqring_overflow_flush(ctx, true, NULL, NULL);
        io_iopoll_try_reap_events(ctx);
        idr_for_each(&ctx->personality_idr, io_remove_personalities, ctx);
 
@@ -8003,7 +8146,7 @@ static bool io_wq_files_match(struct io_wq_work *work, void *data)
 {
        struct files_struct *files = data;
 
-       return work->files == files;
+       return !files || work->files == files;
 }
 
 /*
@@ -8024,12 +8167,6 @@ static bool io_match_link(struct io_kiocb *preq, struct io_kiocb *req)
        return false;
 }
 
-static inline bool io_match_files(struct io_kiocb *req,
-                                      struct files_struct *files)
-{
-       return (req->flags & REQ_F_WORK_INITIALIZED) && req->work.files == files;
-}
-
 static bool io_match_link_files(struct io_kiocb *req,
                                struct files_struct *files)
 {
@@ -8145,11 +8282,14 @@ static void io_cancel_defer_files(struct io_ring_ctx *ctx,
        }
 }
 
-static void io_uring_cancel_files(struct io_ring_ctx *ctx,
+/*
+ * Returns true if we found and killed one or more files pinning requests
+ */
+static bool io_uring_cancel_files(struct io_ring_ctx *ctx,
                                  struct files_struct *files)
 {
        if (list_empty_careful(&ctx->inflight_list))
-               return;
+               return false;
 
        io_cancel_defer_files(ctx, files);
        /* cancel all at once, should be faster than doing it one by one*/
@@ -8161,7 +8301,7 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
 
                spin_lock_irq(&ctx->inflight_lock);
                list_for_each_entry(req, &ctx->inflight_list, inflight_entry) {
-                       if (req->work.files != files)
+                       if (files && req->work.files != files)
                                continue;
                        /* req is being completed, ignore */
                        if (!refcount_inc_not_zero(&req->refs))
@@ -8180,9 +8320,13 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
                /* cancel this request, or head link requests */
                io_attempt_cancel(ctx, cancel_req);
                io_put_req(cancel_req);
+               /* cancellations _may_ trigger task work */
+               io_run_task_work();
                schedule();
                finish_wait(&ctx->inflight_wait, &wait);
        }
+
+       return true;
 }
 
 static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
@@ -8190,21 +8334,220 @@ static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
        struct task_struct *task = data;
 
-       return req->task == task;
+       return io_task_match(req, task);
+}
+
+static bool __io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
+                                           struct task_struct *task,
+                                           struct files_struct *files)
+{
+       bool ret;
+
+       ret = io_uring_cancel_files(ctx, files);
+       if (!files) {
+               enum io_wq_cancel cret;
+
+               cret = io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb, task, true);
+               if (cret != IO_WQ_CANCEL_NOTFOUND)
+                       ret = true;
+
+               /* SQPOLL thread does its own polling */
+               if (!(ctx->flags & IORING_SETUP_SQPOLL)) {
+                       while (!list_empty_careful(&ctx->iopoll_list)) {
+                               io_iopoll_try_reap_events(ctx);
+                               ret = true;
+                       }
+               }
+
+               ret |= io_poll_remove_all(ctx, task);
+               ret |= io_kill_timeouts(ctx, task);
+       }
+
+       return ret;
+}
+
+/*
+ * We need to iteratively cancel requests, in case a request has dependent
+ * hard links. These persist even for failure of cancelations, hence keep
+ * looping until none are found.
+ */
+static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
+                                         struct files_struct *files)
+{
+       struct task_struct *task = current;
+
+       if (ctx->flags & IORING_SETUP_SQPOLL)
+               task = ctx->sqo_thread;
+
+       io_cqring_overflow_flush(ctx, true, task, files);
+
+       while (__io_uring_cancel_task_requests(ctx, task, files)) {
+               io_run_task_work();
+               cond_resched();
+       }
+}
+
+/*
+ * Note that this task has used io_uring. We use it for cancelation purposes.
+ */
+static int io_uring_add_task_file(struct file *file)
+{
+       if (unlikely(!current->io_uring)) {
+               int ret;
+
+               ret = io_uring_alloc_task_context(current);
+               if (unlikely(ret))
+                       return ret;
+       }
+       if (current->io_uring->last != file) {
+               XA_STATE(xas, &current->io_uring->xa, (unsigned long) file);
+               void *old;
+
+               rcu_read_lock();
+               old = xas_load(&xas);
+               if (old != file) {
+                       get_file(file);
+                       xas_lock(&xas);
+                       xas_store(&xas, file);
+                       xas_unlock(&xas);
+               }
+               rcu_read_unlock();
+               current->io_uring->last = file;
+       }
+
+       return 0;
+}
+
+/*
+ * Remove this io_uring_file -> task mapping.
+ */
+static void io_uring_del_task_file(struct file *file)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       XA_STATE(xas, &tctx->xa, (unsigned long) file);
+
+       if (tctx->last == file)
+               tctx->last = NULL;
+
+       xas_lock(&xas);
+       file = xas_store(&xas, NULL);
+       xas_unlock(&xas);
+
+       if (file)
+               fput(file);
+}
+
+static void __io_uring_attempt_task_drop(struct file *file)
+{
+       XA_STATE(xas, &current->io_uring->xa, (unsigned long) file);
+       struct file *old;
+
+       rcu_read_lock();
+       old = xas_load(&xas);
+       rcu_read_unlock();
+
+       if (old == file)
+               io_uring_del_task_file(file);
+}
+
+/*
+ * Drop task note for this file if we're the only ones that hold it after
+ * pending fput()
+ */
+static void io_uring_attempt_task_drop(struct file *file, bool exiting)
+{
+       if (!current->io_uring)
+               return;
+       /*
+        * fput() is pending, will be 2 if the only other ref is our potential
+        * task file note. If the task is exiting, drop regardless of count.
+        */
+       if (!exiting && atomic_long_read(&file->f_count) != 2)
+               return;
+
+       __io_uring_attempt_task_drop(file);
+}
+
+void __io_uring_files_cancel(struct files_struct *files)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       XA_STATE(xas, &tctx->xa, 0);
+
+       /* make sure overflow events are dropped */
+       tctx->in_idle = true;
+
+       do {
+               struct io_ring_ctx *ctx;
+               struct file *file;
+
+               xas_lock(&xas);
+               file = xas_next_entry(&xas, ULONG_MAX);
+               xas_unlock(&xas);
+
+               if (!file)
+                       break;
+
+               ctx = file->private_data;
+
+               io_uring_cancel_task_requests(ctx, files);
+               if (files)
+                       io_uring_del_task_file(file);
+       } while (1);
+}
+
+static inline bool io_uring_task_idle(struct io_uring_task *tctx)
+{
+       return atomic_long_read(&tctx->req_issue) ==
+               atomic_long_read(&tctx->req_complete);
+}
+
+/*
+ * Find any io_uring fd that this task has registered or done IO on, and cancel
+ * requests.
+ */
+void __io_uring_task_cancel(void)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       DEFINE_WAIT(wait);
+       long completions;
+
+       /* make sure overflow events are dropped */
+       tctx->in_idle = true;
+
+       while (!io_uring_task_idle(tctx)) {
+               /* read completions before cancelations */
+               completions = atomic_long_read(&tctx->req_complete);
+               __io_uring_files_cancel(NULL);
+
+               prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
+
+               /*
+                * If we've seen completions, retry. This avoids a race where
+                * a completion comes in before we did prepare_to_wait().
+                */
+               if (completions != atomic_long_read(&tctx->req_complete))
+                       continue;
+               if (io_uring_task_idle(tctx))
+                       break;
+               schedule();
+       }
+
+       finish_wait(&tctx->wait, &wait);
+       tctx->in_idle = false;
 }
 
 static int io_uring_flush(struct file *file, void *data)
 {
        struct io_ring_ctx *ctx = file->private_data;
 
-       io_uring_cancel_files(ctx, data);
-
        /*
         * If the task is going away, cancel work it may have pending
         */
        if (fatal_signal_pending(current) || (current->flags & PF_EXITING))
-               io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb, current, true);
+               data = NULL;
 
+       io_uring_cancel_task_requests(ctx, data);
+       io_uring_attempt_task_drop(file, !data);
        return 0;
 }
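
An aside, not part of the diff itself: these cancelation entry points are non-static, and <linux/io_uring.h> is now included at the top of the file, so the expectation is that the generic exit paths reach them through thin wrappers declared in that header. A sketch of what such a header plausibly looks like, assumed rather than taken from this patch:

    /* include/linux/io_uring.h -- assumed shape, for illustration only */
    #if defined(CONFIG_IO_URING)
    void __io_uring_task_cancel(void);
    void __io_uring_files_cancel(struct files_struct *files);
    void __io_uring_free(struct task_struct *tsk);

    static inline void io_uring_task_cancel(void)
    {
            if (current->io_uring)
                    __io_uring_task_cancel();
    }
    static inline void io_uring_files_cancel(struct files_struct *files)
    {
            if (current->io_uring)
                    __io_uring_files_cancel(files);
    }
    static inline void io_uring_free(struct task_struct *tsk)
    {
            if (tsk->io_uring)
                    __io_uring_free(tsk);
    }
    #else
    static inline void io_uring_task_cancel(void) { }
    static inline void io_uring_files_cancel(struct files_struct *files) { }
    static inline void io_uring_free(struct task_struct *tsk) { }
    #endif

The wrappers would be called from the task-exit and files-exit paths, so a task that ever touched an io_uring fd drains and drops its requests before its files and task_struct go away.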
 
@@ -8305,6 +8648,10 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
        if (!percpu_ref_tryget(&ctx->refs))
                goto out_fput;
 
+       ret = -EBADFD;
+       if (ctx->flags & IORING_SETUP_R_DISABLED)
+               goto out;
+
        /*
         * For SQ polling, the thread will do all submissions and completions.
         * Just return the requested submit count, and wake the thread if
@@ -8313,13 +8660,16 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
        ret = 0;
        if (ctx->flags & IORING_SETUP_SQPOLL) {
                if (!list_empty_careful(&ctx->cq_overflow_list))
-                       io_cqring_overflow_flush(ctx, false);
+                       io_cqring_overflow_flush(ctx, false, NULL, NULL);
                if (flags & IORING_ENTER_SQ_WAKEUP)
-                       wake_up(&ctx->sqo_wait);
+                       wake_up(ctx->sqo_wait);
                submitted = to_submit;
        } else if (to_submit) {
+               ret = io_uring_add_task_file(f.file);
+               if (unlikely(ret))
+                       goto out;
                mutex_lock(&ctx->uring_lock);
-               submitted = io_submit_sqes(ctx, to_submit, f.file, fd);
+               submitted = io_submit_sqes(ctx, to_submit);
                mutex_unlock(&ctx->uring_lock);
 
                if (submitted != to_submit)
@@ -8385,11 +8735,19 @@ static int io_uring_show_cred(int id, void *p, void *data)
 
 static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
 {
+       bool has_lock;
        int i;
 
-       mutex_lock(&ctx->uring_lock);
+       /*
+        * Avoid ABBA deadlock between the seq lock and the io_uring mutex,
+        * since fdinfo case grabs it in the opposite direction of normal use
+        * cases. If we fail to get the lock, we just don't iterate any
+        * structures that could be going away outside the io_uring mutex.
+        */
+       has_lock = mutex_trylock(&ctx->uring_lock);
+
        seq_printf(m, "UserFiles:\t%u\n", ctx->nr_user_files);
-       for (i = 0; i < ctx->nr_user_files; i++) {
+       for (i = 0; has_lock && i < ctx->nr_user_files; i++) {
                struct fixed_file_table *table;
                struct file *f;
 
@@ -8401,13 +8759,13 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
                        seq_printf(m, "%5u: <none>\n", i);
        }
        seq_printf(m, "UserBufs:\t%u\n", ctx->nr_user_bufs);
-       for (i = 0; i < ctx->nr_user_bufs; i++) {
+       for (i = 0; has_lock && i < ctx->nr_user_bufs; i++) {
                struct io_mapped_ubuf *buf = &ctx->user_bufs[i];
 
                seq_printf(m, "%5u: 0x%llx/%u\n", i, buf->ubuf,
                                                (unsigned int) buf->len);
        }
-       if (!idr_is_empty(&ctx->personality_idr)) {
+       if (has_lock && !idr_is_empty(&ctx->personality_idr)) {
                seq_printf(m, "Personalities:\n");
                idr_for_each(&ctx->personality_idr, io_uring_show_cred, m);
        }
@@ -8422,7 +8780,8 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m)
                                        req->task->task_works != NULL);
        }
        spin_unlock_irq(&ctx->completion_lock);
-       mutex_unlock(&ctx->uring_lock);
+       if (has_lock)
+               mutex_unlock(&ctx->uring_lock);
 }
 
 static void io_uring_show_fdinfo(struct seq_file *m, struct file *f)
@@ -8520,6 +8879,7 @@ static int io_uring_get_fd(struct io_ring_ctx *ctx)
        file = anon_inode_getfile("[io_uring]", &io_uring_fops, ctx,
                                        O_RDWR | O_CLOEXEC);
        if (IS_ERR(file)) {
+err_fd:
                put_unused_fd(ret);
                ret = PTR_ERR(file);
                goto err;
@@ -8528,6 +8888,10 @@ static int io_uring_get_fd(struct io_ring_ctx *ctx)
 #if defined(CONFIG_UNIX)
        ctx->ring_sock->file = file;
 #endif
+       if (unlikely(io_uring_add_task_file(file))) {
+               file = ERR_PTR(-ENOMEM);
+               goto err_fd;
+       }
        fd_install(ret, file);
        return ret;
 err:
@@ -8605,8 +8969,16 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
        ctx->user = user;
        ctx->creds = get_current_cred();
 
+       ctx->sqo_task = get_task_struct(current);
+
+       /*
+        * This is just grabbed for accounting purposes. When a process exits,
+        * the mm is exited and dropped before the files, hence we need to hang
+        * on to this mm purely for the purposes of being able to unaccount
+        * memory (locked/pinned vm). It's not used for anything else.
+        */
        mmgrab(current->mm);
-       ctx->sqo_mm = current->mm;
+       ctx->mm_account = current->mm;
 
        /*
         * Account memory _before_ installing the file descriptor. Once
@@ -8622,10 +8994,13 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
        if (ret)
                goto err;
 
-       ret = io_sq_offload_start(ctx, p);
+       ret = io_sq_offload_create(ctx, p);
        if (ret)
                goto err;
 
+       if (!(p->flags & IORING_SETUP_R_DISABLED))
+               io_sq_offload_start(ctx);
+
        memset(&p->sq_off, 0, sizeof(p->sq_off));
        p->sq_off.head = offsetof(struct io_rings, sq.head);
        p->sq_off.tail = offsetof(struct io_rings, sq.tail);
@@ -8688,7 +9063,8 @@ static long io_uring_setup(u32 entries, struct io_uring_params __user *params)
 
        if (p.flags & ~(IORING_SETUP_IOPOLL | IORING_SETUP_SQPOLL |
                        IORING_SETUP_SQ_AFF | IORING_SETUP_CQSIZE |
-                       IORING_SETUP_CLAMP | IORING_SETUP_ATTACH_WQ))
+                       IORING_SETUP_CLAMP | IORING_SETUP_ATTACH_WQ |
+                       IORING_SETUP_R_DISABLED))
                return -EINVAL;
 
        return  io_uring_create(entries, &p, params);
@@ -8764,6 +9140,91 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
        return -EINVAL;
 }
 
+static int io_register_restrictions(struct io_ring_ctx *ctx, void __user *arg,
+                                   unsigned int nr_args)
+{
+       struct io_uring_restriction *res;
+       size_t size;
+       int i, ret;
+
+       /* Restrictions allowed only if rings started disabled */
+       if (!(ctx->flags & IORING_SETUP_R_DISABLED))
+               return -EBADFD;
+
+       /* We allow only a single restrictions registration */
+       if (ctx->restrictions.registered)
+               return -EBUSY;
+
+       if (!arg || nr_args > IORING_MAX_RESTRICTIONS)
+               return -EINVAL;
+
+       size = array_size(nr_args, sizeof(*res));
+       if (size == SIZE_MAX)
+               return -EOVERFLOW;
+
+       res = memdup_user(arg, size);
+       if (IS_ERR(res))
+               return PTR_ERR(res);
+
+       ret = 0;
+
+       for (i = 0; i < nr_args; i++) {
+               switch (res[i].opcode) {
+               case IORING_RESTRICTION_REGISTER_OP:
+                       if (res[i].register_op >= IORING_REGISTER_LAST) {
+                               ret = -EINVAL;
+                               goto out;
+                       }
+
+                       __set_bit(res[i].register_op,
+                                 ctx->restrictions.register_op);
+                       break;
+               case IORING_RESTRICTION_SQE_OP:
+                       if (res[i].sqe_op >= IORING_OP_LAST) {
+                               ret = -EINVAL;
+                               goto out;
+                       }
+
+                       __set_bit(res[i].sqe_op, ctx->restrictions.sqe_op);
+                       break;
+               case IORING_RESTRICTION_SQE_FLAGS_ALLOWED:
+                       ctx->restrictions.sqe_flags_allowed = res[i].sqe_flags;
+                       break;
+               case IORING_RESTRICTION_SQE_FLAGS_REQUIRED:
+                       ctx->restrictions.sqe_flags_required = res[i].sqe_flags;
+                       break;
+               default:
+                       ret = -EINVAL;
+                       goto out;
+               }
+       }
+
+out:
+       /* Reset all restrictions if an error happened */
+       if (ret != 0)
+               memset(&ctx->restrictions, 0, sizeof(ctx->restrictions));
+       else
+               ctx->restrictions.registered = true;
+
+       kfree(res);
+       return ret;
+}
+
+static int io_register_enable_rings(struct io_ring_ctx *ctx)
+{
+       if (!(ctx->flags & IORING_SETUP_R_DISABLED))
+               return -EBADFD;
+
+       if (ctx->restrictions.registered)
+               ctx->restricted = 1;
+
+       ctx->flags &= ~IORING_SETUP_R_DISABLED;
+
+       io_sq_offload_start(ctx);
+
+       return 0;
+}
+
 static bool io_register_op_must_quiesce(int op)
 {
        switch (op) {
@@ -8810,6 +9271,18 @@ static int __io_uring_register(struct io_ring_ctx *ctx, unsigned opcode,
                if (ret) {
                        percpu_ref_resurrect(&ctx->refs);
                        ret = -EINTR;
+                       goto out_quiesce;
+               }
+       }
+
+       if (ctx->restricted) {
+               if (opcode >= IORING_REGISTER_LAST) {
+                       ret = -EINVAL;
+                       goto out;
+               }
+
+               if (!test_bit(opcode, ctx->restrictions.register_op)) {
+                       ret = -EACCES;
                        goto out;
                }
        }
@@ -8873,15 +9346,25 @@ static int __io_uring_register(struct io_ring_ctx *ctx, unsigned opcode,
                        break;
                ret = io_unregister_personality(ctx, nr_args);
                break;
+       case IORING_REGISTER_ENABLE_RINGS:
+               ret = -EINVAL;
+               if (arg || nr_args)
+                       break;
+               ret = io_register_enable_rings(ctx);
+               break;
+       case IORING_REGISTER_RESTRICTIONS:
+               ret = io_register_restrictions(ctx, arg, nr_args);
+               break;
        default:
                ret = -EINVAL;
                break;
        }
 
+out:
        if (io_register_op_must_quiesce(opcode)) {
                /* bring the ctx back to life */
                percpu_ref_reinit(&ctx->refs);
-out:
+out_quiesce:
                reinit_completion(&ctx->ref_comp);
        }
        return ret;
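
For completeness, a minimal userspace sketch of the disabled/restricted ring flow that the registration opcodes above enable. It assumes a liburing build that exposes io_uring_register_restrictions() and io_uring_enable_rings(); the example program and its choice of restrictions are hypothetical, not part of this patch.

    /* Hypothetical user of IORING_SETUP_R_DISABLED + restrictions. */
    #include <liburing.h>
    #include <stdio.h>
    #include <string.h>

    int main(void)
    {
            struct io_uring ring;
            struct io_uring_restriction res[2];
            int ret;

            /* Ring starts disabled: io_uring_enter() returns -EBADFD until enabled. */
            ret = io_uring_queue_init(8, &ring, IORING_SETUP_R_DISABLED);
            if (ret < 0)
                    return 1;

            memset(res, 0, sizeof(res));
            /* Only allow READV SQEs ... */
            res[0].opcode = IORING_RESTRICTION_SQE_OP;
            res[0].sqe_op = IORING_OP_READV;
            /* ... and only the file-registration opcode via io_uring_register(). */
            res[1].opcode = IORING_RESTRICTION_REGISTER_OP;
            res[1].register_op = IORING_REGISTER_FILES;

            ret = io_uring_register_restrictions(&ring, res, 2);
            if (ret < 0) {
                    fprintf(stderr, "register_restrictions: %d\n", ret);
                    return 1;
            }

            /* Lift IORING_SETUP_R_DISABLED; the SQPOLL thread (if any) is woken now. */
            ret = io_uring_enable_rings(&ring);
            if (ret < 0)
                    return 1;

            /* From here on, any SQE other than READV is rejected with -EACCES. */
            io_uring_queue_exit(&ring);
            return 0;
    }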