net/tcp: Calculate TCP-AO traffic keys
authorDmitry Safonov <dima@arista.com>
Mon, 23 Oct 2023 19:21:57 +0000 (20:21 +0100)
committerDavid S. Miller <davem@davemloft.net>
Fri, 27 Oct 2023 09:35:44 +0000 (10:35 +0100)
Add traffic key calculation the way it's described in RFC5926.
Wire it up to tcp_finish_connect() and cache the new keys straight away
on already established TCP connections.

Co-developed-by: Francesco Ruggeri <fruggeri@arista.com>
Signed-off-by: Francesco Ruggeri <fruggeri@arista.com>
Co-developed-by: Salam Noureddine <noureddine@arista.com>
Signed-off-by: Salam Noureddine <noureddine@arista.com>
Signed-off-by: Dmitry Safonov <dima@arista.com>
Acked-by: David Ahern <dsahern@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/tcp.h
include/net/tcp_ao.h
net/ipv4/tcp_ao.c
net/ipv4/tcp_input.c
net/ipv4/tcp_ipv4.c
net/ipv4/tcp_output.c
net/ipv6/tcp_ao.c
net/ipv6/tcp_ipv6.c

index 0272117511eaa01ee8d2fb7ed03ee1f4b3727d5b..b72c46cf229b892395aaffd52466354e09dc532e 100644 (file)
@@ -2197,6 +2197,9 @@ struct tcp_sock_af_ops {
        struct tcp_ao_key *(*ao_lookup)(const struct sock *sk,
                                        struct sock *addr_sk,
                                        int sndid, int rcvid);
+       int (*ao_calc_key_sk)(struct tcp_ao_key *mkt, u8 *key,
+                             const struct sock *sk,
+                             __be32 sisn, __be32 disn, bool send);
 #endif
 };
 
index 3c7f576376f9ac4aa180fbd5734ccfa4bbee2adb..b021a811511bad8f80c25b33bffd4338d64a2f38 100644 (file)
@@ -89,8 +89,32 @@ struct tcp_ao_info {
 };
 
 #ifdef CONFIG_TCP_AO
+/* TCP-AO structures and functions */
+
+struct tcp4_ao_context {
+       __be32          saddr;
+       __be32          daddr;
+       __be16          sport;
+       __be16          dport;
+       __be32          sisn;
+       __be32          disn;
+};
+
+struct tcp6_ao_context {
+       struct in6_addr saddr;
+       struct in6_addr daddr;
+       __be16          sport;
+       __be16          dport;
+       __be32          sisn;
+       __be32          disn;
+};
+
+struct tcp_sigpool;
+
 int tcp_parse_ao(struct sock *sk, int cmd, unsigned short int family,
                 sockptr_t optval, int optlen);
+int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx,
+                           unsigned int len, struct tcp_sigpool *hp);
 void tcp_ao_destroy_sock(struct sock *sk);
 struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
                                    const union tcp_ao_addr *addr,
@@ -99,11 +123,22 @@ struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
 int tcp_v4_parse_ao(struct sock *sk, int cmd, sockptr_t optval, int optlen);
 struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
                                    int sndid, int rcvid);
+int tcp_v4_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
+                         const struct sock *sk,
+                         __be32 sisn, __be32 disn, bool send);
 /* ipv6 specific functions */
-int tcp_v6_parse_ao(struct sock *sk, int cmd, sockptr_t optval, int optlen);
+int tcp_v6_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
+                         const struct sock *sk, __be32 sisn,
+                         __be32 disn, bool send);
 struct tcp_ao_key *tcp_v6_ao_lookup(const struct sock *sk,
                                    struct sock *addr_sk, int sndid, int rcvid);
-#else
+int tcp_v6_parse_ao(struct sock *sk, int cmd, sockptr_t optval, int optlen);
+void tcp_ao_established(struct sock *sk);
+void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb);
+void tcp_ao_connect_init(struct sock *sk);
+
+#else /* CONFIG_TCP_AO */
+
 static inline struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
                const union tcp_ao_addr *addr, int family, int sndid, int rcvid)
 {
@@ -113,6 +148,18 @@ static inline struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
 static inline void tcp_ao_destroy_sock(struct sock *sk)
 {
 }
+
+static inline void tcp_ao_established(struct sock *sk)
+{
+}
+
+static inline void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb)
+{
+}
+
+static inline void tcp_ao_connect_init(struct sock *sk)
+{
+}
 #endif
 
 #endif /* _TCP_AO_H */
index ee23356101f49c2df24b543fb721db84351d25d2..e478341fc33622aed360ebf70b885874aceb7e45 100644 (file)
 #include <net/tcp.h>
 #include <net/ipv6.h>
 
+int tcp_ao_calc_traffic_key(struct tcp_ao_key *mkt, u8 *key, void *ctx,
+                           unsigned int len, struct tcp_sigpool *hp)
+{
+       struct scatterlist sg;
+       int ret;
+
+       if (crypto_ahash_setkey(crypto_ahash_reqtfm(hp->req),
+                               mkt->key, mkt->keylen))
+               goto clear_hash;
+
+       ret = crypto_ahash_init(hp->req);
+       if (ret)
+               goto clear_hash;
+
+       sg_init_one(&sg, ctx, len);
+       ahash_request_set_crypt(hp->req, &sg, key, len);
+       crypto_ahash_update(hp->req);
+
+       ret = crypto_ahash_final(hp->req);
+       if (ret)
+               goto clear_hash;
+
+       return 0;
+clear_hash:
+       memset(key, 0, tcp_ao_digest_size(mkt));
+       return 1;
+}
+
 /* Optimized version of tcp_ao_do_lookup(): only for sockets for which
  * it's known that the keys in ao_info are matching peer's
  * family/address/VRF/etc.
@@ -169,6 +197,71 @@ void tcp_ao_destroy_sock(struct sock *sk)
        kfree_rcu(ao, rcu);
 }
 
+/* 4 tuple and ISNs are expected in NBO */
+static int tcp_v4_ao_calc_key(struct tcp_ao_key *mkt, u8 *key,
+                             __be32 saddr, __be32 daddr,
+                             __be16 sport, __be16 dport,
+                             __be32 sisn,  __be32 disn)
+{
+       /* See RFC5926 3.1.1 */
+       struct kdf_input_block {
+               u8                      counter;
+               u8                      label[6];
+               struct tcp4_ao_context  ctx;
+               __be16                  outlen;
+       } __packed * tmp;
+       struct tcp_sigpool hp;
+       int err;
+
+       err = tcp_sigpool_start(mkt->tcp_sigpool_id, &hp);
+       if (err)
+               return err;
+
+       tmp = hp.scratch;
+       tmp->counter    = 1;
+       memcpy(tmp->label, "TCP-AO", 6);
+       tmp->ctx.saddr  = saddr;
+       tmp->ctx.daddr  = daddr;
+       tmp->ctx.sport  = sport;
+       tmp->ctx.dport  = dport;
+       tmp->ctx.sisn   = sisn;
+       tmp->ctx.disn   = disn;
+       tmp->outlen     = htons(tcp_ao_digest_size(mkt) * 8); /* in bits */
+
+       err = tcp_ao_calc_traffic_key(mkt, key, tmp, sizeof(*tmp), &hp);
+       tcp_sigpool_end(&hp);
+
+       return err;
+}
+
+int tcp_v4_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
+                         const struct sock *sk,
+                         __be32 sisn, __be32 disn, bool send)
+{
+       if (send)
+               return tcp_v4_ao_calc_key(mkt, key, sk->sk_rcv_saddr,
+                                         sk->sk_daddr, htons(sk->sk_num),
+                                         sk->sk_dport, sisn, disn);
+       else
+               return tcp_v4_ao_calc_key(mkt, key, sk->sk_daddr,
+                                         sk->sk_rcv_saddr, sk->sk_dport,
+                                         htons(sk->sk_num), disn, sisn);
+}
+
+static int tcp_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
+                             const struct sock *sk,
+                             __be32 sisn, __be32 disn, bool send)
+{
+       if (mkt->family == AF_INET)
+               return tcp_v4_ao_calc_key_sk(mkt, key, sk, sisn, disn, send);
+#if IS_ENABLED(CONFIG_IPV6)
+       else if (mkt->family == AF_INET6)
+               return tcp_v6_ao_calc_key_sk(mkt, key, sk, sisn, disn, send);
+#endif
+       else
+               return -EOPNOTSUPP;
+}
+
 struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
                                    int sndid, int rcvid)
 {
@@ -177,6 +270,113 @@ struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
        return tcp_ao_do_lookup(sk, addr, AF_INET, sndid, rcvid);
 }
 
+static int tcp_ao_cache_traffic_keys(const struct sock *sk,
+                                    struct tcp_ao_info *ao,
+                                    struct tcp_ao_key *ao_key)
+{
+       u8 *traffic_key = snd_other_key(ao_key);
+       int ret;
+
+       ret = tcp_ao_calc_key_sk(ao_key, traffic_key, sk,
+                                ao->lisn, ao->risn, true);
+       if (ret)
+               return ret;
+
+       traffic_key = rcv_other_key(ao_key);
+       ret = tcp_ao_calc_key_sk(ao_key, traffic_key, sk,
+                                ao->lisn, ao->risn, false);
+       return ret;
+}
+
+void tcp_ao_connect_init(struct sock *sk)
+{
+       struct tcp_sock *tp = tcp_sk(sk);
+       struct tcp_ao_info *ao_info;
+       union tcp_ao_addr *addr;
+       struct tcp_ao_key *key;
+       int family;
+
+       ao_info = rcu_dereference_protected(tp->ao_info,
+                                           lockdep_sock_is_held(sk));
+       if (!ao_info)
+               return;
+
+       /* Remove all keys that don't match the peer */
+       family = sk->sk_family;
+       if (family == AF_INET)
+               addr = (union tcp_ao_addr *)&sk->sk_daddr;
+#if IS_ENABLED(CONFIG_IPV6)
+       else if (family == AF_INET6)
+               addr = (union tcp_ao_addr *)&sk->sk_v6_daddr;
+#endif
+       else
+               return;
+
+       hlist_for_each_entry_rcu(key, &ao_info->head, node) {
+               if (!tcp_ao_key_cmp(key, addr, key->prefixlen, family, -1, -1))
+                       continue;
+
+               if (key == ao_info->current_key)
+                       ao_info->current_key = NULL;
+               if (key == ao_info->rnext_key)
+                       ao_info->rnext_key = NULL;
+               hlist_del_rcu(&key->node);
+               atomic_sub(tcp_ao_sizeof_key(key), &sk->sk_omem_alloc);
+               call_rcu(&key->rcu, tcp_ao_key_free_rcu);
+       }
+
+       key = tp->af_specific->ao_lookup(sk, sk, -1, -1);
+       if (key) {
+               /* if current_key or rnext_key were not provided,
+                * use the first key matching the peer
+                */
+               if (!ao_info->current_key)
+                       ao_info->current_key = key;
+               if (!ao_info->rnext_key)
+                       ao_info->rnext_key = key;
+               tp->tcp_header_len += tcp_ao_len(key);
+
+               ao_info->lisn = htonl(tp->write_seq);
+       } else {
+               /* Can't happen: tcp_connect() verifies that there's
+                * at least one tcp-ao key that matches the remote peer.
+                */
+               WARN_ON_ONCE(1);
+               rcu_assign_pointer(tp->ao_info, NULL);
+               kfree(ao_info);
+       }
+}
+
+void tcp_ao_established(struct sock *sk)
+{
+       struct tcp_ao_info *ao;
+       struct tcp_ao_key *key;
+
+       ao = rcu_dereference_protected(tcp_sk(sk)->ao_info,
+                                      lockdep_sock_is_held(sk));
+       if (!ao)
+               return;
+
+       hlist_for_each_entry_rcu(key, &ao->head, node)
+               tcp_ao_cache_traffic_keys(sk, ao, key);
+}
+
+void tcp_ao_finish_connect(struct sock *sk, struct sk_buff *skb)
+{
+       struct tcp_ao_info *ao;
+       struct tcp_ao_key *key;
+
+       ao = rcu_dereference_protected(tcp_sk(sk)->ao_info,
+                                      lockdep_sock_is_held(sk));
+       if (!ao)
+               return;
+
+       WRITE_ONCE(ao->risn, tcp_hdr(skb)->seq);
+
+       hlist_for_each_entry_rcu(key, &ao->head, node)
+               tcp_ao_cache_traffic_keys(sk, ao, key);
+}
+
 static bool tcp_ao_can_set_current_rnext(struct sock *sk)
 {
        /* There aren't current/rnext keys on TCP_LISTEN sockets */
@@ -558,6 +758,12 @@ static int tcp_ao_add_cmd(struct sock *sk, unsigned short int family,
        if (ret < 0)
                goto err_free_sock;
 
+       /* Change this condition if we allow adding keys in states
+        * like close_wait, syn_sent or fin_wait...
+        */
+       if (sk->sk_state == TCP_ESTABLISHED)
+               tcp_ao_cache_traffic_keys(sk, ao_info, key);
+
        tcp_ao_link_mkt(ao_info, key);
        if (first) {
                sk_gso_disable(sk);
index 00d04ab68958cf3c34280f1aec3f688cd62ee9b8..6ee0342b5338540318f173a4527a4d4b3af17b54 100644 (file)
@@ -6151,6 +6151,7 @@ void tcp_finish_connect(struct sock *sk, struct sk_buff *skb)
        struct tcp_sock *tp = tcp_sk(sk);
        struct inet_connection_sock *icsk = inet_csk(sk);
 
+       tcp_ao_finish_connect(sk, skb);
        tcp_set_state(sk, TCP_ESTABLISHED);
        icsk->icsk_ack.lrcvtime = tcp_jiffies32;
 
@@ -6648,6 +6649,7 @@ int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb)
                                          skb);
                        WRITE_ONCE(tp->copied_seq, tp->rcv_nxt);
                }
+               tcp_ao_established(sk);
                smp_mb();
                tcp_set_state(sk, TCP_ESTABLISHED);
                sk->sk_state_change(sk);
index 698e58a3ccec9d16536104a7b9310f7410322070..3c73b58293775c58a21599cab5a1b88efaea8d5d 100644 (file)
@@ -2288,6 +2288,7 @@ static const struct tcp_sock_af_ops tcp_sock_ipv4_specific = {
 #ifdef CONFIG_TCP_AO
        .ao_lookup              = tcp_v4_ao_lookup,
        .ao_parse               = tcp_v4_parse_ao,
+       .ao_calc_key_sk         = tcp_v4_ao_calc_key_sk,
 #endif
 };
 #endif
index 1b90107f7038b51dcdc7a04cec36fcde4602f6ed..9fbf1b2e20258ae343da6136a7f9cd0cd915dc4f 100644 (file)
@@ -3749,6 +3749,8 @@ static void tcp_connect_init(struct sock *sk)
        if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_timestamps))
                tp->tcp_header_len += TCPOLEN_TSTAMP_ALIGNED;
 
+       tcp_ao_connect_init(sk);
+
        /* If user gave his TCP_MAXSEG, record it to clamp */
        if (tp->rx_opt.user_mss)
                tp->rx_opt.mss_clamp = tp->rx_opt.user_mss;
index 0640acaee67b9b56de75d58e8761b9b5e49e94d5..9ab594fadbd9316863a0ba8c8c9ff275e9153ac3 100644 (file)
 #include <net/tcp.h>
 #include <net/ipv6.h>
 
+static int tcp_v6_ao_calc_key(struct tcp_ao_key *mkt, u8 *key,
+                             const struct in6_addr *saddr,
+                             const struct in6_addr *daddr,
+                             __be16 sport, __be16 dport,
+                             __be32 sisn, __be32 disn)
+{
+       struct kdf_input_block {
+               u8                      counter;
+               u8                      label[6];
+               struct tcp6_ao_context  ctx;
+               __be16                  outlen;
+       } __packed * tmp;
+       struct tcp_sigpool hp;
+       int err;
+
+       err = tcp_sigpool_start(mkt->tcp_sigpool_id, &hp);
+       if (err)
+               return err;
+
+       tmp = hp.scratch;
+       tmp->counter    = 1;
+       memcpy(tmp->label, "TCP-AO", 6);
+       tmp->ctx.saddr  = *saddr;
+       tmp->ctx.daddr  = *daddr;
+       tmp->ctx.sport  = sport;
+       tmp->ctx.dport  = dport;
+       tmp->ctx.sisn   = sisn;
+       tmp->ctx.disn   = disn;
+       tmp->outlen     = htons(tcp_ao_digest_size(mkt) * 8); /* in bits */
+
+       err = tcp_ao_calc_traffic_key(mkt, key, tmp, sizeof(*tmp), &hp);
+       tcp_sigpool_end(&hp);
+
+       return err;
+}
+
+int tcp_v6_ao_calc_key_sk(struct tcp_ao_key *mkt, u8 *key,
+                         const struct sock *sk, __be32 sisn,
+                         __be32 disn, bool send)
+{
+       if (send)
+               return tcp_v6_ao_calc_key(mkt, key, &sk->sk_v6_rcv_saddr,
+                                         &sk->sk_v6_daddr, htons(sk->sk_num),
+                                         sk->sk_dport, sisn, disn);
+       else
+               return tcp_v6_ao_calc_key(mkt, key, &sk->sk_v6_daddr,
+                                         &sk->sk_v6_rcv_saddr, sk->sk_dport,
+                                         htons(sk->sk_num), disn, sisn);
+}
+
 static struct tcp_ao_key *tcp_v6_ao_do_lookup(const struct sock *sk,
                                              const struct in6_addr *addr,
                                              int sndid, int rcvid)
index 70a3842f47faaf8ca979725b7d1da8dbf56c548f..074e16fe00e0814c848533a3740dd8e3d6f41c78 100644 (file)
@@ -1921,6 +1921,7 @@ static const struct tcp_sock_af_ops tcp_sock_ipv6_specific = {
 #ifdef CONFIG_TCP_AO
        .ao_lookup      =       tcp_v6_ao_lookup,
        .ao_parse       =       tcp_v6_parse_ao,
+       .ao_calc_key_sk =       tcp_v6_ao_calc_key_sk,
 #endif
 };
 #endif