bpf/verifier: refactor checks for range computation
[linux-2.6-block.git] / kernel / bpf / verifier.c
index 41c66cc6db80fd3a6899f41af8c742265729589d..bdaf0413bf06f3515cb92912dc5a774a8604faf3 100644 (file)
@@ -13878,6 +13878,50 @@ static void scalar_min_max_arsh(struct bpf_reg_state *dst_reg,
        __update_reg_bounds(dst_reg);
 }
 
+static bool is_safe_to_compute_dst_reg_range(struct bpf_insn *insn,
+                                            const struct bpf_reg_state *src_reg)
+{
+       bool src_is_const = false;
+       u64 insn_bitness = (BPF_CLASS(insn->code) == BPF_ALU64) ? 64 : 32;
+
+       if (insn_bitness == 32) {
+               if (tnum_subreg_is_const(src_reg->var_off)
+                   && src_reg->s32_min_value == src_reg->s32_max_value
+                   && src_reg->u32_min_value == src_reg->u32_max_value)
+                       src_is_const = true;
+       } else {
+               if (tnum_is_const(src_reg->var_off)
+                   && src_reg->smin_value == src_reg->smax_value
+                   && src_reg->umin_value == src_reg->umax_value)
+                       src_is_const = true;
+       }
+
+       switch (BPF_OP(insn->code)) {
+       case BPF_ADD:
+       case BPF_SUB:
+       case BPF_AND:
+               return true;
+
+       /* Compute range for the following only if the src_reg is const.
+        */
+       case BPF_XOR:
+       case BPF_OR:
+       case BPF_MUL:
+               return src_is_const;
+
+       /* Shift operators range is only computable if shift dimension operand
+        * is a constant. Shifts greater than 31 or 63 are undefined. This
+        * includes shifts by a negative number.
+        */
+       case BPF_LSH:
+       case BPF_RSH:
+       case BPF_ARSH:
+               return (src_is_const && src_reg->umax_value < insn_bitness);
+       default:
+               return false;
+       }
+}
+
 /* WARNING: This function does calculations on 64-bit values, but the actual
  * execution may occur on 32-bit values. Therefore, things like bitshifts
  * need extra checks in the 32-bit case.
@@ -13888,51 +13932,10 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                                      struct bpf_reg_state src_reg)
 {
        u8 opcode = BPF_OP(insn->code);
-       bool src_known;
-       s64 smin_val, smax_val;
-       u64 umin_val, umax_val;
-       s32 s32_min_val, s32_max_val;
-       u32 u32_min_val, u32_max_val;
-       u64 insn_bitness = (BPF_CLASS(insn->code) == BPF_ALU64) ? 64 : 32;
        bool alu32 = (BPF_CLASS(insn->code) != BPF_ALU64);
        int ret;
 
-       smin_val = src_reg.smin_value;
-       smax_val = src_reg.smax_value;
-       umin_val = src_reg.umin_value;
-       umax_val = src_reg.umax_value;
-
-       s32_min_val = src_reg.s32_min_value;
-       s32_max_val = src_reg.s32_max_value;
-       u32_min_val = src_reg.u32_min_value;
-       u32_max_val = src_reg.u32_max_value;
-
-       if (alu32) {
-               src_known = tnum_subreg_is_const(src_reg.var_off);
-               if ((src_known &&
-                    (s32_min_val != s32_max_val || u32_min_val != u32_max_val)) ||
-                   s32_min_val > s32_max_val || u32_min_val > u32_max_val) {
-                       /* Taint dst register if offset had invalid bounds
-                        * derived from e.g. dead branches.
-                        */
-                       __mark_reg_unknown(env, dst_reg);
-                       return 0;
-               }
-       } else {
-               src_known = tnum_is_const(src_reg.var_off);
-               if ((src_known &&
-                    (smin_val != smax_val || umin_val != umax_val)) ||
-                   smin_val > smax_val || umin_val > umax_val) {
-                       /* Taint dst register if offset had invalid bounds
-                        * derived from e.g. dead branches.
-                        */
-                       __mark_reg_unknown(env, dst_reg);
-                       return 0;
-               }
-       }
-
-       if (!src_known &&
-           opcode != BPF_ADD && opcode != BPF_SUB && opcode != BPF_AND) {
+       if (!is_safe_to_compute_dst_reg_range(insn, &src_reg)) {
                __mark_reg_unknown(env, dst_reg);
                return 0;
        }
@@ -13989,46 +13992,24 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
                scalar_min_max_xor(dst_reg, &src_reg);
                break;
        case BPF_LSH:
-               if (umax_val >= insn_bitness) {
-                       /* Shifts greater than 31 or 63 are undefined.
-                        * This includes shifts by a negative number.
-                        */
-                       __mark_reg_unknown(env, dst_reg);
-                       break;
-               }
                if (alu32)
                        scalar32_min_max_lsh(dst_reg, &src_reg);
                else
                        scalar_min_max_lsh(dst_reg, &src_reg);
                break;
        case BPF_RSH:
-               if (umax_val >= insn_bitness) {
-                       /* Shifts greater than 31 or 63 are undefined.
-                        * This includes shifts by a negative number.
-                        */
-                       __mark_reg_unknown(env, dst_reg);
-                       break;
-               }
                if (alu32)
                        scalar32_min_max_rsh(dst_reg, &src_reg);
                else
                        scalar_min_max_rsh(dst_reg, &src_reg);
                break;
        case BPF_ARSH:
-               if (umax_val >= insn_bitness) {
-                       /* Shifts greater than 31 or 63 are undefined.
-                        * This includes shifts by a negative number.
-                        */
-                       __mark_reg_unknown(env, dst_reg);
-                       break;
-               }
                if (alu32)
                        scalar32_min_max_arsh(dst_reg, &src_reg);
                else
                        scalar_min_max_arsh(dst_reg, &src_reg);
                break;
        default:
-               __mark_reg_unknown(env, dst_reg);
                break;
        }