fs/userfaultfd.c: add userfaultfd write-protect (UFFD-WP) support — UFFDIO_WRITEPROTECT ioctl, WP fault detection, and blocking-state helpers
[linux-block.git] / fs / userfaultfd.c
index 37df7c9eedb1527089466f45a930c69d8bab1925..e39fdec8a0b08be4b6262abf6b37919fc735c38c 100644 (file)
@@ -314,8 +314,11 @@ static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
        if (!pmd_present(_pmd))
                goto out;
 
-       if (pmd_trans_huge(_pmd))
+       if (pmd_trans_huge(_pmd)) {
+               if (!pmd_write(_pmd) && (reason & VM_UFFD_WP))
+                       ret = true;
                goto out;
+       }
 
        /*
         * the pmd is stable (as in !pmd_trans_unstable) so we can re-read it
@@ -328,12 +331,38 @@ static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
         */
        if (pte_none(*pte))
                ret = true;
+       if (!pte_write(*pte) && (reason & VM_UFFD_WP))
+               ret = true;
        pte_unmap(pte);
 
 out:
        return ret;
 }
 
+/* Should pair with userfaultfd_signal_pending() */
+static inline long userfaultfd_get_blocking_state(unsigned int flags)
+{
+       if (flags & FAULT_FLAG_INTERRUPTIBLE)
+               return TASK_INTERRUPTIBLE;
+
+       if (flags & FAULT_FLAG_KILLABLE)
+               return TASK_KILLABLE;
+
+       return TASK_UNINTERRUPTIBLE;
+}
+
+/* Should pair with userfaultfd_get_blocking_state() */
+static inline bool userfaultfd_signal_pending(unsigned int flags)
+{
+       if (flags & FAULT_FLAG_INTERRUPTIBLE)
+               return signal_pending(current);
+
+       if (flags & FAULT_FLAG_KILLABLE)
+               return fatal_signal_pending(current);
+
+       return false;
+}
+
 /*
  * The locking rules involved in returning VM_FAULT_RETRY depending on
  * FAULT_FLAG_ALLOW_RETRY, FAULT_FLAG_RETRY_NOWAIT and
@@ -355,7 +384,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
        struct userfaultfd_ctx *ctx;
        struct userfaultfd_wait_queue uwq;
        vm_fault_t ret = VM_FAULT_SIGBUS;
-       bool must_wait, return_to_userland;
+       bool must_wait;
        long blocking_state;
 
        /*
@@ -462,11 +491,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
        uwq.ctx = ctx;
        uwq.waken = false;
 
-       return_to_userland =
-               (vmf->flags & (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE)) ==
-               (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE);
-       blocking_state = return_to_userland ? TASK_INTERRUPTIBLE :
-                        TASK_KILLABLE;
+       blocking_state = userfaultfd_get_blocking_state(vmf->flags);
 
        spin_lock_irq(&ctx->fault_pending_wqh.lock);
        /*
@@ -492,8 +517,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
        up_read(&mm->mmap_sem);
 
        if (likely(must_wait && !READ_ONCE(ctx->released) &&
-                  (return_to_userland ? !signal_pending(current) :
-                   !fatal_signal_pending(current)))) {
+                  !userfaultfd_signal_pending(vmf->flags))) {
                wake_up_poll(&ctx->fd_wqh, EPOLLIN);
                schedule();
                ret |= VM_FAULT_MAJOR;
@@ -515,8 +539,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
                        set_current_state(blocking_state);
                        if (READ_ONCE(uwq.waken) ||
                            READ_ONCE(ctx->released) ||
-                           (return_to_userland ? signal_pending(current) :
-                            fatal_signal_pending(current)))
+                           userfaultfd_signal_pending(vmf->flags))
                                break;
                        schedule();
                }
@@ -524,30 +547,6 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 
        __set_current_state(TASK_RUNNING);
 
-       if (return_to_userland) {
-               if (signal_pending(current) &&
-                   !fatal_signal_pending(current)) {
-                       /*
-                        * If we got a SIGSTOP or SIGCONT and this is
-                        * a normal userland page fault, just let
-                        * userland return so the signal will be
-                        * handled and gdb debugging works.  The page
-                        * fault code immediately after we return from
-                        * this function is going to release the
-                        * mmap_sem and it's not depending on it
-                        * (unlike gup would if we were not to return
-                        * VM_FAULT_RETRY).
-                        *
-                        * If a fatal signal is pending we still take
-                        * the streamlined VM_FAULT_RETRY failure path
-                        * and there's no need to retake the mmap_sem
-                        * in such case.
-                        */
-                       down_read(&mm->mmap_sem);
-                       ret = VM_FAULT_NOPAGE;
-               }
-       }
-
        /*
         * Here we race with the list_del; list_add in
         * userfaultfd_ctx_read(), however because we don't ever run
@@ -1293,10 +1292,13 @@ static __always_inline int validate_range(struct mm_struct *mm,
        return 0;
 }
 
-static inline bool vma_can_userfault(struct vm_area_struct *vma)
+static inline bool vma_can_userfault(struct vm_area_struct *vma,
+                                    unsigned long vm_flags)
 {
-       return vma_is_anonymous(vma) || is_vm_hugetlb_page(vma) ||
-               vma_is_shmem(vma);
+       /* FIXME: add WP support to hugetlbfs and shmem */
+       return vma_is_anonymous(vma) ||
+               ((is_vm_hugetlb_page(vma) || vma_is_shmem(vma)) &&
+                !(vm_flags & VM_UFFD_WP));
 }
 
 static int userfaultfd_register(struct userfaultfd_ctx *ctx,
@@ -1328,15 +1330,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
        vm_flags = 0;
        if (uffdio_register.mode & UFFDIO_REGISTER_MODE_MISSING)
                vm_flags |= VM_UFFD_MISSING;
-       if (uffdio_register.mode & UFFDIO_REGISTER_MODE_WP) {
+       if (uffdio_register.mode & UFFDIO_REGISTER_MODE_WP)
                vm_flags |= VM_UFFD_WP;
-               /*
-                * FIXME: remove the below error constraint by
-                * implementing the wprotect tracking mode.
-                */
-               ret = -EINVAL;
-               goto out;
-       }
 
        ret = validate_range(mm, &uffdio_register.range.start,
                             uffdio_register.range.len);
@@ -1386,7 +1381,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 
                /* check not compatible vmas */
                ret = -EINVAL;
-               if (!vma_can_userfault(cur))
+               if (!vma_can_userfault(cur, vm_flags))
                        goto out_unlock;
 
                /*
@@ -1414,6 +1409,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
                        if (end & (vma_hpagesize - 1))
                                goto out_unlock;
                }
+               if ((vm_flags & VM_UFFD_WP) && !(cur->vm_flags & VM_MAYWRITE))
+                       goto out_unlock;
 
                /*
                 * Check that this vma isn't already owned by a
@@ -1443,7 +1440,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
        do {
                cond_resched();
 
-               BUG_ON(!vma_can_userfault(vma));
+               BUG_ON(!vma_can_userfault(vma, vm_flags));
                BUG_ON(vma->vm_userfaultfd_ctx.ctx &&
                       vma->vm_userfaultfd_ctx.ctx != ctx);
                WARN_ON(!(vma->vm_flags & VM_MAYWRITE));
@@ -1498,14 +1495,24 @@ out_unlock:
        up_write(&mm->mmap_sem);
        mmput(mm);
        if (!ret) {
+               __u64 ioctls_out;
+
+               ioctls_out = basic_ioctls ? UFFD_API_RANGE_IOCTLS_BASIC :
+                   UFFD_API_RANGE_IOCTLS;
+
+               /*
+                * Declare the WP ioctl only if the WP mode is
+                * specified and all checks passed with the range
+                */
+               if (!(uffdio_register.mode & UFFDIO_REGISTER_MODE_WP))
+                       ioctls_out &= ~((__u64)1 << _UFFDIO_WRITEPROTECT);
+
                /*
                 * Now that we scanned all vmas we can already tell
                 * userland which ioctls methods are guaranteed to
                 * succeed on this range.
                 */
-               if (put_user(basic_ioctls ? UFFD_API_RANGE_IOCTLS_BASIC :
-                            UFFD_API_RANGE_IOCTLS,
-                            &user_uffdio_register->ioctls))
+               if (put_user(ioctls_out, &user_uffdio_register->ioctls))
                        ret = -EFAULT;
        }
 out:
@@ -1581,7 +1588,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
                 * provides for more strict behavior to notice
                 * unregistration errors.
                 */
-               if (!vma_can_userfault(cur))
+               if (!vma_can_userfault(cur, cur->vm_flags))
                        goto out_unlock;
 
                found = true;
@@ -1595,7 +1602,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
        do {
                cond_resched();
 
-               BUG_ON(!vma_can_userfault(vma));
+               BUG_ON(!vma_can_userfault(vma, vma->vm_flags));
 
                /*
                 * Nothing to do: this vma is already registered into this
@@ -1730,11 +1737,12 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
        ret = -EINVAL;
        if (uffdio_copy.src + uffdio_copy.len <= uffdio_copy.src)
                goto out;
-       if (uffdio_copy.mode & ~UFFDIO_COPY_MODE_DONTWAKE)
+       if (uffdio_copy.mode & ~(UFFDIO_COPY_MODE_DONTWAKE|UFFDIO_COPY_MODE_WP))
                goto out;
        if (mmget_not_zero(ctx->mm)) {
                ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src,
-                                  uffdio_copy.len, &ctx->mmap_changing);
+                                  uffdio_copy.len, &ctx->mmap_changing,
+                                  uffdio_copy.mode);
                mmput(ctx->mm);
        } else {
                return -ESRCH;
@@ -1807,6 +1815,53 @@ out:
        return ret;
 }
 
+static int userfaultfd_writeprotect(struct userfaultfd_ctx *ctx,
+                                   unsigned long arg)
+{
+       int ret;
+       struct uffdio_writeprotect uffdio_wp;
+       struct uffdio_writeprotect __user *user_uffdio_wp;
+       struct userfaultfd_wake_range range;
+       bool mode_wp, mode_dontwake;
+
+       if (READ_ONCE(ctx->mmap_changing))
+               return -EAGAIN;
+
+       user_uffdio_wp = (struct uffdio_writeprotect __user *) arg;
+
+       if (copy_from_user(&uffdio_wp, user_uffdio_wp,
+                          sizeof(struct uffdio_writeprotect)))
+               return -EFAULT;
+
+       ret = validate_range(ctx->mm, &uffdio_wp.range.start,
+                            uffdio_wp.range.len);
+       if (ret)
+               return ret;
+
+       if (uffdio_wp.mode & ~(UFFDIO_WRITEPROTECT_MODE_DONTWAKE |
+                              UFFDIO_WRITEPROTECT_MODE_WP))
+               return -EINVAL;
+
+       mode_wp = uffdio_wp.mode & UFFDIO_WRITEPROTECT_MODE_WP;
+       mode_dontwake = uffdio_wp.mode & UFFDIO_WRITEPROTECT_MODE_DONTWAKE;
+
+       if (mode_wp && mode_dontwake)
+               return -EINVAL;
+
+       ret = mwriteprotect_range(ctx->mm, uffdio_wp.range.start,
+                                 uffdio_wp.range.len, mode_wp,
+                                 &ctx->mmap_changing);
+       if (ret)
+               return ret;
+
+       if (!mode_wp && !mode_dontwake) {
+               range.start = uffdio_wp.range.start;
+               range.len = uffdio_wp.range.len;
+               wake_userfault(ctx, &range);
+       }
+       return ret;
+}
+
 static inline unsigned int uffd_ctx_features(__u64 user_features)
 {
        /*
@@ -1888,6 +1943,9 @@ static long userfaultfd_ioctl(struct file *file, unsigned cmd,
        case UFFDIO_ZEROPAGE:
                ret = userfaultfd_zeropage(ctx, arg);
                break;
+       case UFFDIO_WRITEPROTECT:
+               ret = userfaultfd_writeprotect(ctx, arg);
+               break;
        }
        return ret;
 }