rxrpc: Permit multiple service binding
[linux-2.6-block.git] / net / rxrpc / af_rxrpc.c
index 7fb59c3f1542af319b882399b4a0f563dc0b8a0d..3b982bca7d22d029fc134042aeb4376a6dd6d4f4 100644 (file)
@@ -38,9 +38,6 @@ MODULE_PARM_DESC(debug, "RxRPC debugging mask");
 static struct proto rxrpc_proto;
 static const struct proto_ops rxrpc_rpc_ops;
 
-/* local epoch for detecting local-end reset */
-u32 rxrpc_epoch;
-
 /* current debugging ID */
 atomic_t rxrpc_debug_id;
 
@@ -134,9 +131,8 @@ static int rxrpc_validate_address(struct rxrpc_sock *rx,
 static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
 {
        struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *)saddr;
-       struct sock *sk = sock->sk;
        struct rxrpc_local *local;
-       struct rxrpc_sock *rx = rxrpc_sk(sk);
+       struct rxrpc_sock *rx = rxrpc_sk(sock->sk);
        u16 service_id = srx->srx_service;
        int ret;
 
@@ -148,31 +144,48 @@ static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len)
 
        lock_sock(&rx->sk);
 
-       if (rx->sk.sk_state != RXRPC_UNBOUND) {
-               ret = -EINVAL;
-               goto error_unlock;
-       }
-
-       memcpy(&rx->srx, srx, sizeof(rx->srx));
+       switch (rx->sk.sk_state) {
+       case RXRPC_UNBOUND:
+               rx->srx = *srx;
+               local = rxrpc_lookup_local(sock_net(&rx->sk), &rx->srx);
+               if (IS_ERR(local)) {
+                       ret = PTR_ERR(local);
+                       goto error_unlock;
+               }
 
-       local = rxrpc_lookup_local(&rx->srx);
-       if (IS_ERR(local)) {
-               ret = PTR_ERR(local);
-               goto error_unlock;
-       }
+               if (service_id) {
+                       write_lock(&local->services_lock);
+                       if (rcu_access_pointer(local->service))
+                               goto service_in_use;
+                       rx->local = local;
+                       rcu_assign_pointer(local->service, rx);
+                       write_unlock(&local->services_lock);
+
+                       rx->sk.sk_state = RXRPC_SERVER_BOUND;
+               } else {
+                       rx->local = local;
+                       rx->sk.sk_state = RXRPC_CLIENT_BOUND;
+               }
+               break;
 
-       if (service_id) {
-               write_lock(&local->services_lock);
-               if (rcu_access_pointer(local->service))
-                       goto service_in_use;
-               rx->local = local;
-               rcu_assign_pointer(local->service, rx);
-               write_unlock(&local->services_lock);
+       case RXRPC_SERVER_BOUND:
+               ret = -EINVAL;
+               if (service_id == 0)
+                       goto error_unlock;
+               ret = -EADDRINUSE;
+               if (service_id == rx->srx.srx_service)
+                       goto error_unlock;
+               ret = -EINVAL;
+               srx->srx_service = rx->srx.srx_service;
+               if (memcmp(srx, &rx->srx, sizeof(*srx)) != 0)
+                       goto error_unlock;
+               rx->second_service = service_id;
+               rx->sk.sk_state = RXRPC_SERVER_BOUND2;
+               break;
 
-               rx->sk.sk_state = RXRPC_SERVER_BOUND;
-       } else {
-               rx->local = local;
-               rx->sk.sk_state = RXRPC_CLIENT_BOUND;
+       default:
+               ret = -EINVAL;
+               goto error_unlock;
        }
 
        release_sock(&rx->sk);
@@ -209,6 +222,7 @@ static int rxrpc_listen(struct socket *sock, int backlog)
                ret = -EADDRNOTAVAIL;
                break;
        case RXRPC_SERVER_BOUND:
+       case RXRPC_SERVER_BOUND2:
                ASSERT(rx->local != NULL);
                max = READ_ONCE(rxrpc_max_backlog);
                ret = -EINVAL;
@@ -434,7 +448,7 @@ static int rxrpc_sendmsg(struct socket *sock, struct msghdr *m, size_t len)
                        ret = -EAFNOSUPPORT;
                        goto error_unlock;
                }
-               local = rxrpc_lookup_local(&rx->srx);
+               local = rxrpc_lookup_local(sock_net(sock->sk), &rx->srx);
                if (IS_ERR(local)) {
                        ret = PTR_ERR(local);
                        goto error_unlock;
@@ -582,9 +596,6 @@ static int rxrpc_create(struct net *net, struct socket *sock, int protocol,
 
        _enter("%p,%d", sock, protocol);
 
-       if (!net_eq(net, &init_net))
-               return -EAFNOSUPPORT;
-
        /* we support transport protocol UDP/UDP6 only */
        if (protocol != PF_INET &&
            IS_ENABLED(CONFIG_AF_RXRPC_IPV6) && protocol != PF_INET6)
@@ -780,8 +791,6 @@ static int __init af_rxrpc_init(void)
 
        BUILD_BUG_ON(sizeof(struct rxrpc_skb_priv) > FIELD_SIZEOF(struct sk_buff, cb));
 
-       get_random_bytes(&rxrpc_epoch, sizeof(rxrpc_epoch));
-       rxrpc_epoch |= RXRPC_RANDOM_EPOCH;
        get_random_bytes(&tmp, sizeof(tmp));
        tmp &= 0x3fffffff;
        if (tmp == 0)
@@ -809,6 +818,10 @@ static int __init af_rxrpc_init(void)
                goto error_security;
        }
 
+       ret = register_pernet_subsys(&rxrpc_net_ops);
+       if (ret)
+               goto error_pernet;
+
        ret = proto_register(&rxrpc_proto, 1);
        if (ret < 0) {
                pr_crit("Cannot register protocol\n");
@@ -839,11 +852,6 @@ static int __init af_rxrpc_init(void)
                goto error_sysctls;
        }
 
-#ifdef CONFIG_PROC_FS
-       proc_create("rxrpc_calls", 0, init_net.proc_net, &rxrpc_call_seq_fops);
-       proc_create("rxrpc_conns", 0, init_net.proc_net,
-                   &rxrpc_connection_seq_fops);
-#endif
        return 0;
 
 error_sysctls:
@@ -855,6 +863,8 @@ error_key_type:
 error_sock:
        proto_unregister(&rxrpc_proto);
 error_proto:
+       unregister_pernet_subsys(&rxrpc_net_ops);
+error_pernet:
        rxrpc_exit_security();
 error_security:
        destroy_workqueue(rxrpc_workqueue);
@@ -875,14 +885,16 @@ static void __exit af_rxrpc_exit(void)
        unregister_key_type(&key_type_rxrpc);
        sock_unregister(PF_RXRPC);
        proto_unregister(&rxrpc_proto);
-       rxrpc_destroy_all_calls();
-       rxrpc_destroy_all_connections();
+       unregister_pernet_subsys(&rxrpc_net_ops);
        ASSERTCMP(atomic_read(&rxrpc_n_tx_skbs), ==, 0);
        ASSERTCMP(atomic_read(&rxrpc_n_rx_skbs), ==, 0);
-       rxrpc_destroy_all_locals();
 
-       remove_proc_entry("rxrpc_conns", init_net.proc_net);
-       remove_proc_entry("rxrpc_calls", init_net.proc_net);
+       /* Make sure the local and peer records pinned by any dying connections
+        * are released.
+        */
+       rcu_barrier();
+       rxrpc_destroy_client_conn_ids();
+
        destroy_workqueue(rxrpc_workqueue);
        rxrpc_exit_security();
        kmem_cache_destroy(rxrpc_call_jar);