nvme: optimise io_uring passthrough completion
[linux-block.git] / fs / ksmbd / connection.c
index 365ac32af505804f270eadaabfba989dbd5e6bfb..4ed379f9b1aa65d635e774ae354310a9c0a60a95 100644 (file)
@@ -20,7 +20,7 @@ static DEFINE_MUTEX(init_lock);
 static struct ksmbd_conn_ops default_conn_ops;
 
 LIST_HEAD(conn_list);
-DEFINE_RWLOCK(conn_list_lock);
+DECLARE_RWSEM(conn_list_lock);
 
 /**
  * ksmbd_conn_free() - free resources of the connection instance
@@ -32,9 +32,9 @@ DEFINE_RWLOCK(conn_list_lock);
  */
 void ksmbd_conn_free(struct ksmbd_conn *conn)
 {
-       write_lock(&conn_list_lock);
+       down_write(&conn_list_lock);
        list_del(&conn->conns_list);
-       write_unlock(&conn_list_lock);
+       up_write(&conn_list_lock);
 
        xa_destroy(&conn->sessions);
        kvfree(conn->request_buf);
@@ -56,7 +56,7 @@ struct ksmbd_conn *ksmbd_conn_alloc(void)
                return NULL;
 
        conn->need_neg = true;
-       conn->status = KSMBD_SESS_NEW;
+       ksmbd_conn_set_new(conn);
        conn->local_nls = load_nls("utf8");
        if (!conn->local_nls)
                conn->local_nls = load_nls_default();
@@ -84,9 +84,9 @@ struct ksmbd_conn *ksmbd_conn_alloc(void)
        spin_lock_init(&conn->llist_lock);
        INIT_LIST_HEAD(&conn->lock_list);
 
-       write_lock(&conn_list_lock);
+       down_write(&conn_list_lock);
        list_add(&conn->conns_list, &conn_list);
-       write_unlock(&conn_list_lock);
+       up_write(&conn_list_lock);
        return conn;
 }
 
@@ -95,7 +95,7 @@ bool ksmbd_conn_lookup_dialect(struct ksmbd_conn *c)
        struct ksmbd_conn *t;
        bool ret = false;
 
-       read_lock(&conn_list_lock);
+       down_read(&conn_list_lock);
        list_for_each_entry(t, &conn_list, conns_list) {
                if (memcmp(t->ClientGUID, c->ClientGUID, SMB2_CLIENT_GUID_SIZE))
                        continue;
@@ -103,7 +103,7 @@ bool ksmbd_conn_lookup_dialect(struct ksmbd_conn *c)
                ret = true;
                break;
        }
-       read_unlock(&conn_list_lock);
+       up_read(&conn_list_lock);
        return ret;
 }
 
@@ -147,19 +147,47 @@ int ksmbd_conn_try_dequeue_request(struct ksmbd_work *work)
        return ret;
 }
 
-static void ksmbd_conn_lock(struct ksmbd_conn *conn)
+void ksmbd_conn_lock(struct ksmbd_conn *conn)
 {
        mutex_lock(&conn->srv_mutex);
 }
 
-static void ksmbd_conn_unlock(struct ksmbd_conn *conn)
+void ksmbd_conn_unlock(struct ksmbd_conn *conn)
 {
        mutex_unlock(&conn->srv_mutex);
 }
 
-void ksmbd_conn_wait_idle(struct ksmbd_conn *conn)
+void ksmbd_all_conn_set_status(u64 sess_id, u32 status)
 {
+       struct ksmbd_conn *conn;
+
+       down_read(&conn_list_lock);
+       list_for_each_entry(conn, &conn_list, conns_list) {
+               if (conn->binding || xa_load(&conn->sessions, sess_id))
+                       WRITE_ONCE(conn->status, status);
+       }
+       up_read(&conn_list_lock);
+}
+
+void ksmbd_conn_wait_idle(struct ksmbd_conn *conn, u64 sess_id)
+{
+       struct ksmbd_conn *bind_conn;
+
        wait_event(conn->req_running_q, atomic_read(&conn->req_running) < 2);
+
+       down_read(&conn_list_lock);
+       list_for_each_entry(bind_conn, &conn_list, conns_list) {
+               if (bind_conn == conn)
+                       continue;
+
+               if ((bind_conn->binding || xa_load(&bind_conn->sessions, sess_id)) &&
+                   !ksmbd_conn_releasing(bind_conn) &&
+                   atomic_read(&bind_conn->req_running)) {
+                       wait_event(bind_conn->req_running_q,
+                               atomic_read(&bind_conn->req_running) == 0);
+               }
+       }
+       up_read(&conn_list_lock);
 }
 
 int ksmbd_conn_write(struct ksmbd_work *work)
@@ -243,7 +271,7 @@ bool ksmbd_conn_alive(struct ksmbd_conn *conn)
        if (!ksmbd_server_running())
                return false;
 
-       if (conn->status == KSMBD_SESS_EXITING)
+       if (ksmbd_conn_exiting(conn))
                return false;
 
        if (kthread_should_stop())
@@ -303,7 +331,7 @@ int ksmbd_conn_handler_loop(void *p)
                pdu_size = get_rfc1002_len(hdr_buf);
                ksmbd_debug(CONN, "RFC1002 header %u bytes\n", pdu_size);
 
-               if (conn->status == KSMBD_SESS_GOOD)
+               if (ksmbd_conn_good(conn))
                        max_allowed_pdu_size =
                                SMB3_MAX_MSGSIZE + conn->vals->max_write_size;
                else
@@ -312,7 +340,7 @@ int ksmbd_conn_handler_loop(void *p)
                if (pdu_size > max_allowed_pdu_size) {
                        pr_err_ratelimited("PDU length(%u) exceeded maximum allowed pdu size(%u) on connection(%d)\n",
                                        pdu_size, max_allowed_pdu_size,
-                                       conn->status);
+                                       READ_ONCE(conn->status));
                        break;
                }
 
@@ -360,10 +388,10 @@ int ksmbd_conn_handler_loop(void *p)
        }
 
 out:
+       ksmbd_conn_set_releasing(conn);
        /* Wait till all reference dropped to the Server object*/
        wait_event(conn->r_count_q, atomic_read(&conn->r_count) == 0);
 
-
        if (IS_ENABLED(CONFIG_UNICODE))
                utf8_unload(conn->um);
        unload_nls(conn->local_nls);
@@ -407,7 +435,7 @@ static void stop_sessions(void)
        struct ksmbd_transport *t;
 
 again:
-       read_lock(&conn_list_lock);
+       down_read(&conn_list_lock);
        list_for_each_entry(conn, &conn_list, conns_list) {
                struct task_struct *task;
 
@@ -416,14 +444,14 @@ again:
                if (task)
                        ksmbd_debug(CONN, "Stop session handler %s/%d\n",
                                    task->comm, task_pid_nr(task));
-               conn->status = KSMBD_SESS_EXITING;
+               ksmbd_conn_set_exiting(conn);
                if (t->ops->shutdown) {
-                       read_unlock(&conn_list_lock);
+                       up_read(&conn_list_lock);
                        t->ops->shutdown(t);
-                       read_lock(&conn_list_lock);
+                       down_read(&conn_list_lock);
                }
        }
-       read_unlock(&conn_list_lock);
+       up_read(&conn_list_lock);
 
        if (!list_empty(&conn_list)) {
                schedule_timeout_interruptible(HZ / 10); /* 100ms */