[linux-2.6-block.git] / drivers / vhost / vhost.c
index 60c9ebd629dd159d0ceb8a0c14f96091be193753..c71d573f1c9497c37e2da7693becd09f10bf5989 100644 (file)
@@ -187,13 +187,15 @@ EXPORT_SYMBOL_GPL(vhost_work_init);
 
 /* Init poll structure */
 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
-                    __poll_t mask, struct vhost_dev *dev)
+                    __poll_t mask, struct vhost_dev *dev,
+                    struct vhost_virtqueue *vq)
 {
        init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
        init_poll_funcptr(&poll->table, vhost_poll_func);
        poll->mask = mask;
        poll->dev = dev;
        poll->wqh = NULL;
+       poll->vq = vq;
 
        vhost_work_init(&poll->work, fn);
 }
@@ -231,46 +233,102 @@ void vhost_poll_stop(struct vhost_poll *poll)
 }
 EXPORT_SYMBOL_GPL(vhost_poll_stop);
 
-void vhost_dev_flush(struct vhost_dev *dev)
+static void vhost_worker_queue(struct vhost_worker *worker,
+                              struct vhost_work *work)
+{
+       if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
+               /* We can only add the work to the list after we're
+                * sure it was not in the list.
+                * test_and_set_bit() implies a memory barrier.
+                */
+               llist_add(&work->node, &worker->work_list);
+               vhost_task_wake(worker->vtsk);
+       }
+}
+
+bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
+{
+       struct vhost_worker *worker;
+       bool queued = false;
+
+       rcu_read_lock();
+       worker = rcu_dereference(vq->worker);
+       if (worker) {
+               queued = true;
+               vhost_worker_queue(worker, work);
+       }
+       rcu_read_unlock();
+
+       return queued;
+}
+EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
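
For context, a driver-side caller of the new per-virtqueue API would queue work roughly as in the sketch below (illustrative only; my_drv_handler/my_drv_kick are hypothetical names and the fallback policy is the caller's choice):

	/* Hypothetical driver snippet: defer work to the worker attached to this vq. */
	static void my_drv_handler(struct vhost_work *work)
	{
		/* Runs in the vhost_task context of the vq's attached worker. */
	}

	static void my_drv_kick(struct vhost_virtqueue *vq, struct vhost_work *work)
	{
		vhost_work_init(work, my_drv_handler);

		/* Returns false when no worker is attached (e.g. before VHOST_SET_OWNER). */
		if (!vhost_vq_work_queue(vq, work))
			pr_debug("vhost: no worker attached, work not queued\n");
	}
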
+
+void vhost_vq_flush(struct vhost_virtqueue *vq)
 {
        struct vhost_flush_struct flush;
 
-       if (dev->worker.vtsk) {
-               init_completion(&flush.wait_event);
-               vhost_work_init(&flush.work, vhost_flush_work);
+       init_completion(&flush.wait_event);
+       vhost_work_init(&flush.work, vhost_flush_work);
 
-               vhost_work_queue(dev, &flush.work);
+       if (vhost_vq_work_queue(vq, &flush.work))
                wait_for_completion(&flush.wait_event);
-       }
 }
-EXPORT_SYMBOL_GPL(vhost_dev_flush);
+EXPORT_SYMBOL_GPL(vhost_vq_flush);
 
-void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
+/**
+ * vhost_worker_flush - flush a worker
+ * @worker: worker to flush
+ *
+ * This does not use RCU to protect the worker, so the device or worker
+ * mutex must be held.
+ */
+static void vhost_worker_flush(struct vhost_worker *worker)
 {
-       if (!dev->worker.vtsk)
-               return;
+       struct vhost_flush_struct flush;
 
-       if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
-               /* We can only add the work to the list after we're
-                * sure it was not in the list.
-                * test_and_set_bit() implies a memory barrier.
-                */
-               llist_add(&work->node, &dev->worker.work_list);
-               vhost_task_wake(dev->worker.vtsk);
+       init_completion(&flush.wait_event);
+       vhost_work_init(&flush.work, vhost_flush_work);
+
+       vhost_worker_queue(worker, &flush.work);
+       wait_for_completion(&flush.wait_event);
+}
+
+void vhost_dev_flush(struct vhost_dev *dev)
+{
+       struct vhost_worker *worker;
+       unsigned long i;
+
+       xa_for_each(&dev->worker_xa, i, worker) {
+               mutex_lock(&worker->mutex);
+               if (!worker->attachment_cnt) {
+                       mutex_unlock(&worker->mutex);
+                       continue;
+               }
+               vhost_worker_flush(worker);
+               mutex_unlock(&worker->mutex);
        }
 }
-EXPORT_SYMBOL_GPL(vhost_work_queue);
+EXPORT_SYMBOL_GPL(vhost_dev_flush);
 
 /* A lockless hint for busy polling code to exit the loop */
-bool vhost_has_work(struct vhost_dev *dev)
+bool vhost_vq_has_work(struct vhost_virtqueue *vq)
 {
-       return !llist_empty(&dev->worker.work_list);
+       struct vhost_worker *worker;
+       bool has_work = false;
+
+       rcu_read_lock();
+       worker = rcu_dereference(vq->worker);
+       if (worker && !llist_empty(&worker->work_list))
+               has_work = true;
+       rcu_read_unlock();
+
+       return has_work;
 }
-EXPORT_SYMBOL_GPL(vhost_has_work);
+EXPORT_SYMBOL_GPL(vhost_vq_has_work);
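
As an illustration (not part of this patch), busy-polling code can use the per-vq hint to stop spinning as soon as the attached worker has pending work:

	/* Illustrative busy-poll loop using the lockless per-vq hint. */
	while (!need_resched() && !signal_pending(current)) {
		if (vhost_vq_has_work(vq))
			break;	/* let the vq's worker run the queued work */
		cpu_relax();
	}
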
 
 void vhost_poll_queue(struct vhost_poll *poll)
 {
-       vhost_work_queue(poll->dev, &poll->work);
+       vhost_vq_work_queue(poll->vq, &poll->work);
 }
 EXPORT_SYMBOL_GPL(vhost_poll_queue);
 
@@ -329,6 +387,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
        vq->busyloop_timeout = 0;
        vq->umem = NULL;
        vq->iotlb = NULL;
+       rcu_assign_pointer(vq->worker, NULL);
        vhost_vring_call_reset(&vq->call_ctx);
        __vhost_vq_meta_reset(vq);
 }
@@ -458,8 +517,6 @@ void vhost_dev_init(struct vhost_dev *dev,
        dev->umem = NULL;
        dev->iotlb = NULL;
        dev->mm = NULL;
-       memset(&dev->worker, 0, sizeof(dev->worker));
-       init_llist_head(&dev->worker.work_list);
        dev->iov_limit = iov_limit;
        dev->weight = weight;
        dev->byte_weight = byte_weight;
@@ -469,7 +526,7 @@ void vhost_dev_init(struct vhost_dev *dev,
        INIT_LIST_HEAD(&dev->read_list);
        INIT_LIST_HEAD(&dev->pending_list);
        spin_lock_init(&dev->iotlb_lock);
-
+       xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);
 
        for (i = 0; i < dev->nvqs; ++i) {
                vq = dev->vqs[i];
@@ -481,7 +538,7 @@ void vhost_dev_init(struct vhost_dev *dev,
                vhost_vq_reset(dev, vq);
                if (vq->handle_kick)
                        vhost_poll_init(&vq->poll, vq->handle_kick,
-                                       EPOLLIN, dev);
+                                       EPOLLIN, dev, vq);
        }
 }
 EXPORT_SYMBOL_GPL(vhost_dev_init);
@@ -531,38 +588,284 @@ static void vhost_detach_mm(struct vhost_dev *dev)
        dev->mm = NULL;
 }
 
-static void vhost_worker_free(struct vhost_dev *dev)
+static void vhost_worker_destroy(struct vhost_dev *dev,
+                                struct vhost_worker *worker)
+{
+       if (!worker)
+               return;
+
+       WARN_ON(!llist_empty(&worker->work_list));
+       xa_erase(&dev->worker_xa, worker->id);
+       vhost_task_stop(worker->vtsk);
+       kfree(worker);
+}
+
+static void vhost_workers_free(struct vhost_dev *dev)
 {
-       if (!dev->worker.vtsk)
+       struct vhost_worker *worker;
+       unsigned long i;
+
+       if (!dev->use_worker)
                return;
 
-       WARN_ON(!llist_empty(&dev->worker.work_list));
-       vhost_task_stop(dev->worker.vtsk);
-       dev->worker.kcov_handle = 0;
-       dev->worker.vtsk = NULL;
+       for (i = 0; i < dev->nvqs; i++)
+               rcu_assign_pointer(dev->vqs[i]->worker, NULL);
+       /*
+        * Free the default worker we created and clean up any workers
+        * userspace created but did not clean up (it forgot or crashed).
+        */
+       xa_for_each(&dev->worker_xa, i, worker)
+               vhost_worker_destroy(dev, worker);
+       xa_destroy(&dev->worker_xa);
 }
 
-static int vhost_worker_create(struct vhost_dev *dev)
+static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
 {
+       struct vhost_worker *worker;
        struct vhost_task *vtsk;
        char name[TASK_COMM_LEN];
+       int ret;
+       u32 id;
+
+       worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
+       if (!worker)
+               return NULL;
 
        snprintf(name, sizeof(name), "vhost-%d", current->pid);
 
-       vtsk = vhost_task_create(vhost_worker, &dev->worker, name);
+       vtsk = vhost_task_create(vhost_worker, worker, name);
        if (!vtsk)
-               return -ENOMEM;
+               goto free_worker;
+
+       mutex_init(&worker->mutex);
+       init_llist_head(&worker->work_list);
+       worker->kcov_handle = kcov_common_handle();
+       worker->vtsk = vtsk;
 
-       dev->worker.kcov_handle = kcov_common_handle();
-       dev->worker.vtsk = vtsk;
        vhost_task_start(vtsk);
+
+       ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
+       if (ret < 0)
+               goto stop_worker;
+       worker->id = id;
+
+       return worker;
+
+stop_worker:
+       vhost_task_stop(vtsk);
+free_worker:
+       kfree(worker);
+       return NULL;
+}
+
+/* Caller must have device mutex */
+static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
+                                    struct vhost_worker *worker)
+{
+       struct vhost_worker *old_worker;
+
+       old_worker = rcu_dereference_check(vq->worker,
+                                          lockdep_is_held(&vq->dev->mutex));
+
+       mutex_lock(&worker->mutex);
+       worker->attachment_cnt++;
+       mutex_unlock(&worker->mutex);
+       rcu_assign_pointer(vq->worker, worker);
+
+       if (!old_worker)
+               return;
+       /*
+        * Take the worker mutex to make sure we see the work queued from
+        * device-wide flushes, which don't use RCU for execution.
+        */
+       mutex_lock(&old_worker->mutex);
+       old_worker->attachment_cnt--;
+       /*
+        * We don't want to call synchronize_rcu for every vq during setup
+        * because it will slow down VM startup. If we haven't done
+        * VHOST_SET_VRING_KICK and not done the driver specific
+        * SET_ENDPOINT/RUNNING then we can skip the sync since there will
+        * not be any works queued for scsi and net.
+        */
+       mutex_lock(&vq->mutex);
+       if (!vhost_vq_get_backend(vq) && !vq->kick) {
+               mutex_unlock(&vq->mutex);
+               mutex_unlock(&old_worker->mutex);
+               /*
+                * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
+                * Warn if it adds support for multiple workers but forgets to
+                * handle the early queueing case.
+                */
+               WARN_ON(!old_worker->attachment_cnt &&
+                       !llist_empty(&old_worker->work_list));
+               return;
+       }
+       mutex_unlock(&vq->mutex);
+
+       /* Make sure new vq queue/flush/poll calls see the new worker */
+       synchronize_rcu();
+       /* Make sure whatever was queued gets run */
+       vhost_worker_flush(old_worker);
+       mutex_unlock(&old_worker->mutex);
+}
+
+/* Caller must have device mutex */
+static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
+                                 struct vhost_vring_worker *info)
+{
+       unsigned long index = info->worker_id;
+       struct vhost_dev *dev = vq->dev;
+       struct vhost_worker *worker;
+
+       if (!dev->use_worker)
+               return -EINVAL;
+
+       worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
+       if (!worker || worker->id != info->worker_id)
+               return -ENODEV;
+
+       __vhost_vq_attach_worker(vq, worker);
+       return 0;
+}
+
+/* Caller must have device mutex */
+static int vhost_new_worker(struct vhost_dev *dev,
+                           struct vhost_worker_state *info)
+{
+       struct vhost_worker *worker;
+
+       worker = vhost_worker_create(dev);
+       if (!worker)
+               return -ENOMEM;
+
+       info->worker_id = worker->id;
+       return 0;
+}
+
+/* Caller must have device mutex */
+static int vhost_free_worker(struct vhost_dev *dev,
+                            struct vhost_worker_state *info)
+{
+       unsigned long index = info->worker_id;
+       struct vhost_worker *worker;
+
+       worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
+       if (!worker || worker->id != info->worker_id)
+               return -ENODEV;
+
+       mutex_lock(&worker->mutex);
+       if (worker->attachment_cnt) {
+               mutex_unlock(&worker->mutex);
+               return -EBUSY;
+       }
+       mutex_unlock(&worker->mutex);
+
+       vhost_worker_destroy(dev, worker);
        return 0;
 }
 
+static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp,
+                                 struct vhost_virtqueue **vq, u32 *id)
+{
+       u32 __user *idxp = argp;
+       u32 idx;
+       long r;
+
+       r = get_user(idx, idxp);
+       if (r < 0)
+               return r;
+
+       if (idx >= dev->nvqs)
+               return -ENOBUFS;
+
+       idx = array_index_nospec(idx, dev->nvqs);
+
+       *vq = dev->vqs[idx];
+       *id = idx;
+       return 0;
+}
+
+/* Caller must have device mutex */
+long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
+                       void __user *argp)
+{
+       struct vhost_vring_worker ring_worker;
+       struct vhost_worker_state state;
+       struct vhost_worker *worker;
+       struct vhost_virtqueue *vq;
+       long ret;
+       u32 idx;
+
+       if (!dev->use_worker)
+               return -EINVAL;
+
+       if (!vhost_dev_has_owner(dev))
+               return -EINVAL;
+
+       ret = vhost_dev_check_owner(dev);
+       if (ret)
+               return ret;
+
+       switch (ioctl) {
+       /* dev worker ioctls */
+       case VHOST_NEW_WORKER:
+               ret = vhost_new_worker(dev, &state);
+               if (!ret && copy_to_user(argp, &state, sizeof(state)))
+                       ret = -EFAULT;
+               return ret;
+       case VHOST_FREE_WORKER:
+               if (copy_from_user(&state, argp, sizeof(state)))
+                       return -EFAULT;
+               return vhost_free_worker(dev, &state);
+       /* vring worker ioctls */
+       case VHOST_ATTACH_VRING_WORKER:
+       case VHOST_GET_VRING_WORKER:
+               break;
+       default:
+               return -ENOIOCTLCMD;
+       }
+
+       ret = vhost_get_vq_from_user(dev, argp, &vq, &idx);
+       if (ret)
+               return ret;
+
+       switch (ioctl) {
+       case VHOST_ATTACH_VRING_WORKER:
+               if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
+                       ret = -EFAULT;
+                       break;
+               }
+
+               ret = vhost_vq_attach_worker(vq, &ring_worker);
+               break;
+       case VHOST_GET_VRING_WORKER:
+               worker = rcu_dereference_check(vq->worker,
+                                              lockdep_is_held(&dev->mutex));
+               if (!worker) {
+                       ret = -EINVAL;
+                       break;
+               }
+
+               ring_worker.index = idx;
+               ring_worker.worker_id = worker->id;
+
+               if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
+                       ret = -EFAULT;
+               break;
+       default:
+               ret = -ENOIOCTLCMD;
+               break;
+       }
+
+       return ret;
+}
+EXPORT_SYMBOL_GPL(vhost_worker_ioctl);
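
From userspace, the new worker ioctls are expected to be used roughly as follows (a minimal sketch assuming the VHOST_NEW_WORKER/VHOST_ATTACH_VRING_WORKER definitions from <linux/vhost.h>; error handling trimmed):

	/* Userspace sketch: create an extra worker and attach vring 1 to it. */
	#include <sys/ioctl.h>
	#include <linux/vhost.h>

	static int attach_new_worker(int vhost_fd)
	{
		struct vhost_worker_state state = {};
		struct vhost_vring_worker ring = {};

		/* Requires VHOST_SET_OWNER to have been done on vhost_fd first. */
		if (ioctl(vhost_fd, VHOST_NEW_WORKER, &state) < 0)
			return -1;

		ring.index = 1;			/* virtqueue to move */
		ring.worker_id = state.worker_id;
		return ioctl(vhost_fd, VHOST_ATTACH_VRING_WORKER, &ring);
	}
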
+
 /* Caller should have device mutex */
 long vhost_dev_set_owner(struct vhost_dev *dev)
 {
-       int err;
+       struct vhost_worker *worker;
+       int err, i;
 
        /* Is there an owner already? */
        if (vhost_dev_has_owner(dev)) {
@@ -572,20 +875,32 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 
        vhost_attach_mm(dev);
 
-       if (dev->use_worker) {
-               err = vhost_worker_create(dev);
-               if (err)
-                       goto err_worker;
-       }
-
        err = vhost_dev_alloc_iovecs(dev);
        if (err)
                goto err_iovecs;
 
+       if (dev->use_worker) {
+               /*
+                * Create the worker last: because vsock can queue work
+                * before VHOST_SET_OWNER, doing this last simplifies the
+                * failure path below since we don't have to worry about
+                * vsock queueing work while we free the worker.
+                */
+               worker = vhost_worker_create(dev);
+               if (!worker) {
+                       err = -ENOMEM;
+                       goto err_worker;
+               }
+
+               for (i = 0; i < dev->nvqs; i++)
+                       __vhost_vq_attach_worker(dev->vqs[i], worker);
+       }
+
        return 0;
-err_iovecs:
-       vhost_worker_free(dev);
+
 err_worker:
+       vhost_dev_free_iovecs(dev);
+err_iovecs:
        vhost_detach_mm(dev);
 err_mm:
        return err;
@@ -677,7 +992,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
        dev->iotlb = NULL;
        vhost_clear_msg(dev);
        wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
-       vhost_worker_free(dev);
+       vhost_workers_free(dev);
        vhost_detach_mm(dev);
 }
 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -1565,21 +1880,15 @@ long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *arg
        struct file *eventfp, *filep = NULL;
        bool pollstart = false, pollstop = false;
        struct eventfd_ctx *ctx = NULL;
-       u32 __user *idxp = argp;
        struct vhost_virtqueue *vq;
        struct vhost_vring_state s;
        struct vhost_vring_file f;
        u32 idx;
        long r;
 
-       r = get_user(idx, idxp);
+       r = vhost_get_vq_from_user(d, argp, &vq, &idx);
        if (r < 0)
                return r;
-       if (idx >= d->nvqs)
-               return -ENOBUFS;
-
-       idx = array_index_nospec(idx, d->nvqs);
-       vq = d->vqs[idx];
 
        if (ioctl == VHOST_SET_VRING_NUM ||
            ioctl == VHOST_SET_VRING_ADDR) {