staging: vc04_services: use kref + RCU to reference count services
author Marcelo Diop-Gonzalez <marcgonzalez@google.com>
Wed, 12 Feb 2020 18:43:32 +0000 (13:43 -0500)
committer Greg Kroah-Hartman <gregkh@linuxfoundation.org>
Wed, 12 Feb 2020 21:40:43 +0000 (13:40 -0800)
Currently reference counts are implemented by locking service_spinlock
and then incrementing the service's ->ref_count field, calling
kfree() when the last reference has been dropped. But at the same
time, there's code in multiple places that dereferences pointers
to services without holding a reference, so it can race with the
service being freed.

It should be possible to avoid taking any lock in unlock_service()
or service_release() because we are setting a single array element
to NULL, and on service creation, a mutex is locked before looking
for a NULL spot to put the new service in.

Using a struct kref and RCU-delaying the freeing of services fixes
this race condition while still making it possible to skip
grabbing a reference in many places. It also avoids having to
acquire one global spinlock when e.g. taking a reference on
state->services[i] while somebody else is in the middle of taking
a reference on state->services[j].

Signed-off-by: Marcelo Diop-Gonzalez <marcgonzalez@google.com>
Link: https://lore.kernel.org/r/3bf6f1ec6ace64d7072025505e165b8dd18b25ca.1581532523.git.marcgonzalez@google.com
Signed-off-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
drivers/staging/vc04_services/interface/vchiq_arm/vchiq_arm.c
drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.c
drivers/staging/vc04_services/interface/vchiq_arm/vchiq_core.h

index c456ced431af0d9132a8cbb838aa3cf8a083ef95..3ed0e4ea7f5c0e91e4d3fc0f5c3600b4a6455ed2 100644 (file)
@@ -22,6 +22,7 @@
 #include <linux/platform_device.h>
 #include <linux/compat.h>
 #include <linux/dma-mapping.h>
+#include <linux/rcupdate.h>
 #include <soc/bcm2835/raspberrypi-firmware.h>
 
 #include "vchiq_core.h"
@@ -2096,10 +2097,12 @@ int vchiq_dump_platform_instances(void *dump_context)
        /* There is no list of instances, so instead scan all services,
                marking those that have been dumped. */
 
+       rcu_read_lock();
        for (i = 0; i < state->unused_service; i++) {
-               struct vchiq_service *service = state->services[i];
+               struct vchiq_service *service;
                struct vchiq_instance *instance;
 
+               service = rcu_dereference(state->services[i]);
                if (!service || service->base.callback != service_callback)
                        continue;
 
@@ -2107,18 +2110,26 @@ int vchiq_dump_platform_instances(void *dump_context)
                if (instance)
                        instance->mark = 0;
        }
+       rcu_read_unlock();
 
        for (i = 0; i < state->unused_service; i++) {
-               struct vchiq_service *service = state->services[i];
+               struct vchiq_service *service;
                struct vchiq_instance *instance;
                int err;
 
-               if (!service || service->base.callback != service_callback)
+               rcu_read_lock();
+               service = rcu_dereference(state->services[i]);
+               if (!service || service->base.callback != service_callback) {
+                       rcu_read_unlock();
                        continue;
+               }
 
                instance = service->instance;
-               if (!instance || instance->mark)
+               if (!instance || instance->mark) {
+                       rcu_read_unlock();
                        continue;
+               }
+               rcu_read_unlock();
 
                len = snprintf(buf, sizeof(buf),
                               "Instance %pK: pid %d,%s completions %d/%d",
@@ -2128,7 +2139,6 @@ int vchiq_dump_platform_instances(void *dump_context)
                               instance->completion_insert -
                               instance->completion_remove,
                               MAX_COMPLETIONS);
-
                err = vchiq_dump(dump_context, buf, len + 1);
                if (err)
                        return err;
@@ -2585,8 +2595,10 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
        if (active_services > MAX_SERVICES)
                only_nonzero = 1;
 
+       rcu_read_lock();
        for (i = 0; i < active_services; i++) {
-               struct vchiq_service *service_ptr = state->services[i];
+               struct vchiq_service *service_ptr =
+                       rcu_dereference(state->services[i]);
 
                if (!service_ptr)
                        continue;
@@ -2604,6 +2616,7 @@ vchiq_dump_service_use_state(struct vchiq_state *state)
                if (found >= MAX_SERVICES)
                        break;
        }
+       rcu_read_unlock();
 
        read_unlock_bh(&arm_state->susp_res_lock);
 
index b2d9013b7f796142f8d8a5c394308b2917026691..65270a5b29db266329023caba8e23fe4672e6d2f 100644 (file)
@@ -1,6 +1,9 @@
 // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
 /* Copyright (c) 2010-2012 Broadcom. All rights reserved. */
 
+#include <linux/kref.h>
+#include <linux/rcupdate.h>
+
 #include "vchiq_core.h"
 
 #define VCHIQ_SLOT_HANDLER_STACK 8192
@@ -54,7 +57,6 @@ int vchiq_core_log_level = VCHIQ_LOG_DEFAULT;
 int vchiq_core_msg_log_level = VCHIQ_LOG_DEFAULT;
 int vchiq_sync_log_level = VCHIQ_LOG_DEFAULT;
 
-static DEFINE_SPINLOCK(service_spinlock);
 DEFINE_SPINLOCK(bulk_waiter_spinlock);
 static DEFINE_SPINLOCK(quota_spinlock);
 
@@ -136,44 +138,41 @@ find_service_by_handle(unsigned int handle)
 {
        struct vchiq_service *service;
 
-       spin_lock(&service_spinlock);
+       rcu_read_lock();
        service = handle_to_service(handle);
        if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
-           service->handle == handle) {
-               WARN_ON(service->ref_count == 0);
-               service->ref_count++;
-       } else
-               service = NULL;
-       spin_unlock(&service_spinlock);
-
-       if (!service)
-               vchiq_log_info(vchiq_core_log_level,
-                       "Invalid service handle 0x%x", handle);
-
-       return service;
+           service->handle == handle &&
+           kref_get_unless_zero(&service->ref_count)) {
+               service = rcu_pointer_handoff(service);
+               rcu_read_unlock();
+               return service;
+       }
+       rcu_read_unlock();
+       vchiq_log_info(vchiq_core_log_level,
+                      "Invalid service handle 0x%x", handle);
+       return NULL;
 }
 
 struct vchiq_service *
 find_service_by_port(struct vchiq_state *state, int localport)
 {
-       struct vchiq_service *service = NULL;
 
        if ((unsigned int)localport <= VCHIQ_PORT_MAX) {
-               spin_lock(&service_spinlock);
-               service = state->services[localport];
-               if (service && service->srvstate != VCHIQ_SRVSTATE_FREE) {
-                       WARN_ON(service->ref_count == 0);
-                       service->ref_count++;
-               } else
-                       service = NULL;
-               spin_unlock(&service_spinlock);
-       }
-
-       if (!service)
-               vchiq_log_info(vchiq_core_log_level,
-                       "Invalid port %d", localport);
+               struct vchiq_service *service;
 
-       return service;
+               rcu_read_lock();
+               service = rcu_dereference(state->services[localport]);
+               if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
+                   kref_get_unless_zero(&service->ref_count)) {
+                       service = rcu_pointer_handoff(service);
+                       rcu_read_unlock();
+                       return service;
+               }
+               rcu_read_unlock();
+       }
+       vchiq_log_info(vchiq_core_log_level,
+                      "Invalid port %d", localport);
+       return NULL;
 }
 
 struct vchiq_service *
@@ -182,22 +181,20 @@ find_service_for_instance(struct vchiq_instance *instance,
 {
        struct vchiq_service *service;
 
-       spin_lock(&service_spinlock);
+       rcu_read_lock();
        service = handle_to_service(handle);
        if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
            service->handle == handle &&
-           service->instance == instance) {
-               WARN_ON(service->ref_count == 0);
-               service->ref_count++;
-       } else
-               service = NULL;
-       spin_unlock(&service_spinlock);
-
-       if (!service)
-               vchiq_log_info(vchiq_core_log_level,
-                       "Invalid service handle 0x%x", handle);
-
-       return service;
+           service->instance == instance &&
+           kref_get_unless_zero(&service->ref_count)) {
+               service = rcu_pointer_handoff(service);
+               rcu_read_unlock();
+               return service;
+       }
+       rcu_read_unlock();
+       vchiq_log_info(vchiq_core_log_level,
+                      "Invalid service handle 0x%x", handle);
+       return NULL;
 }
 
 struct vchiq_service *
@@ -206,23 +203,21 @@ find_closed_service_for_instance(struct vchiq_instance *instance,
 {
        struct vchiq_service *service;
 
-       spin_lock(&service_spinlock);
+       rcu_read_lock();
        service = handle_to_service(handle);
        if (service &&
            (service->srvstate == VCHIQ_SRVSTATE_FREE ||
             service->srvstate == VCHIQ_SRVSTATE_CLOSED) &&
            service->handle == handle &&
-           service->instance == instance) {
-               WARN_ON(service->ref_count == 0);
-               service->ref_count++;
-       } else
-               service = NULL;
-       spin_unlock(&service_spinlock);
-
-       if (!service)
-               vchiq_log_info(vchiq_core_log_level,
-                       "Invalid service handle 0x%x", handle);
-
+           service->instance == instance &&
+           kref_get_unless_zero(&service->ref_count)) {
+               service = rcu_pointer_handoff(service);
+               rcu_read_unlock();
+               return service;
+       }
+       rcu_read_unlock();
+       vchiq_log_info(vchiq_core_log_level,
+                      "Invalid service handle 0x%x", handle);
        return service;
 }
 
@@ -233,19 +228,19 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta
        struct vchiq_service *service = NULL;
        int idx = *pidx;
 
-       spin_lock(&service_spinlock);
+       rcu_read_lock();
        while (idx < state->unused_service) {
-               struct vchiq_service *srv = state->services[idx++];
+               struct vchiq_service *srv;
 
+               srv = rcu_dereference(state->services[idx++]);
                if (srv && srv->srvstate != VCHIQ_SRVSTATE_FREE &&
-                   srv->instance == instance) {
-                       service = srv;
-                       WARN_ON(service->ref_count == 0);
-                       service->ref_count++;
+                   srv->instance == instance &&
+                   kref_get_unless_zero(&srv->ref_count)) {
+                       service = rcu_pointer_handoff(srv);
                        break;
                }
        }
-       spin_unlock(&service_spinlock);
+       rcu_read_unlock();
 
        *pidx = idx;
 
@@ -255,43 +250,34 @@ next_service_by_instance(struct vchiq_state *state, struct vchiq_instance *insta
 void
 lock_service(struct vchiq_service *service)
 {
-       spin_lock(&service_spinlock);
-       WARN_ON(!service);
-       if (service) {
-               WARN_ON(service->ref_count == 0);
-               service->ref_count++;
+       if (!service) {
+               WARN(1, "%s service is NULL\n", __func__);
+               return;
        }
-       spin_unlock(&service_spinlock);
+       kref_get(&service->ref_count);
+}
+
+static void service_release(struct kref *kref)
+{
+       struct vchiq_service *service =
+               container_of(kref, struct vchiq_service, ref_count);
+       struct vchiq_state *state = service->state;
+
+       WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
+       rcu_assign_pointer(state->services[service->localport], NULL);
+       if (service->userdata_term)
+               service->userdata_term(service->base.userdata);
+       kfree_rcu(service, rcu);
 }
 
 void
 unlock_service(struct vchiq_service *service)
 {
-       spin_lock(&service_spinlock);
        if (!service) {
                WARN(1, "%s: service is NULL\n", __func__);
-               goto unlock;
-       }
-       if (!service->ref_count) {
-               WARN(1, "%s: ref_count is zero\n", __func__);
-               goto unlock;
-       }
-       service->ref_count--;
-       if (!service->ref_count) {
-               struct vchiq_state *state = service->state;
-
-               WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
-               state->services[service->localport] = NULL;
-       } else {
-               service = NULL;
+               return;
        }
-unlock:
-       spin_unlock(&service_spinlock);
-
-       if (service && service->userdata_term)
-               service->userdata_term(service->base.userdata);
-
-       kfree(service);
+       kref_put(&service->ref_count, service_release);
 }
 
 int
@@ -310,9 +296,14 @@ vchiq_get_client_id(unsigned int handle)
 void *
 vchiq_get_service_userdata(unsigned int handle)
 {
-       struct vchiq_service *service = handle_to_service(handle);
+       void *userdata;
+       struct vchiq_service *service;
 
-       return service ? service->base.userdata : NULL;
+       rcu_read_lock();
+       service = handle_to_service(handle);
+       userdata = service ? service->base.userdata : NULL;
+       rcu_read_unlock();
+       return userdata;
 }
 
 static void
@@ -460,19 +451,23 @@ get_listening_service(struct vchiq_state *state, int fourcc)
 
        WARN_ON(fourcc == VCHIQ_FOURCC_INVALID);
 
+       rcu_read_lock();
        for (i = 0; i < state->unused_service; i++) {
-               struct vchiq_service *service = state->services[i];
+               struct vchiq_service *service;
 
+               service = rcu_dereference(state->services[i]);
                if (service &&
                    service->public_fourcc == fourcc &&
                    (service->srvstate == VCHIQ_SRVSTATE_LISTENING ||
                     (service->srvstate == VCHIQ_SRVSTATE_OPEN &&
-                     service->remoteport == VCHIQ_PORT_FREE))) {
-                       lock_service(service);
+                     service->remoteport == VCHIQ_PORT_FREE)) &&
+                   kref_get_unless_zero(&service->ref_count)) {
+                       service = rcu_pointer_handoff(service);
+                       rcu_read_unlock();
                        return service;
                }
        }
-
+       rcu_read_unlock();
        return NULL;
 }
 
@@ -482,15 +477,20 @@ get_connected_service(struct vchiq_state *state, unsigned int port)
 {
        int i;
 
+       rcu_read_lock();
        for (i = 0; i < state->unused_service; i++) {
-               struct vchiq_service *service = state->services[i];
+               struct vchiq_service *service =
+                       rcu_dereference(state->services[i]);
 
                if (service && service->srvstate == VCHIQ_SRVSTATE_OPEN &&
-                   service->remoteport == port) {
-                       lock_service(service);
+                   service->remoteport == port &&
+                   kref_get_unless_zero(&service->ref_count)) {
+                       service = rcu_pointer_handoff(service);
+                       rcu_read_unlock();
                        return service;
                }
        }
+       rcu_read_unlock();
        return NULL;
 }
 
@@ -2260,7 +2260,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
                           vchiq_userdata_term userdata_term)
 {
        struct vchiq_service *service;
-       struct vchiq_service **pservice = NULL;
+       struct vchiq_service __rcu **pservice = NULL;
        struct vchiq_service_quota *service_quota;
        int i;
 
@@ -2272,7 +2272,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
        service->base.callback = params->callback;
        service->base.userdata = params->userdata;
        service->handle        = VCHIQ_SERVICE_HANDLE_INVALID;
-       service->ref_count     = 1;
+       kref_init(&service->ref_count);
        service->srvstate      = VCHIQ_SRVSTATE_FREE;
        service->userdata_term = userdata_term;
        service->localport     = VCHIQ_PORT_FREE;
@@ -2298,7 +2298,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
        mutex_init(&service->bulk_mutex);
        memset(&service->stats, 0, sizeof(service->stats));
 
-       /* Although it is perfectly possible to use service_spinlock
+       /* Although it is perfectly possible to use spinlock
        ** to protect the creation of services, it is overkill as it
        ** disables interrupts while the array is searched.
        ** The only danger is of another thread trying to create a
@@ -2316,17 +2316,17 @@ vchiq_add_service_internal(struct vchiq_state *state,
 
        if (srvstate == VCHIQ_SRVSTATE_OPENING) {
                for (i = 0; i < state->unused_service; i++) {
-                       struct vchiq_service *srv = state->services[i];
-
-                       if (!srv) {
+                       if (!rcu_access_pointer(state->services[i])) {
                                pservice = &state->services[i];
                                break;
                        }
                }
        } else {
+               rcu_read_lock();
                for (i = (state->unused_service - 1); i >= 0; i--) {
-                       struct vchiq_service *srv = state->services[i];
+                       struct vchiq_service *srv;
 
+                       srv = rcu_dereference(state->services[i]);
                        if (!srv)
                                pservice = &state->services[i];
                        else if ((srv->public_fourcc == params->fourcc)
@@ -2339,6 +2339,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
                                break;
                        }
                }
+               rcu_read_unlock();
        }
 
        if (pservice) {
@@ -2350,7 +2351,7 @@ vchiq_add_service_internal(struct vchiq_state *state,
                        (state->id * VCHIQ_MAX_SERVICES) |
                        service->localport;
                handle_seq += VCHIQ_MAX_STATES * VCHIQ_MAX_SERVICES;
-               *pservice = service;
+               rcu_assign_pointer(*pservice, service);
                if (pservice == &state->services[state->unused_service])
                        state->unused_service++;
        }
@@ -2416,10 +2417,10 @@ vchiq_open_service_internal(struct vchiq_service *service, int client_id)
                           (service->srvstate != VCHIQ_SRVSTATE_OPENSYNC)) {
                        if (service->srvstate != VCHIQ_SRVSTATE_CLOSEWAIT)
                                vchiq_log_error(vchiq_core_log_level,
-                                               "%d: osi - srvstate = %s (ref %d)",
+                                               "%d: osi - srvstate = %s (ref %u)",
                                                service->state->id,
                                                srvstate_names[service->srvstate],
-                                               service->ref_count);
+                                               kref_read(&service->ref_count));
                        status = VCHIQ_ERROR;
                        VCHIQ_SERVICE_STATS_INC(service, error_count);
                        vchiq_release_service_internal(service);
@@ -3425,10 +3426,13 @@ int vchiq_dump_service_state(void *dump_context, struct vchiq_service *service)
        char buf[80];
        int len;
        int err;
+       unsigned int ref_count;
 
+       /*Don't include the lock just taken*/
+       ref_count = kref_read(&service->ref_count) - 1;
        len = scnprintf(buf, sizeof(buf), "Service %u: %s (ref %u)",
                        service->localport, srvstate_names[service->srvstate],
-                       service->ref_count - 1); /*Don't include the lock just taken*/
+                       ref_count);
 
        if (service->srvstate != VCHIQ_SRVSTATE_FREE) {
                char remoteport[30];
index 604d0c3308191280610c44ca78527b9af2542b10..30e4965c76667cdd2be77d9a1ee414ec12e28ef5 100644 (file)
@@ -7,6 +7,8 @@
 #include <linux/mutex.h>
 #include <linux/completion.h>
 #include <linux/kthread.h>
+#include <linux/kref.h>
+#include <linux/rcupdate.h>
 #include <linux/wait.h>
 
 #include "vchiq_cfg.h"
@@ -251,7 +253,8 @@ struct vchiq_slot_info {
 struct vchiq_service {
        struct vchiq_service_base base;
        unsigned int handle;
-       unsigned int ref_count;
+       struct kref ref_count;
+       struct rcu_head rcu;
        int srvstate;
        vchiq_userdata_term userdata_term;
        unsigned int localport;
@@ -464,7 +467,7 @@ struct vchiq_state {
                int error_count;
        } stats;
 
-       struct vchiq_service *services[VCHIQ_MAX_SERVICES];
+       struct vchiq_service __rcu *services[VCHIQ_MAX_SERVICES];
        struct vchiq_service_quota service_quotas[VCHIQ_MAX_SERVICES];
        struct vchiq_slot_info slot_info[VCHIQ_MAX_SLOTS];
 
@@ -545,12 +548,13 @@ request_poll(struct vchiq_state *state, struct vchiq_service *service,
 static inline struct vchiq_service *
 handle_to_service(unsigned int handle)
 {
+       int idx = handle & (VCHIQ_MAX_SERVICES - 1);
        struct vchiq_state *state = vchiq_states[(handle / VCHIQ_MAX_SERVICES) &
                (VCHIQ_MAX_STATES - 1)];
+
        if (!state)
                return NULL;
-
-       return state->services[handle & (VCHIQ_MAX_SERVICES - 1)];
+       return rcu_dereference(state->services[idx]);
 }
 
 extern struct vchiq_service *