udp: fix data-race in udp_set_dev_scratch()
[linux-2.6-block.git] / net / ipv4 / udp.c
index e8443cc5c1ab99970be7e0bf55a6ff2a7ce9a554..d1ed160af202c054839387201abd3f13b55d00e9 100644 (file)
@@ -821,6 +821,7 @@ static int udp_send_skb(struct sk_buff *skb, struct flowi4 *fl4,
        int is_udplite = IS_UDPLITE(sk);
        int offset = skb_transport_offset(skb);
        int len = skb->len - offset;
+       int datalen = len - sizeof(*uh);
        __wsum csum = 0;
 
        /*
@@ -854,10 +855,12 @@ static int udp_send_skb(struct sk_buff *skb, struct flowi4 *fl4,
                        return -EIO;
                }
 
-               skb_shinfo(skb)->gso_size = cork->gso_size;
-               skb_shinfo(skb)->gso_type = SKB_GSO_UDP_L4;
-               skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(len - sizeof(uh),
-                                                        cork->gso_size);
+               if (datalen > cork->gso_size) {
+                       skb_shinfo(skb)->gso_size = cork->gso_size;
+                       skb_shinfo(skb)->gso_type = SKB_GSO_UDP_L4;
+                       skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(datalen,
+                                                                cork->gso_size);
+               }
                goto csum_partial;
        }
 
@@ -1313,6 +1316,20 @@ static void udp_set_dev_scratch(struct sk_buff *skb)
                scratch->_tsize_state |= UDP_SKB_IS_STATELESS;
 }
 
+static void udp_skb_csum_unnecessary_set(struct sk_buff *skb)
+{
+       /* We come here after udp_lib_checksum_complete() returned 0.
+        * This means that __skb_checksum_complete() might have
+        * set skb->csum_valid to 1.
+        * On 64bit platforms, we can set csum_unnecessary
+        * to true, but only if the skb is not shared.
+        */
+#if BITS_PER_LONG == 64
+       if (!skb_shared(skb))
+               udp_skb_scratch(skb)->csum_unnecessary = true;
+#endif
+}
+
 static int udp_skb_truesize(struct sk_buff *skb)
 {
        return udp_skb_scratch(skb)->_tsize_state & ~UDP_SKB_IS_STATELESS;
@@ -1547,10 +1564,7 @@ static struct sk_buff *__first_packet_length(struct sock *sk,
                        *total += skb->truesize;
                        kfree_skb(skb);
                } else {
-                       /* the csum related bits could be changed, refresh
-                        * the scratch area
-                        */
-                       udp_set_dev_scratch(skb);
+                       udp_skb_csum_unnecessary_set(skb);
                        break;
                }
        }
@@ -1574,7 +1588,7 @@ static int first_packet_length(struct sock *sk)
 
        spin_lock_bh(&rcvq->lock);
        skb = __first_packet_length(sk, rcvq, &total);
-       if (!skb && !skb_queue_empty(sk_queue)) {
+       if (!skb && !skb_queue_empty_lockless(sk_queue)) {
                spin_lock(&sk_queue->lock);
                skb_queue_splice_tail_init(sk_queue, rcvq);
                spin_unlock(&sk_queue->lock);
@@ -1647,7 +1661,7 @@ struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags,
                                return skb;
                        }
 
-                       if (skb_queue_empty(sk_queue)) {
+                       if (skb_queue_empty_lockless(sk_queue)) {
                                spin_unlock_bh(&queue->lock);
                                goto busy_check;
                        }
@@ -1673,7 +1687,7 @@ busy_check:
                                break;
 
                        sk_busy_loop(sk, flags & MSG_DONTWAIT);
-               } while (!skb_queue_empty(sk_queue));
+               } while (!skb_queue_empty_lockless(sk_queue));
 
                /* sk_queue is empty, reader_queue may contain peeked packets */
        } while (timeo &&
@@ -2709,7 +2723,7 @@ __poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait)
        __poll_t mask = datagram_poll(file, sock, wait);
        struct sock *sk = sock->sk;
 
-       if (!skb_queue_empty(&udp_sk(sk)->reader_queue))
+       if (!skb_queue_empty_lockless(&udp_sk(sk)->reader_queue))
                mask |= EPOLLIN | EPOLLRDNORM;
 
        /* Check for false positives due to checksum errors */