futex: Implement FUTEX2_MPOL
author Peter Zijlstra <peterz@infradead.org>
Wed, 16 Apr 2025 16:29:17 +0000 (18:29 +0200)
committer Peter Zijlstra <peterz@infradead.org>
Sat, 3 May 2025 10:02:09 +0000 (12:02 +0200)
Extend the futex2 interface to be aware of mempolicy.

When FUTEX2_MPOL is specified and the memory policy covering the futex
address is MPOL_PREFERRED, or is MPOL_BIND/MPOL_PREFERRED_MANY with a
home_node set, use the hash-map of that node.

Notably, in this case the futex will go to the global node hashtable,
even if it is a PRIVATE futex.

When FUTEX2_NUMA|FUTEX2_MPOL is specified and the user specified node
value is FUTEX_NO_NODE, the MPOL lookup (as described above) will be
tried first before falling back to setting node to the local node.

[bigeasy: add CONFIG_FUTEX_MPOL, add MPOL to FUTEX2_VALID_MASK, write
the node back to user space only if FUTEX_NO_NODE was supplied]

Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Signed-off-by: Sebastian Andrzej Siewior <bigeasy@linutronix.de>
Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Link: https://lore.kernel.org/r/20250416162921.513656-18-bigeasy@linutronix.de
include/linux/mmap_lock.h
include/uapi/linux/futex.h
init/Kconfig
kernel/futex/core.c
kernel/futex/futex.h

index 4706c676990275d480fcf3e20437462f7858cdd8..e0eddfd306ef3b50c9cd7faee0110dd7b51efc5b 100644 (file)
@@ -7,6 +7,7 @@
 #include <linux/rwsem.h>
 #include <linux/tracepoint-defs.h>
 #include <linux/types.h>
+#include <linux/cleanup.h>
 
 #define MMAP_LOCK_INITIALIZER(name) \
        .mmap_lock = __RWSEM_INITIALIZER((name).mmap_lock),
@@ -211,6 +212,9 @@ static inline void mmap_read_unlock(struct mm_struct *mm)
        up_read(&mm->mmap_lock);
 }
 
+DEFINE_GUARD(mmap_read_lock, struct mm_struct *,
+            mmap_read_lock(_T), mmap_read_unlock(_T))
+
 static inline void mmap_read_unlock_non_owner(struct mm_struct *mm)
 {
        __mmap_lock_trace_released(mm, false);
index 6b94da467e705d10cad9d4103a9252ac0d04f912..7e2744ec89336a260e89883e95222eda199eeb7f 100644 (file)
@@ -63,7 +63,7 @@
 #define FUTEX2_SIZE_U32                0x02
 #define FUTEX2_SIZE_U64                0x03
 #define FUTEX2_NUMA            0x04
-                       /*      0x08 */
+#define FUTEX2_MPOL            0x08
                        /*      0x10 */
                        /*      0x20 */
                        /*      0x40 */
index 4b84da2b2ec456a35f2212e677a86010bf21c651..b373267ba2e2b6363865a73e54edc9a27da91ac3 100644 (file)
@@ -1704,6 +1704,11 @@ config FUTEX_PRIVATE_HASH
        depends on FUTEX && !BASE_SMALL && MMU
        default y
 
+config FUTEX_MPOL
+       bool
+       depends on FUTEX && NUMA
+       default y
+
 config EPOLL
        bool "Enable eventpoll support" if EXPERT
        default y
index 1490e6492993e3d7655ff82012a928b38e66a9ad..19a2c65f3d373c0b60c864a6fe0604787221d342 100644 (file)
@@ -43,6 +43,8 @@
 #include <linux/slab.h>
 #include <linux/prctl.h>
 #include <linux/rcuref.h>
+#include <linux/mempolicy.h>
+#include <linux/mmap_lock.h>
 
 #include "futex.h"
 #include "../locking/rtmutex_common.h"
@@ -328,6 +330,75 @@ struct futex_hash_bucket *futex_hash(union futex_key *key)
 
 #endif /* CONFIG_FUTEX_PRIVATE_HASH */
 
+#ifdef CONFIG_FUTEX_MPOL
+
+static int __futex_key_to_node(struct mm_struct *mm, unsigned long addr)
+{
+       struct vm_area_struct *vma = vma_lookup(mm, addr);
+       struct mempolicy *mpol;
+       int node = FUTEX_NO_NODE;
+
+       if (!vma)
+               return FUTEX_NO_NODE;
+
+       mpol = vma_policy(vma);
+       if (!mpol)
+               return FUTEX_NO_NODE;
+
+       switch (mpol->mode) {
+       case MPOL_PREFERRED:
+               node = first_node(mpol->nodes);
+               break;
+       case MPOL_PREFERRED_MANY:
+       case MPOL_BIND:
+               if (mpol->home_node != NUMA_NO_NODE)
+                       node = mpol->home_node;
+               break;
+       default:
+               break;
+       }
+
+       return node;
+}
+
+static int futex_key_to_node_opt(struct mm_struct *mm, unsigned long addr)
+{
+       int seq, node;
+
+       guard(rcu)();
+
+       if (!mmap_lock_speculate_try_begin(mm, &seq))
+               return -EBUSY;
+
+       node = __futex_key_to_node(mm, addr);
+
+       if (mmap_lock_speculate_retry(mm, seq))
+               return -EAGAIN;
+
+       return node;
+}
+
+static int futex_mpol(struct mm_struct *mm, unsigned long addr)
+{
+       int node;
+
+       node = futex_key_to_node_opt(mm, addr);
+       if (node >= FUTEX_NO_NODE)
+               return node;
+
+       guard(mmap_read_lock)(mm);
+       return __futex_key_to_node(mm, addr);
+}
+
+#else /* !CONFIG_FUTEX_MPOL */
+
+static int futex_mpol(struct mm_struct *mm, unsigned long addr)
+{
+       return FUTEX_NO_NODE;
+}
+
+#endif /* CONFIG_FUTEX_MPOL */
+
 /**
  * __futex_hash - Return the hash bucket
  * @key:       Pointer to the futex key for which the hash is calculated
@@ -342,18 +413,20 @@ struct futex_hash_bucket *futex_hash(union futex_key *key)
 static struct futex_hash_bucket *
 __futex_hash(union futex_key *key, struct futex_private_hash *fph)
 {
-       struct futex_hash_bucket *hb;
+       int node = key->both.node;
        u32 hash;
-       int node;
 
-       hb = __futex_hash_private(key, fph);
-       if (hb)
-               return hb;
+       if (node == FUTEX_NO_NODE) {
+               struct futex_hash_bucket *hb;
+
+               hb = __futex_hash_private(key, fph);
+               if (hb)
+                       return hb;
+       }
 
        hash = jhash2((u32 *)key,
                      offsetof(typeof(*key), both.offset) / sizeof(u32),
                      key->both.offset);
-       node = key->both.node;
 
        if (node == FUTEX_NO_NODE) {
                /*
@@ -480,6 +553,7 @@ int get_futex_key(u32 __user *uaddr, unsigned int flags, union futex_key *key,
        struct folio *folio;
        struct address_space *mapping;
        int node, err, size, ro = 0;
+       bool node_updated = false;
        bool fshared;
 
        fshared = flags & FLAGS_SHARED;
@@ -501,27 +575,37 @@ int get_futex_key(u32 __user *uaddr, unsigned int flags, union futex_key *key,
        if (unlikely(should_fail_futex(fshared)))
                return -EFAULT;
 
+       node = FUTEX_NO_NODE;
+
        if (flags & FLAGS_NUMA) {
                u32 __user *naddr = (void *)uaddr + size / 2;
 
                if (futex_get_value(&node, naddr))
                        return -EFAULT;
 
-               if (node == FUTEX_NO_NODE) {
-                       node = numa_node_id();
-                       if (futex_put_value(node, naddr))
-                               return -EFAULT;
-
-               } else if (node >= MAX_NUMNODES || !node_possible(node)) {
+               if (node != FUTEX_NO_NODE &&
+                   (node >= MAX_NUMNODES || !node_possible(node)))
                        return -EINVAL;
-               }
+       }
 
-               key->both.node = node;
+       if (node == FUTEX_NO_NODE && (flags & FLAGS_MPOL)) {
+               node = futex_mpol(mm, address);
+               node_updated = true;
+       }
 
-       } else {
-               key->both.node = FUTEX_NO_NODE;
+       if (flags & FLAGS_NUMA) {
+               u32 __user *naddr = (void *)uaddr + size / 2;
+
+               if (node == FUTEX_NO_NODE) {
+                       node = numa_node_id();
+                       node_updated = true;
+               }
+               if (node_updated && futex_put_value(node, naddr))
+                       return -EFAULT;
        }
 
+       key->both.node = node;
+
        /*
         * PROCESS_PRIVATE futexes are fast.
         * As the mm cannot disappear under us and the 'key' only needs
index acc79536788980cfbdf0e54af292ffd4866f6319..069fc2a83080d87d1a25b4b8543172174f7b29dc 100644 (file)
@@ -39,6 +39,7 @@
 #define FLAGS_HAS_TIMEOUT      0x0040
 #define FLAGS_NUMA             0x0080
 #define FLAGS_STRICT           0x0100
+#define FLAGS_MPOL             0x0200
 
 /* FUTEX_ to FLAGS_ */
 static inline unsigned int futex_to_flags(unsigned int op)
@@ -54,7 +55,7 @@ static inline unsigned int futex_to_flags(unsigned int op)
        return flags;
 }
 
-#define FUTEX2_VALID_MASK (FUTEX2_SIZE_MASK | FUTEX2_NUMA | FUTEX2_PRIVATE)
+#define FUTEX2_VALID_MASK (FUTEX2_SIZE_MASK | FUTEX2_NUMA | FUTEX2_MPOL | FUTEX2_PRIVATE)
 
 /* FUTEX2_ to FLAGS_ */
 static inline unsigned int futex2_to_flags(unsigned int flags2)
@@ -67,6 +68,9 @@ static inline unsigned int futex2_to_flags(unsigned int flags2)
        if (flags2 & FUTEX2_NUMA)
                flags |= FLAGS_NUMA;
 
+       if (flags2 & FUTEX2_MPOL)
+               flags |= FLAGS_MPOL;
+
        return flags;
 }