vhost_task: Handle SIGKILL by flushing work and exiting
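
Not part of the patch: a minimal sketch of how a driver wires up the two callbacks that vhost_task_create() takes after this change. The first callback drains queued work and reports whether anything ran (mirroring vhost_run_work_list()'s return of !!node); the second runs on SIGKILL, marks the worker dead, and flushes whatever is already queued before the task exits. The example_* names, the work-item layout, and the absence of locking are illustrative only.

#include <linux/llist.h>
#include <linux/sched/vhost_task.h>

struct example_work {
	struct llist_node node;
	void (*fn)(struct example_work *work);
};

struct example_worker {
	struct llist_head work_list;
	bool killed;
};

static bool example_run_work_list(void *data)
{
	struct example_worker *w = data;
	struct example_work *work, *next;
	struct llist_node *node;

	/* Grab everything queued so far and run it in submission order. */
	node = llist_del_all(&w->work_list);
	if (!node)
		return false;

	node = llist_reverse_order(node);
	llist_for_each_entry_safe(work, next, node, node)
		work->fn(work);
	return true;
}

static void example_handle_sigkill(void *data)
{
	struct example_worker *w = data;

	/*
	 * Mark the worker as killed (queueing paths are expected to check
	 * this flag), then flush whatever already snuck onto the list.
	 */
	w->killed = true;
	example_run_work_list(w);
}

static struct vhost_task *example_worker_create(struct example_worker *w)
{
	struct vhost_task *vtsk;

	init_llist_head(&w->work_list);
	w->killed = false;

	vtsk = vhost_task_create(example_run_work_list,
				 example_handle_sigkill, w, "example");
	if (!vtsk)
		return NULL;

	vhost_task_start(vtsk);
	return vtsk;
}

Queueing paths would check example_worker.killed before adding new work, which is the same role worker->killed plays in the hunks below.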
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index c6448ff3776830e7392f9b92dd079e5cfd92ca80..b609556824748f5667c9ee952bce584af646e478 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -273,7 +273,7 @@ static void __vhost_worker_flush(struct vhost_worker *worker)
 {
        struct vhost_flush_struct flush;
 
-       if (!worker->attachment_cnt)
+       if (!worker->attachment_cnt || worker->killed)
                return;
 
        init_completion(&flush.wait_event);
@@ -388,7 +388,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
        __vhost_vq_meta_reset(vq);
 }
 
-static bool vhost_worker(void *data)
+static bool vhost_run_work_list(void *data)
 {
        struct vhost_worker *worker = data;
        struct vhost_work *work, *work_next;
@@ -413,6 +413,40 @@ static bool vhost_worker(void *data)
        return !!node;
 }
 
+static void vhost_worker_killed(void *data)
+{
+       struct vhost_worker *worker = data;
+       struct vhost_dev *dev = worker->dev;
+       struct vhost_virtqueue *vq;
+       int i, attach_cnt = 0;
+
+       mutex_lock(&worker->mutex);
+       worker->killed = true;
+
+       for (i = 0; i < dev->nvqs; i++) {
+               vq = dev->vqs[i];
+
+               mutex_lock(&vq->mutex);
+               if (worker ==
+                   rcu_dereference_check(vq->worker,
+                                         lockdep_is_held(&vq->mutex))) {
+                       rcu_assign_pointer(vq->worker, NULL);
+                       attach_cnt++;
+               }
+               mutex_unlock(&vq->mutex);
+       }
+
+       worker->attachment_cnt -= attach_cnt;
+       if (attach_cnt)
+               synchronize_rcu();
+       /*
+        * Finish vhost_worker_flush calls and any other works that snuck in
+        * before the synchronize_rcu.
+        */
+       vhost_run_work_list(worker);
+       mutex_unlock(&worker->mutex);
+}
+
 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
 {
        kfree(vq->indirect);
@@ -627,9 +661,11 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
        if (!worker)
                return NULL;
 
+       worker->dev = dev;
        snprintf(name, sizeof(name), "vhost-%d", current->pid);
 
-       vtsk = vhost_task_create(vhost_worker, worker, name);
+       vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
+                                worker, name);
        if (!vtsk)
                goto free_worker;
 
@@ -661,6 +697,11 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
        struct vhost_worker *old_worker;
 
        mutex_lock(&worker->mutex);
+       if (worker->killed) {
+               mutex_unlock(&worker->mutex);
+               return;
+       }
+
        mutex_lock(&vq->mutex);
 
        old_worker = rcu_dereference_check(vq->worker,
@@ -681,6 +722,11 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
         * device wide flushes which doesn't use RCU for execution.
         */
        mutex_lock(&old_worker->mutex);
+       if (old_worker->killed) {
+               mutex_unlock(&old_worker->mutex);
+               return;
+       }
+
        /*
         * We don't want to call synchronize_rcu for every vq during setup
         * because it will slow down VM startup. If we haven't done
@@ -758,7 +804,7 @@ static int vhost_free_worker(struct vhost_dev *dev,
                return -ENODEV;
 
        mutex_lock(&worker->mutex);
-       if (worker->attachment_cnt) {
+       if (worker->attachment_cnt || worker->killed) {
                mutex_unlock(&worker->mutex);
                return -EBUSY;
        }