Merge tag 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/rdma/rdma
diff --git a/drivers/infiniband/hw/hfi1/user_exp_rcv.c b/drivers/infiniband/hw/hfi1/user_exp_rcv.c
index fbe48cd23c9a92088b1b2dd0f3ddeb17890cbb14..96058baf36edc09a4c0b4b1e13c64b54d61aed8f 100644
@@ -23,17 +23,24 @@ static void cacheless_tid_rb_remove(struct hfi1_filedata *fdata,
 static bool tid_rb_invalidate(struct mmu_interval_notifier *mni,
                              const struct mmu_notifier_range *range,
                              unsigned long cur_seq);
+static bool tid_cover_invalidate(struct mmu_interval_notifier *mni,
+                                const struct mmu_notifier_range *range,
+                                unsigned long cur_seq);
 static int program_rcvarray(struct hfi1_filedata *fd, struct tid_user_buf *,
                            struct tid_group *grp, u16 count,
                            u32 *tidlist, unsigned int *tididx,
                            unsigned int *pmapped);
-static int unprogram_rcvarray(struct hfi1_filedata *fd, u32 tidinfo,
-                             struct tid_group **grp);
+static int unprogram_rcvarray(struct hfi1_filedata *fd, u32 tidinfo);
+static void __clear_tid_node(struct hfi1_filedata *fd,
+                            struct tid_rb_node *node);
 static void clear_tid_node(struct hfi1_filedata *fd, struct tid_rb_node *node);
 
 static const struct mmu_interval_notifier_ops tid_mn_ops = {
        .invalidate = tid_rb_invalidate,
 };
+static const struct mmu_interval_notifier_ops tid_cover_ops = {
+       .invalidate = tid_cover_invalidate,
+};
 
 /*
  * Initialize context and file private data needed for Expected
@@ -246,54 +253,66 @@ int hfi1_user_exp_rcv_setup(struct hfi1_filedata *fd,
                tididx = 0, mapped, mapped_pages = 0;
        u32 *tidlist = NULL;
        struct tid_user_buf *tidbuf;
+       unsigned long mmu_seq = 0;
 
        if (!PAGE_ALIGNED(tinfo->vaddr))
                return -EINVAL;
+       if (tinfo->length == 0)
+               return -EINVAL;
 
        tidbuf = kzalloc(sizeof(*tidbuf), GFP_KERNEL);
        if (!tidbuf)
                return -ENOMEM;
 
+       mutex_init(&tidbuf->cover_mutex);
        tidbuf->vaddr = tinfo->vaddr;
        tidbuf->length = tinfo->length;
        tidbuf->npages = num_user_pages(tidbuf->vaddr, tidbuf->length);
        tidbuf->psets = kcalloc(uctxt->expected_count, sizeof(*tidbuf->psets),
                                GFP_KERNEL);
        if (!tidbuf->psets) {
-               kfree(tidbuf);
-               return -ENOMEM;
+               ret = -ENOMEM;
+               goto fail_release_mem;
+       }
+
+       if (fd->use_mn) {
+               ret = mmu_interval_notifier_insert(
+                       &tidbuf->notifier, current->mm,
+                       tidbuf->vaddr, tidbuf->npages * PAGE_SIZE,
+                       &tid_cover_ops);
+               if (ret)
+                       goto fail_release_mem;
+               mmu_seq = mmu_interval_read_begin(&tidbuf->notifier);
        }
 
        pinned = pin_rcv_pages(fd, tidbuf);
        if (pinned <= 0) {
-               kfree(tidbuf->psets);
-               kfree(tidbuf);
-               return pinned;
+               ret = (pinned < 0) ? pinned : -ENOSPC;
+               goto fail_unpin;
        }
 
        /* Find sets of physically contiguous pages */
        tidbuf->n_psets = find_phys_blocks(tidbuf, pinned);
 
-       /*
-        * We don't need to access this under a lock since tid_used is per
-        * process and the same process cannot be in hfi1_user_exp_rcv_clear()
-        * and hfi1_user_exp_rcv_setup() at the same time.
-        */
+       /* Reserve the number of expected tids to be used. */
        spin_lock(&fd->tid_lock);
        if (fd->tid_used + tidbuf->n_psets > fd->tid_limit)
                pageset_count = fd->tid_limit - fd->tid_used;
        else
                pageset_count = tidbuf->n_psets;
+       fd->tid_used += pageset_count;
        spin_unlock(&fd->tid_lock);
 
-       if (!pageset_count)
-               goto bail;
+       if (!pageset_count) {
+               ret = -ENOSPC;
+               goto fail_unreserve;
+       }
 
        ngroups = pageset_count / dd->rcv_entries.group_size;
        tidlist = kcalloc(pageset_count, sizeof(*tidlist), GFP_KERNEL);
        if (!tidlist) {
                ret = -ENOMEM;
-               goto nomem;
+               goto fail_unreserve;
        }
 
        tididx = 0;
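
The tid_used handling in the hunk above switches from charging after the fact to reserving up front: the worst-case pageset count is charged against fd->tid_limit before any RcvArray entry is programmed, and the surplus is handed back once the real count is known, or on any failure path via fail_unreserve. A minimal sketch of that reserve-then-trim pattern, using hypothetical names (demo_fd, reserve_slots() and release_slots() are illustrations, not driver code):

#include <linux/kernel.h>
#include <linux/spinlock.h>
#include <linux/types.h>

struct demo_fd {                        /* hypothetical stand-in for hfi1_filedata */
        spinlock_t lock;                /* spin_lock_init() assumed at open time */
        u16 used;
        u16 limit;
};

/* Charge up to 'wanted' slots against the limit; return how many were granted. */
static u16 reserve_slots(struct demo_fd *fd, u16 wanted)
{
        u16 granted;

        spin_lock(&fd->lock);
        granted = min_t(u16, wanted, fd->limit - fd->used);
        fd->used += granted;
        spin_unlock(&fd->lock);
        return granted;
}

/* Give back slots that were reserved but not used (or all of them on failure). */
static void release_slots(struct demo_fd *fd, u16 count)
{
        spin_lock(&fd->lock);
        fd->used -= count;
        spin_unlock(&fd->lock);
}

A caller reserves before programming and returns the surplus once the programmed count is known, or the whole reservation when nothing could be programmed, which is what the "adjust reserved tid_used to actual count" block and the fail_unreserve label in the next hunk do.
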
@@ -387,43 +406,78 @@ int hfi1_user_exp_rcv_setup(struct hfi1_filedata *fd,
        }
 unlock:
        mutex_unlock(&uctxt->exp_mutex);
-nomem:
        hfi1_cdbg(TID, "total mapped: tidpairs:%u pages:%u (%d)", tididx,
                  mapped_pages, ret);
-       if (tididx) {
-               spin_lock(&fd->tid_lock);
-               fd->tid_used += tididx;
-               spin_unlock(&fd->tid_lock);
-               tinfo->tidcnt = tididx;
-               tinfo->length = mapped_pages * PAGE_SIZE;
-
-               if (copy_to_user(u64_to_user_ptr(tinfo->tidlist),
-                                tidlist, sizeof(tidlist[0]) * tididx)) {
-                       /*
-                        * On failure to copy to the user level, we need to undo
-                        * everything done so far so we don't leak resources.
-                        */
-                       tinfo->tidlist = (unsigned long)&tidlist;
-                       hfi1_user_exp_rcv_clear(fd, tinfo);
-                       tinfo->tidlist = 0;
-                       ret = -EFAULT;
-                       goto bail;
+
+       /* fail if nothing was programmed, set error if none provided */
+       if (tididx == 0) {
+               if (ret >= 0)
+                       ret = -ENOSPC;
+               goto fail_unreserve;
+       }
+
+       /* adjust reserved tid_used to actual count */
+       spin_lock(&fd->tid_lock);
+       fd->tid_used -= pageset_count - tididx;
+       spin_unlock(&fd->tid_lock);
+
+       /* unpin all pages not covered by a TID */
+       unpin_rcv_pages(fd, tidbuf, NULL, mapped_pages, pinned - mapped_pages,
+                       false);
+
+       if (fd->use_mn) {
+               /* check for an invalidate during setup */
+               bool fail = false;
+
+               mutex_lock(&tidbuf->cover_mutex);
+               fail = mmu_interval_read_retry(&tidbuf->notifier, mmu_seq);
+               mutex_unlock(&tidbuf->cover_mutex);
+
+               if (fail) {
+                       ret = -EBUSY;
+                       goto fail_unprogram;
                }
        }
 
-       /*
-        * If not everything was mapped (due to insufficient RcvArray entries,
-        * for example), unpin all unmapped pages so we can pin them nex time.
-        */
-       if (mapped_pages != pinned)
-               unpin_rcv_pages(fd, tidbuf, NULL, mapped_pages,
-                               (pinned - mapped_pages), false);
-bail:
+       tinfo->tidcnt = tididx;
+       tinfo->length = mapped_pages * PAGE_SIZE;
+
+       if (copy_to_user(u64_to_user_ptr(tinfo->tidlist),
+                        tidlist, sizeof(tidlist[0]) * tididx)) {
+               ret = -EFAULT;
+               goto fail_unprogram;
+       }
+
+       if (fd->use_mn)
+               mmu_interval_notifier_remove(&tidbuf->notifier);
+       kfree(tidbuf->pages);
        kfree(tidbuf->psets);
+       kfree(tidbuf);
        kfree(tidlist);
+       return 0;
+
+fail_unprogram:
+       /* unprogram, unmap, and unpin all allocated TIDs */
+       tinfo->tidlist = (unsigned long)tidlist;
+       hfi1_user_exp_rcv_clear(fd, tinfo);
+       tinfo->tidlist = 0;
+       pinned = 0;             /* nothing left to unpin */
+       pageset_count = 0;      /* nothing left reserved */
+fail_unreserve:
+       spin_lock(&fd->tid_lock);
+       fd->tid_used -= pageset_count;
+       spin_unlock(&fd->tid_lock);
+fail_unpin:
+       if (fd->use_mn)
+               mmu_interval_notifier_remove(&tidbuf->notifier);
+       if (pinned > 0)
+               unpin_rcv_pages(fd, tidbuf, NULL, 0, pinned, false);
+fail_release_mem:
        kfree(tidbuf->pages);
+       kfree(tidbuf->psets);
        kfree(tidbuf);
-       return ret > 0 ? 0 : ret;
+       kfree(tidlist);
+       return ret;
 }
 
 int hfi1_user_exp_rcv_clear(struct hfi1_filedata *fd,
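
Besides collapsing the partial-failure handling into a single unwind ladder, the hunk above addresses the ordering problem the old FIXME in set_rcvarray_entry() pointed at: a notifier covering the whole user buffer is inserted and its sequence sampled before pin_rcv_pages(), and mmu_interval_read_retry() is consulted only after all TIDs are programmed, so an unmap that races with setup is detected and the call fails with -EBUSY. A condensed sketch of that begin/retry pattern against the stock mmu_interval_notifier API (demo_buf, demo_setup() and the elided pin/program step are placeholders, not hfi1 code):

#include <linux/errno.h>
#include <linux/kernel.h>
#include <linux/mm.h>
#include <linux/mmu_notifier.h>
#include <linux/mutex.h>
#include <linux/sched.h>

struct demo_buf {                       /* hypothetical stand-in for tid_user_buf */
        struct mmu_interval_notifier notifier;
        struct mutex cover_mutex;       /* serializes set_seq vs. read_retry */
        unsigned long vaddr;
        unsigned long npages;
};

static bool demo_cover_invalidate(struct mmu_interval_notifier *mni,
                                  const struct mmu_notifier_range *range,
                                  unsigned long cur_seq)
{
        struct demo_buf *buf = container_of(mni, struct demo_buf, notifier);

        /* Publish the new sequence so a later read_retry() reports the race. */
        mutex_lock(&buf->cover_mutex);
        mmu_interval_set_seq(mni, cur_seq);
        mutex_unlock(&buf->cover_mutex);
        return true;
}

static const struct mmu_interval_notifier_ops demo_cover_ops = {
        .invalidate = demo_cover_invalidate,
};

static int demo_setup(struct demo_buf *buf)
{
        unsigned long seq;
        bool raced;
        int ret;

        mutex_init(&buf->cover_mutex);

        /* 1. Register the covering notifier before any page is pinned. */
        ret = mmu_interval_notifier_insert(&buf->notifier, current->mm,
                                           buf->vaddr,
                                           buf->npages * PAGE_SIZE,
                                           &demo_cover_ops);
        if (ret)
                return ret;

        /* 2. Sample the sequence that step 4 will compare against. */
        seq = mmu_interval_read_begin(&buf->notifier);

        /* 3. Pin the pages and program the hardware (elided). */

        /* 4. Detect an unmap that raced with steps 2-3. */
        mutex_lock(&buf->cover_mutex);
        raced = mmu_interval_read_retry(&buf->notifier, seq);
        mutex_unlock(&buf->cover_mutex);

        mmu_interval_notifier_remove(&buf->notifier);
        return raced ? -EBUSY : 0;
}

Unlike the usual retry loop in other mmu_interval_notifier users, the driver (and this sketch) simply unwinds and reports -EBUSY when the sequence moved, leaving the retry decision to userspace.
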
@@ -444,7 +498,7 @@ int hfi1_user_exp_rcv_clear(struct hfi1_filedata *fd,
 
        mutex_lock(&uctxt->exp_mutex);
        for (tididx = 0; tididx < tinfo->tidcnt; tididx++) {
-               ret = unprogram_rcvarray(fd, tidinfo[tididx], NULL);
+               ret = unprogram_rcvarray(fd, tidinfo[tididx]);
                if (ret) {
                        hfi1_cdbg(TID, "Failed to unprogram rcv array %d",
                                  ret);
@@ -696,6 +750,7 @@ static int set_rcvarray_entry(struct hfi1_filedata *fd,
        }
 
        node->fdata = fd;
+       mutex_init(&node->invalidate_mutex);
        node->phys = page_to_phys(pages[0]);
        node->npages = npages;
        node->rcventry = rcventry;
@@ -711,11 +766,6 @@ static int set_rcvarray_entry(struct hfi1_filedata *fd,
                        &tid_mn_ops);
                if (ret)
                        goto out_unmap;
-               /*
-                * FIXME: This is in the wrong order, the notifier should be
-                * established before the pages are pinned by pin_rcv_pages.
-                */
-               mmu_interval_read_begin(&node->notifier);
        }
        fd->entry_to_rb[node->rcventry - uctxt->expected_base] = node;
 
@@ -735,8 +785,7 @@ out_unmap:
        return -EFAULT;
 }
 
-static int unprogram_rcvarray(struct hfi1_filedata *fd, u32 tidinfo,
-                             struct tid_group **grp)
+static int unprogram_rcvarray(struct hfi1_filedata *fd, u32 tidinfo)
 {
        struct hfi1_ctxtdata *uctxt = fd->uctxt;
        struct hfi1_devdata *dd = uctxt->dd;
@@ -759,9 +808,6 @@ static int unprogram_rcvarray(struct hfi1_filedata *fd, u32 tidinfo,
        if (!node || node->rcventry != (uctxt->expected_base + rcventry))
                return -EBADF;
 
-       if (grp)
-               *grp = node->grp;
-
        if (fd->use_mn)
                mmu_interval_notifier_remove(&node->notifier);
        cacheless_tid_rb_remove(fd, node);
@@ -769,23 +815,34 @@ static int unprogram_rcvarray(struct hfi1_filedata *fd, u32 tidinfo,
        return 0;
 }
 
-static void clear_tid_node(struct hfi1_filedata *fd, struct tid_rb_node *node)
+static void __clear_tid_node(struct hfi1_filedata *fd, struct tid_rb_node *node)
 {
        struct hfi1_ctxtdata *uctxt = fd->uctxt;
        struct hfi1_devdata *dd = uctxt->dd;
 
+       mutex_lock(&node->invalidate_mutex);
+       if (node->freed)
+               goto done;
+       node->freed = true;
+
        trace_hfi1_exp_tid_unreg(uctxt->ctxt, fd->subctxt, node->rcventry,
                                 node->npages,
                                 node->notifier.interval_tree.start, node->phys,
                                 node->dma_addr);
 
-       /*
-        * Make sure device has seen the write before we unpin the
-        * pages.
-        */
+       /* Make sure device has seen the write before pages are unpinned */
        hfi1_put_tid(dd, node->rcventry, PT_INVALID_FLUSH, 0, 0);
 
        unpin_rcv_pages(fd, NULL, node, 0, node->npages, true);
+done:
+       mutex_unlock(&node->invalidate_mutex);
+}
+
+static void clear_tid_node(struct hfi1_filedata *fd, struct tid_rb_node *node)
+{
+       struct hfi1_ctxtdata *uctxt = fd->uctxt;
+
+       __clear_tid_node(fd, node);
 
        node->grp->used--;
        node->grp->map &= ~(1 << (node->rcventry - node->grp->base));
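
clear_tid_node() is split above so that the hardware flush and the unpin live in __clear_tid_node(), which can now be reached from two directions: the normal free path and tid_rb_invalidate(). The new invalidate_mutex together with the node->freed flag makes whichever caller arrives second a no-op, so the RcvArray entry is flushed and the pages unpinned exactly once. A minimal sketch of that idempotent-teardown idiom (demo_node and demo_teardown_hw() are hypothetical placeholders):

#include <linux/mutex.h>
#include <linux/types.h>

struct demo_node {                      /* hypothetical stand-in for tid_rb_node */
        struct mutex invalidate_mutex;  /* serializes user free vs. notifier */
        bool freed;                     /* set once the HW entry is torn down */
};

static void demo_teardown_hw(struct demo_node *node)
{
        /* flush the device entry, then unpin the pages (elided) */
}

/* Safe to call from either the free path or the invalidate callback. */
static void demo_clear(struct demo_node *node)
{
        mutex_lock(&node->invalidate_mutex);
        if (!node->freed) {
                node->freed = true;
                demo_teardown_hw(node);
        }
        mutex_unlock(&node->invalidate_mutex);
}
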
@@ -844,10 +901,16 @@ static bool tid_rb_invalidate(struct mmu_interval_notifier *mni,
        if (node->freed)
                return true;
 
+       /* take action only if unmapping */
+       if (range->event != MMU_NOTIFY_UNMAP)
+               return true;
+
        trace_hfi1_exp_tid_inval(uctxt->ctxt, fdata->subctxt,
                                 node->notifier.interval_tree.start,
                                 node->rcventry, node->npages, node->dma_addr);
-       node->freed = true;
+
+       /* clear the hardware rcvarray entry */
+       __clear_tid_node(fdata, node);
 
        spin_lock(&fdata->invalid_lock);
        if (fdata->invalid_tid_idx < uctxt->expected_count) {
@@ -876,6 +939,23 @@ static bool tid_rb_invalidate(struct mmu_interval_notifier *mni,
        return true;
 }
 
+static bool tid_cover_invalidate(struct mmu_interval_notifier *mni,
+                                const struct mmu_notifier_range *range,
+                                unsigned long cur_seq)
+{
+       struct tid_user_buf *tidbuf =
+               container_of(mni, struct tid_user_buf, notifier);
+
+       /* take action only if unmapping */
+       if (range->event == MMU_NOTIFY_UNMAP) {
+               mutex_lock(&tidbuf->cover_mutex);
+               mmu_interval_set_seq(mni, cur_seq);
+               mutex_unlock(&tidbuf->cover_mutex);
+       }
+
+       return true;
+}
+
 static void cacheless_tid_rb_remove(struct hfi1_filedata *fdata,
                                    struct tid_rb_node *tnode)
 {