vfio/mtty: Overhaul mtty interrupt handling
[linux-block.git] / samples / vfio-mdev / mtty.c
index 5af00387c519e245cb3a923d0e5b3b7cc0f8e7cc..245db52bedf299a522c1299f70cf97ecc5310515 100644 (file)
@@ -127,7 +127,6 @@ struct serial_port {
 /* State of each mdev device */
 struct mdev_state {
        struct vfio_device vdev;
-       int irq_fd;
        struct eventfd_ctx *intx_evtfd;
        struct eventfd_ctx *msi_evtfd;
        int irq_index;
@@ -141,6 +140,7 @@ struct mdev_state {
        struct mutex rxtx_lock;
        struct vfio_device_info dev_info;
        int nr_ports;
+       u8 intx_mask:1;
 };
 
 static struct mtty_type {
@@ -166,10 +166,6 @@ static const struct file_operations vd_fops = {
 
 static const struct vfio_device_ops mtty_dev_ops;
 
-/* function prototypes */
-
-static int mtty_trigger_interrupt(struct mdev_state *mdev_state);
-
 /* Helper functions */
 
 static void dump_buffer(u8 *buf, uint32_t count)
@@ -186,6 +182,36 @@ static void dump_buffer(u8 *buf, uint32_t count)
 #endif
 }
 
+static bool is_intx(struct mdev_state *mdev_state)
+{
+       return mdev_state->irq_index == VFIO_PCI_INTX_IRQ_INDEX;
+}
+
+static bool is_msi(struct mdev_state *mdev_state)
+{
+       return mdev_state->irq_index == VFIO_PCI_MSI_IRQ_INDEX;
+}
+
+static bool is_noirq(struct mdev_state *mdev_state)
+{
+       return !is_intx(mdev_state) && !is_msi(mdev_state);
+}
+
+static void mtty_trigger_interrupt(struct mdev_state *mdev_state)
+{
+       lockdep_assert_held(&mdev_state->ops_lock);
+
+       if (is_msi(mdev_state)) {
+               if (mdev_state->msi_evtfd)
+                       eventfd_signal(mdev_state->msi_evtfd, 1);
+       } else if (is_intx(mdev_state)) {
+               if (mdev_state->intx_evtfd && !mdev_state->intx_mask) {
+                       eventfd_signal(mdev_state->intx_evtfd, 1);
+                       mdev_state->intx_mask = true;
+               }
+       }
+}
+
 static void mtty_create_config_space(struct mdev_state *mdev_state)
 {
        /* PCI dev ID */
@@ -921,6 +947,25 @@ write_err:
        return -EFAULT;
 }
 
+static void mtty_disable_intx(struct mdev_state *mdev_state)
+{
+       if (mdev_state->intx_evtfd) {
+               eventfd_ctx_put(mdev_state->intx_evtfd);
+               mdev_state->intx_evtfd = NULL;
+               mdev_state->intx_mask = false;
+               mdev_state->irq_index = -1;
+       }
+}
+
+static void mtty_disable_msi(struct mdev_state *mdev_state)
+{
+       if (mdev_state->msi_evtfd) {
+               eventfd_ctx_put(mdev_state->msi_evtfd);
+               mdev_state->msi_evtfd = NULL;
+               mdev_state->irq_index = -1;
+       }
+}
+
 static int mtty_set_irqs(struct mdev_state *mdev_state, uint32_t flags,
                         unsigned int index, unsigned int start,
                         unsigned int count, void *data)
@@ -932,59 +977,113 @@ static int mtty_set_irqs(struct mdev_state *mdev_state, uint32_t flags,
        case VFIO_PCI_INTX_IRQ_INDEX:
                switch (flags & VFIO_IRQ_SET_ACTION_TYPE_MASK) {
                case VFIO_IRQ_SET_ACTION_MASK:
+                       if (!is_intx(mdev_state) || start != 0 || count != 1) {
+                               ret = -EINVAL;
+                               break;
+                       }
+
+                       if (flags & VFIO_IRQ_SET_DATA_NONE) {
+                               mdev_state->intx_mask = true;
+                       } else if (flags & VFIO_IRQ_SET_DATA_BOOL) {
+                               uint8_t mask = *(uint8_t *)data;
+
+                               if (mask)
+                                       mdev_state->intx_mask = true;
+                       } else if (flags &  VFIO_IRQ_SET_DATA_EVENTFD) {
+                               ret = -ENOTTY; /* No support for mask fd */
+                       }
+                       break;
                case VFIO_IRQ_SET_ACTION_UNMASK:
+                       if (!is_intx(mdev_state) || start != 0 || count != 1) {
+                               ret = -EINVAL;
+                               break;
+                       }
+
+                       if (flags & VFIO_IRQ_SET_DATA_NONE) {
+                               mdev_state->intx_mask = false;
+                       } else if (flags & VFIO_IRQ_SET_DATA_BOOL) {
+                               uint8_t mask = *(uint8_t *)data;
+
+                               if (mask)
+                                       mdev_state->intx_mask = false;
+                       } else if (flags &  VFIO_IRQ_SET_DATA_EVENTFD) {
+                               ret = -ENOTTY; /* No support for unmask fd */
+                       }
                        break;
                case VFIO_IRQ_SET_ACTION_TRIGGER:
-               {
-                       if (flags & VFIO_IRQ_SET_DATA_NONE) {
-                               pr_info("%s: disable INTx\n", __func__);
-                               if (mdev_state->intx_evtfd)
-                                       eventfd_ctx_put(mdev_state->intx_evtfd);
+                       if (is_intx(mdev_state) && !count &&
+                           (flags & VFIO_IRQ_SET_DATA_NONE)) {
+                               mtty_disable_intx(mdev_state);
+                               break;
+                       }
+
+                       if (!(is_intx(mdev_state) || is_noirq(mdev_state)) ||
+                           start != 0 || count != 1) {
+                               ret = -EINVAL;
                                break;
                        }
 
                        if (flags & VFIO_IRQ_SET_DATA_EVENTFD) {
                                int fd = *(int *)data;
+                               struct eventfd_ctx *evt;
+
+                               mtty_disable_intx(mdev_state);
+
+                               if (fd < 0)
+                                       break;
 
-                               if (fd > 0) {
-                                       struct eventfd_ctx *evt;
-
-                                       evt = eventfd_ctx_fdget(fd);
-                                       if (IS_ERR(evt)) {
-                                               ret = PTR_ERR(evt);
-                                               break;
-                                       }
-                                       mdev_state->intx_evtfd = evt;
-                                       mdev_state->irq_fd = fd;
-                                       mdev_state->irq_index = index;
+                               evt = eventfd_ctx_fdget(fd);
+                               if (IS_ERR(evt)) {
+                                       ret = PTR_ERR(evt);
                                        break;
                                }
+                               mdev_state->intx_evtfd = evt;
+                               mdev_state->irq_index = index;
+                               break;
+                       }
+
+                       if (!is_intx(mdev_state)) {
+                               ret = -EINVAL;
+                               break;
+                       }
+
+                       if (flags & VFIO_IRQ_SET_DATA_NONE) {
+                               mtty_trigger_interrupt(mdev_state);
+                       } else if (flags & VFIO_IRQ_SET_DATA_BOOL) {
+                               uint8_t trigger = *(uint8_t *)data;
+
+                               if (trigger)
+                                       mtty_trigger_interrupt(mdev_state);
                        }
                        break;
                }
-               }
                break;
        case VFIO_PCI_MSI_IRQ_INDEX:
                switch (flags & VFIO_IRQ_SET_ACTION_TYPE_MASK) {
                case VFIO_IRQ_SET_ACTION_MASK:
                case VFIO_IRQ_SET_ACTION_UNMASK:
+                       ret = -ENOTTY;
                        break;
                case VFIO_IRQ_SET_ACTION_TRIGGER:
-                       if (flags & VFIO_IRQ_SET_DATA_NONE) {
-                               if (mdev_state->msi_evtfd)
-                                       eventfd_ctx_put(mdev_state->msi_evtfd);
-                               pr_info("%s: disable MSI\n", __func__);
-                               mdev_state->irq_index = VFIO_PCI_INTX_IRQ_INDEX;
+                       if (is_msi(mdev_state) && !count &&
+                           (flags & VFIO_IRQ_SET_DATA_NONE)) {
+                               mtty_disable_msi(mdev_state);
                                break;
                        }
+
+                       if (!(is_msi(mdev_state) || is_noirq(mdev_state)) ||
+                           start != 0 || count != 1) {
+                               ret = -EINVAL;
+                               break;
+                       }
+
                        if (flags & VFIO_IRQ_SET_DATA_EVENTFD) {
                                int fd = *(int *)data;
                                struct eventfd_ctx *evt;
 
-                               if (fd <= 0)
-                                       break;
+                               mtty_disable_msi(mdev_state);
 
-                               if (mdev_state->msi_evtfd)
+                               if (fd < 0)
                                        break;
 
                                evt = eventfd_ctx_fdget(fd);
@@ -993,20 +1092,37 @@ static int mtty_set_irqs(struct mdev_state *mdev_state, uint32_t flags,
                                        break;
                                }
                                mdev_state->msi_evtfd = evt;
-                               mdev_state->irq_fd = fd;
                                mdev_state->irq_index = index;
+                               break;
+                       }
+
+                       if (!is_msi(mdev_state)) {
+                               ret = -EINVAL;
+                               break;
+                       }
+
+                       if (flags & VFIO_IRQ_SET_DATA_NONE) {
+                               mtty_trigger_interrupt(mdev_state);
+                       } else if (flags & VFIO_IRQ_SET_DATA_BOOL) {
+                               uint8_t trigger = *(uint8_t *)data;
+
+                               if (trigger)
+                                       mtty_trigger_interrupt(mdev_state);
                        }
                        break;
-       }
-       break;
+               }
+               break;
        case VFIO_PCI_MSIX_IRQ_INDEX:
-               pr_info("%s: MSIX_IRQ\n", __func__);
+               dev_dbg(mdev_state->vdev.dev, "%s: MSIX_IRQ\n", __func__);
+               ret = -ENOTTY;
                break;
        case VFIO_PCI_ERR_IRQ_INDEX:
-               pr_info("%s: ERR_IRQ\n", __func__);
+               dev_dbg(mdev_state->vdev.dev, "%s: ERR_IRQ\n", __func__);
+               ret = -ENOTTY;
                break;
        case VFIO_PCI_REQ_IRQ_INDEX:
-               pr_info("%s: REQ_IRQ\n", __func__);
+               dev_dbg(mdev_state->vdev.dev, "%s: REQ_IRQ\n", __func__);
+               ret = -ENOTTY;
                break;
        }
 
@@ -1014,33 +1130,6 @@ static int mtty_set_irqs(struct mdev_state *mdev_state, uint32_t flags,
        return ret;
 }
 
-static int mtty_trigger_interrupt(struct mdev_state *mdev_state)
-{
-       int ret = -1;
-
-       if ((mdev_state->irq_index == VFIO_PCI_MSI_IRQ_INDEX) &&
-           (!mdev_state->msi_evtfd))
-               return -EINVAL;
-       else if ((mdev_state->irq_index == VFIO_PCI_INTX_IRQ_INDEX) &&
-                (!mdev_state->intx_evtfd)) {
-               pr_info("%s: Intr eventfd not found\n", __func__);
-               return -EINVAL;
-       }
-
-       if (mdev_state->irq_index == VFIO_PCI_MSI_IRQ_INDEX)
-               ret = eventfd_signal(mdev_state->msi_evtfd, 1);
-       else
-               ret = eventfd_signal(mdev_state->intx_evtfd, 1);
-
-#if defined(DEBUG_INTR)
-       pr_info("Intx triggered\n");
-#endif
-       if (ret != 1)
-               pr_err("%s: eventfd signal failed (%d)\n", __func__, ret);
-
-       return ret;
-}
-
 static int mtty_get_region_info(struct mdev_state *mdev_state,
                         struct vfio_region_info *region_info,
                         u16 *cap_type_id, void **cap_type)
@@ -1084,22 +1173,16 @@ static int mtty_get_region_info(struct mdev_state *mdev_state,
 
 static int mtty_get_irq_info(struct vfio_irq_info *irq_info)
 {
-       switch (irq_info->index) {
-       case VFIO_PCI_INTX_IRQ_INDEX:
-       case VFIO_PCI_MSI_IRQ_INDEX:
-       case VFIO_PCI_REQ_IRQ_INDEX:
-               break;
-
-       default:
+       if (irq_info->index != VFIO_PCI_INTX_IRQ_INDEX &&
+           irq_info->index != VFIO_PCI_MSI_IRQ_INDEX)
                return -EINVAL;
-       }
 
        irq_info->flags = VFIO_IRQ_INFO_EVENTFD;
        irq_info->count = 1;
 
        if (irq_info->index == VFIO_PCI_INTX_IRQ_INDEX)
-               irq_info->flags |= (VFIO_IRQ_INFO_MASKABLE |
-                               VFIO_IRQ_INFO_AUTOMASKED);
+               irq_info->flags |= VFIO_IRQ_INFO_MASKABLE |
+                                  VFIO_IRQ_INFO_AUTOMASKED;
        else
                irq_info->flags |= VFIO_IRQ_INFO_NORESIZE;
 
@@ -1262,6 +1345,15 @@ static unsigned int mtty_get_available(struct mdev_type *mtype)
        return atomic_read(&mdev_avail_ports) / type->nr_ports;
 }
 
+static void mtty_close(struct vfio_device *vdev)
+{
+       struct mdev_state *mdev_state =
+                               container_of(vdev, struct mdev_state, vdev);
+
+       mtty_disable_intx(mdev_state);
+       mtty_disable_msi(mdev_state);
+}
+
 static const struct vfio_device_ops mtty_dev_ops = {
        .name = "vfio-mtty",
        .init = mtty_init_dev,
@@ -1273,6 +1365,7 @@ static const struct vfio_device_ops mtty_dev_ops = {
        .unbind_iommufd = vfio_iommufd_emulated_unbind,
        .attach_ioas    = vfio_iommufd_emulated_attach_ioas,
        .detach_ioas    = vfio_iommufd_emulated_detach_ioas,
+       .close_device   = mtty_close,
 };
 
 static struct mdev_driver mtty_driver = {