IB/core: Add netdev and gid attributes paramteres to cache
[linux-2.6-block.git] / drivers / infiniband / core / cm.c
index 3a972ebf3c0d1170efe280aa7bcf781c831fa98f..2d8a0e4c42d6567cbc925ce46a1d51e770ff5a89 100644 (file)
@@ -58,7 +58,7 @@ MODULE_DESCRIPTION("InfiniBand CM");
 MODULE_LICENSE("Dual BSD/GPL");
 
 static void cm_add_one(struct ib_device *device);
-static void cm_remove_one(struct ib_device *device);
+static void cm_remove_one(struct ib_device *device, void *client_data);
 
 static struct ib_client cm_client = {
        .name   = "cm",
@@ -213,13 +213,15 @@ struct cm_id_private {
        spinlock_t lock;        /* Do not acquire inside cm.lock */
        struct completion comp;
        atomic_t refcount;
+       /* Number of clients sharing this ib_cm_id. Only valid for listeners.
+        * Protected by the cm.lock spinlock. */
+       int listen_sharecount;
 
        struct ib_mad_send_buf *msg;
        struct cm_timewait_info *timewait_info;
        /* todo: use alternate port on send failure */
        struct cm_av av;
        struct cm_av alt_av;
-       struct ib_cm_compare_data *compare_data;
 
        void *private_data;
        __be64 tid;
@@ -363,7 +365,7 @@ static int cm_init_av_by_path(struct ib_sa_path_rec *path, struct cm_av *av)
        read_lock_irqsave(&cm.device_lock, flags);
        list_for_each_entry(cm_dev, &cm.device_list, list) {
                if (!ib_find_cached_gid(cm_dev->ib_device, &path->sgid,
-                                       &p, NULL)) {
+                                       NULL, &p, NULL)) {
                        port = cm_dev->port[p-1];
                        break;
                }
@@ -440,40 +442,6 @@ static struct cm_id_private * cm_acquire_id(__be32 local_id, __be32 remote_id)
        return cm_id_priv;
 }
 
-static void cm_mask_copy(u32 *dst, const u32 *src, const u32 *mask)
-{
-       int i;
-
-       for (i = 0; i < IB_CM_COMPARE_SIZE; i++)
-               dst[i] = src[i] & mask[i];
-}
-
-static int cm_compare_data(struct ib_cm_compare_data *src_data,
-                          struct ib_cm_compare_data *dst_data)
-{
-       u32 src[IB_CM_COMPARE_SIZE];
-       u32 dst[IB_CM_COMPARE_SIZE];
-
-       if (!src_data || !dst_data)
-               return 0;
-
-       cm_mask_copy(src, src_data->data, dst_data->mask);
-       cm_mask_copy(dst, dst_data->data, src_data->mask);
-       return memcmp(src, dst, sizeof(src));
-}
-
-static int cm_compare_private_data(u32 *private_data,
-                                  struct ib_cm_compare_data *dst_data)
-{
-       u32 src[IB_CM_COMPARE_SIZE];
-
-       if (!dst_data)
-               return 0;
-
-       cm_mask_copy(src, private_data, dst_data->mask);
-       return memcmp(src, dst_data->data, sizeof(src));
-}
-
 /*
  * Trivial helpers to strip endian annotation and compare; the
  * endianness doesn't actually matter since we just need a stable
@@ -506,18 +474,14 @@ static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv)
        struct cm_id_private *cur_cm_id_priv;
        __be64 service_id = cm_id_priv->id.service_id;
        __be64 service_mask = cm_id_priv->id.service_mask;
-       int data_cmp;
 
        while (*link) {
                parent = *link;
                cur_cm_id_priv = rb_entry(parent, struct cm_id_private,
                                          service_node);
-               data_cmp = cm_compare_data(cm_id_priv->compare_data,
-                                          cur_cm_id_priv->compare_data);
                if ((cur_cm_id_priv->id.service_mask & service_id) ==
                    (service_mask & cur_cm_id_priv->id.service_id) &&
-                   (cm_id_priv->id.device == cur_cm_id_priv->id.device) &&
-                   !data_cmp)
+                   (cm_id_priv->id.device == cur_cm_id_priv->id.device))
                        return cur_cm_id_priv;
 
                if (cm_id_priv->id.device < cur_cm_id_priv->id.device)
@@ -528,8 +492,6 @@ static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv)
                        link = &(*link)->rb_left;
                else if (be64_gt(service_id, cur_cm_id_priv->id.service_id))
                        link = &(*link)->rb_right;
-               else if (data_cmp < 0)
-                       link = &(*link)->rb_left;
                else
                        link = &(*link)->rb_right;
        }
@@ -539,20 +501,16 @@ static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv)
 }
 
 static struct cm_id_private * cm_find_listen(struct ib_device *device,
-                                            __be64 service_id,
-                                            u32 *private_data)
+                                            __be64 service_id)
 {
        struct rb_node *node = cm.listen_service_table.rb_node;
        struct cm_id_private *cm_id_priv;
-       int data_cmp;
 
        while (node) {
                cm_id_priv = rb_entry(node, struct cm_id_private, service_node);
-               data_cmp = cm_compare_private_data(private_data,
-                                                  cm_id_priv->compare_data);
                if ((cm_id_priv->id.service_mask & service_id) ==
                     cm_id_priv->id.service_id &&
-                   (cm_id_priv->id.device == device) && !data_cmp)
+                   (cm_id_priv->id.device == device))
                        return cm_id_priv;
 
                if (device < cm_id_priv->id.device)
@@ -563,8 +521,6 @@ static struct cm_id_private * cm_find_listen(struct ib_device *device,
                        node = node->rb_left;
                else if (be64_gt(service_id, cm_id_priv->id.service_id))
                        node = node->rb_right;
-               else if (data_cmp < 0)
-                       node = node->rb_left;
                else
                        node = node->rb_right;
        }
@@ -859,9 +815,15 @@ retest:
        spin_lock_irq(&cm_id_priv->lock);
        switch (cm_id->state) {
        case IB_CM_LISTEN:
-               cm_id->state = IB_CM_IDLE;
                spin_unlock_irq(&cm_id_priv->lock);
+
                spin_lock_irq(&cm.lock);
+               if (--cm_id_priv->listen_sharecount > 0) {
+                       /* The id is still shared. */
+                       cm_deref_id(cm_id_priv);
+                       spin_unlock_irq(&cm.lock);
+                       return;
+               }
                rb_erase(&cm_id_priv->service_node, &cm.listen_service_table);
                spin_unlock_irq(&cm.lock);
                break;
@@ -873,6 +835,11 @@ retest:
        case IB_CM_SIDR_REQ_RCVD:
                spin_unlock_irq(&cm_id_priv->lock);
                cm_reject_sidr_req(cm_id_priv, IB_SIDR_REJECT);
+               spin_lock_irq(&cm.lock);
+               if (!RB_EMPTY_NODE(&cm_id_priv->sidr_id_node))
+                       rb_erase(&cm_id_priv->sidr_id_node,
+                                &cm.remote_sidr_table);
+               spin_unlock_irq(&cm.lock);
                break;
        case IB_CM_REQ_SENT:
        case IB_CM_MRA_REQ_RCVD:
@@ -930,7 +897,6 @@ retest:
        wait_for_completion(&cm_id_priv->comp);
        while ((work = cm_dequeue_work(cm_id_priv)) != NULL)
                cm_free_work(work);
-       kfree(cm_id_priv->compare_data);
        kfree(cm_id_priv->private_data);
        kfree(cm_id_priv);
 }
@@ -941,11 +907,23 @@ void ib_destroy_cm_id(struct ib_cm_id *cm_id)
 }
 EXPORT_SYMBOL(ib_destroy_cm_id);
 
-int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask,
-                struct ib_cm_compare_data *compare_data)
+/**
+ * __ib_cm_listen - Initiates listening on the specified service ID for
+ *   connection and service ID resolution requests.
+ * @cm_id: Connection identifier associated with the listen request.
+ * @service_id: Service identifier matched against incoming connection
+ *   and service ID resolution requests.  The service ID should be specified
+ *   network-byte order.  If set to IB_CM_ASSIGN_SERVICE_ID, the CM will
+ *   assign a service ID to the caller.
+ * @service_mask: Mask applied to service ID used to listen across a
+ *   range of service IDs.  If set to 0, the service ID is matched
+ *   exactly.  This parameter is ignored if %service_id is set to
+ *   IB_CM_ASSIGN_SERVICE_ID.
+ */
+static int __ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id,
+                         __be64 service_mask)
 {
        struct cm_id_private *cm_id_priv, *cur_cm_id_priv;
-       unsigned long flags;
        int ret = 0;
 
        service_mask = service_mask ? service_mask : ~cpu_to_be64(0);
@@ -958,20 +936,9 @@ int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask,
        if (cm_id->state != IB_CM_IDLE)
                return -EINVAL;
 
-       if (compare_data) {
-               cm_id_priv->compare_data = kzalloc(sizeof *compare_data,
-                                                  GFP_KERNEL);
-               if (!cm_id_priv->compare_data)
-                       return -ENOMEM;
-               cm_mask_copy(cm_id_priv->compare_data->data,
-                            compare_data->data, compare_data->mask);
-               memcpy(cm_id_priv->compare_data->mask, compare_data->mask,
-                      sizeof(compare_data->mask));
-       }
-
        cm_id->state = IB_CM_LISTEN;
+       ++cm_id_priv->listen_sharecount;
 
-       spin_lock_irqsave(&cm.lock, flags);
        if (service_id == IB_CM_ASSIGN_SERVICE_ID) {
                cm_id->service_id = cpu_to_be64(cm.listen_service_id++);
                cm_id->service_mask = ~cpu_to_be64(0);
@@ -980,18 +947,95 @@ int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask,
                cm_id->service_mask = service_mask;
        }
        cur_cm_id_priv = cm_insert_listen(cm_id_priv);
-       spin_unlock_irqrestore(&cm.lock, flags);
 
        if (cur_cm_id_priv) {
                cm_id->state = IB_CM_IDLE;
-               kfree(cm_id_priv->compare_data);
-               cm_id_priv->compare_data = NULL;
+               --cm_id_priv->listen_sharecount;
                ret = -EBUSY;
        }
        return ret;
 }
+
+int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask)
+{
+       unsigned long flags;
+       int ret;
+
+       spin_lock_irqsave(&cm.lock, flags);
+       ret = __ib_cm_listen(cm_id, service_id, service_mask);
+       spin_unlock_irqrestore(&cm.lock, flags);
+
+       return ret;
+}
 EXPORT_SYMBOL(ib_cm_listen);
 
+/**
+ * Create a new listening ib_cm_id and listen on the given service ID.
+ *
+ * If there's an existing ID listening on that same device and service ID,
+ * return it.
+ *
+ * @device: Device associated with the cm_id.  All related communication will
+ * be associated with the specified device.
+ * @cm_handler: Callback invoked to notify the user of CM events.
+ * @service_id: Service identifier matched against incoming connection
+ *   and service ID resolution requests.  The service ID should be specified
+ *   network-byte order.  If set to IB_CM_ASSIGN_SERVICE_ID, the CM will
+ *   assign a service ID to the caller.
+ *
+ * Callers should call ib_destroy_cm_id when done with the listener ID.
+ */
+struct ib_cm_id *ib_cm_insert_listen(struct ib_device *device,
+                                    ib_cm_handler cm_handler,
+                                    __be64 service_id)
+{
+       struct cm_id_private *cm_id_priv;
+       struct ib_cm_id *cm_id;
+       unsigned long flags;
+       int err = 0;
+
+       /* Create an ID in advance, since the creation may sleep */
+       cm_id = ib_create_cm_id(device, cm_handler, NULL);
+       if (IS_ERR(cm_id))
+               return cm_id;
+
+       spin_lock_irqsave(&cm.lock, flags);
+
+       if (service_id == IB_CM_ASSIGN_SERVICE_ID)
+               goto new_id;
+
+       /* Find an existing ID */
+       cm_id_priv = cm_find_listen(device, service_id);
+       if (cm_id_priv) {
+               if (cm_id->cm_handler != cm_handler || cm_id->context) {
+                       /* Sharing an ib_cm_id with different handlers is not
+                        * supported */
+                       spin_unlock_irqrestore(&cm.lock, flags);
+                       return ERR_PTR(-EINVAL);
+               }
+               atomic_inc(&cm_id_priv->refcount);
+               ++cm_id_priv->listen_sharecount;
+               spin_unlock_irqrestore(&cm.lock, flags);
+
+               ib_destroy_cm_id(cm_id);
+               cm_id = &cm_id_priv->id;
+               return cm_id;
+       }
+
+new_id:
+       /* Use newly created ID */
+       err = __ib_cm_listen(cm_id, service_id, 0);
+
+       spin_unlock_irqrestore(&cm.lock, flags);
+
+       if (err) {
+               ib_destroy_cm_id(cm_id);
+               return ERR_PTR(err);
+       }
+       return cm_id;
+}
+EXPORT_SYMBOL(ib_cm_insert_listen);
+
 static __be64 cm_form_tid(struct cm_id_private *cm_id_priv,
                          enum cm_msg_sequence msg_seq)
 {
@@ -1268,6 +1312,7 @@ static void cm_format_paths_from_req(struct cm_req_msg *req_msg,
        primary_path->packet_life_time =
                cm_req_get_primary_local_ack_timeout(req_msg);
        primary_path->packet_life_time -= (primary_path->packet_life_time > 0);
+       primary_path->service_id = req_msg->service_id;
 
        if (req_msg->alt_local_lid) {
                memset(alt_path, 0, sizeof *alt_path);
@@ -1289,9 +1334,28 @@ static void cm_format_paths_from_req(struct cm_req_msg *req_msg,
                alt_path->packet_life_time =
                        cm_req_get_alt_local_ack_timeout(req_msg);
                alt_path->packet_life_time -= (alt_path->packet_life_time > 0);
+               alt_path->service_id = req_msg->service_id;
        }
 }
 
+static u16 cm_get_bth_pkey(struct cm_work *work)
+{
+       struct ib_device *ib_dev = work->port->cm_dev->ib_device;
+       u8 port_num = work->port->port_num;
+       u16 pkey_index = work->mad_recv_wc->wc->pkey_index;
+       u16 pkey;
+       int ret;
+
+       ret = ib_get_cached_pkey(ib_dev, port_num, pkey_index, &pkey);
+       if (ret) {
+               dev_warn_ratelimited(&ib_dev->dev, "ib_cm: Couldn't retrieve pkey for incoming request (port %d, pkey index %d). %d\n",
+                                    port_num, pkey_index, ret);
+               return 0;
+       }
+
+       return pkey;
+}
+
 static void cm_format_req_event(struct cm_work *work,
                                struct cm_id_private *cm_id_priv,
                                struct ib_cm_id *listen_id)
@@ -1302,6 +1366,7 @@ static void cm_format_req_event(struct cm_work *work,
        req_msg = (struct cm_req_msg *)work->mad_recv_wc->recv_buf.mad;
        param = &work->cm_event.param.req_rcvd;
        param->listen_id = listen_id;
+       param->bth_pkey = cm_get_bth_pkey(work);
        param->port = cm_id_priv->av.port->port_num;
        param->primary_path = &work->path[0];
        if (req_msg->alt_local_lid)
@@ -1484,8 +1549,7 @@ static struct cm_id_private * cm_match_req(struct cm_work *work,
 
        /* Find matching listen request. */
        listen_cm_id_priv = cm_find_listen(cm_id_priv->id.device,
-                                          req_msg->service_id,
-                                          req_msg->private_data);
+                                          req_msg->service_id);
        if (!listen_cm_id_priv) {
                cm_cleanup_timewait(cm_id_priv->timewait_info);
                spin_unlock_irq(&cm.lock);
@@ -1579,7 +1643,8 @@ static int cm_req_handler(struct cm_work *work)
        ret = cm_init_av_by_path(&work->path[0], &cm_id_priv->av);
        if (ret) {
                ib_get_cached_gid(work->port->cm_dev->ib_device,
-                                 work->port->port_num, 0, &work->path[0].sgid);
+                                 work->port->port_num, 0, &work->path[0].sgid,
+                                 NULL);
                ib_send_cm_rej(cm_id, IB_CM_REJ_INVALID_GID,
                               &work->path[0].sgid, sizeof work->path[0].sgid,
                               NULL, 0);
@@ -2992,6 +3057,8 @@ static void cm_format_sidr_req_event(struct cm_work *work,
        param = &work->cm_event.param.sidr_req_rcvd;
        param->pkey = __be16_to_cpu(sidr_req_msg->pkey);
        param->listen_id = listen_id;
+       param->service_id = sidr_req_msg->service_id;
+       param->bth_pkey = cm_get_bth_pkey(work);
        param->port = work->port->port_num;
        work->cm_event.private_data = &sidr_req_msg->private_data;
 }
@@ -3031,8 +3098,7 @@ static int cm_sidr_req_handler(struct cm_work *work)
        }
        cm_id_priv->id.state = IB_CM_SIDR_REQ_RCVD;
        cur_cm_id_priv = cm_find_listen(cm_id->device,
-                                       sidr_req_msg->service_id,
-                                       sidr_req_msg->private_data);
+                                       sidr_req_msg->service_id);
        if (!cur_cm_id_priv) {
                spin_unlock_irq(&cm.lock);
                cm_reject_sidr_req(cm_id_priv, IB_SIDR_UNSUPPORTED);
@@ -3112,7 +3178,10 @@ int ib_send_cm_sidr_rep(struct ib_cm_id *cm_id,
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 
        spin_lock_irqsave(&cm.lock, flags);
-       rb_erase(&cm_id_priv->sidr_id_node, &cm.remote_sidr_table);
+       if (!RB_EMPTY_NODE(&cm_id_priv->sidr_id_node)) {
+               rb_erase(&cm_id_priv->sidr_id_node, &cm.remote_sidr_table);
+               RB_CLEAR_NODE(&cm_id_priv->sidr_id_node);
+       }
        spin_unlock_irqrestore(&cm.lock, flags);
        return 0;
 
@@ -3886,9 +3955,9 @@ free:
        kfree(cm_dev);
 }
 
-static void cm_remove_one(struct ib_device *ib_device)
+static void cm_remove_one(struct ib_device *ib_device, void *client_data)
 {
-       struct cm_device *cm_dev;
+       struct cm_device *cm_dev = client_data;
        struct cm_port *port;
        struct ib_port_modify port_modify = {
                .clr_port_cap_mask = IB_PORT_CM_SUP
@@ -3896,7 +3965,6 @@ static void cm_remove_one(struct ib_device *ib_device)
        unsigned long flags;
        int i;
 
-       cm_dev = ib_get_client_data(ib_device, &cm_client);
        if (!cm_dev)
                return;