io_uring: introduce a struct for hash table
[linux-block.git] / io_uring / poll.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/fs.h>
5 #include <linux/file.h>
6 #include <linux/mm.h>
7 #include <linux/slab.h>
8 #include <linux/poll.h>
9 #include <linux/hashtable.h>
10 #include <linux/io_uring.h>
11
12 #include <trace/events/io_uring.h>
13
14 #include <uapi/linux/io_uring.h>
15
16 #include "io_uring_types.h"
17 #include "io_uring.h"
18 #include "refs.h"
19 #include "opdef.h"
20 #include "kbuf.h"
21 #include "poll.h"
22 #include "cancel.h"
23
24 struct io_poll_update {
25         struct file                     *file;
26         u64                             old_user_data;
27         u64                             new_user_data;
28         __poll_t                        events;
29         bool                            update_events;
30         bool                            update_user_data;
31 };
32
33 struct io_poll_table {
34         struct poll_table_struct pt;
35         struct io_kiocb *req;
36         int nr_entries;
37         int error;
38 };
39
40 #define IO_POLL_CANCEL_FLAG     BIT(31)
41 #define IO_POLL_REF_MASK        GENMASK(30, 0)
42
43 /*
44  * If refs part of ->poll_refs (see IO_POLL_REF_MASK) is 0, it's free. We can
45  * bump it and acquire ownership. It's disallowed to modify requests while not
46  * owning it, that prevents from races for enqueueing task_work's and b/w
47  * arming poll and wakeups.
48  */
49 static inline bool io_poll_get_ownership(struct io_kiocb *req)
50 {
51         return !(atomic_fetch_inc(&req->poll_refs) & IO_POLL_REF_MASK);
52 }
53
54 static void io_poll_mark_cancelled(struct io_kiocb *req)
55 {
56         atomic_or(IO_POLL_CANCEL_FLAG, &req->poll_refs);
57 }
58
59 static struct io_poll *io_poll_get_double(struct io_kiocb *req)
60 {
61         /* pure poll stashes this in ->async_data, poll driven retry elsewhere */
62         if (req->opcode == IORING_OP_POLL_ADD)
63                 return req->async_data;
64         return req->apoll->double_poll;
65 }
66
67 static struct io_poll *io_poll_get_single(struct io_kiocb *req)
68 {
69         if (req->opcode == IORING_OP_POLL_ADD)
70                 return io_kiocb_to_cmd(req);
71         return &req->apoll->poll;
72 }
73
74 static void io_poll_req_insert(struct io_kiocb *req)
75 {
76         struct io_hash_table *table = &req->ctx->cancel_table;
77         u32 index = hash_long(req->cqe.user_data, table->hash_bits);
78         struct io_hash_bucket *hb = &table->hbs[index];
79
80         spin_lock(&hb->lock);
81         hlist_add_head(&req->hash_node, &hb->list);
82         spin_unlock(&hb->lock);
83 }
84
85 static void io_poll_req_delete(struct io_kiocb *req, struct io_ring_ctx *ctx)
86 {
87         struct io_hash_table *table = &req->ctx->cancel_table;
88         u32 index = hash_long(req->cqe.user_data, table->hash_bits);
89         spinlock_t *lock = &table->hbs[index].lock;
90
91         spin_lock(lock);
92         hash_del(&req->hash_node);
93         spin_unlock(lock);
94 }
95
96 static void io_init_poll_iocb(struct io_poll *poll, __poll_t events,
97                               wait_queue_func_t wake_func)
98 {
99         poll->head = NULL;
100 #define IO_POLL_UNMASK  (EPOLLERR|EPOLLHUP|EPOLLNVAL|EPOLLRDHUP)
101         /* mask in events that we always want/need */
102         poll->events = events | IO_POLL_UNMASK;
103         INIT_LIST_HEAD(&poll->wait.entry);
104         init_waitqueue_func_entry(&poll->wait, wake_func);
105 }
106
107 static inline void io_poll_remove_entry(struct io_poll *poll)
108 {
109         struct wait_queue_head *head = smp_load_acquire(&poll->head);
110
111         if (head) {
112                 spin_lock_irq(&head->lock);
113                 list_del_init(&poll->wait.entry);
114                 poll->head = NULL;
115                 spin_unlock_irq(&head->lock);
116         }
117 }
118
119 static void io_poll_remove_entries(struct io_kiocb *req)
120 {
121         /*
122          * Nothing to do if neither of those flags are set. Avoid dipping
123          * into the poll/apoll/double cachelines if we can.
124          */
125         if (!(req->flags & (REQ_F_SINGLE_POLL | REQ_F_DOUBLE_POLL)))
126                 return;
127
128         /*
129          * While we hold the waitqueue lock and the waitqueue is nonempty,
130          * wake_up_pollfree() will wait for us.  However, taking the waitqueue
131          * lock in the first place can race with the waitqueue being freed.
132          *
133          * We solve this as eventpoll does: by taking advantage of the fact that
134          * all users of wake_up_pollfree() will RCU-delay the actual free.  If
135          * we enter rcu_read_lock() and see that the pointer to the queue is
136          * non-NULL, we can then lock it without the memory being freed out from
137          * under us.
138          *
139          * Keep holding rcu_read_lock() as long as we hold the queue lock, in
140          * case the caller deletes the entry from the queue, leaving it empty.
141          * In that case, only RCU prevents the queue memory from being freed.
142          */
143         rcu_read_lock();
144         if (req->flags & REQ_F_SINGLE_POLL)
145                 io_poll_remove_entry(io_poll_get_single(req));
146         if (req->flags & REQ_F_DOUBLE_POLL)
147                 io_poll_remove_entry(io_poll_get_double(req));
148         rcu_read_unlock();
149 }
150
151 /*
152  * All poll tw should go through this. Checks for poll events, manages
153  * references, does rewait, etc.
154  *
155  * Returns a negative error on failure. >0 when no action require, which is
156  * either spurious wakeup or multishot CQE is served. 0 when it's done with
157  * the request, then the mask is stored in req->cqe.res.
158  */
159 static int io_poll_check_events(struct io_kiocb *req, bool *locked)
160 {
161         struct io_ring_ctx *ctx = req->ctx;
162         int v, ret;
163
164         /* req->task == current here, checking PF_EXITING is safe */
165         if (unlikely(req->task->flags & PF_EXITING))
166                 return -ECANCELED;
167
168         do {
169                 v = atomic_read(&req->poll_refs);
170
171                 /* tw handler should be the owner, and so have some references */
172                 if (WARN_ON_ONCE(!(v & IO_POLL_REF_MASK)))
173                         return 0;
174                 if (v & IO_POLL_CANCEL_FLAG)
175                         return -ECANCELED;
176
177                 if (!req->cqe.res) {
178                         struct poll_table_struct pt = { ._key = req->apoll_events };
179                         req->cqe.res = vfs_poll(req->file, &pt) & req->apoll_events;
180                 }
181
182                 if ((unlikely(!req->cqe.res)))
183                         continue;
184                 if (req->apoll_events & EPOLLONESHOT)
185                         return 0;
186
187                 /* multishot, just fill a CQE and proceed */
188                 if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
189                         __poll_t mask = mangle_poll(req->cqe.res &
190                                                     req->apoll_events);
191                         bool filled;
192
193                         spin_lock(&ctx->completion_lock);
194                         filled = io_fill_cqe_aux(ctx, req->cqe.user_data,
195                                                  mask, IORING_CQE_F_MORE);
196                         io_commit_cqring(ctx);
197                         spin_unlock(&ctx->completion_lock);
198                         if (filled) {
199                                 io_cqring_ev_posted(ctx);
200                                 continue;
201                         }
202                         return -ECANCELED;
203                 }
204
205                 ret = io_poll_issue(req, locked);
206                 if (ret)
207                         return ret;
208
209                 /*
210                  * Release all references, retry if someone tried to restart
211                  * task_work while we were executing it.
212                  */
213         } while (atomic_sub_return(v & IO_POLL_REF_MASK, &req->poll_refs));
214
215         return 1;
216 }
217
218 static void io_poll_task_func(struct io_kiocb *req, bool *locked)
219 {
220         struct io_ring_ctx *ctx = req->ctx;
221         int ret;
222
223         ret = io_poll_check_events(req, locked);
224         if (ret > 0)
225                 return;
226
227         if (!ret) {
228                 struct io_poll *poll = io_kiocb_to_cmd(req);
229
230                 req->cqe.res = mangle_poll(req->cqe.res & poll->events);
231         } else {
232                 req->cqe.res = ret;
233                 req_set_fail(req);
234         }
235
236         io_poll_remove_entries(req);
237         io_poll_req_delete(req, ctx);
238         io_req_set_res(req, req->cqe.res, 0);
239         io_req_task_complete(req, locked);
240 }
241
242 static void io_apoll_task_func(struct io_kiocb *req, bool *locked)
243 {
244         int ret;
245
246         ret = io_poll_check_events(req, locked);
247         if (ret > 0)
248                 return;
249
250         io_poll_remove_entries(req);
251         io_poll_req_delete(req, req->ctx);
252
253         if (!ret)
254                 io_req_task_submit(req, locked);
255         else
256                 io_req_complete_failed(req, ret);
257 }
258
259 static void __io_poll_execute(struct io_kiocb *req, int mask,
260                               __poll_t __maybe_unused events)
261 {
262         io_req_set_res(req, mask, 0);
263         /*
264          * This is useful for poll that is armed on behalf of another
265          * request, and where the wakeup path could be on a different
266          * CPU. We want to avoid pulling in req->apoll->events for that
267          * case.
268          */
269         if (req->opcode == IORING_OP_POLL_ADD)
270                 req->io_task_work.func = io_poll_task_func;
271         else
272                 req->io_task_work.func = io_apoll_task_func;
273
274         trace_io_uring_task_add(req->ctx, req, req->cqe.user_data, req->opcode, mask);
275         io_req_task_work_add(req);
276 }
277
278 static inline void io_poll_execute(struct io_kiocb *req, int res,
279                 __poll_t events)
280 {
281         if (io_poll_get_ownership(req))
282                 __io_poll_execute(req, res, events);
283 }
284
285 static void io_poll_cancel_req(struct io_kiocb *req)
286 {
287         io_poll_mark_cancelled(req);
288         /* kick tw, which should complete the request */
289         io_poll_execute(req, 0, 0);
290 }
291
292 #define wqe_to_req(wait)        ((void *)((unsigned long) (wait)->private & ~1))
293 #define wqe_is_double(wait)     ((unsigned long) (wait)->private & 1)
294 #define IO_ASYNC_POLL_COMMON    (EPOLLONESHOT | EPOLLPRI)
295
296 static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
297                         void *key)
298 {
299         struct io_kiocb *req = wqe_to_req(wait);
300         struct io_poll *poll = container_of(wait, struct io_poll, wait);
301         __poll_t mask = key_to_poll(key);
302
303         if (unlikely(mask & POLLFREE)) {
304                 io_poll_mark_cancelled(req);
305                 /* we have to kick tw in case it's not already */
306                 io_poll_execute(req, 0, poll->events);
307
308                 /*
309                  * If the waitqueue is being freed early but someone is already
310                  * holds ownership over it, we have to tear down the request as
311                  * best we can. That means immediately removing the request from
312                  * its waitqueue and preventing all further accesses to the
313                  * waitqueue via the request.
314                  */
315                 list_del_init(&poll->wait.entry);
316
317                 /*
318                  * Careful: this *must* be the last step, since as soon
319                  * as req->head is NULL'ed out, the request can be
320                  * completed and freed, since aio_poll_complete_work()
321                  * will no longer need to take the waitqueue lock.
322                  */
323                 smp_store_release(&poll->head, NULL);
324                 return 1;
325         }
326
327         /* for instances that support it check for an event match first */
328         if (mask && !(mask & (poll->events & ~IO_ASYNC_POLL_COMMON)))
329                 return 0;
330
331         if (io_poll_get_ownership(req)) {
332                 /* optional, saves extra locking for removal in tw handler */
333                 if (mask && poll->events & EPOLLONESHOT) {
334                         list_del_init(&poll->wait.entry);
335                         poll->head = NULL;
336                         if (wqe_is_double(wait))
337                                 req->flags &= ~REQ_F_DOUBLE_POLL;
338                         else
339                                 req->flags &= ~REQ_F_SINGLE_POLL;
340                 }
341                 __io_poll_execute(req, mask, poll->events);
342         }
343         return 1;
344 }
345
346 static void __io_queue_proc(struct io_poll *poll, struct io_poll_table *pt,
347                             struct wait_queue_head *head,
348                             struct io_poll **poll_ptr)
349 {
350         struct io_kiocb *req = pt->req;
351         unsigned long wqe_private = (unsigned long) req;
352
353         /*
354          * The file being polled uses multiple waitqueues for poll handling
355          * (e.g. one for read, one for write). Setup a separate io_poll
356          * if this happens.
357          */
358         if (unlikely(pt->nr_entries)) {
359                 struct io_poll *first = poll;
360
361                 /* double add on the same waitqueue head, ignore */
362                 if (first->head == head)
363                         return;
364                 /* already have a 2nd entry, fail a third attempt */
365                 if (*poll_ptr) {
366                         if ((*poll_ptr)->head == head)
367                                 return;
368                         pt->error = -EINVAL;
369                         return;
370                 }
371
372                 poll = kmalloc(sizeof(*poll), GFP_ATOMIC);
373                 if (!poll) {
374                         pt->error = -ENOMEM;
375                         return;
376                 }
377                 /* mark as double wq entry */
378                 wqe_private |= 1;
379                 req->flags |= REQ_F_DOUBLE_POLL;
380                 io_init_poll_iocb(poll, first->events, first->wait.func);
381                 *poll_ptr = poll;
382                 if (req->opcode == IORING_OP_POLL_ADD)
383                         req->flags |= REQ_F_ASYNC_DATA;
384         }
385
386         req->flags |= REQ_F_SINGLE_POLL;
387         pt->nr_entries++;
388         poll->head = head;
389         poll->wait.private = (void *) wqe_private;
390
391         if (poll->events & EPOLLEXCLUSIVE)
392                 add_wait_queue_exclusive(head, &poll->wait);
393         else
394                 add_wait_queue(head, &poll->wait);
395 }
396
397 static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head,
398                                struct poll_table_struct *p)
399 {
400         struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
401         struct io_poll *poll = io_kiocb_to_cmd(pt->req);
402
403         __io_queue_proc(poll, pt, head,
404                         (struct io_poll **) &pt->req->async_data);
405 }
406
407 static int __io_arm_poll_handler(struct io_kiocb *req,
408                                  struct io_poll *poll,
409                                  struct io_poll_table *ipt, __poll_t mask)
410 {
411         struct io_ring_ctx *ctx = req->ctx;
412         int v;
413
414         INIT_HLIST_NODE(&req->hash_node);
415         req->work.cancel_seq = atomic_read(&ctx->cancel_seq);
416         io_init_poll_iocb(poll, mask, io_poll_wake);
417         poll->file = req->file;
418
419         req->apoll_events = poll->events;
420
421         ipt->pt._key = mask;
422         ipt->req = req;
423         ipt->error = 0;
424         ipt->nr_entries = 0;
425
426         /*
427          * Take the ownership to delay any tw execution up until we're done
428          * with poll arming. see io_poll_get_ownership().
429          */
430         atomic_set(&req->poll_refs, 1);
431         mask = vfs_poll(req->file, &ipt->pt) & poll->events;
432
433         if (mask &&
434            ((poll->events & (EPOLLET|EPOLLONESHOT)) == (EPOLLET|EPOLLONESHOT))) {
435                 io_poll_remove_entries(req);
436                 /* no one else has access to the req, forget about the ref */
437                 return mask;
438         }
439
440         if (!mask && unlikely(ipt->error || !ipt->nr_entries)) {
441                 io_poll_remove_entries(req);
442                 if (!ipt->error)
443                         ipt->error = -EINVAL;
444                 return 0;
445         }
446
447         io_poll_req_insert(req);
448
449         if (mask && (poll->events & EPOLLET)) {
450                 /* can't multishot if failed, just queue the event we've got */
451                 if (unlikely(ipt->error || !ipt->nr_entries)) {
452                         poll->events |= EPOLLONESHOT;
453                         req->apoll_events |= EPOLLONESHOT;
454                         ipt->error = 0;
455                 }
456                 __io_poll_execute(req, mask, poll->events);
457                 return 0;
458         }
459
460         /*
461          * Release ownership. If someone tried to queue a tw while it was
462          * locked, kick it off for them.
463          */
464         v = atomic_dec_return(&req->poll_refs);
465         if (unlikely(v & IO_POLL_REF_MASK))
466                 __io_poll_execute(req, 0, poll->events);
467         return 0;
468 }
469
470 static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
471                                struct poll_table_struct *p)
472 {
473         struct io_poll_table *pt = container_of(p, struct io_poll_table, pt);
474         struct async_poll *apoll = pt->req->apoll;
475
476         __io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
477 }
478
479 int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags)
480 {
481         const struct io_op_def *def = &io_op_defs[req->opcode];
482         struct io_ring_ctx *ctx = req->ctx;
483         struct async_poll *apoll;
484         struct io_poll_table ipt;
485         __poll_t mask = POLLPRI | POLLERR | EPOLLET;
486         int ret;
487
488         if (!def->pollin && !def->pollout)
489                 return IO_APOLL_ABORTED;
490         if (!file_can_poll(req->file))
491                 return IO_APOLL_ABORTED;
492         if ((req->flags & (REQ_F_POLLED|REQ_F_PARTIAL_IO)) == REQ_F_POLLED)
493                 return IO_APOLL_ABORTED;
494         if (!(req->flags & REQ_F_APOLL_MULTISHOT))
495                 mask |= EPOLLONESHOT;
496
497         if (def->pollin) {
498                 mask |= EPOLLIN | EPOLLRDNORM;
499
500                 /* If reading from MSG_ERRQUEUE using recvmsg, ignore POLLIN */
501                 if (req->flags & REQ_F_CLEAR_POLLIN)
502                         mask &= ~EPOLLIN;
503         } else {
504                 mask |= EPOLLOUT | EPOLLWRNORM;
505         }
506         if (def->poll_exclusive)
507                 mask |= EPOLLEXCLUSIVE;
508         if (req->flags & REQ_F_POLLED) {
509                 apoll = req->apoll;
510                 kfree(apoll->double_poll);
511         } else if (!(issue_flags & IO_URING_F_UNLOCKED) &&
512                    !list_empty(&ctx->apoll_cache)) {
513                 apoll = list_first_entry(&ctx->apoll_cache, struct async_poll,
514                                                 poll.wait.entry);
515                 list_del_init(&apoll->poll.wait.entry);
516         } else {
517                 apoll = kmalloc(sizeof(*apoll), GFP_ATOMIC);
518                 if (unlikely(!apoll))
519                         return IO_APOLL_ABORTED;
520         }
521         apoll->double_poll = NULL;
522         req->apoll = apoll;
523         req->flags |= REQ_F_POLLED;
524         ipt.pt._qproc = io_async_queue_proc;
525
526         io_kbuf_recycle(req, issue_flags);
527
528         ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask);
529         if (ret || ipt.error)
530                 return ret ? IO_APOLL_READY : IO_APOLL_ABORTED;
531
532         trace_io_uring_poll_arm(ctx, req, req->cqe.user_data, req->opcode,
533                                 mask, apoll->poll.events);
534         return IO_APOLL_OK;
535 }
536
537 /*
538  * Returns true if we found and killed one or more poll requests
539  */
540 __cold bool io_poll_remove_all(struct io_ring_ctx *ctx, struct task_struct *tsk,
541                                bool cancel_all)
542 {
543         struct io_hash_table *table = &ctx->cancel_table;
544         unsigned nr_buckets = 1U << table->hash_bits;
545         struct hlist_node *tmp;
546         struct io_kiocb *req;
547         bool found = false;
548         int i;
549
550         for (i = 0; i < nr_buckets; i++) {
551                 struct io_hash_bucket *hb = &table->hbs[i];
552
553                 spin_lock(&hb->lock);
554                 hlist_for_each_entry_safe(req, tmp, &hb->list, hash_node) {
555                         if (io_match_task_safe(req, tsk, cancel_all)) {
556                                 hlist_del_init(&req->hash_node);
557                                 io_poll_cancel_req(req);
558                                 found = true;
559                         }
560                 }
561                 spin_unlock(&hb->lock);
562         }
563         return found;
564 }
565
566 static struct io_kiocb *io_poll_find(struct io_ring_ctx *ctx, bool poll_only,
567                                      struct io_cancel_data *cd,
568                                      struct io_hash_table *table,
569                                      struct io_hash_bucket **out_bucket)
570 {
571         struct io_kiocb *req;
572         u32 index = hash_long(cd->data, table->hash_bits);
573         struct io_hash_bucket *hb = &table->hbs[index];
574
575         *out_bucket = NULL;
576
577         spin_lock(&hb->lock);
578         hlist_for_each_entry(req, &hb->list, hash_node) {
579                 if (cd->data != req->cqe.user_data)
580                         continue;
581                 if (poll_only && req->opcode != IORING_OP_POLL_ADD)
582                         continue;
583                 if (cd->flags & IORING_ASYNC_CANCEL_ALL) {
584                         if (cd->seq == req->work.cancel_seq)
585                                 continue;
586                         req->work.cancel_seq = cd->seq;
587                 }
588                 *out_bucket = hb;
589                 return req;
590         }
591         spin_unlock(&hb->lock);
592         return NULL;
593 }
594
595 static struct io_kiocb *io_poll_file_find(struct io_ring_ctx *ctx,
596                                           struct io_cancel_data *cd,
597                                           struct io_hash_table *table,
598                                           struct io_hash_bucket **out_bucket)
599 {
600         unsigned nr_buckets = 1U << table->hash_bits;
601         struct io_kiocb *req;
602         int i;
603
604         *out_bucket = NULL;
605
606         for (i = 0; i < nr_buckets; i++) {
607                 struct io_hash_bucket *hb = &table->hbs[i];
608
609                 spin_lock(&hb->lock);
610                 hlist_for_each_entry(req, &hb->list, hash_node) {
611                         if (!(cd->flags & IORING_ASYNC_CANCEL_ANY) &&
612                             req->file != cd->file)
613                                 continue;
614                         if (cd->seq == req->work.cancel_seq)
615                                 continue;
616                         req->work.cancel_seq = cd->seq;
617                         *out_bucket = hb;
618                         return req;
619                 }
620                 spin_unlock(&hb->lock);
621         }
622         return NULL;
623 }
624
625 static bool io_poll_disarm(struct io_kiocb *req)
626 {
627         if (!io_poll_get_ownership(req))
628                 return false;
629         io_poll_remove_entries(req);
630         hash_del(&req->hash_node);
631         return true;
632 }
633
634 static int __io_poll_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
635                             struct io_hash_table *table)
636 {
637         struct io_hash_bucket *bucket;
638         struct io_kiocb *req;
639
640         if (cd->flags & (IORING_ASYNC_CANCEL_FD|IORING_ASYNC_CANCEL_ANY))
641                 req = io_poll_file_find(ctx, cd, table, &bucket);
642         else
643                 req = io_poll_find(ctx, false, cd, table, &bucket);
644
645         if (req)
646                 io_poll_cancel_req(req);
647         if (bucket)
648                 spin_unlock(&bucket->lock);
649         return req ? 0 : -ENOENT;
650 }
651
652 int io_poll_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd)
653 {
654         return __io_poll_cancel(ctx, cd, &ctx->cancel_table);
655 }
656
657 static __poll_t io_poll_parse_events(const struct io_uring_sqe *sqe,
658                                      unsigned int flags)
659 {
660         u32 events;
661
662         events = READ_ONCE(sqe->poll32_events);
663 #ifdef __BIG_ENDIAN
664         events = swahw32(events);
665 #endif
666         if (!(flags & IORING_POLL_ADD_MULTI))
667                 events |= EPOLLONESHOT;
668         if (!(flags & IORING_POLL_ADD_LEVEL))
669                 events |= EPOLLET;
670         return demangle_poll(events) |
671                 (events & (EPOLLEXCLUSIVE|EPOLLONESHOT|EPOLLET));
672 }
673
674 int io_poll_remove_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
675 {
676         struct io_poll_update *upd = io_kiocb_to_cmd(req);
677         u32 flags;
678
679         if (sqe->buf_index || sqe->splice_fd_in)
680                 return -EINVAL;
681         flags = READ_ONCE(sqe->len);
682         if (flags & ~(IORING_POLL_UPDATE_EVENTS | IORING_POLL_UPDATE_USER_DATA |
683                       IORING_POLL_ADD_MULTI))
684                 return -EINVAL;
685         /* meaningless without update */
686         if (flags == IORING_POLL_ADD_MULTI)
687                 return -EINVAL;
688
689         upd->old_user_data = READ_ONCE(sqe->addr);
690         upd->update_events = flags & IORING_POLL_UPDATE_EVENTS;
691         upd->update_user_data = flags & IORING_POLL_UPDATE_USER_DATA;
692
693         upd->new_user_data = READ_ONCE(sqe->off);
694         if (!upd->update_user_data && upd->new_user_data)
695                 return -EINVAL;
696         if (upd->update_events)
697                 upd->events = io_poll_parse_events(sqe, flags);
698         else if (sqe->poll32_events)
699                 return -EINVAL;
700
701         return 0;
702 }
703
704 int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
705 {
706         struct io_poll *poll = io_kiocb_to_cmd(req);
707         u32 flags;
708
709         if (sqe->buf_index || sqe->off || sqe->addr)
710                 return -EINVAL;
711         flags = READ_ONCE(sqe->len);
712         if (flags & ~(IORING_POLL_ADD_MULTI|IORING_POLL_ADD_LEVEL))
713                 return -EINVAL;
714         if ((flags & IORING_POLL_ADD_MULTI) && (req->flags & REQ_F_CQE_SKIP))
715                 return -EINVAL;
716
717         poll->events = io_poll_parse_events(sqe, flags);
718         return 0;
719 }
720
721 int io_poll_add(struct io_kiocb *req, unsigned int issue_flags)
722 {
723         struct io_poll *poll = io_kiocb_to_cmd(req);
724         struct io_poll_table ipt;
725         int ret;
726
727         ipt.pt._qproc = io_poll_queue_proc;
728
729         ret = __io_arm_poll_handler(req, poll, &ipt, poll->events);
730         if (ret) {
731                 io_req_set_res(req, ret, 0);
732                 return IOU_OK;
733         }
734         if (ipt.error) {
735                 req_set_fail(req);
736                 return ipt.error;
737         }
738
739         return IOU_ISSUE_SKIP_COMPLETE;
740 }
741
742 int io_poll_remove(struct io_kiocb *req, unsigned int issue_flags)
743 {
744         struct io_poll_update *poll_update = io_kiocb_to_cmd(req);
745         struct io_cancel_data cd = { .data = poll_update->old_user_data, };
746         struct io_ring_ctx *ctx = req->ctx;
747         struct io_hash_bucket *bucket;
748         struct io_kiocb *preq;
749         int ret2, ret = 0;
750         bool locked;
751
752         preq = io_poll_find(ctx, true, &cd, &ctx->cancel_table, &bucket);
753         if (preq)
754                 ret2 = io_poll_disarm(preq);
755         if (bucket)
756                 spin_unlock(&bucket->lock);
757
758         if (!preq) {
759                 ret = -ENOENT;
760                 goto out;
761         }
762         if (!ret2) {
763                 ret = -EALREADY;
764                 goto out;
765         }
766
767         if (poll_update->update_events || poll_update->update_user_data) {
768                 /* only mask one event flags, keep behavior flags */
769                 if (poll_update->update_events) {
770                         struct io_poll *poll = io_kiocb_to_cmd(preq);
771
772                         poll->events &= ~0xffff;
773                         poll->events |= poll_update->events & 0xffff;
774                         poll->events |= IO_POLL_UNMASK;
775                 }
776                 if (poll_update->update_user_data)
777                         preq->cqe.user_data = poll_update->new_user_data;
778
779                 ret2 = io_poll_add(preq, issue_flags);
780                 /* successfully updated, don't complete poll request */
781                 if (!ret2 || ret2 == -EIOCBQUEUED)
782                         goto out;
783         }
784
785         req_set_fail(preq);
786         io_req_set_res(preq, -ECANCELED, 0);
787         locked = !(issue_flags & IO_URING_F_UNLOCKED);
788         io_req_task_complete(preq, &locked);
789 out:
790         if (ret < 0) {
791                 req_set_fail(req);
792                 return ret;
793         }
794         /* complete update request, we're done with it */
795         io_req_set_res(req, ret, 0);
796         return IOU_OK;
797 }