Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
[linux-block.git] / drivers / vhost / vhost.c
index 8995730ce0bfc82d193bd7128e51817fba43de76..b609556824748f5667c9ee952bce584af646e478 100644 (file)
@@ -263,34 +263,37 @@ bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
 }
 EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
 
-void vhost_vq_flush(struct vhost_virtqueue *vq)
-{
-       struct vhost_flush_struct flush;
-
-       init_completion(&flush.wait_event);
-       vhost_work_init(&flush.work, vhost_flush_work);
-
-       if (vhost_vq_work_queue(vq, &flush.work))
-               wait_for_completion(&flush.wait_event);
-}
-EXPORT_SYMBOL_GPL(vhost_vq_flush);
-
 /**
- * vhost_worker_flush - flush a worker
+ * __vhost_worker_flush - flush a worker
  * @worker: worker to flush
  *
- * This does not use RCU to protect the worker, so the device or worker
- * mutex must be held.
+ * The worker's mutex must be held; it is dropped while waiting for the flush.
  */
-static void vhost_worker_flush(struct vhost_worker *worker)
+static void __vhost_worker_flush(struct vhost_worker *worker)
 {
        struct vhost_flush_struct flush;
 
+       if (!worker->attachment_cnt || worker->killed)
+               return;
+
        init_completion(&flush.wait_event);
        vhost_work_init(&flush.work, vhost_flush_work);
 
        vhost_worker_queue(worker, &flush.work);
+       /*
+        * Drop mutex in case our worker is killed and it needs to take the
+        * mutex to force cleanup.
+        */
+       mutex_unlock(&worker->mutex);
        wait_for_completion(&flush.wait_event);
+       mutex_lock(&worker->mutex);
+}
+
+static void vhost_worker_flush(struct vhost_worker *worker)
+{
+       mutex_lock(&worker->mutex);
+       __vhost_worker_flush(worker);
+       mutex_unlock(&worker->mutex);
 }
 
 void vhost_dev_flush(struct vhost_dev *dev)
@@ -298,15 +301,8 @@ void vhost_dev_flush(struct vhost_dev *dev)
        struct vhost_worker *worker;
        unsigned long i;
 
-       xa_for_each(&dev->worker_xa, i, worker) {
-               mutex_lock(&worker->mutex);
-               if (!worker->attachment_cnt) {
-                       mutex_unlock(&worker->mutex);
-                       continue;
-               }
+       xa_for_each(&dev->worker_xa, i, worker)
                vhost_worker_flush(worker);
-               mutex_unlock(&worker->mutex);
-       }
 }
 EXPORT_SYMBOL_GPL(vhost_dev_flush);
 
@@ -392,7 +388,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
        __vhost_vq_meta_reset(vq);
 }
 
-static bool vhost_worker(void *data)
+static bool vhost_run_work_list(void *data)
 {
        struct vhost_worker *worker = data;
        struct vhost_work *work, *work_next;
@@ -417,6 +413,40 @@ static bool vhost_worker(void *data)
        return !!node;
 }
 
+static void vhost_worker_killed(void *data)
+{
+       struct vhost_worker *worker = data;
+       struct vhost_dev *dev = worker->dev;
+       struct vhost_virtqueue *vq;
+       int i, attach_cnt = 0;
+
+       mutex_lock(&worker->mutex);
+       worker->killed = true;
+
+       for (i = 0; i < dev->nvqs; i++) {
+               vq = dev->vqs[i];
+
+               mutex_lock(&vq->mutex);
+               if (worker ==
+                   rcu_dereference_check(vq->worker,
+                                         lockdep_is_held(&vq->mutex))) {
+                       rcu_assign_pointer(vq->worker, NULL);
+                       attach_cnt++;
+               }
+               mutex_unlock(&vq->mutex);
+       }
+
+       worker->attachment_cnt -= attach_cnt;
+       if (attach_cnt)
+               synchronize_rcu();
+       /*
+        * Finish vhost_worker_flush calls and any other works that snuck in
+        * before the synchronize_rcu.
+        */
+       vhost_run_work_list(worker);
+       mutex_unlock(&worker->mutex);
+}
+
 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
 {
        kfree(vq->indirect);
@@ -631,9 +661,11 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
        if (!worker)
                return NULL;
 
+       worker->dev = dev;
        snprintf(name, sizeof(name), "vhost-%d", current->pid);
 
-       vtsk = vhost_task_create(vhost_worker, worker, name);
+       vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
+                                worker, name);
        if (!vtsk)
                goto free_worker;
 
@@ -664,22 +696,37 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
 {
        struct vhost_worker *old_worker;
 
-       old_worker = rcu_dereference_check(vq->worker,
-                                          lockdep_is_held(&vq->dev->mutex));
-
        mutex_lock(&worker->mutex);
-       worker->attachment_cnt++;
-       mutex_unlock(&worker->mutex);
+       if (worker->killed) {
+               mutex_unlock(&worker->mutex);
+               return;
+       }
+
+       mutex_lock(&vq->mutex);
+
+       old_worker = rcu_dereference_check(vq->worker,
+                                          lockdep_is_held(&vq->mutex));
        rcu_assign_pointer(vq->worker, worker);
+       worker->attachment_cnt++;
 
-       if (!old_worker)
+       if (!old_worker) {
+               mutex_unlock(&vq->mutex);
+               mutex_unlock(&worker->mutex);
                return;
+       }
+       mutex_unlock(&vq->mutex);
+       mutex_unlock(&worker->mutex);
+
        /*
         * Take the worker mutex to make sure we see the work queued from
         * device wide flushes which doesn't use RCU for execution.
         */
        mutex_lock(&old_worker->mutex);
-       old_worker->attachment_cnt--;
+       if (old_worker->killed) {
+               mutex_unlock(&old_worker->mutex);
+               return;
+       }
+
        /*
         * We don't want to call synchronize_rcu for every vq during setup
         * because it will slow down VM startup. If we haven't done
@@ -690,6 +737,8 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
        mutex_lock(&vq->mutex);
        if (!vhost_vq_get_backend(vq) && !vq->kick) {
                mutex_unlock(&vq->mutex);
+
+               old_worker->attachment_cnt--;
                mutex_unlock(&old_worker->mutex);
                /*
                 * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
@@ -705,7 +754,8 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
        /* Make sure new vq queue/flush/poll calls see the new worker */
        synchronize_rcu();
        /* Make sure whatever was queued gets run */
-       vhost_worker_flush(old_worker);
+       __vhost_worker_flush(old_worker);
+       old_worker->attachment_cnt--;
        mutex_unlock(&old_worker->mutex);
 }
 
@@ -754,10 +804,16 @@ static int vhost_free_worker(struct vhost_dev *dev,
                return -ENODEV;
 
        mutex_lock(&worker->mutex);
-       if (worker->attachment_cnt) {
+       if (worker->attachment_cnt || worker->killed) {
                mutex_unlock(&worker->mutex);
                return -EBUSY;
        }
+       /*
+        * A flush might have raced and snuck in before attachment_cnt was set
+        * to zero. Make sure flushes are flushed from the queue before
+        * freeing.
+        */
+       __vhost_worker_flush(worker);
        mutex_unlock(&worker->mutex);
 
        vhost_worker_destroy(dev, worker);