Merge tag 's390-6.4-1' of git://git.kernel.org/pub/scm/linux/kernel/git/s390/linux
[linux-2.6-block.git] / drivers / iommu / iommu-sva.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Helpers for IOMMU drivers implementing SVA
4  */
5 #include <linux/mmu_context.h>
6 #include <linux/mutex.h>
7 #include <linux/sched/mm.h>
8 #include <linux/iommu.h>
9
10 #include "iommu-sva.h"
11
12 static DEFINE_MUTEX(iommu_sva_lock);
13 static DECLARE_IOASID_SET(iommu_sva_pasid);
14
15 /**
16  * iommu_sva_alloc_pasid - Allocate a PASID for the mm
17  * @mm: the mm
18  * @min: minimum PASID value (inclusive)
19  * @max: maximum PASID value (inclusive)
20  *
21  * Try to allocate a PASID for this mm, or take a reference to the existing one
22  * provided it fits within the [@min, @max] range. On success the PASID is
23  * available in mm->pasid and will be available for the lifetime of the mm.
24  *
25  * Returns 0 on success and < 0 on error.
26  */
27 int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
28 {
29         int ret = 0;
30         ioasid_t pasid;
31
32         if (min == INVALID_IOASID || max == INVALID_IOASID ||
33             min == 0 || max < min)
34                 return -EINVAL;
35
36         if (!arch_pgtable_dma_compat(mm))
37                 return -EBUSY;
38
39         mutex_lock(&iommu_sva_lock);
40         /* Is a PASID already associated with this mm? */
41         if (mm_valid_pasid(mm)) {
42                 if (mm->pasid < min || mm->pasid >= max)
43                         ret = -EOVERFLOW;
44                 goto out;
45         }
46
47         pasid = ioasid_alloc(&iommu_sva_pasid, min, max, mm);
48         if (pasid == INVALID_IOASID)
49                 ret = -ENOMEM;
50         else
51                 mm_pasid_set(mm, pasid);
52 out:
53         mutex_unlock(&iommu_sva_lock);
54         return ret;
55 }
56 EXPORT_SYMBOL_GPL(iommu_sva_alloc_pasid);
57
58 /* ioasid_find getter() requires a void * argument */
59 static bool __mmget_not_zero(void *mm)
60 {
61         return mmget_not_zero(mm);
62 }
63
64 /**
65  * iommu_sva_find() - Find mm associated to the given PASID
66  * @pasid: Process Address Space ID assigned to the mm
67  *
68  * On success a reference to the mm is taken, and must be released with mmput().
69  *
70  * Returns the mm corresponding to this PASID, or an error if not found.
71  */
72 struct mm_struct *iommu_sva_find(ioasid_t pasid)
73 {
74         return ioasid_find(&iommu_sva_pasid, pasid, __mmget_not_zero);
75 }
76 EXPORT_SYMBOL_GPL(iommu_sva_find);
77
78 /**
79  * iommu_sva_bind_device() - Bind a process address space to a device
80  * @dev: the device
81  * @mm: the mm to bind, caller must hold a reference to mm_users
82  *
83  * Create a bond between device and address space, allowing the device to
84  * access the mm using the PASID returned by iommu_sva_get_pasid(). If a
85  * bond already exists between @device and @mm, an additional internal
86  * reference is taken. Caller must call iommu_sva_unbind_device()
87  * to release each reference.
88  *
89  * iommu_dev_enable_feature(dev, IOMMU_DEV_FEAT_SVA) must be called first, to
90  * initialize the required SVA features.
91  *
92  * On error, returns an ERR_PTR value.
93  */
94 struct iommu_sva *iommu_sva_bind_device(struct device *dev, struct mm_struct *mm)
95 {
96         struct iommu_domain *domain;
97         struct iommu_sva *handle;
98         ioasid_t max_pasids;
99         int ret;
100
101         max_pasids = dev->iommu->max_pasids;
102         if (!max_pasids)
103                 return ERR_PTR(-EOPNOTSUPP);
104
105         /* Allocate mm->pasid if necessary. */
106         ret = iommu_sva_alloc_pasid(mm, 1, max_pasids - 1);
107         if (ret)
108                 return ERR_PTR(ret);
109
110         handle = kzalloc(sizeof(*handle), GFP_KERNEL);
111         if (!handle)
112                 return ERR_PTR(-ENOMEM);
113
114         mutex_lock(&iommu_sva_lock);
115         /* Search for an existing domain. */
116         domain = iommu_get_domain_for_dev_pasid(dev, mm->pasid,
117                                                 IOMMU_DOMAIN_SVA);
118         if (IS_ERR(domain)) {
119                 ret = PTR_ERR(domain);
120                 goto out_unlock;
121         }
122
123         if (domain) {
124                 domain->users++;
125                 goto out;
126         }
127
128         /* Allocate a new domain and set it on device pasid. */
129         domain = iommu_sva_domain_alloc(dev, mm);
130         if (!domain) {
131                 ret = -ENOMEM;
132                 goto out_unlock;
133         }
134
135         ret = iommu_attach_device_pasid(domain, dev, mm->pasid);
136         if (ret)
137                 goto out_free_domain;
138         domain->users = 1;
139 out:
140         mutex_unlock(&iommu_sva_lock);
141         handle->dev = dev;
142         handle->domain = domain;
143
144         return handle;
145
146 out_free_domain:
147         iommu_domain_free(domain);
148 out_unlock:
149         mutex_unlock(&iommu_sva_lock);
150         kfree(handle);
151
152         return ERR_PTR(ret);
153 }
154 EXPORT_SYMBOL_GPL(iommu_sva_bind_device);
155
156 /**
157  * iommu_sva_unbind_device() - Remove a bond created with iommu_sva_bind_device
158  * @handle: the handle returned by iommu_sva_bind_device()
159  *
160  * Put reference to a bond between device and address space. The device should
161  * not be issuing any more transaction for this PASID. All outstanding page
162  * requests for this PASID must have been flushed to the IOMMU.
163  */
164 void iommu_sva_unbind_device(struct iommu_sva *handle)
165 {
166         struct iommu_domain *domain = handle->domain;
167         ioasid_t pasid = domain->mm->pasid;
168         struct device *dev = handle->dev;
169
170         mutex_lock(&iommu_sva_lock);
171         if (--domain->users == 0) {
172                 iommu_detach_device_pasid(domain, dev, pasid);
173                 iommu_domain_free(domain);
174         }
175         mutex_unlock(&iommu_sva_lock);
176         kfree(handle);
177 }
178 EXPORT_SYMBOL_GPL(iommu_sva_unbind_device);
179
180 u32 iommu_sva_get_pasid(struct iommu_sva *handle)
181 {
182         struct iommu_domain *domain = handle->domain;
183
184         return domain->mm->pasid;
185 }
186 EXPORT_SYMBOL_GPL(iommu_sva_get_pasid);
187
188 /*
189  * I/O page fault handler for SVA
190  */
191 enum iommu_page_response_code
192 iommu_sva_handle_iopf(struct iommu_fault *fault, void *data)
193 {
194         vm_fault_t ret;
195         struct vm_area_struct *vma;
196         struct mm_struct *mm = data;
197         unsigned int access_flags = 0;
198         unsigned int fault_flags = FAULT_FLAG_REMOTE;
199         struct iommu_fault_page_request *prm = &fault->prm;
200         enum iommu_page_response_code status = IOMMU_PAGE_RESP_INVALID;
201
202         if (!(prm->flags & IOMMU_FAULT_PAGE_REQUEST_PASID_VALID))
203                 return status;
204
205         if (!mmget_not_zero(mm))
206                 return status;
207
208         mmap_read_lock(mm);
209
210         vma = find_extend_vma(mm, prm->addr);
211         if (!vma)
212                 /* Unmapped area */
213                 goto out_put_mm;
214
215         if (prm->perm & IOMMU_FAULT_PERM_READ)
216                 access_flags |= VM_READ;
217
218         if (prm->perm & IOMMU_FAULT_PERM_WRITE) {
219                 access_flags |= VM_WRITE;
220                 fault_flags |= FAULT_FLAG_WRITE;
221         }
222
223         if (prm->perm & IOMMU_FAULT_PERM_EXEC) {
224                 access_flags |= VM_EXEC;
225                 fault_flags |= FAULT_FLAG_INSTRUCTION;
226         }
227
228         if (!(prm->perm & IOMMU_FAULT_PERM_PRIV))
229                 fault_flags |= FAULT_FLAG_USER;
230
231         if (access_flags & ~vma->vm_flags)
232                 /* Access fault */
233                 goto out_put_mm;
234
235         ret = handle_mm_fault(vma, prm->addr, fault_flags, NULL);
236         status = ret & VM_FAULT_ERROR ? IOMMU_PAGE_RESP_INVALID :
237                 IOMMU_PAGE_RESP_SUCCESS;
238
239 out_put_mm:
240         mmap_read_unlock(mm);
241         mmput(mm);
242
243         return status;
244 }