inet: ipmr: fix data-races
authorEric Dumazet <edumazet@google.com>
Tue, 14 Jan 2025 22:10:49 +0000 (22:10 +0000)
committerJakub Kicinski <kuba@kernel.org>
Wed, 15 Jan 2025 23:07:23 +0000 (15:07 -0800)
Following fields of 'struct mr_mfc' can be updated
concurrently (no lock protection) from ip_mr_forward()
and ip6_mr_forward()

- bytes
- pkt
- wrong_if
- lastuse

They also can be read from other functions.

Convert bytes, pkt and wrong_if to atomic_long_t,
and use READ_ONCE()/WRITE_ONCE() for lastuse.

Fixes: 1da177e4c3f4 ("Linux-2.6.12-rc2")
Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: David Ahern <dsahern@kernel.org>
Link: https://patch.msgid.link/20250114221049.1190631-1-edumazet@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
drivers/net/ethernet/mellanox/mlxsw/spectrum_mr.c
include/linux/mroute_base.h
net/ipv4/ipmr.c
net/ipv4/ipmr_base.c
net/ipv6/ip6mr.c

index 69cd689dbc83e9b2a1508b8d160f07159081c7cc..5afe6b155ef0d5c9c3048c5f33f4822ebc8a7ac2 100644 (file)
@@ -1003,10 +1003,10 @@ static void mlxsw_sp_mr_route_stats_update(struct mlxsw_sp *mlxsw_sp,
        mr->mr_ops->route_stats(mlxsw_sp, mr_route->route_priv, &packets,
                                &bytes);
 
-       if (mr_route->mfc->mfc_un.res.pkt != packets)
-               mr_route->mfc->mfc_un.res.lastuse = jiffies;
-       mr_route->mfc->mfc_un.res.pkt = packets;
-       mr_route->mfc->mfc_un.res.bytes = bytes;
+       if (atomic_long_read(&mr_route->mfc->mfc_un.res.pkt) != packets)
+               WRITE_ONCE(mr_route->mfc->mfc_un.res.lastuse, jiffies);
+       atomic_long_set(&mr_route->mfc->mfc_un.res.pkt, packets);
+       atomic_long_set(&mr_route->mfc->mfc_un.res.bytes, bytes);
 }
 
 static void mlxsw_sp_mr_stats_update(struct work_struct *work)
index 9dd4bf1572553ffbf41bade97393fac091797a8d..58a2401e4b551bad0eadb9c5d4c341ddad48b39b 100644 (file)
@@ -146,9 +146,9 @@ struct mr_mfc {
                        unsigned long last_assert;
                        int minvif;
                        int maxvif;
-                       unsigned long bytes;
-                       unsigned long pkt;
-                       unsigned long wrong_if;
+                       atomic_long_t bytes;
+                       atomic_long_t pkt;
+                       atomic_long_t wrong_if;
                        unsigned long lastuse;
                        unsigned char ttls[MAXVIFS];
                        refcount_t refcount;
index 99d8faa508e5325c653c920855a52f24b661618c..21ae7594a8525a0df01ce01b801d0075dada0959 100644 (file)
@@ -831,7 +831,7 @@ static void ipmr_update_thresholds(struct mr_table *mrt, struct mr_mfc *cache,
                                cache->mfc_un.res.maxvif = vifi + 1;
                }
        }
-       cache->mfc_un.res.lastuse = jiffies;
+       WRITE_ONCE(cache->mfc_un.res.lastuse, jiffies);
 }
 
 static int vif_add(struct net *net, struct mr_table *mrt,
@@ -1681,9 +1681,9 @@ int ipmr_ioctl(struct sock *sk, int cmd, void *arg)
                rcu_read_lock();
                c = ipmr_cache_find(mrt, sr->src.s_addr, sr->grp.s_addr);
                if (c) {
-                       sr->pktcnt = c->_c.mfc_un.res.pkt;
-                       sr->bytecnt = c->_c.mfc_un.res.bytes;
-                       sr->wrong_if = c->_c.mfc_un.res.wrong_if;
+                       sr->pktcnt = atomic_long_read(&c->_c.mfc_un.res.pkt);
+                       sr->bytecnt = atomic_long_read(&c->_c.mfc_un.res.bytes);
+                       sr->wrong_if = atomic_long_read(&c->_c.mfc_un.res.wrong_if);
                        rcu_read_unlock();
                        return 0;
                }
@@ -1753,9 +1753,9 @@ int ipmr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg)
                rcu_read_lock();
                c = ipmr_cache_find(mrt, sr.src.s_addr, sr.grp.s_addr);
                if (c) {
-                       sr.pktcnt = c->_c.mfc_un.res.pkt;
-                       sr.bytecnt = c->_c.mfc_un.res.bytes;
-                       sr.wrong_if = c->_c.mfc_un.res.wrong_if;
+                       sr.pktcnt = atomic_long_read(&c->_c.mfc_un.res.pkt);
+                       sr.bytecnt = atomic_long_read(&c->_c.mfc_un.res.bytes);
+                       sr.wrong_if = atomic_long_read(&c->_c.mfc_un.res.wrong_if);
                        rcu_read_unlock();
 
                        if (copy_to_user(arg, &sr, sizeof(sr)))
@@ -1988,9 +1988,9 @@ static void ip_mr_forward(struct net *net, struct mr_table *mrt,
        int vif, ct;
 
        vif = c->_c.mfc_parent;
-       c->_c.mfc_un.res.pkt++;
-       c->_c.mfc_un.res.bytes += skb->len;
-       c->_c.mfc_un.res.lastuse = jiffies;
+       atomic_long_inc(&c->_c.mfc_un.res.pkt);
+       atomic_long_add(skb->len, &c->_c.mfc_un.res.bytes);
+       WRITE_ONCE(c->_c.mfc_un.res.lastuse, jiffies);
 
        if (c->mfc_origin == htonl(INADDR_ANY) && true_vifi >= 0) {
                struct mfc_cache *cache_proxy;
@@ -2021,7 +2021,7 @@ static void ip_mr_forward(struct net *net, struct mr_table *mrt,
                        goto dont_forward;
                }
 
-               c->_c.mfc_un.res.wrong_if++;
+               atomic_long_inc(&c->_c.mfc_un.res.wrong_if);
 
                if (true_vifi >= 0 && mrt->mroute_do_assert &&
                    /* pimsm uses asserts, when switching from RPT to SPT,
@@ -3029,9 +3029,9 @@ static int ipmr_mfc_seq_show(struct seq_file *seq, void *v)
 
                if (it->cache != &mrt->mfc_unres_queue) {
                        seq_printf(seq, " %8lu %8lu %8lu",
-                                  mfc->_c.mfc_un.res.pkt,
-                                  mfc->_c.mfc_un.res.bytes,
-                                  mfc->_c.mfc_un.res.wrong_if);
+                                  atomic_long_read(&mfc->_c.mfc_un.res.pkt),
+                                  atomic_long_read(&mfc->_c.mfc_un.res.bytes),
+                                  atomic_long_read(&mfc->_c.mfc_un.res.wrong_if));
                        for (n = mfc->_c.mfc_un.res.minvif;
                             n < mfc->_c.mfc_un.res.maxvif; n++) {
                                if (VIF_EXISTS(mrt, n) &&
index f0af12a2f70bcdf5ba54321bf7ebebe798318abb..03b6eee407a24117612d2254c5eb72e78f39c196 100644 (file)
@@ -263,9 +263,9 @@ int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb,
        lastuse = READ_ONCE(c->mfc_un.res.lastuse);
        lastuse = time_after_eq(jiffies, lastuse) ? jiffies - lastuse : 0;
 
-       mfcs.mfcs_packets = c->mfc_un.res.pkt;
-       mfcs.mfcs_bytes = c->mfc_un.res.bytes;
-       mfcs.mfcs_wrong_if = c->mfc_un.res.wrong_if;
+       mfcs.mfcs_packets = atomic_long_read(&c->mfc_un.res.pkt);
+       mfcs.mfcs_bytes = atomic_long_read(&c->mfc_un.res.bytes);
+       mfcs.mfcs_wrong_if = atomic_long_read(&c->mfc_un.res.wrong_if);
        if (nla_put_64bit(skb, RTA_MFC_STATS, sizeof(mfcs), &mfcs, RTA_PAD) ||
            nla_put_u64_64bit(skb, RTA_EXPIRES, jiffies_to_clock_t(lastuse),
                              RTA_PAD))
index 578ff1336afeff7a9f468d54c8fc47fddcaedbbb..535e9f72514c06ad655e46d3200c14298f584d99 100644 (file)
@@ -520,9 +520,9 @@ static int ipmr_mfc_seq_show(struct seq_file *seq, void *v)
 
                if (it->cache != &mrt->mfc_unres_queue) {
                        seq_printf(seq, " %8lu %8lu %8lu",
-                                  mfc->_c.mfc_un.res.pkt,
-                                  mfc->_c.mfc_un.res.bytes,
-                                  mfc->_c.mfc_un.res.wrong_if);
+                                  atomic_long_read(&mfc->_c.mfc_un.res.pkt),
+                                  atomic_long_read(&mfc->_c.mfc_un.res.bytes),
+                                  atomic_long_read(&mfc->_c.mfc_un.res.wrong_if));
                        for (n = mfc->_c.mfc_un.res.minvif;
                             n < mfc->_c.mfc_un.res.maxvif; n++) {
                                if (VIF_EXISTS(mrt, n) &&
@@ -884,7 +884,7 @@ static void ip6mr_update_thresholds(struct mr_table *mrt,
                                cache->mfc_un.res.maxvif = vifi + 1;
                }
        }
-       cache->mfc_un.res.lastuse = jiffies;
+       WRITE_ONCE(cache->mfc_un.res.lastuse, jiffies);
 }
 
 static int mif6_add(struct net *net, struct mr_table *mrt,
@@ -1945,9 +1945,9 @@ int ip6mr_ioctl(struct sock *sk, int cmd, void *arg)
                c = ip6mr_cache_find(mrt, &sr->src.sin6_addr,
                                     &sr->grp.sin6_addr);
                if (c) {
-                       sr->pktcnt = c->_c.mfc_un.res.pkt;
-                       sr->bytecnt = c->_c.mfc_un.res.bytes;
-                       sr->wrong_if = c->_c.mfc_un.res.wrong_if;
+                       sr->pktcnt = atomic_long_read(&c->_c.mfc_un.res.pkt);
+                       sr->bytecnt = atomic_long_read(&c->_c.mfc_un.res.bytes);
+                       sr->wrong_if = atomic_long_read(&c->_c.mfc_un.res.wrong_if);
                        rcu_read_unlock();
                        return 0;
                }
@@ -2017,9 +2017,9 @@ int ip6mr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg)
                rcu_read_lock();
                c = ip6mr_cache_find(mrt, &sr.src.sin6_addr, &sr.grp.sin6_addr);
                if (c) {
-                       sr.pktcnt = c->_c.mfc_un.res.pkt;
-                       sr.bytecnt = c->_c.mfc_un.res.bytes;
-                       sr.wrong_if = c->_c.mfc_un.res.wrong_if;
+                       sr.pktcnt = atomic_long_read(&c->_c.mfc_un.res.pkt);
+                       sr.bytecnt = atomic_long_read(&c->_c.mfc_un.res.bytes);
+                       sr.wrong_if = atomic_long_read(&c->_c.mfc_un.res.wrong_if);
                        rcu_read_unlock();
 
                        if (copy_to_user(arg, &sr, sizeof(sr)))
@@ -2142,9 +2142,9 @@ static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
        int true_vifi = ip6mr_find_vif(mrt, dev);
 
        vif = c->_c.mfc_parent;
-       c->_c.mfc_un.res.pkt++;
-       c->_c.mfc_un.res.bytes += skb->len;
-       c->_c.mfc_un.res.lastuse = jiffies;
+       atomic_long_inc(&c->_c.mfc_un.res.pkt);
+       atomic_long_add(skb->len, &c->_c.mfc_un.res.bytes);
+       WRITE_ONCE(c->_c.mfc_un.res.lastuse, jiffies);
 
        if (ipv6_addr_any(&c->mf6c_origin) && true_vifi >= 0) {
                struct mfc6_cache *cache_proxy;
@@ -2162,7 +2162,7 @@ static void ip6_mr_forward(struct net *net, struct mr_table *mrt,
         * Wrong interface: drop packet and (maybe) send PIM assert.
         */
        if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) {
-               c->_c.mfc_un.res.wrong_if++;
+               atomic_long_inc(&c->_c.mfc_un.res.wrong_if);
 
                if (true_vifi >= 0 && mrt->mroute_do_assert &&
                    /* pimsm uses asserts, when switching from RPT to SPT,