netlink: add mask validation
[linux-2.6-block.git] / lib / nlattr.c
index 80ff9fe83696f881a5b5cda30c2afe13cd33b25a..9c99f5daa4d2ada5192038099de17feffbe421ca 100644 (file)
@@ -323,6 +323,37 @@ static int nla_validate_int_range(const struct nla_policy *pt,
        }
 }
 
+static int nla_validate_mask(const struct nla_policy *pt,
+                            const struct nlattr *nla,
+                            struct netlink_ext_ack *extack)
+{
+       u64 value;
+
+       switch (pt->type) {
+       case NLA_U8:
+               value = nla_get_u8(nla);
+               break;
+       case NLA_U16:
+               value = nla_get_u16(nla);
+               break;
+       case NLA_U32:
+               value = nla_get_u32(nla);
+               break;
+       case NLA_U64:
+               value = nla_get_u64(nla);
+               break;
+       default:
+               return -EINVAL;
+       }
+
+       if (value & ~(u64)pt->mask) {
+               NL_SET_ERR_MSG_ATTR(extack, nla, "reserved bit set");
+               return -EINVAL;
+       }
+
+       return 0;
+}
+
 static int validate_nla(const struct nlattr *nla, int maxtype,
                        const struct nla_policy *policy, unsigned int validate,
                        struct netlink_ext_ack *extack, unsigned int depth)
@@ -503,6 +534,11 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
                if (err)
                        return err;
                break;
+       case NLA_VALIDATE_MASK:
+               err = nla_validate_mask(pt, nla, extack);
+               if (err)
+                       return err;
+               break;
        case NLA_VALIDATE_FUNCTION:
                if (pt->validate) {
                        err = pt->validate(nla, extack);