Merge branch 'linux-4.7' of git://github.com/skeggsb/linux into drm-fixes
[linux-2.6-block.git] / drivers / infiniband / sw / rdmavt / mr.c
1 /*
2  * Copyright(c) 2016 Intel Corporation.
3  *
4  * This file is provided under a dual BSD/GPLv2 license.  When using or
5  * redistributing this file, you may do so under either license.
6  *
7  * GPL LICENSE SUMMARY
8  *
9  * This program is free software; you can redistribute it and/or modify
10  * it under the terms of version 2 of the GNU General Public License as
11  * published by the Free Software Foundation.
12  *
13  * This program is distributed in the hope that it will be useful, but
14  * WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * General Public License for more details.
17  *
18  * BSD LICENSE
19  *
20  * Redistribution and use in source and binary forms, with or without
21  * modification, are permitted provided that the following conditions
22  * are met:
23  *
24  *  - Redistributions of source code must retain the above copyright
25  *    notice, this list of conditions and the following disclaimer.
26  *  - Redistributions in binary form must reproduce the above copyright
27  *    notice, this list of conditions and the following disclaimer in
28  *    the documentation and/or other materials provided with the
29  *    distribution.
30  *  - Neither the name of Intel Corporation nor the names of its
31  *    contributors may be used to endorse or promote products derived
32  *    from this software without specific prior written permission.
33  *
34  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
35  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
36  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
37  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
38  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
39  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
40  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
41  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
42  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
43  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
44  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
45  *
46  */
47
48 #include <linux/slab.h>
49 #include <linux/vmalloc.h>
50 #include <rdma/ib_umem.h>
51 #include <rdma/rdma_vt.h>
52 #include "vt.h"
53 #include "mr.h"
54
55 /**
56  * rvt_driver_mr_init - Init MR resources per driver
57  * @rdi: rvt dev struct
58  *
59  * Do any intilization needed when a driver registers with rdmavt.
60  *
61  * Return: 0 on success or errno on failure
62  */
63 int rvt_driver_mr_init(struct rvt_dev_info *rdi)
64 {
65         unsigned int lkey_table_size = rdi->dparms.lkey_table_size;
66         unsigned lk_tab_size;
67         int i;
68
69         /*
70          * The top hfi1_lkey_table_size bits are used to index the
71          * table.  The lower 8 bits can be owned by the user (copied from
72          * the LKEY).  The remaining bits act as a generation number or tag.
73          */
74         if (!lkey_table_size)
75                 return -EINVAL;
76
77         spin_lock_init(&rdi->lkey_table.lock);
78
79         /* ensure generation is at least 4 bits */
80         if (lkey_table_size > RVT_MAX_LKEY_TABLE_BITS) {
81                 rvt_pr_warn(rdi, "lkey bits %u too large, reduced to %u\n",
82                             lkey_table_size, RVT_MAX_LKEY_TABLE_BITS);
83                 rdi->dparms.lkey_table_size = RVT_MAX_LKEY_TABLE_BITS;
84                 lkey_table_size = rdi->dparms.lkey_table_size;
85         }
86         rdi->lkey_table.max = 1 << lkey_table_size;
87         lk_tab_size = rdi->lkey_table.max * sizeof(*rdi->lkey_table.table);
88         rdi->lkey_table.table = (struct rvt_mregion __rcu **)
89                                vmalloc_node(lk_tab_size, rdi->dparms.node);
90         if (!rdi->lkey_table.table)
91                 return -ENOMEM;
92
93         RCU_INIT_POINTER(rdi->dma_mr, NULL);
94         for (i = 0; i < rdi->lkey_table.max; i++)
95                 RCU_INIT_POINTER(rdi->lkey_table.table[i], NULL);
96
97         return 0;
98 }
99
100 /**
101  *rvt_mr_exit: clean up MR
102  *@rdi: rvt dev structure
103  *
104  * called when drivers have unregistered or perhaps failed to register with us
105  */
106 void rvt_mr_exit(struct rvt_dev_info *rdi)
107 {
108         if (rdi->dma_mr)
109                 rvt_pr_err(rdi, "DMA MR not null!\n");
110
111         vfree(rdi->lkey_table.table);
112 }
113
114 static void rvt_deinit_mregion(struct rvt_mregion *mr)
115 {
116         int i = mr->mapsz;
117
118         mr->mapsz = 0;
119         while (i)
120                 kfree(mr->map[--i]);
121 }
122
123 static int rvt_init_mregion(struct rvt_mregion *mr, struct ib_pd *pd,
124                             int count)
125 {
126         int m, i = 0;
127         struct rvt_dev_info *dev = ib_to_rvt(pd->device);
128
129         mr->mapsz = 0;
130         m = (count + RVT_SEGSZ - 1) / RVT_SEGSZ;
131         for (; i < m; i++) {
132                 mr->map[i] = kzalloc_node(sizeof(*mr->map[0]), GFP_KERNEL,
133                                           dev->dparms.node);
134                 if (!mr->map[i]) {
135                         rvt_deinit_mregion(mr);
136                         return -ENOMEM;
137                 }
138                 mr->mapsz++;
139         }
140         init_completion(&mr->comp);
141         /* count returning the ptr to user */
142         atomic_set(&mr->refcount, 1);
143         mr->pd = pd;
144         mr->max_segs = count;
145         return 0;
146 }
147
148 /**
149  * rvt_alloc_lkey - allocate an lkey
150  * @mr: memory region that this lkey protects
151  * @dma_region: 0->normal key, 1->restricted DMA key
152  *
153  * Returns 0 if successful, otherwise returns -errno.
154  *
155  * Increments mr reference count as required.
156  *
157  * Sets the lkey field mr for non-dma regions.
158  *
159  */
160 static int rvt_alloc_lkey(struct rvt_mregion *mr, int dma_region)
161 {
162         unsigned long flags;
163         u32 r;
164         u32 n;
165         int ret = 0;
166         struct rvt_dev_info *dev = ib_to_rvt(mr->pd->device);
167         struct rvt_lkey_table *rkt = &dev->lkey_table;
168
169         rvt_get_mr(mr);
170         spin_lock_irqsave(&rkt->lock, flags);
171
172         /* special case for dma_mr lkey == 0 */
173         if (dma_region) {
174                 struct rvt_mregion *tmr;
175
176                 tmr = rcu_access_pointer(dev->dma_mr);
177                 if (!tmr) {
178                         rcu_assign_pointer(dev->dma_mr, mr);
179                         mr->lkey_published = 1;
180                 } else {
181                         rvt_put_mr(mr);
182                 }
183                 goto success;
184         }
185
186         /* Find the next available LKEY */
187         r = rkt->next;
188         n = r;
189         for (;;) {
190                 if (!rcu_access_pointer(rkt->table[r]))
191                         break;
192                 r = (r + 1) & (rkt->max - 1);
193                 if (r == n)
194                         goto bail;
195         }
196         rkt->next = (r + 1) & (rkt->max - 1);
197         /*
198          * Make sure lkey is never zero which is reserved to indicate an
199          * unrestricted LKEY.
200          */
201         rkt->gen++;
202         /*
203          * bits are capped to ensure enough bits for generation number
204          */
205         mr->lkey = (r << (32 - dev->dparms.lkey_table_size)) |
206                 ((((1 << (24 - dev->dparms.lkey_table_size)) - 1) & rkt->gen)
207                  << 8);
208         if (mr->lkey == 0) {
209                 mr->lkey |= 1 << 8;
210                 rkt->gen++;
211         }
212         rcu_assign_pointer(rkt->table[r], mr);
213         mr->lkey_published = 1;
214 success:
215         spin_unlock_irqrestore(&rkt->lock, flags);
216 out:
217         return ret;
218 bail:
219         rvt_put_mr(mr);
220         spin_unlock_irqrestore(&rkt->lock, flags);
221         ret = -ENOMEM;
222         goto out;
223 }
224
225 /**
226  * rvt_free_lkey - free an lkey
227  * @mr: mr to free from tables
228  */
229 static void rvt_free_lkey(struct rvt_mregion *mr)
230 {
231         unsigned long flags;
232         u32 lkey = mr->lkey;
233         u32 r;
234         struct rvt_dev_info *dev = ib_to_rvt(mr->pd->device);
235         struct rvt_lkey_table *rkt = &dev->lkey_table;
236         int freed = 0;
237
238         spin_lock_irqsave(&rkt->lock, flags);
239         if (!mr->lkey_published)
240                 goto out;
241         if (lkey == 0) {
242                 RCU_INIT_POINTER(dev->dma_mr, NULL);
243         } else {
244                 r = lkey >> (32 - dev->dparms.lkey_table_size);
245                 RCU_INIT_POINTER(rkt->table[r], NULL);
246         }
247         mr->lkey_published = 0;
248         freed++;
249 out:
250         spin_unlock_irqrestore(&rkt->lock, flags);
251         if (freed) {
252                 synchronize_rcu();
253                 rvt_put_mr(mr);
254         }
255 }
256
257 static struct rvt_mr *__rvt_alloc_mr(int count, struct ib_pd *pd)
258 {
259         struct rvt_mr *mr;
260         int rval = -ENOMEM;
261         int m;
262
263         /* Allocate struct plus pointers to first level page tables. */
264         m = (count + RVT_SEGSZ - 1) / RVT_SEGSZ;
265         mr = kzalloc(sizeof(*mr) + m * sizeof(mr->mr.map[0]), GFP_KERNEL);
266         if (!mr)
267                 goto bail;
268
269         rval = rvt_init_mregion(&mr->mr, pd, count);
270         if (rval)
271                 goto bail;
272         /*
273          * ib_reg_phys_mr() will initialize mr->ibmr except for
274          * lkey and rkey.
275          */
276         rval = rvt_alloc_lkey(&mr->mr, 0);
277         if (rval)
278                 goto bail_mregion;
279         mr->ibmr.lkey = mr->mr.lkey;
280         mr->ibmr.rkey = mr->mr.lkey;
281 done:
282         return mr;
283
284 bail_mregion:
285         rvt_deinit_mregion(&mr->mr);
286 bail:
287         kfree(mr);
288         mr = ERR_PTR(rval);
289         goto done;
290 }
291
292 static void __rvt_free_mr(struct rvt_mr *mr)
293 {
294         rvt_deinit_mregion(&mr->mr);
295         rvt_free_lkey(&mr->mr);
296         vfree(mr);
297 }
298
299 /**
300  * rvt_get_dma_mr - get a DMA memory region
301  * @pd: protection domain for this memory region
302  * @acc: access flags
303  *
304  * Return: the memory region on success, otherwise returns an errno.
305  * Note that all DMA addresses should be created via the
306  * struct ib_dma_mapping_ops functions (see dma.c).
307  */
308 struct ib_mr *rvt_get_dma_mr(struct ib_pd *pd, int acc)
309 {
310         struct rvt_mr *mr;
311         struct ib_mr *ret;
312         int rval;
313
314         if (ibpd_to_rvtpd(pd)->user)
315                 return ERR_PTR(-EPERM);
316
317         mr = kzalloc(sizeof(*mr), GFP_KERNEL);
318         if (!mr) {
319                 ret = ERR_PTR(-ENOMEM);
320                 goto bail;
321         }
322
323         rval = rvt_init_mregion(&mr->mr, pd, 0);
324         if (rval) {
325                 ret = ERR_PTR(rval);
326                 goto bail;
327         }
328
329         rval = rvt_alloc_lkey(&mr->mr, 1);
330         if (rval) {
331                 ret = ERR_PTR(rval);
332                 goto bail_mregion;
333         }
334
335         mr->mr.access_flags = acc;
336         ret = &mr->ibmr;
337 done:
338         return ret;
339
340 bail_mregion:
341         rvt_deinit_mregion(&mr->mr);
342 bail:
343         kfree(mr);
344         goto done;
345 }
346
347 /**
348  * rvt_reg_user_mr - register a userspace memory region
349  * @pd: protection domain for this memory region
350  * @start: starting userspace address
351  * @length: length of region to register
352  * @mr_access_flags: access flags for this memory region
353  * @udata: unused by the driver
354  *
355  * Return: the memory region on success, otherwise returns an errno.
356  */
357 struct ib_mr *rvt_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
358                               u64 virt_addr, int mr_access_flags,
359                               struct ib_udata *udata)
360 {
361         struct rvt_mr *mr;
362         struct ib_umem *umem;
363         struct scatterlist *sg;
364         int n, m, entry;
365         struct ib_mr *ret;
366
367         if (length == 0)
368                 return ERR_PTR(-EINVAL);
369
370         umem = ib_umem_get(pd->uobject->context, start, length,
371                            mr_access_flags, 0);
372         if (IS_ERR(umem))
373                 return (void *)umem;
374
375         n = umem->nmap;
376
377         mr = __rvt_alloc_mr(n, pd);
378         if (IS_ERR(mr)) {
379                 ret = (struct ib_mr *)mr;
380                 goto bail_umem;
381         }
382
383         mr->mr.user_base = start;
384         mr->mr.iova = virt_addr;
385         mr->mr.length = length;
386         mr->mr.offset = ib_umem_offset(umem);
387         mr->mr.access_flags = mr_access_flags;
388         mr->umem = umem;
389
390         if (is_power_of_2(umem->page_size))
391                 mr->mr.page_shift = ilog2(umem->page_size);
392         m = 0;
393         n = 0;
394         for_each_sg(umem->sg_head.sgl, sg, umem->nmap, entry) {
395                 void *vaddr;
396
397                 vaddr = page_address(sg_page(sg));
398                 if (!vaddr) {
399                         ret = ERR_PTR(-EINVAL);
400                         goto bail_inval;
401                 }
402                 mr->mr.map[m]->segs[n].vaddr = vaddr;
403                 mr->mr.map[m]->segs[n].length = umem->page_size;
404                 n++;
405                 if (n == RVT_SEGSZ) {
406                         m++;
407                         n = 0;
408                 }
409         }
410         return &mr->ibmr;
411
412 bail_inval:
413         __rvt_free_mr(mr);
414
415 bail_umem:
416         ib_umem_release(umem);
417
418         return ret;
419 }
420
421 /**
422  * rvt_dereg_mr - unregister and free a memory region
423  * @ibmr: the memory region to free
424  *
425  *
426  * Note that this is called to free MRs created by rvt_get_dma_mr()
427  * or rvt_reg_user_mr().
428  *
429  * Returns 0 on success.
430  */
431 int rvt_dereg_mr(struct ib_mr *ibmr)
432 {
433         struct rvt_mr *mr = to_imr(ibmr);
434         struct rvt_dev_info *rdi = ib_to_rvt(ibmr->pd->device);
435         int ret = 0;
436         unsigned long timeout;
437
438         rvt_free_lkey(&mr->mr);
439
440         rvt_put_mr(&mr->mr); /* will set completion if last */
441         timeout = wait_for_completion_timeout(&mr->mr.comp, 5 * HZ);
442         if (!timeout) {
443                 rvt_pr_err(rdi,
444                            "rvt_dereg_mr timeout mr %p pd %p refcount %u\n",
445                            mr, mr->mr.pd, atomic_read(&mr->mr.refcount));
446                 rvt_get_mr(&mr->mr);
447                 ret = -EBUSY;
448                 goto out;
449         }
450         rvt_deinit_mregion(&mr->mr);
451         if (mr->umem)
452                 ib_umem_release(mr->umem);
453         kfree(mr);
454 out:
455         return ret;
456 }
457
458 /**
459  * rvt_alloc_mr - Allocate a memory region usable with the
460  * @pd: protection domain for this memory region
461  * @mr_type: mem region type
462  * @max_num_sg: Max number of segments allowed
463  *
464  * Return: the memory region on success, otherwise return an errno.
465  */
466 struct ib_mr *rvt_alloc_mr(struct ib_pd *pd,
467                            enum ib_mr_type mr_type,
468                            u32 max_num_sg)
469 {
470         struct rvt_mr *mr;
471
472         if (mr_type != IB_MR_TYPE_MEM_REG)
473                 return ERR_PTR(-EINVAL);
474
475         mr = __rvt_alloc_mr(max_num_sg, pd);
476         if (IS_ERR(mr))
477                 return (struct ib_mr *)mr;
478
479         return &mr->ibmr;
480 }
481
482 /**
483  * rvt_alloc_fmr - allocate a fast memory region
484  * @pd: the protection domain for this memory region
485  * @mr_access_flags: access flags for this memory region
486  * @fmr_attr: fast memory region attributes
487  *
488  * Return: the memory region on success, otherwise returns an errno.
489  */
490 struct ib_fmr *rvt_alloc_fmr(struct ib_pd *pd, int mr_access_flags,
491                              struct ib_fmr_attr *fmr_attr)
492 {
493         struct rvt_fmr *fmr;
494         int m;
495         struct ib_fmr *ret;
496         int rval = -ENOMEM;
497
498         /* Allocate struct plus pointers to first level page tables. */
499         m = (fmr_attr->max_pages + RVT_SEGSZ - 1) / RVT_SEGSZ;
500         fmr = kzalloc(sizeof(*fmr) + m * sizeof(fmr->mr.map[0]), GFP_KERNEL);
501         if (!fmr)
502                 goto bail;
503
504         rval = rvt_init_mregion(&fmr->mr, pd, fmr_attr->max_pages);
505         if (rval)
506                 goto bail;
507
508         /*
509          * ib_alloc_fmr() will initialize fmr->ibfmr except for lkey &
510          * rkey.
511          */
512         rval = rvt_alloc_lkey(&fmr->mr, 0);
513         if (rval)
514                 goto bail_mregion;
515         fmr->ibfmr.rkey = fmr->mr.lkey;
516         fmr->ibfmr.lkey = fmr->mr.lkey;
517         /*
518          * Resources are allocated but no valid mapping (RKEY can't be
519          * used).
520          */
521         fmr->mr.access_flags = mr_access_flags;
522         fmr->mr.max_segs = fmr_attr->max_pages;
523         fmr->mr.page_shift = fmr_attr->page_shift;
524
525         ret = &fmr->ibfmr;
526 done:
527         return ret;
528
529 bail_mregion:
530         rvt_deinit_mregion(&fmr->mr);
531 bail:
532         kfree(fmr);
533         ret = ERR_PTR(rval);
534         goto done;
535 }
536
537 /**
538  * rvt_map_phys_fmr - set up a fast memory region
539  * @ibmfr: the fast memory region to set up
540  * @page_list: the list of pages to associate with the fast memory region
541  * @list_len: the number of pages to associate with the fast memory region
542  * @iova: the virtual address of the start of the fast memory region
543  *
544  * This may be called from interrupt context.
545  *
546  * Return: 0 on success
547  */
548
549 int rvt_map_phys_fmr(struct ib_fmr *ibfmr, u64 *page_list,
550                      int list_len, u64 iova)
551 {
552         struct rvt_fmr *fmr = to_ifmr(ibfmr);
553         struct rvt_lkey_table *rkt;
554         unsigned long flags;
555         int m, n, i;
556         u32 ps;
557         struct rvt_dev_info *rdi = ib_to_rvt(ibfmr->device);
558
559         i = atomic_read(&fmr->mr.refcount);
560         if (i > 2)
561                 return -EBUSY;
562
563         if (list_len > fmr->mr.max_segs)
564                 return -EINVAL;
565
566         rkt = &rdi->lkey_table;
567         spin_lock_irqsave(&rkt->lock, flags);
568         fmr->mr.user_base = iova;
569         fmr->mr.iova = iova;
570         ps = 1 << fmr->mr.page_shift;
571         fmr->mr.length = list_len * ps;
572         m = 0;
573         n = 0;
574         for (i = 0; i < list_len; i++) {
575                 fmr->mr.map[m]->segs[n].vaddr = (void *)page_list[i];
576                 fmr->mr.map[m]->segs[n].length = ps;
577                 if (++n == RVT_SEGSZ) {
578                         m++;
579                         n = 0;
580                 }
581         }
582         spin_unlock_irqrestore(&rkt->lock, flags);
583         return 0;
584 }
585
586 /**
587  * rvt_unmap_fmr - unmap fast memory regions
588  * @fmr_list: the list of fast memory regions to unmap
589  *
590  * Return: 0 on success.
591  */
592 int rvt_unmap_fmr(struct list_head *fmr_list)
593 {
594         struct rvt_fmr *fmr;
595         struct rvt_lkey_table *rkt;
596         unsigned long flags;
597         struct rvt_dev_info *rdi;
598
599         list_for_each_entry(fmr, fmr_list, ibfmr.list) {
600                 rdi = ib_to_rvt(fmr->ibfmr.device);
601                 rkt = &rdi->lkey_table;
602                 spin_lock_irqsave(&rkt->lock, flags);
603                 fmr->mr.user_base = 0;
604                 fmr->mr.iova = 0;
605                 fmr->mr.length = 0;
606                 spin_unlock_irqrestore(&rkt->lock, flags);
607         }
608         return 0;
609 }
610
611 /**
612  * rvt_dealloc_fmr - deallocate a fast memory region
613  * @ibfmr: the fast memory region to deallocate
614  *
615  * Return: 0 on success.
616  */
617 int rvt_dealloc_fmr(struct ib_fmr *ibfmr)
618 {
619         struct rvt_fmr *fmr = to_ifmr(ibfmr);
620         int ret = 0;
621         unsigned long timeout;
622
623         rvt_free_lkey(&fmr->mr);
624         rvt_put_mr(&fmr->mr); /* will set completion if last */
625         timeout = wait_for_completion_timeout(&fmr->mr.comp, 5 * HZ);
626         if (!timeout) {
627                 rvt_get_mr(&fmr->mr);
628                 ret = -EBUSY;
629                 goto out;
630         }
631         rvt_deinit_mregion(&fmr->mr);
632         kfree(fmr);
633 out:
634         return ret;
635 }
636
637 /**
638  * rvt_lkey_ok - check IB SGE for validity and initialize
639  * @rkt: table containing lkey to check SGE against
640  * @pd: protection domain
641  * @isge: outgoing internal SGE
642  * @sge: SGE to check
643  * @acc: access flags
644  *
645  * Check the IB SGE for validity and initialize our internal version
646  * of it.
647  *
648  * Return: 1 if valid and successful, otherwise returns 0.
649  *
650  * increments the reference count upon success
651  *
652  */
653 int rvt_lkey_ok(struct rvt_lkey_table *rkt, struct rvt_pd *pd,
654                 struct rvt_sge *isge, struct ib_sge *sge, int acc)
655 {
656         struct rvt_mregion *mr;
657         unsigned n, m;
658         size_t off;
659         struct rvt_dev_info *dev = ib_to_rvt(pd->ibpd.device);
660
661         /*
662          * We use LKEY == zero for kernel virtual addresses
663          * (see rvt_get_dma_mr and dma.c).
664          */
665         rcu_read_lock();
666         if (sge->lkey == 0) {
667                 if (pd->user)
668                         goto bail;
669                 mr = rcu_dereference(dev->dma_mr);
670                 if (!mr)
671                         goto bail;
672                 atomic_inc(&mr->refcount);
673                 rcu_read_unlock();
674
675                 isge->mr = mr;
676                 isge->vaddr = (void *)sge->addr;
677                 isge->length = sge->length;
678                 isge->sge_length = sge->length;
679                 isge->m = 0;
680                 isge->n = 0;
681                 goto ok;
682         }
683         mr = rcu_dereference(
684                 rkt->table[(sge->lkey >> (32 - dev->dparms.lkey_table_size))]);
685         if (unlikely(!mr || mr->lkey != sge->lkey || mr->pd != &pd->ibpd))
686                 goto bail;
687
688         off = sge->addr - mr->user_base;
689         if (unlikely(sge->addr < mr->user_base ||
690                      off + sge->length > mr->length ||
691                      (mr->access_flags & acc) != acc))
692                 goto bail;
693         atomic_inc(&mr->refcount);
694         rcu_read_unlock();
695
696         off += mr->offset;
697         if (mr->page_shift) {
698                 /*
699                  * page sizes are uniform power of 2 so no loop is necessary
700                  * entries_spanned_by_off is the number of times the loop below
701                  * would have executed.
702                 */
703                 size_t entries_spanned_by_off;
704
705                 entries_spanned_by_off = off >> mr->page_shift;
706                 off -= (entries_spanned_by_off << mr->page_shift);
707                 m = entries_spanned_by_off / RVT_SEGSZ;
708                 n = entries_spanned_by_off % RVT_SEGSZ;
709         } else {
710                 m = 0;
711                 n = 0;
712                 while (off >= mr->map[m]->segs[n].length) {
713                         off -= mr->map[m]->segs[n].length;
714                         n++;
715                         if (n >= RVT_SEGSZ) {
716                                 m++;
717                                 n = 0;
718                         }
719                 }
720         }
721         isge->mr = mr;
722         isge->vaddr = mr->map[m]->segs[n].vaddr + off;
723         isge->length = mr->map[m]->segs[n].length - off;
724         isge->sge_length = sge->length;
725         isge->m = m;
726         isge->n = n;
727 ok:
728         return 1;
729 bail:
730         rcu_read_unlock();
731         return 0;
732 }
733 EXPORT_SYMBOL(rvt_lkey_ok);
734
735 /**
736  * rvt_rkey_ok - check the IB virtual address, length, and RKEY
737  * @qp: qp for validation
738  * @sge: SGE state
739  * @len: length of data
740  * @vaddr: virtual address to place data
741  * @rkey: rkey to check
742  * @acc: access flags
743  *
744  * Return: 1 if successful, otherwise 0.
745  *
746  * increments the reference count upon success
747  */
748 int rvt_rkey_ok(struct rvt_qp *qp, struct rvt_sge *sge,
749                 u32 len, u64 vaddr, u32 rkey, int acc)
750 {
751         struct rvt_dev_info *dev = ib_to_rvt(qp->ibqp.device);
752         struct rvt_lkey_table *rkt = &dev->lkey_table;
753         struct rvt_mregion *mr;
754         unsigned n, m;
755         size_t off;
756
757         /*
758          * We use RKEY == zero for kernel virtual addresses
759          * (see rvt_get_dma_mr and dma.c).
760          */
761         rcu_read_lock();
762         if (rkey == 0) {
763                 struct rvt_pd *pd = ibpd_to_rvtpd(qp->ibqp.pd);
764                 struct rvt_dev_info *rdi = ib_to_rvt(pd->ibpd.device);
765
766                 if (pd->user)
767                         goto bail;
768                 mr = rcu_dereference(rdi->dma_mr);
769                 if (!mr)
770                         goto bail;
771                 atomic_inc(&mr->refcount);
772                 rcu_read_unlock();
773
774                 sge->mr = mr;
775                 sge->vaddr = (void *)vaddr;
776                 sge->length = len;
777                 sge->sge_length = len;
778                 sge->m = 0;
779                 sge->n = 0;
780                 goto ok;
781         }
782
783         mr = rcu_dereference(
784                 rkt->table[(rkey >> (32 - dev->dparms.lkey_table_size))]);
785         if (unlikely(!mr || mr->lkey != rkey || qp->ibqp.pd != mr->pd))
786                 goto bail;
787
788         off = vaddr - mr->iova;
789         if (unlikely(vaddr < mr->iova || off + len > mr->length ||
790                      (mr->access_flags & acc) == 0))
791                 goto bail;
792         atomic_inc(&mr->refcount);
793         rcu_read_unlock();
794
795         off += mr->offset;
796         if (mr->page_shift) {
797                 /*
798                  * page sizes are uniform power of 2 so no loop is necessary
799                  * entries_spanned_by_off is the number of times the loop below
800                  * would have executed.
801                 */
802                 size_t entries_spanned_by_off;
803
804                 entries_spanned_by_off = off >> mr->page_shift;
805                 off -= (entries_spanned_by_off << mr->page_shift);
806                 m = entries_spanned_by_off / RVT_SEGSZ;
807                 n = entries_spanned_by_off % RVT_SEGSZ;
808         } else {
809                 m = 0;
810                 n = 0;
811                 while (off >= mr->map[m]->segs[n].length) {
812                         off -= mr->map[m]->segs[n].length;
813                         n++;
814                         if (n >= RVT_SEGSZ) {
815                                 m++;
816                                 n = 0;
817                         }
818                 }
819         }
820         sge->mr = mr;
821         sge->vaddr = mr->map[m]->segs[n].vaddr + off;
822         sge->length = mr->map[m]->segs[n].length - off;
823         sge->sge_length = len;
824         sge->m = m;
825         sge->n = n;
826 ok:
827         return 1;
828 bail:
829         rcu_read_unlock();
830         return 0;
831 }
832 EXPORT_SYMBOL(rvt_rkey_ok);