Merge tag 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/rdma/rdma
[linux-2.6-block.git] / drivers / infiniband / core / umem_odp.c
1 /*
2  * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * OpenIB.org BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  */
32
33 #include <linux/types.h>
34 #include <linux/sched.h>
35 #include <linux/sched/mm.h>
36 #include <linux/sched/task.h>
37 #include <linux/pid.h>
38 #include <linux/slab.h>
39 #include <linux/export.h>
40 #include <linux/vmalloc.h>
41 #include <linux/hugetlb.h>
42 #include <linux/interval_tree.h>
43 #include <linux/pagemap.h>
44
45 #include <rdma/ib_verbs.h>
46 #include <rdma/ib_umem.h>
47 #include <rdma/ib_umem_odp.h>
48
49 #include "uverbs.h"
50
51 static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
52 {
53         mutex_lock(&umem_odp->umem_mutex);
54         if (umem_odp->notifiers_count++ == 0)
55                 /*
56                  * Initialize the completion object for waiting on
57                  * notifiers. Since notifier_count is zero, no one should be
58                  * waiting right now.
59                  */
60                 reinit_completion(&umem_odp->notifier_completion);
61         mutex_unlock(&umem_odp->umem_mutex);
62 }
63
64 static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
65 {
66         mutex_lock(&umem_odp->umem_mutex);
67         /*
68          * This sequence increase will notify the QP page fault that the page
69          * that is going to be mapped in the spte could have been freed.
70          */
71         ++umem_odp->notifiers_seq;
72         if (--umem_odp->notifiers_count == 0)
73                 complete_all(&umem_odp->notifier_completion);
74         mutex_unlock(&umem_odp->umem_mutex);
75 }
76
77 static void ib_umem_notifier_release(struct mmu_notifier *mn,
78                                      struct mm_struct *mm)
79 {
80         struct ib_ucontext_per_mm *per_mm =
81                 container_of(mn, struct ib_ucontext_per_mm, mn);
82         struct rb_node *node;
83
84         down_read(&per_mm->umem_rwsem);
85         if (!per_mm->mn.users)
86                 goto out;
87
88         for (node = rb_first_cached(&per_mm->umem_tree); node;
89              node = rb_next(node)) {
90                 struct ib_umem_odp *umem_odp =
91                         rb_entry(node, struct ib_umem_odp, interval_tree.rb);
92
93                 /*
94                  * Increase the number of notifiers running, to prevent any
95                  * further fault handling on this MR.
96                  */
97                 ib_umem_notifier_start_account(umem_odp);
98                 complete_all(&umem_odp->notifier_completion);
99                 umem_odp->umem.ibdev->ops.invalidate_range(
100                         umem_odp, ib_umem_start(umem_odp),
101                         ib_umem_end(umem_odp));
102         }
103
104 out:
105         up_read(&per_mm->umem_rwsem);
106 }
107
108 static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
109                                              u64 start, u64 end, void *cookie)
110 {
111         ib_umem_notifier_start_account(item);
112         item->umem.ibdev->ops.invalidate_range(item, start, end);
113         return 0;
114 }
115
116 static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
117                                 const struct mmu_notifier_range *range)
118 {
119         struct ib_ucontext_per_mm *per_mm =
120                 container_of(mn, struct ib_ucontext_per_mm, mn);
121         int rc;
122
123         if (mmu_notifier_range_blockable(range))
124                 down_read(&per_mm->umem_rwsem);
125         else if (!down_read_trylock(&per_mm->umem_rwsem))
126                 return -EAGAIN;
127
128         if (!per_mm->mn.users) {
129                 up_read(&per_mm->umem_rwsem);
130                 /*
131                  * At this point users is permanently zero and visible to this
132                  * CPU without a lock, that fact is relied on to skip the unlock
133                  * in range_end.
134                  */
135                 return 0;
136         }
137
138         rc = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
139                                            range->end,
140                                            invalidate_range_start_trampoline,
141                                            mmu_notifier_range_blockable(range),
142                                            NULL);
143         if (rc)
144                 up_read(&per_mm->umem_rwsem);
145         return rc;
146 }
147
148 static int invalidate_range_end_trampoline(struct ib_umem_odp *item, u64 start,
149                                            u64 end, void *cookie)
150 {
151         ib_umem_notifier_end_account(item);
152         return 0;
153 }
154
155 static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
156                                 const struct mmu_notifier_range *range)
157 {
158         struct ib_ucontext_per_mm *per_mm =
159                 container_of(mn, struct ib_ucontext_per_mm, mn);
160
161         if (unlikely(!per_mm->mn.users))
162                 return;
163
164         rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
165                                       range->end,
166                                       invalidate_range_end_trampoline, true, NULL);
167         up_read(&per_mm->umem_rwsem);
168 }
169
170 static struct mmu_notifier *ib_umem_alloc_notifier(struct mm_struct *mm)
171 {
172         struct ib_ucontext_per_mm *per_mm;
173
174         per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
175         if (!per_mm)
176                 return ERR_PTR(-ENOMEM);
177
178         per_mm->umem_tree = RB_ROOT_CACHED;
179         init_rwsem(&per_mm->umem_rwsem);
180
181         WARN_ON(mm != current->mm);
182         rcu_read_lock();
183         per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
184         rcu_read_unlock();
185         return &per_mm->mn;
186 }
187
188 static void ib_umem_free_notifier(struct mmu_notifier *mn)
189 {
190         struct ib_ucontext_per_mm *per_mm =
191                 container_of(mn, struct ib_ucontext_per_mm, mn);
192
193         WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
194
195         put_pid(per_mm->tgid);
196         kfree(per_mm);
197 }
198
199 static const struct mmu_notifier_ops ib_umem_notifiers = {
200         .release                    = ib_umem_notifier_release,
201         .invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
202         .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
203         .alloc_notifier             = ib_umem_alloc_notifier,
204         .free_notifier              = ib_umem_free_notifier,
205 };
206
207 static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp)
208 {
209         struct ib_ucontext_per_mm *per_mm;
210         struct mmu_notifier *mn;
211         int ret;
212
213         umem_odp->umem.is_odp = 1;
214         if (!umem_odp->is_implicit_odp) {
215                 size_t page_size = 1UL << umem_odp->page_shift;
216                 size_t pages;
217
218                 umem_odp->interval_tree.start =
219                         ALIGN_DOWN(umem_odp->umem.address, page_size);
220                 if (check_add_overflow(umem_odp->umem.address,
221                                        (unsigned long)umem_odp->umem.length,
222                                        &umem_odp->interval_tree.last))
223                         return -EOVERFLOW;
224                 umem_odp->interval_tree.last =
225                         ALIGN(umem_odp->interval_tree.last, page_size);
226                 if (unlikely(umem_odp->interval_tree.last < page_size))
227                         return -EOVERFLOW;
228
229                 pages = (umem_odp->interval_tree.last -
230                          umem_odp->interval_tree.start) >>
231                         umem_odp->page_shift;
232                 if (!pages)
233                         return -EINVAL;
234
235                 /*
236                  * Note that the representation of the intervals in the
237                  * interval tree considers the ending point as contained in
238                  * the interval.
239                  */
240                 umem_odp->interval_tree.last--;
241
242                 umem_odp->page_list = kvcalloc(
243                         pages, sizeof(*umem_odp->page_list), GFP_KERNEL);
244                 if (!umem_odp->page_list)
245                         return -ENOMEM;
246
247                 umem_odp->dma_list = kvcalloc(
248                         pages, sizeof(*umem_odp->dma_list), GFP_KERNEL);
249                 if (!umem_odp->dma_list) {
250                         ret = -ENOMEM;
251                         goto out_page_list;
252                 }
253         }
254
255         mn = mmu_notifier_get(&ib_umem_notifiers, umem_odp->umem.owning_mm);
256         if (IS_ERR(mn)) {
257                 ret = PTR_ERR(mn);
258                 goto out_dma_list;
259         }
260         umem_odp->per_mm = per_mm =
261                 container_of(mn, struct ib_ucontext_per_mm, mn);
262
263         mutex_init(&umem_odp->umem_mutex);
264         init_completion(&umem_odp->notifier_completion);
265
266         if (!umem_odp->is_implicit_odp) {
267                 down_write(&per_mm->umem_rwsem);
268                 interval_tree_insert(&umem_odp->interval_tree,
269                                      &per_mm->umem_tree);
270                 up_write(&per_mm->umem_rwsem);
271         }
272         mmgrab(umem_odp->umem.owning_mm);
273
274         return 0;
275
276 out_dma_list:
277         kvfree(umem_odp->dma_list);
278 out_page_list:
279         kvfree(umem_odp->page_list);
280         return ret;
281 }
282
283 /**
284  * ib_umem_odp_alloc_implicit - Allocate a parent implicit ODP umem
285  *
286  * Implicit ODP umems do not have a VA range and do not have any page lists.
287  * They exist only to hold the per_mm reference to help the driver create
288  * children umems.
289  *
290  * @udata: udata from the syscall being used to create the umem
291  * @access: ib_reg_mr access flags
292  */
293 struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
294                                                int access)
295 {
296         struct ib_ucontext *context =
297                 container_of(udata, struct uverbs_attr_bundle, driver_udata)
298                         ->context;
299         struct ib_umem *umem;
300         struct ib_umem_odp *umem_odp;
301         int ret;
302
303         if (access & IB_ACCESS_HUGETLB)
304                 return ERR_PTR(-EINVAL);
305
306         if (!context)
307                 return ERR_PTR(-EIO);
308         if (WARN_ON_ONCE(!context->device->ops.invalidate_range))
309                 return ERR_PTR(-EINVAL);
310
311         umem_odp = kzalloc(sizeof(*umem_odp), GFP_KERNEL);
312         if (!umem_odp)
313                 return ERR_PTR(-ENOMEM);
314         umem = &umem_odp->umem;
315         umem->ibdev = context->device;
316         umem->writable = ib_access_writable(access);
317         umem->owning_mm = current->mm;
318         umem_odp->is_implicit_odp = 1;
319         umem_odp->page_shift = PAGE_SHIFT;
320
321         ret = ib_init_umem_odp(umem_odp);
322         if (ret) {
323                 kfree(umem_odp);
324                 return ERR_PTR(ret);
325         }
326         return umem_odp;
327 }
328 EXPORT_SYMBOL(ib_umem_odp_alloc_implicit);
329
330 /**
331  * ib_umem_odp_alloc_child - Allocate a child ODP umem under an implicit
332  *                           parent ODP umem
333  *
334  * @root: The parent umem enclosing the child. This must be allocated using
335  *        ib_alloc_implicit_odp_umem()
336  * @addr: The starting userspace VA
337  * @size: The length of the userspace VA
338  */
339 struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
340                                             unsigned long addr, size_t size)
341 {
342         /*
343          * Caller must ensure that root cannot be freed during the call to
344          * ib_alloc_odp_umem.
345          */
346         struct ib_umem_odp *odp_data;
347         struct ib_umem *umem;
348         int ret;
349
350         if (WARN_ON(!root->is_implicit_odp))
351                 return ERR_PTR(-EINVAL);
352
353         odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
354         if (!odp_data)
355                 return ERR_PTR(-ENOMEM);
356         umem = &odp_data->umem;
357         umem->ibdev = root->umem.ibdev;
358         umem->length     = size;
359         umem->address    = addr;
360         umem->writable   = root->umem.writable;
361         umem->owning_mm  = root->umem.owning_mm;
362         odp_data->page_shift = PAGE_SHIFT;
363
364         ret = ib_init_umem_odp(odp_data);
365         if (ret) {
366                 kfree(odp_data);
367                 return ERR_PTR(ret);
368         }
369         return odp_data;
370 }
371 EXPORT_SYMBOL(ib_umem_odp_alloc_child);
372
373 /**
374  * ib_umem_odp_get - Create a umem_odp for a userspace va
375  *
376  * @udata: userspace context to pin memory for
377  * @addr: userspace virtual address to start at
378  * @size: length of region to pin
379  * @access: IB_ACCESS_xxx flags for memory being pinned
380  *
381  * The driver should use when the access flags indicate ODP memory. It avoids
382  * pinning, instead, stores the mm for future page fault handling in
383  * conjunction with MMU notifiers.
384  */
385 struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
386                                     size_t size, int access)
387 {
388         struct ib_umem_odp *umem_odp;
389         struct ib_ucontext *context;
390         struct mm_struct *mm;
391         int ret;
392
393         if (!udata)
394                 return ERR_PTR(-EIO);
395
396         context = container_of(udata, struct uverbs_attr_bundle, driver_udata)
397                           ->context;
398         if (!context)
399                 return ERR_PTR(-EIO);
400
401         if (WARN_ON_ONCE(!(access & IB_ACCESS_ON_DEMAND)) ||
402             WARN_ON_ONCE(!context->device->ops.invalidate_range))
403                 return ERR_PTR(-EINVAL);
404
405         umem_odp = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
406         if (!umem_odp)
407                 return ERR_PTR(-ENOMEM);
408
409         umem_odp->umem.ibdev = context->device;
410         umem_odp->umem.length = size;
411         umem_odp->umem.address = addr;
412         umem_odp->umem.writable = ib_access_writable(access);
413         umem_odp->umem.owning_mm = mm = current->mm;
414
415         umem_odp->page_shift = PAGE_SHIFT;
416         if (access & IB_ACCESS_HUGETLB) {
417                 struct vm_area_struct *vma;
418                 struct hstate *h;
419
420                 down_read(&mm->mmap_sem);
421                 vma = find_vma(mm, ib_umem_start(umem_odp));
422                 if (!vma || !is_vm_hugetlb_page(vma)) {
423                         up_read(&mm->mmap_sem);
424                         ret = -EINVAL;
425                         goto err_free;
426                 }
427                 h = hstate_vma(vma);
428                 umem_odp->page_shift = huge_page_shift(h);
429                 up_read(&mm->mmap_sem);
430         }
431
432         ret = ib_init_umem_odp(umem_odp);
433         if (ret)
434                 goto err_free;
435         return umem_odp;
436
437 err_free:
438         kfree(umem_odp);
439         return ERR_PTR(ret);
440 }
441 EXPORT_SYMBOL(ib_umem_odp_get);
442
443 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
444 {
445         struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
446
447         /*
448          * Ensure that no more pages are mapped in the umem.
449          *
450          * It is the driver's responsibility to ensure, before calling us,
451          * that the hardware will not attempt to access the MR any more.
452          */
453         if (!umem_odp->is_implicit_odp) {
454                 ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
455                                             ib_umem_end(umem_odp));
456                 kvfree(umem_odp->dma_list);
457                 kvfree(umem_odp->page_list);
458         }
459
460         down_write(&per_mm->umem_rwsem);
461         if (!umem_odp->is_implicit_odp) {
462                 interval_tree_remove(&umem_odp->interval_tree,
463                                      &per_mm->umem_tree);
464                 complete_all(&umem_odp->notifier_completion);
465         }
466         /*
467          * NOTE! mmu_notifier_unregister() can happen between a start/end
468          * callback, resulting in a missing end, and thus an unbalanced
469          * lock. This doesn't really matter to us since we are about to kfree
470          * the memory that holds the lock, however LOCKDEP doesn't like this.
471          * Thus we call the mmu_notifier_put under the rwsem and test the
472          * internal users count to reliably see if we are past this point.
473          */
474         mmu_notifier_put(&per_mm->mn);
475         up_write(&per_mm->umem_rwsem);
476
477         mmdrop(umem_odp->umem.owning_mm);
478         kfree(umem_odp);
479 }
480 EXPORT_SYMBOL(ib_umem_odp_release);
481
482 /*
483  * Map for DMA and insert a single page into the on-demand paging page tables.
484  *
485  * @umem: the umem to insert the page to.
486  * @page_index: index in the umem to add the page to.
487  * @page: the page struct to map and add.
488  * @access_mask: access permissions needed for this page.
489  * @current_seq: sequence number for synchronization with invalidations.
490  *               the sequence number is taken from
491  *               umem_odp->notifiers_seq.
492  *
493  * The function returns -EFAULT if the DMA mapping operation fails. It returns
494  * -EAGAIN if a concurrent invalidation prevents us from updating the page.
495  *
496  * The page is released via put_user_page even if the operation failed. For
497  * on-demand pinning, the page is released whenever it isn't stored in the
498  * umem.
499  */
500 static int ib_umem_odp_map_dma_single_page(
501                 struct ib_umem_odp *umem_odp,
502                 int page_index,
503                 struct page *page,
504                 u64 access_mask,
505                 unsigned long current_seq)
506 {
507         struct ib_device *dev = umem_odp->umem.ibdev;
508         dma_addr_t dma_addr;
509         int remove_existing_mapping = 0;
510         int ret = 0;
511
512         /*
513          * Note: we avoid writing if seq is different from the initial seq, to
514          * handle case of a racing notifier. This check also allows us to bail
515          * early if we have a notifier running in parallel with us.
516          */
517         if (ib_umem_mmu_notifier_retry(umem_odp, current_seq)) {
518                 ret = -EAGAIN;
519                 goto out;
520         }
521         if (!(umem_odp->dma_list[page_index])) {
522                 dma_addr =
523                         ib_dma_map_page(dev, page, 0, BIT(umem_odp->page_shift),
524                                         DMA_BIDIRECTIONAL);
525                 if (ib_dma_mapping_error(dev, dma_addr)) {
526                         ret = -EFAULT;
527                         goto out;
528                 }
529                 umem_odp->dma_list[page_index] = dma_addr | access_mask;
530                 umem_odp->page_list[page_index] = page;
531                 umem_odp->npages++;
532         } else if (umem_odp->page_list[page_index] == page) {
533                 umem_odp->dma_list[page_index] |= access_mask;
534         } else {
535                 pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
536                        umem_odp->page_list[page_index], page);
537                 /* Better remove the mapping now, to prevent any further
538                  * damage. */
539                 remove_existing_mapping = 1;
540         }
541
542 out:
543         put_user_page(page);
544
545         if (remove_existing_mapping) {
546                 ib_umem_notifier_start_account(umem_odp);
547                 dev->ops.invalidate_range(
548                         umem_odp,
549                         ib_umem_start(umem_odp) +
550                                 (page_index << umem_odp->page_shift),
551                         ib_umem_start(umem_odp) +
552                                 ((page_index + 1) << umem_odp->page_shift));
553                 ib_umem_notifier_end_account(umem_odp);
554                 ret = -EAGAIN;
555         }
556
557         return ret;
558 }
559
560 /**
561  * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
562  *
563  * Pins the range of pages passed in the argument, and maps them to
564  * DMA addresses. The DMA addresses of the mapped pages is updated in
565  * umem_odp->dma_list.
566  *
567  * Returns the number of pages mapped in success, negative error code
568  * for failure.
569  * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
570  * the function from completing its task.
571  * An -ENOENT error code indicates that userspace process is being terminated
572  * and mm was already destroyed.
573  * @umem_odp: the umem to map and pin
574  * @user_virt: the address from which we need to map.
575  * @bcnt: the minimal number of bytes to pin and map. The mapping might be
576  *        bigger due to alignment, and may also be smaller in case of an error
577  *        pinning or mapping a page. The actual pages mapped is returned in
578  *        the return value.
579  * @access_mask: bit mask of the requested access permissions for the given
580  *               range.
581  * @current_seq: the MMU notifiers sequance value for synchronization with
582  *               invalidations. the sequance number is read from
583  *               umem_odp->notifiers_seq before calling this function
584  */
585 int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
586                               u64 bcnt, u64 access_mask,
587                               unsigned long current_seq)
588 {
589         struct task_struct *owning_process  = NULL;
590         struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
591         struct page       **local_page_list = NULL;
592         u64 page_mask, off;
593         int j, k, ret = 0, start_idx, npages = 0;
594         unsigned int flags = 0, page_shift;
595         phys_addr_t p = 0;
596
597         if (access_mask == 0)
598                 return -EINVAL;
599
600         if (user_virt < ib_umem_start(umem_odp) ||
601             user_virt + bcnt > ib_umem_end(umem_odp))
602                 return -EFAULT;
603
604         local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
605         if (!local_page_list)
606                 return -ENOMEM;
607
608         page_shift = umem_odp->page_shift;
609         page_mask = ~(BIT(page_shift) - 1);
610         off = user_virt & (~page_mask);
611         user_virt = user_virt & page_mask;
612         bcnt += off; /* Charge for the first page offset as well. */
613
614         /*
615          * owning_process is allowed to be NULL, this means somehow the mm is
616          * existing beyond the lifetime of the originating process.. Presumably
617          * mmget_not_zero will fail in this case.
618          */
619         owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
620         if (!owning_process || !mmget_not_zero(owning_mm)) {
621                 ret = -EINVAL;
622                 goto out_put_task;
623         }
624
625         if (access_mask & ODP_WRITE_ALLOWED_BIT)
626                 flags |= FOLL_WRITE;
627
628         start_idx = (user_virt - ib_umem_start(umem_odp)) >> page_shift;
629         k = start_idx;
630
631         while (bcnt > 0) {
632                 const size_t gup_num_pages = min_t(size_t,
633                                 (bcnt + BIT(page_shift) - 1) >> page_shift,
634                                 PAGE_SIZE / sizeof(struct page *));
635
636                 down_read(&owning_mm->mmap_sem);
637                 /*
638                  * Note: this might result in redundent page getting. We can
639                  * avoid this by checking dma_list to be 0 before calling
640                  * get_user_pages. However, this make the code much more
641                  * complex (and doesn't gain us much performance in most use
642                  * cases).
643                  */
644                 npages = get_user_pages_remote(owning_process, owning_mm,
645                                 user_virt, gup_num_pages,
646                                 flags, local_page_list, NULL, NULL);
647                 up_read(&owning_mm->mmap_sem);
648
649                 if (npages < 0) {
650                         if (npages != -EAGAIN)
651                                 pr_warn("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
652                         else
653                                 pr_debug("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
654                         break;
655                 }
656
657                 bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
658                 mutex_lock(&umem_odp->umem_mutex);
659                 for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
660                         if (user_virt & ~page_mask) {
661                                 p += PAGE_SIZE;
662                                 if (page_to_phys(local_page_list[j]) != p) {
663                                         ret = -EFAULT;
664                                         break;
665                                 }
666                                 put_user_page(local_page_list[j]);
667                                 continue;
668                         }
669
670                         ret = ib_umem_odp_map_dma_single_page(
671                                         umem_odp, k, local_page_list[j],
672                                         access_mask, current_seq);
673                         if (ret < 0) {
674                                 if (ret != -EAGAIN)
675                                         pr_warn("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
676                                 else
677                                         pr_debug("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
678                                 break;
679                         }
680
681                         p = page_to_phys(local_page_list[j]);
682                         k++;
683                 }
684                 mutex_unlock(&umem_odp->umem_mutex);
685
686                 if (ret < 0) {
687                         /*
688                          * Release pages, remembering that the first page
689                          * to hit an error was already released by
690                          * ib_umem_odp_map_dma_single_page().
691                          */
692                         if (npages - (j + 1) > 0)
693                                 put_user_pages(&local_page_list[j+1],
694                                                npages - (j + 1));
695                         break;
696                 }
697         }
698
699         if (ret >= 0) {
700                 if (npages < 0 && k == start_idx)
701                         ret = npages;
702                 else
703                         ret = k - start_idx;
704         }
705
706         mmput(owning_mm);
707 out_put_task:
708         if (owning_process)
709                 put_task_struct(owning_process);
710         free_page((unsigned long)local_page_list);
711         return ret;
712 }
713 EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
714
715 void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
716                                  u64 bound)
717 {
718         int idx;
719         u64 addr;
720         struct ib_device *dev = umem_odp->umem.ibdev;
721
722         virt = max_t(u64, virt, ib_umem_start(umem_odp));
723         bound = min_t(u64, bound, ib_umem_end(umem_odp));
724         /* Note that during the run of this function, the
725          * notifiers_count of the MR is > 0, preventing any racing
726          * faults from completion. We might be racing with other
727          * invalidations, so we must make sure we free each page only
728          * once. */
729         mutex_lock(&umem_odp->umem_mutex);
730         for (addr = virt; addr < bound; addr += BIT(umem_odp->page_shift)) {
731                 idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
732                 if (umem_odp->page_list[idx]) {
733                         struct page *page = umem_odp->page_list[idx];
734                         dma_addr_t dma = umem_odp->dma_list[idx];
735                         dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;
736
737                         WARN_ON(!dma_addr);
738
739                         ib_dma_unmap_page(dev, dma_addr,
740                                           BIT(umem_odp->page_shift),
741                                           DMA_BIDIRECTIONAL);
742                         if (dma & ODP_WRITE_ALLOWED_BIT) {
743                                 struct page *head_page = compound_head(page);
744                                 /*
745                                  * set_page_dirty prefers being called with
746                                  * the page lock. However, MMU notifiers are
747                                  * called sometimes with and sometimes without
748                                  * the lock. We rely on the umem_mutex instead
749                                  * to prevent other mmu notifiers from
750                                  * continuing and allowing the page mapping to
751                                  * be removed.
752                                  */
753                                 set_page_dirty(head_page);
754                         }
755                         umem_odp->page_list[idx] = NULL;
756                         umem_odp->dma_list[idx] = 0;
757                         umem_odp->npages--;
758                 }
759         }
760         mutex_unlock(&umem_odp->umem_mutex);
761 }
762 EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);
763
764 /* @last is not a part of the interval. See comment for function
765  * node_last.
766  */
767 int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
768                                   u64 start, u64 last,
769                                   umem_call_back cb,
770                                   bool blockable,
771                                   void *cookie)
772 {
773         int ret_val = 0;
774         struct interval_tree_node *node, *next;
775         struct ib_umem_odp *umem;
776
777         if (unlikely(start == last))
778                 return ret_val;
779
780         for (node = interval_tree_iter_first(root, start, last - 1);
781                         node; node = next) {
782                 /* TODO move the blockable decision up to the callback */
783                 if (!blockable)
784                         return -EAGAIN;
785                 next = interval_tree_iter_next(node, start, last - 1);
786                 umem = container_of(node, struct ib_umem_odp, interval_tree);
787                 ret_val = cb(umem, start, last, cookie) || ret_val;
788         }
789
790         return ret_val;
791 }