iommufd: Share iommufd_hwpt_alloc with IOMMUFD_OBJ_HWPT_NESTED
authorNicolin Chen <nicolinc@nvidia.com>
Thu, 26 Oct 2023 04:39:33 +0000 (21:39 -0700)
committerJason Gunthorpe <jgg@nvidia.com>
Thu, 26 Oct 2023 14:15:57 +0000 (11:15 -0300)
Allow iommufd_hwpt_alloc() to have a common routine but jump to different
allocators corresponding to different user input pt_obj types, either an
IOMMUFD_OBJ_IOAS for a PAGING hwpt or an IOMMUFD_OBJ_HWPT_PAGING as the
parent for a NESTED hwpt.

Also, move the "flags" validation to the hwpt allocator (paging), so that
later the hwpt_nested allocator can do its own separate flags validation.

Link: https://lore.kernel.org/r/20231026043938.63898-6-yi.l.liu@intel.com
Signed-off-by: Nicolin Chen <nicolinc@nvidia.com>
Signed-off-by: Yi Liu <yi.l.liu@intel.com>
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Reviewed-by: Jason Gunthorpe <jgg@nvidia.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
drivers/iommu/iommufd/hw_pagetable.c

index 39b8b625b48d406c369f7d0fc6d5db0f8dbfa744..6bce9af0cb8d68ef6174c498921c481c5e2bd049 100644 (file)
@@ -82,6 +82,8 @@ iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
                          struct iommufd_device *idev, u32 flags,
                          bool immediate_attach)
 {
+       const u32 valid_flags = IOMMU_HWPT_ALLOC_NEST_PARENT |
+                               IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
        const struct iommu_ops *ops = dev_iommu_ops(idev->dev);
        struct iommufd_hwpt_paging *hwpt_paging;
        struct iommufd_hw_pagetable *hwpt;
@@ -91,6 +93,8 @@ iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
 
        if (flags && !ops->domain_alloc_user)
                return ERR_PTR(-EOPNOTSUPP);
+       if (flags & ~valid_flags)
+               return ERR_PTR(-EOPNOTSUPP);
 
        hwpt_paging = __iommufd_object_alloc(
                ictx, hwpt_paging, IOMMUFD_OBJ_HWPT_PAGING, common.obj);
@@ -167,35 +171,41 @@ out_abort:
 int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
 {
        struct iommu_hwpt_alloc *cmd = ucmd->cmd;
-       struct iommufd_hwpt_paging *hwpt_paging;
        struct iommufd_hw_pagetable *hwpt;
+       struct iommufd_ioas *ioas = NULL;
+       struct iommufd_object *pt_obj;
        struct iommufd_device *idev;
-       struct iommufd_ioas *ioas;
        int rc;
 
-       if ((cmd->flags & ~(IOMMU_HWPT_ALLOC_NEST_PARENT |
-                           IOMMU_HWPT_ALLOC_DIRTY_TRACKING)) ||
-           cmd->__reserved)
+       if (cmd->__reserved)
                return -EOPNOTSUPP;
 
        idev = iommufd_get_device(ucmd, cmd->dev_id);
        if (IS_ERR(idev))
                return PTR_ERR(idev);
 
-       ioas = iommufd_get_ioas(ucmd->ictx, cmd->pt_id);
-       if (IS_ERR(ioas)) {
-               rc = PTR_ERR(ioas);
+       pt_obj = iommufd_get_object(ucmd->ictx, cmd->pt_id, IOMMUFD_OBJ_ANY);
+       if (IS_ERR(pt_obj)) {
+               rc = -EINVAL;
                goto out_put_idev;
        }
 
-       mutex_lock(&ioas->mutex);
-       hwpt_paging = iommufd_hwpt_paging_alloc(ucmd->ictx, ioas, idev,
-                                               cmd->flags, false);
-       if (IS_ERR(hwpt_paging)) {
-               rc = PTR_ERR(hwpt_paging);
-               goto out_unlock;
+       if (pt_obj->type == IOMMUFD_OBJ_IOAS) {
+               struct iommufd_hwpt_paging *hwpt_paging;
+
+               ioas = container_of(pt_obj, struct iommufd_ioas, obj);
+               mutex_lock(&ioas->mutex);
+               hwpt_paging = iommufd_hwpt_paging_alloc(ucmd->ictx, ioas, idev,
+                                                       cmd->flags, false);
+               if (IS_ERR(hwpt_paging)) {
+                       rc = PTR_ERR(hwpt_paging);
+                       goto out_unlock;
+               }
+               hwpt = &hwpt_paging->common;
+       } else {
+               rc = -EINVAL;
+               goto out_put_pt;
        }
-       hwpt = &hwpt_paging->common;
 
        cmd->out_hwpt_id = hwpt->obj.id;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
@@ -207,8 +217,10 @@ int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
 out_hwpt:
        iommufd_object_abort_and_destroy(ucmd->ictx, &hwpt->obj);
 out_unlock:
-       mutex_unlock(&ioas->mutex);
-       iommufd_put_object(&ioas->obj);
+       if (ioas)
+               mutex_unlock(&ioas->mutex);
+out_put_pt:
+       iommufd_put_object(pt_obj);
 out_put_idev:
        iommufd_put_object(&idev->obj);
        return rc;