statmount: simplify string option retrieval
[linux-2.6-block.git] / fs / namespace.c
index 16c2f84c65708813d0245f4365988875cc414d76..7f1618ed2aba934f6d6ee4c27871edcb08a14841 100644 (file)
@@ -4694,78 +4694,15 @@ static struct vfsmount *lookup_mnt_in_ns(u64 id, struct mnt_namespace *ns)
 }
 
 struct kstatmount {
-       struct statmount __user *const buf;
-       size_t const bufsize;
-       struct vfsmount *const mnt;
-       u64 const mask;
-       struct seq_file seq;
+       struct statmount __user *buf;
+       size_t bufsize;
+       struct vfsmount *mnt;
+       u64 mask;
        struct path root;
        struct statmount sm;
-       size_t pos;
-       int err;
+       struct seq_file seq;
 };
 
-typedef int (*statmount_func_t)(struct kstatmount *);
-
-static int statmount_string_seq(struct kstatmount *s, statmount_func_t func)
-{
-       size_t rem = s->bufsize - s->pos - sizeof(s->sm);
-       struct seq_file *seq = &s->seq;
-       int ret;
-
-       seq->count = 0;
-       seq->size = min(seq->size, rem);
-       seq->buf = kvmalloc(seq->size, GFP_KERNEL_ACCOUNT);
-       if (!seq->buf)
-               return -ENOMEM;
-
-       ret = func(s);
-       if (ret)
-               return ret;
-
-       if (seq_has_overflowed(seq)) {
-               if (seq->size == rem)
-                       return -EOVERFLOW;
-               seq->size *= 2;
-               if (seq->size > MAX_RW_COUNT)
-                       return -ENOMEM;
-               kvfree(seq->buf);
-               return 0;
-       }
-
-       /* Done */
-       return 1;
-}
-
-static void statmount_string(struct kstatmount *s, u64 mask, statmount_func_t func,
-                      u32 *str)
-{
-       int ret = s->pos + sizeof(s->sm) >= s->bufsize ? -EOVERFLOW : 0;
-       struct statmount *sm = &s->sm;
-       struct seq_file *seq = &s->seq;
-
-       if (s->err || !(s->mask & mask))
-               return;
-
-       seq->size = PAGE_SIZE;
-       while (!ret)
-               ret = statmount_string_seq(s, func);
-
-       if (ret < 0) {
-               s->err = ret;
-       } else {
-               seq->buf[seq->count++] = '\0';
-               if (copy_to_user(s->buf->str + s->pos, seq->buf, seq->count)) {
-                       s->err = -EFAULT;
-               } else {
-                       *str = s->pos;
-                       s->pos += seq->count;
-               }
-       }
-       kvfree(seq->buf);
-       sm->mask |= mask;
-}
-
 static u64 mnt_to_attr_flags(struct vfsmount *mnt)
 {
        unsigned int mnt_flags = READ_ONCE(mnt->mnt_flags);
@@ -4848,41 +4785,109 @@ static void statmount_propagate_from(struct kstatmount *s)
                s->sm.propagate_from = get_dominating_id(m, &current->fs->root);
 }
 
-static int statmount_mnt_root(struct kstatmount *s)
+static int statmount_mnt_root(struct kstatmount *s, struct seq_file *seq)
 {
-       struct seq_file *seq = &s->seq;
-       int err = show_path(seq, s->mnt->mnt_root);
+       int ret;
+       size_t start = seq->count;
 
-       if (!err && !seq_has_overflowed(seq)) {
-               seq->buf[seq->count] = '\0';
-               seq->count = string_unescape_inplace(seq->buf, UNESCAPE_OCTAL);
-       }
-       return err;
+       ret = show_path(seq, s->mnt->mnt_root);
+       if (ret)
+               return ret;
+
+       if (unlikely(seq_has_overflowed(seq)))
+               return -EAGAIN;
+
+       /*
+         * Unescape the result. It would be better if supplied string was not
+         * escaped in the first place, but that's a pretty invasive change.
+         */
+       seq->buf[seq->count] = '\0';
+       seq->count = start;
+       seq_commit(seq, string_unescape_inplace(seq->buf + start, UNESCAPE_OCTAL));
+       return 0;
 }
 
-static int statmount_mnt_point(struct kstatmount *s)
+static int statmount_mnt_point(struct kstatmount *s, struct seq_file *seq)
 {
        struct vfsmount *mnt = s->mnt;
        struct path mnt_path = { .dentry = mnt->mnt_root, .mnt = mnt };
-       int err = seq_path_root(&s->seq, &mnt_path, &s->root, "");
+       int err;
 
+       err = seq_path_root(seq, &mnt_path, &s->root, "");
        return err == SEQ_SKIP ? 0 : err;
 }
 
-static int statmount_fs_type(struct kstatmount *s)
+static int statmount_fs_type(struct kstatmount *s, struct seq_file *seq)
 {
-       struct seq_file *seq = &s->seq;
        struct super_block *sb = s->mnt->mnt_sb;
 
        seq_puts(seq, sb->s_type->name);
        return 0;
 }
 
-static int do_statmount(struct kstatmount *s)
+static int statmount_string(struct kstatmount *s, u64 flag)
 {
+       int ret;
+       size_t kbufsize;
+       struct seq_file *seq = &s->seq;
        struct statmount *sm = &s->sm;
-       struct mount *m = real_mount(s->mnt);
+
+       switch (flag) {
+       case STATMOUNT_FS_TYPE:
+               sm->fs_type = seq->count;
+               ret = statmount_fs_type(s, seq);
+               break;
+       case STATMOUNT_MNT_ROOT:
+               sm->mnt_root = seq->count;
+               ret = statmount_mnt_root(s, seq);
+               break;
+       case STATMOUNT_MNT_POINT:
+               sm->mnt_point = seq->count;
+               ret = statmount_mnt_point(s, seq);
+               break;
+       default:
+               WARN_ON_ONCE(true);
+               return -EINVAL;
+       }
+
+       if (unlikely(check_add_overflow(sizeof(*sm), seq->count, &kbufsize)))
+               return -EOVERFLOW;
+       if (kbufsize >= s->bufsize)
+               return -EOVERFLOW;
+
+       /* signal a retry */
+       if (unlikely(seq_has_overflowed(seq)))
+               return -EAGAIN;
+
+       if (ret)
+               return ret;
+
+       seq->buf[seq->count++] = '\0';
+       sm->mask |= flag;
+       return 0;
+}
+
+static int copy_statmount_to_user(struct kstatmount *s)
+{
+       struct statmount *sm = &s->sm;
+       struct seq_file *seq = &s->seq;
+       char __user *str = ((char __user *)s->buf) + sizeof(*sm);
        size_t copysize = min_t(size_t, s->bufsize, sizeof(*sm));
+
+       if (seq->count && copy_to_user(str, seq->buf, seq->count))
+               return -EFAULT;
+
+       /* Return the number of bytes copied to the buffer */
+       sm->size = copysize + seq->count;
+       if (copy_to_user(s->buf, sm, copysize))
+               return -EFAULT;
+
+       return 0;
+}
+
+static int do_statmount(struct kstatmount *s)
+{
+       struct mount *m = real_mount(s->mnt);
        int err;
 
        /*
@@ -4906,19 +4911,47 @@ static int do_statmount(struct kstatmount *s)
        if (s->mask & STATMOUNT_PROPAGATE_FROM)
                statmount_propagate_from(s);
 
-       statmount_string(s, STATMOUNT_FS_TYPE, statmount_fs_type, &sm->fs_type);
-       statmount_string(s, STATMOUNT_MNT_ROOT, statmount_mnt_root, &sm->mnt_root);
-       statmount_string(s, STATMOUNT_MNT_POINT, statmount_mnt_point, &sm->mnt_point);
+       if (s->mask & STATMOUNT_FS_TYPE)
+               err = statmount_string(s, STATMOUNT_FS_TYPE);
 
-       if (s->err)
-               return s->err;
+       if (!err && s->mask & STATMOUNT_MNT_ROOT)
+               err = statmount_string(s, STATMOUNT_MNT_ROOT);
 
-       /* Return the number of bytes copied to the buffer */
-       sm->size = copysize + s->pos;
+       if (!err && s->mask & STATMOUNT_MNT_POINT)
+               err = statmount_string(s, STATMOUNT_MNT_POINT);
 
-       if (copy_to_user(s->buf, sm, copysize))
+       if (err)
+               return err;
+
+       return 0;
+}
+
+static inline bool retry_statmount(const long ret, size_t *seq_size)
+{
+       if (likely(ret != -EAGAIN))
+               return false;
+       if (unlikely(check_mul_overflow(*seq_size, 2, seq_size)))
+               return false;
+       if (unlikely(*seq_size > MAX_RW_COUNT))
+               return false;
+       return true;
+}
+
+static int prepare_kstatmount(struct kstatmount *ks, struct mnt_id_req *kreq,
+                             struct statmount __user *buf, size_t bufsize,
+                             size_t seq_size)
+{
+       if (!access_ok(buf, bufsize))
                return -EFAULT;
 
+       memset(ks, 0, sizeof(*ks));
+       ks->mask = kreq->request_mask;
+       ks->buf = buf;
+       ks->bufsize = bufsize;
+       ks->seq.size = seq_size;
+       ks->seq.buf = kvmalloc(seq_size, GFP_KERNEL_ACCOUNT);
+       if (!ks->seq.buf)
+               return -ENOMEM;
        return 0;
 }
 
@@ -4928,6 +4961,9 @@ SYSCALL_DEFINE4(statmount, const struct mnt_id_req __user *, req,
 {
        struct vfsmount *mnt;
        struct mnt_id_req kreq;
+       struct kstatmount ks;
+       /* We currently support retrieval of 3 strings. */
+       size_t seq_size = 3 * PATH_MAX;
        int ret;
 
        if (flags)
@@ -4936,23 +4972,30 @@ SYSCALL_DEFINE4(statmount, const struct mnt_id_req __user *, req,
        if (copy_from_user(&kreq, req, sizeof(kreq)))
                return -EFAULT;
 
+retry:
+       ret = prepare_kstatmount(&ks, &kreq, buf, bufsize, seq_size);
+       if (ret)
+               return ret;
+
        down_read(&namespace_sem);
        mnt = lookup_mnt_in_ns(kreq.mnt_id, current->nsproxy->mnt_ns);
-       ret = -ENOENT;
-       if (mnt) {
-               struct kstatmount s = {
-                       .mask = kreq.request_mask,
-                       .buf = buf,
-                       .bufsize = bufsize,
-                       .mnt = mnt,
-               };
-
-               get_fs_root(current->fs, &s.root);
-               ret = do_statmount(&s);
-               path_put(&s.root);
+       if (!mnt) {
+               up_read(&namespace_sem);
+               kvfree(ks.seq.buf);
+               return -ENOENT;
        }
+
+       ks.mnt = mnt;
+       get_fs_root(current->fs, &ks.root);
+       ret = do_statmount(&ks);
+       path_put(&ks.root);
        up_read(&namespace_sem);
 
+       if (!ret)
+               ret = copy_statmount_to_user(&ks);
+       kvfree(ks.seq.buf);
+       if (retry_statmount(ret, &seq_size))
+               goto retry;
        return ret;
 }