virtio_pci: harden MSI-X interrupts
authorJason Wang <jasowang@redhat.com>
Tue, 19 Oct 2021 07:01:46 +0000 (15:01 +0800)
committerMichael S. Tsirkin <mst@redhat.com>
Mon, 1 Nov 2021 09:26:48 +0000 (05:26 -0400)
We used to synchronize pending MSI-X irq handlers via
synchronize_irq(), this may not work for the untrusted device which
may keep sending interrupts after reset which may lead unexpected
results. Similarly, we should not enable MSI-X interrupt until the
device is ready. So this patch fixes those two issues by:

1) switching to use disable_irq() to prevent the virtio interrupt
   handlers to be called after the device is reset.
2) using IRQF_NO_AUTOEN and enable the MSI-X irq during .ready()

This can make sure the virtio interrupt handler won't be called before
virtio_device_ready() and after reset.

Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Paul E. McKenney <paulmck@kernel.org>
Signed-off-by: Jason Wang <jasowang@redhat.com>
Link: https://lore.kernel.org/r/20211019070152.8236-5-jasowang@redhat.com
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
drivers/virtio/virtio_pci_common.c
drivers/virtio/virtio_pci_common.h
drivers/virtio/virtio_pci_legacy.c
drivers/virtio/virtio_pci_modern.c

index d724f676608ba34a973803eb8643bf3bba604262..3f51fdb7be45431ae562c84cbbef19277197811d 100644 (file)
@@ -24,8 +24,8 @@ MODULE_PARM_DESC(force_legacy,
                 "Force legacy mode for transitional virtio 1 devices");
 #endif
 
-/* wait for pending irq handlers */
-void vp_synchronize_vectors(struct virtio_device *vdev)
+/* disable irq handlers */
+void vp_disable_cbs(struct virtio_device *vdev)
 {
        struct virtio_pci_device *vp_dev = to_vp_device(vdev);
        int i;
@@ -34,7 +34,20 @@ void vp_synchronize_vectors(struct virtio_device *vdev)
                synchronize_irq(vp_dev->pci_dev->irq);
 
        for (i = 0; i < vp_dev->msix_vectors; ++i)
-               synchronize_irq(pci_irq_vector(vp_dev->pci_dev, i));
+               disable_irq(pci_irq_vector(vp_dev->pci_dev, i));
+}
+
+/* enable irq handlers */
+void vp_enable_cbs(struct virtio_device *vdev)
+{
+       struct virtio_pci_device *vp_dev = to_vp_device(vdev);
+       int i;
+
+       if (vp_dev->intx_enabled)
+               return;
+
+       for (i = 0; i < vp_dev->msix_vectors; ++i)
+               enable_irq(pci_irq_vector(vp_dev->pci_dev, i));
 }
 
 /* the notify function used when creating a virt queue */
@@ -141,7 +154,8 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors,
        snprintf(vp_dev->msix_names[v], sizeof *vp_dev->msix_names,
                 "%s-config", name);
        err = request_irq(pci_irq_vector(vp_dev->pci_dev, v),
-                         vp_config_changed, 0, vp_dev->msix_names[v],
+                         vp_config_changed, IRQF_NO_AUTOEN,
+                         vp_dev->msix_names[v],
                          vp_dev);
        if (err)
                goto error;
@@ -160,7 +174,8 @@ static int vp_request_msix_vectors(struct virtio_device *vdev, int nvectors,
                snprintf(vp_dev->msix_names[v], sizeof *vp_dev->msix_names,
                         "%s-virtqueues", name);
                err = request_irq(pci_irq_vector(vp_dev->pci_dev, v),
-                                 vp_vring_interrupt, 0, vp_dev->msix_names[v],
+                                 vp_vring_interrupt, IRQF_NO_AUTOEN,
+                                 vp_dev->msix_names[v],
                                  vp_dev);
                if (err)
                        goto error;
@@ -337,7 +352,7 @@ static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned nvqs,
                         "%s-%s",
                         dev_name(&vp_dev->vdev.dev), names[i]);
                err = request_irq(pci_irq_vector(vp_dev->pci_dev, msix_vec),
-                                 vring_interrupt, 0,
+                                 vring_interrupt, IRQF_NO_AUTOEN,
                                  vp_dev->msix_names[msix_vec],
                                  vqs[i]);
                if (err)
index eb17a29fc7ef1f6cbc3c93a390d03a80f0076fc6..d3c6f72c73906029b5c32451a32ea6c073cbf301 100644 (file)
@@ -101,8 +101,10 @@ static struct virtio_pci_device *to_vp_device(struct virtio_device *vdev)
        return container_of(vdev, struct virtio_pci_device, vdev);
 }
 
-/* wait for pending irq handlers */
-void vp_synchronize_vectors(struct virtio_device *vdev);
+/* disable irq handlers */
+void vp_disable_cbs(struct virtio_device *vdev);
+/* enable irq handlers */
+void vp_enable_cbs(struct virtio_device *vdev);
 /* the notify function used when creating a virt queue */
 bool vp_notify(struct virtqueue *vq);
 /* the config->del_vqs() implementation */
index 82eb437ad9205033f37812b0db4fa64a6a0ef5a9..b3f8128b7983b234489ba0ed87e2f2aafdbf25b5 100644 (file)
@@ -98,8 +98,8 @@ static void vp_reset(struct virtio_device *vdev)
        /* Flush out the status write, and flush in device writes,
         * including MSi-X interrupts, if any. */
        vp_legacy_get_status(&vp_dev->ldev);
-       /* Flush pending VQ/configuration callbacks. */
-       vp_synchronize_vectors(vdev);
+       /* Disable VQ/configuration callbacks. */
+       vp_disable_cbs(vdev);
 }
 
 static u16 vp_config_vector(struct virtio_pci_device *vp_dev, u16 vector)
@@ -185,6 +185,7 @@ static void del_vq(struct virtio_pci_vq_info *info)
 }
 
 static const struct virtio_config_ops virtio_pci_config_ops = {
+       .enable_cbs     = vp_enable_cbs,
        .get            = vp_get,
        .set            = vp_set,
        .get_status     = vp_get_status,
index 30654d3a0b41ed6fca1a46f9dac6ac01d925da53..5455bc041fb6916bf265d45565905668fbdf9cf3 100644 (file)
@@ -172,8 +172,8 @@ static void vp_reset(struct virtio_device *vdev)
         */
        while (vp_modern_get_status(mdev))
                msleep(1);
-       /* Flush pending VQ/configuration callbacks. */
-       vp_synchronize_vectors(vdev);
+       /* Disable VQ/configuration callbacks. */
+       vp_disable_cbs(vdev);
 }
 
 static u16 vp_config_vector(struct virtio_pci_device *vp_dev, u16 vector)
@@ -380,6 +380,7 @@ static bool vp_get_shm_region(struct virtio_device *vdev,
 }
 
 static const struct virtio_config_ops virtio_pci_config_nodev_ops = {
+       .enable_cbs     = vp_enable_cbs,
        .get            = NULL,
        .set            = NULL,
        .generation     = vp_generation,
@@ -397,6 +398,7 @@ static const struct virtio_config_ops virtio_pci_config_nodev_ops = {
 };
 
 static const struct virtio_config_ops virtio_pci_config_ops = {
+       .enable_cbs     = vp_enable_cbs,
        .get            = vp_get,
        .set            = vp_set,
        .generation     = vp_generation,