iommu/arm-smmu: Correct group reference count
[linux-2.6-block.git] / drivers / iommu / arm-smmu-v3.c
index 2e3e235f509c12f4cecc88596579ced2825abb87..3ea4d576bf087f0be0ea9457229ef179f26dc660 100644 (file)
@@ -1809,13 +1809,13 @@ static int arm_smmu_add_device(struct device *dev)
                smmu = arm_smmu_get_for_pci_dev(pdev);
                if (!smmu) {
                        ret = -ENOENT;
-                       goto out_put_group;
+                       goto out_remove_dev;
                }
 
                smmu_group = kzalloc(sizeof(*smmu_group), GFP_KERNEL);
                if (!smmu_group) {
                        ret = -ENOMEM;
-                       goto out_put_group;
+                       goto out_remove_dev;
                }
 
                smmu_group->ste.valid   = true;
@@ -1831,20 +1831,20 @@ static int arm_smmu_add_device(struct device *dev)
        for (i = 0; i < smmu_group->num_sids; ++i) {
                /* If we already know about this SID, then we're done */
                if (smmu_group->sids[i] == sid)
-                       return 0;
+                       goto out_put_group;
        }
 
        /* Check the SID is in range of the SMMU and our stream table */
        if (!arm_smmu_sid_in_range(smmu, sid)) {
                ret = -ERANGE;
-               goto out_put_group;
+               goto out_remove_dev;
        }
 
        /* Ensure l2 strtab is initialised */
        if (smmu->features & ARM_SMMU_FEAT_2_LVL_STRTAB) {
                ret = arm_smmu_init_l2_strtab(smmu, sid);
                if (ret)
-                       goto out_put_group;
+                       goto out_remove_dev;
        }
 
        /* Resize the SID array for the group */
@@ -1854,15 +1854,19 @@ static int arm_smmu_add_device(struct device *dev)
        if (!sids) {
                smmu_group->num_sids--;
                ret = -ENOMEM;
-               goto out_put_group;
+               goto out_remove_dev;
        }
 
        /* Add the new SID */
        sids[smmu_group->num_sids - 1] = sid;
        smmu_group->sids = sids;
-       return 0;
 
 out_put_group:
+       iommu_group_put(group);
+       return 0;
+
+out_remove_dev:
+       iommu_group_remove_device(dev);
        iommu_group_put(group);
        return ret;
 }