net/tcp: Prevent TCP-MD5 with TCP-AO being set
[linux-2.6-block.git] / net / ipv6 / tcp_ipv6.c
index 44b6949d72b2216b4d3934cd302bfff94ecae23d..70a3842f47faaf8ca979725b7d1da8dbf56c548f 100644 (file)
@@ -76,16 +76,9 @@ INDIRECT_CALLABLE_SCOPE int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb);
 
 static const struct inet_connection_sock_af_ops ipv6_mapped;
 const struct inet_connection_sock_af_ops ipv6_specific;
-#ifdef CONFIG_TCP_MD5SIG
+#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO)
 static const struct tcp_sock_af_ops tcp_sock_ipv6_specific;
 static const struct tcp_sock_af_ops tcp_sock_ipv6_mapped_specific;
-#else
-static struct tcp_md5sig_key *tcp_v6_md5_do_lookup(const struct sock *sk,
-                                                  const struct in6_addr *addr,
-                                                  int l3index)
-{
-       return NULL;
-}
 #endif
 
 /* Helper returning the inet6 address from a given tcp socket.
@@ -135,7 +128,7 @@ static int tcp_v6_pre_connect(struct sock *sk, struct sockaddr *uaddr,
 
        sock_owned_by_me(sk);
 
-       return BPF_CGROUP_RUN_PROG_INET6_CONNECT(sk, uaddr);
+       return BPF_CGROUP_RUN_PROG_INET6_CONNECT(sk, uaddr, &addr_len);
 }
 
 static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
@@ -163,7 +156,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
 
        memset(&fl6, 0, sizeof(fl6));
 
-       if (np->sndflow) {
+       if (inet6_test_bit(SNDFLOW, sk)) {
                fl6.flowlabel = usin->sin6_flowinfo&IPV6_FLOWINFO_MASK;
                IP6_ECN_flow_init(fl6.flowlabel);
                if (fl6.flowlabel&IPV6_FLOWLABEL_MASK) {
@@ -239,7 +232,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
                if (sk_is_mptcp(sk))
                        mptcpv6_handle_mapped(sk, true);
                sk->sk_backlog_rcv = tcp_v4_do_rcv;
-#ifdef CONFIG_TCP_MD5SIG
+#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO)
                tp->af_specific = &tcp_sock_ipv6_mapped_specific;
 #endif
 
@@ -252,7 +245,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
                        if (sk_is_mptcp(sk))
                                mptcpv6_handle_mapped(sk, false);
                        sk->sk_backlog_rcv = tcp_v6_do_rcv;
-#ifdef CONFIG_TCP_MD5SIG
+#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO)
                        tp->af_specific = &tcp_sock_ipv6_specific;
 #endif
                        goto failure;
@@ -286,6 +279,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
                goto failure;
        }
 
+       tp->tcp_usec_ts = dst_tcp_usec_ts(dst);
        tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row;
 
        if (!saddr) {
@@ -508,7 +502,7 @@ static int tcp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
                        tcp_ld_RTO_revert(sk, seq);
        }
 
-       if (!sock_owned_by_user(sk) && np->recverr) {
+       if (!sock_owned_by_user(sk) && inet6_test_bit(RECVERR6, sk)) {
                WRITE_ONCE(sk->sk_err, err);
                sk_error_report(sk);
        } else {
@@ -548,7 +542,7 @@ static int tcp_v6_send_synack(const struct sock *sk, struct dst_entry *dst,
                                    &ireq->ir_v6_rmt_addr);
 
                fl6->daddr = ireq->ir_v6_rmt_addr;
-               if (np->repflow && ireq->pktopts)
+               if (inet6_test_bit(REPFLOW, sk) && ireq->pktopts)
                        fl6->flowlabel = ip6_flowlabel(ipv6_hdr(ireq->pktopts));
 
                tclass = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reflect_tos) ?
@@ -565,7 +559,7 @@ static int tcp_v6_send_synack(const struct sock *sk, struct dst_entry *dst,
                if (!opt)
                        opt = rcu_dereference(np->opt);
                err = ip6_xmit(sk, skb, fl6, skb->mark ? : READ_ONCE(sk->sk_mark),
-                              opt, tclass, sk->sk_priority);
+                              opt, tclass, READ_ONCE(sk->sk_priority));
                rcu_read_unlock();
                err = net_xmit_eval(err);
        }
@@ -606,6 +600,7 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
 {
        struct tcp_md5sig cmd;
        struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd.tcpm_addr;
+       union tcp_ao_addr *addr;
        int l3index = 0;
        u8 prefixlen;
        u8 flags;
@@ -660,17 +655,32 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
        if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
                return -EINVAL;
 
-       if (ipv6_addr_v4mapped(&sin6->sin6_addr))
-               return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
+       if (ipv6_addr_v4mapped(&sin6->sin6_addr)) {
+               addr = (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3];
+
+               /* Don't allow keys for peers that have a matching TCP-AO key.
+                * See the comment in tcp_ao_add_cmd()
+                */
+               if (tcp_ao_required(sk, addr, AF_INET))
+                       return -EKEYREJECTED;
+               return tcp_md5_do_add(sk, addr,
                                      AF_INET, prefixlen, l3index, flags,
                                      cmd.tcpm_key, cmd.tcpm_keylen);
+       }
 
-       return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
-                             AF_INET6, prefixlen, l3index, flags,
+       addr = (union tcp_md5_addr *)&sin6->sin6_addr;
+
+       /* Don't allow keys for peers that have a matching TCP-AO key.
+        * See the comment in tcp_ao_add_cmd()
+        */
+       if (tcp_ao_required(sk, addr, AF_INET6))
+               return -EKEYREJECTED;
+
+       return tcp_md5_do_add(sk, addr, AF_INET6, prefixlen, l3index, flags,
                              cmd.tcpm_key, cmd.tcpm_keylen);
 }
 
-static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
+static int tcp_v6_md5_hash_headers(struct tcp_sigpool *hp,
                                   const struct in6_addr *daddr,
                                   const struct in6_addr *saddr,
                                   const struct tcphdr *th, int nbytes)
@@ -691,39 +701,36 @@ static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
        _th->check = 0;
 
        sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th));
-       ahash_request_set_crypt(hp->md5_req, &sg, NULL,
+       ahash_request_set_crypt(hp->req, &sg, NULL,
                                sizeof(*bp) + sizeof(*th));
-       return crypto_ahash_update(hp->md5_req);
+       return crypto_ahash_update(hp->req);
 }
 
 static int tcp_v6_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key,
                               const struct in6_addr *daddr, struct in6_addr *saddr,
                               const struct tcphdr *th)
 {
-       struct tcp_md5sig_pool *hp;
-       struct ahash_request *req;
+       struct tcp_sigpool hp;
 
-       hp = tcp_get_md5sig_pool();
-       if (!hp)
-               goto clear_hash_noput;
-       req = hp->md5_req;
+       if (tcp_sigpool_start(tcp_md5_sigpool_id, &hp))
+               goto clear_hash_nostart;
 
-       if (crypto_ahash_init(req))
+       if (crypto_ahash_init(hp.req))
                goto clear_hash;
-       if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2))
+       if (tcp_v6_md5_hash_headers(&hp, daddr, saddr, th, th->doff << 2))
                goto clear_hash;
-       if (tcp_md5_hash_key(hp, key))
+       if (tcp_md5_hash_key(&hp, key))
                goto clear_hash;
-       ahash_request_set_crypt(req, NULL, md5_hash, 0);
-       if (crypto_ahash_final(req))
+       ahash_request_set_crypt(hp.req, NULL, md5_hash, 0);
+       if (crypto_ahash_final(hp.req))
                goto clear_hash;
 
-       tcp_put_md5sig_pool();
+       tcp_sigpool_end(&hp);
        return 0;
 
 clear_hash:
-       tcp_put_md5sig_pool();
-clear_hash_noput:
+       tcp_sigpool_end(&hp);
+clear_hash_nostart:
        memset(md5_hash, 0, 16);
        return 1;
 }
@@ -733,10 +740,9 @@ static int tcp_v6_md5_hash_skb(char *md5_hash,
                               const struct sock *sk,
                               const struct sk_buff *skb)
 {
-       const struct in6_addr *saddr, *daddr;
-       struct tcp_md5sig_pool *hp;
-       struct ahash_request *req;
        const struct tcphdr *th = tcp_hdr(skb);
+       const struct in6_addr *saddr, *daddr;
+       struct tcp_sigpool hp;
 
        if (sk) { /* valid for establish/request sockets */
                saddr = &sk->sk_v6_rcv_saddr;
@@ -747,34 +753,38 @@ static int tcp_v6_md5_hash_skb(char *md5_hash,
                daddr = &ip6h->daddr;
        }
 
-       hp = tcp_get_md5sig_pool();
-       if (!hp)
-               goto clear_hash_noput;
-       req = hp->md5_req;
+       if (tcp_sigpool_start(tcp_md5_sigpool_id, &hp))
+               goto clear_hash_nostart;
 
-       if (crypto_ahash_init(req))
+       if (crypto_ahash_init(hp.req))
                goto clear_hash;
 
-       if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, skb->len))
+       if (tcp_v6_md5_hash_headers(&hp, daddr, saddr, th, skb->len))
                goto clear_hash;
-       if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2))
+       if (tcp_sigpool_hash_skb_data(&hp, skb, th->doff << 2))
                goto clear_hash;
-       if (tcp_md5_hash_key(hp, key))
+       if (tcp_md5_hash_key(&hp, key))
                goto clear_hash;
-       ahash_request_set_crypt(req, NULL, md5_hash, 0);
-       if (crypto_ahash_final(req))
+       ahash_request_set_crypt(hp.req, NULL, md5_hash, 0);
+       if (crypto_ahash_final(hp.req))
                goto clear_hash;
 
-       tcp_put_md5sig_pool();
+       tcp_sigpool_end(&hp);
        return 0;
 
 clear_hash:
-       tcp_put_md5sig_pool();
-clear_hash_noput:
+       tcp_sigpool_end(&hp);
+clear_hash_nostart:
        memset(md5_hash, 0, 16);
        return 1;
 }
-
+#else /* CONFIG_TCP_MD5SIG */
+static struct tcp_md5sig_key *tcp_v6_md5_do_lookup(const struct sock *sk,
+                                                  const struct in6_addr *addr,
+                                                  int l3index)
+{
+       return NULL;
+}
 #endif
 
 static void tcp_v6_init_req(struct request_sock *req,
@@ -797,7 +807,7 @@ static void tcp_v6_init_req(struct request_sock *req,
            (ipv6_opt_accepted(sk_listener, skb, &TCP_SKB_CB(skb)->header.h6) ||
             np->rxopt.bits.rxinfo ||
             np->rxopt.bits.rxoinfo || np->rxopt.bits.rxhlim ||
-            np->rxopt.bits.rxohlim || np->repflow)) {
+            np->rxopt.bits.rxohlim || inet6_test_bit(REPFLOW, sk_listener))) {
                refcount_inc(&skb->users);
                ireq->pktopts = skb;
        }
@@ -1055,12 +1065,10 @@ static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb)
        if (sk) {
                oif = sk->sk_bound_dev_if;
                if (sk_fullsock(sk)) {
-                       const struct ipv6_pinfo *np = tcp_inet6_sk(sk);
-
                        trace_tcp_send_reset(sk, skb);
-                       if (np->repflow)
+                       if (inet6_test_bit(REPFLOW, sk))
                                label = ip6_flowlabel(ipv6h);
-                       priority = sk->sk_priority;
+                       priority = READ_ONCE(sk->sk_priority);
                        txhash = sk->sk_txhash;
                }
                if (sk->sk_state == TCP_TIME_WAIT) {
@@ -1098,7 +1106,7 @@ static void tcp_v6_timewait_ack(struct sock *sk, struct sk_buff *skb)
 
        tcp_v6_send_ack(sk, skb, tcptw->tw_snd_nxt, tcptw->tw_rcv_nxt,
                        tcptw->tw_rcv_wnd >> tw->tw_rcv_wscale,
-                       tcp_time_stamp_raw() + tcptw->tw_ts_offset,
+                       tcp_tw_tsval(tcptw),
                        tcptw->tw_ts_recent, tw->tw_bound_dev_if, tcp_twsk_md5_key(tcptw),
                        tw->tw_tclass, cpu_to_be32(tw->tw_flowlabel), tw->tw_priority,
                        tw->tw_txhash);
@@ -1125,7 +1133,7 @@ static void tcp_v6_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb,
                        tcp_rsk(req)->snt_isn + 1 : tcp_sk(sk)->snd_nxt,
                        tcp_rsk(req)->rcv_nxt,
                        req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale,
-                       tcp_time_stamp_raw() + tcp_rsk(req)->ts_off,
+                       tcp_rsk_tsval(tcp_rsk(req)),
                        READ_ONCE(req->ts_recent), sk->sk_bound_dev_if,
                        tcp_v6_md5_do_lookup(sk, &ipv6_hdr(skb)->saddr, l3index),
                        ipv6_get_dsfield(ipv6_hdr(skb)), 0,
@@ -1235,7 +1243,7 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
                if (sk_is_mptcp(newsk))
                        mptcpv6_handle_mapped(newsk, true);
                newsk->sk_backlog_rcv = tcp_v4_do_rcv;
-#ifdef CONFIG_TCP_MD5SIG
+#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO)
                newtp->af_specific = &tcp_sock_ipv6_mapped_specific;
 #endif
 
@@ -1247,7 +1255,7 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
                newnp->mcast_oif   = inet_iif(skb);
                newnp->mcast_hops  = ip_hdr(skb)->ttl;
                newnp->rcv_flowinfo = 0;
-               if (np->repflow)
+               if (inet6_test_bit(REPFLOW, sk))
                        newnp->flow_label = 0;
 
                /*
@@ -1320,7 +1328,7 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
        newnp->mcast_oif  = tcp_v6_iif(skb);
        newnp->mcast_hops = ipv6_hdr(skb)->hop_limit;
        newnp->rcv_flowinfo = ip6_flowinfo(ipv6_hdr(skb));
-       if (np->repflow)
+       if (inet6_test_bit(REPFLOW, sk))
                newnp->flow_label = ip6_flowlabel(ipv6_hdr(skb));
 
        /* Set ToS of the new socket based upon the value of incoming SYN.
@@ -1542,10 +1550,11 @@ ipv6_pktoptions:
                if (np->rxopt.bits.rxinfo || np->rxopt.bits.rxoinfo)
                        np->mcast_oif = tcp_v6_iif(opt_skb);
                if (np->rxopt.bits.rxhlim || np->rxopt.bits.rxohlim)
-                       np->mcast_hops = ipv6_hdr(opt_skb)->hop_limit;
+                       WRITE_ONCE(np->mcast_hops,
+                                  ipv6_hdr(opt_skb)->hop_limit);
                if (np->rxopt.bits.rxflow || np->rxopt.bits.rxtclass)
                        np->rcv_flowinfo = ip6_flowinfo(ipv6_hdr(opt_skb));
-               if (np->repflow)
+               if (inet6_test_bit(REPFLOW, sk))
                        np->flow_label = ip6_flowlabel(ipv6_hdr(opt_skb));
                if (ipv6_opt_accepted(sk, opt_skb, &TCP_SKB_CB(opt_skb)->header.h6)) {
                        tcp_v6_restore_cb(opt_skb);
@@ -1895,7 +1904,6 @@ const struct inet_connection_sock_af_ops ipv6_specific = {
        .conn_request      = tcp_v6_conn_request,
        .syn_recv_sock     = tcp_v6_syn_recv_sock,
        .net_header_len    = sizeof(struct ipv6hdr),
-       .net_frag_header_len = sizeof(struct frag_hdr),
        .setsockopt        = ipv6_setsockopt,
        .getsockopt        = ipv6_getsockopt,
        .addr2sockaddr     = inet6_csk_addr2sockaddr,
@@ -1903,11 +1911,17 @@ const struct inet_connection_sock_af_ops ipv6_specific = {
        .mtu_reduced       = tcp_v6_mtu_reduced,
 };
 
-#ifdef CONFIG_TCP_MD5SIG
+#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO)
 static const struct tcp_sock_af_ops tcp_sock_ipv6_specific = {
+#ifdef CONFIG_TCP_MD5SIG
        .md5_lookup     =       tcp_v6_md5_lookup,
        .calc_md5_hash  =       tcp_v6_md5_hash_skb,
        .md5_parse      =       tcp_v6_parse_md5_keys,
+#endif
+#ifdef CONFIG_TCP_AO
+       .ao_lookup      =       tcp_v6_ao_lookup,
+       .ao_parse       =       tcp_v6_parse_ao,
+#endif
 };
 #endif
 
@@ -1929,11 +1943,17 @@ static const struct inet_connection_sock_af_ops ipv6_mapped = {
        .mtu_reduced       = tcp_v4_mtu_reduced,
 };
 
-#ifdef CONFIG_TCP_MD5SIG
+#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO)
 static const struct tcp_sock_af_ops tcp_sock_ipv6_mapped_specific = {
+#ifdef CONFIG_TCP_MD5SIG
        .md5_lookup     =       tcp_v4_md5_lookup,
        .calc_md5_hash  =       tcp_v4_md5_hash_skb,
        .md5_parse      =       tcp_v6_parse_md5_keys,
+#endif
+#ifdef CONFIG_TCP_AO
+       .ao_lookup      =       tcp_v6_ao_lookup,
+       .ao_parse       =       tcp_v6_parse_ao,
+#endif
 };
 #endif
 
@@ -1948,7 +1968,7 @@ static int tcp_v6_init_sock(struct sock *sk)
 
        icsk->icsk_af_ops = &ipv6_specific;
 
-#ifdef CONFIG_TCP_MD5SIG
+#if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO)
        tcp_sk(sk)->af_specific = &tcp_sock_ipv6_specific;
 #endif