Zulip Chat Archive

Stream: metaprogramming / tactics

Topic: Forward chaining with `aesop`


Geoffrey Irving (Nov 26 2023 at 16:19):

I'm applying aesop to a theorem with a bunch of ifs, which generate a bunch of fresh hypotheses via split_ifs. I'd like to do the following with the hypotheses:

  1. Generate all consequences of them according to a list of functions. E.g., if there are two hypotheses h0 h1, f is in the list of functions, and f h0 h1 type checks, I want f h0 h1 to become a hypothesis.
  2. Simplify using all available hypotheses.

Is there a good way to do that?
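For step 1, one option is Aesop's `forward` rule builder, which the thread settles on below. The following is a minimal sketch that assumes a recent Aesop. Core's `Nat.lt_trans : n < m → m < k → n < k` stands in for `f`, and the example goal is made up for illustration:

```lean
import Aesop

-- `Nat.lt_trans : n < m → m < k → n < k` plays the role of `f`.
-- It is added as a safe forward rule for this call only. When hypotheses
-- `h0 : a < b` and `h1 : b < c` are both present, Aesop adds
-- `Nat.lt_trans h0 h1 : a < c` to the context, which then closes the goal.
example (a b c : Nat) (h0 : a < b) (h1 : b < c) : a < c := by
  aesop (add safe forward Nat.lt_trans)
```

In this setup, step 2 needs no extra work, because Aesop's normalisation phase runs `simp_all` over the enlarged context. A transitivity lemma can add many hypotheses in a large context, so it is used here for illustration, not as a recommended rule.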

For example, if there are hypotheses

  1. t : ℕ
  2. _ : t < 128
  3. _ : ↑t &&& 127 = 0 -- In UInt64

I want to use a lemma that takes these and produces t = 0, then have aesop simp using that.

Geoffrey Irving (Nov 26 2023 at 20:36):

Ah, there is a forward rule builder.

Jannis Limperg (Nov 27 2023 at 10:59):

Yes, forward is what you're looking for. The current implementation is very inefficient, so you may run into trouble if you have big goals or many forward rules. We're designing a better implementation right now. (In fact, if you run into performance issues, I'd appreciate a ping because the issues might make for a good case study.)
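A rule like this can also be registered globally with an attribute. Below is a sketch for the lemma shape from the question, under two assumptions. First, it is stated over `ℕ` only, whereas the question's second hypothesis lives in `UInt64`. Second, the name `eq_zero_of_lt_128_of_land_127` and its `sorry` proof are hypothetical placeholders:

```lean
import Aesop

-- Hypothetical lemma (placeholder name and `sorry` proof): from `t < 128`
-- and `t &&& 127 = 0`, conclude `t = 0`.
@[aesop safe forward]
theorem eq_zero_of_lt_128_of_land_127 {t : Nat} (h0 : t < 128)
    (h1 : t &&& 127 = 0) : t = 0 :=
  sorry

-- With both premises in context, the forward rule adds `t = 0`. Aesop can
-- then use that equation (via `simp_all` or `subst`) to finish the goal.
example (t : Nat) (h0 : t < 128) (h1 : t &&& 127 = 0) : t + 1 = 1 := by
  aesop
```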

Geoffrey Irving (Nov 27 2023 at 21:42):

Hmm, actually aesop seems to be doing this forward chaining automatically, but inconsistently. I can make an MWE if it seems worth it, but it'd be a nontrivial amount of work so I want to check first. Here's what's happening.

The (very non-closed, and thus nonworking!) code is

/-- `shiftRight` rounds down -/
lemma UInt128.toNat_shiftRight (x : UInt128) {s : UInt64} (sl : s < 128) :
    (x.shiftRight s).toNat = x.toNat / 2^s.toNat := by
  generalize ht : s.toNat = t
  have st : s = (t : UInt64) := by rw [ht, UInt64.cast_toNat]
  have t64 : t < 2^64 := by rw [ht]; exact UInt64.toNat_lt_2_pow_64 _
  have t64' : t < UInt64.size := t64
  have p64 : (64 : UInt64).toNat = 64 := rfl
  have p127 : (127 : UInt64).toNat = 127 := rfl
  have p128 : (128 : UInt64).toNat = 128 := rfl
  simp only [st, UInt64.lt_iff_toNat_lt, UInt64.toNat_cast, Nat.mod_eq_of_lt t64', p128, st] at sl 
  have t127 : t &&& 127 = t := by nth_rw 2 [Nat.mod_eq_of_lt sl]; exact @Nat.land_eq_mod _ 7
  have t127' : (t : UInt64) &&& 127 = t := by
    simp only [UInt64.eq_iff_toNat_eq, UInt64.toNat_land, UInt64.toNat_cast, Nat.mod_eq_of_lt t64',
      p127, t127]
  refine Nat.eq_of_testBit_eq fun i ↦ ?_
  have e0 : ¬t = 0 → t < 64 → (64 - (t : UInt64)).toNat % 64 = 64 - t := by
    intro t0 t64
    rw [UInt64.toNat_sub]
    · simp only [p64, UInt64.toNat_cast, Nat.mod_eq_of_lt t64', ge_iff_le, Nat.mod_succ_eq_iff_lt]
      exact Nat.sub_lt (by norm_num) (Nat.pos_iff_ne_zero.mpr t0)
    · simp only [UInt64.le_iff_toNat_le, UInt64.toNat_cast, Nat.mod_eq_of_lt t64']
      exact t64.le
  have e1 : t < 64 → 64 - (64 - t) = t := fun h ↦ Nat.sub_sub_self h.le
  have e2 : t < 64 → i - (64 - t) = i + t - 64 := fun h ↦ by rw [Nat.sub_sub_assoc h.le]
  have e3 : t < 64 → t % 64 = t := fun h ↦ by rw [Nat.mod_eq_of_lt h]
  have a0 : 64 ≤ i + t → i + t < 64 → False := fun le lt ↦ not_lt.mpr le lt
  simp only [UInt128.shiftRight, UInt128.testBit_eq, bif_eq_if, beq_iff_eq, decide_eq_true_eq,
    apply_ite (f := UInt64.toNat), apply_ite (f := fun x ↦ Nat.testBit x i),
    apply_ite (f := fun x  Nat.testBit x (i - 64)),
    UInt64.toNat_shiftRight', UInt64.toNat_lor, Nat.testBit_lor, Nat.testBit_div_two_pow,
    Nat.testBit_mul_two_pow, UInt64.toNat_shiftLeft', Nat.testBit_mod_two_pow, apply_decide,
    UInt64.land_eq_hand, UInt64.eq_iff_toNat_eq, UInt64.toNat_zero, UInt64.toNat_land, p127,
    UInt64.toNat_cast, Nat.mod_eq_of_lt t64', UInt64.lt_iff_toNat_lt, t127, p64, Bool.or_true,
    Bool.or_false, Nat.testBit_zero', t127']
  aesop

Right before the aesop, the state is

x: UInt128
s: UInt64
t: ℕ
ht: s.toNat = t
st: s = t
t64: t < 2 ^ 64
t64': t < UInt64.size
p64: 64.toNat = 64
p127: 127.toNat = 127
p128: 128.toNat = 128
sl: t < 128
t127: t &&& 127 = t
t127': t &&& 127 = t
i: ℕ
e0: ¬t = 0 → t < 64 → (64 - t).toNat % 64 = 64 - t
e1: t < 64 → 64 - (64 - t) = t
e2: t < 64 → i - (64 - t) = i + t - 64
e3: t < 64 → t % 64 = t
a0: 64 ≤ i + t → i + t < 64 → False
⊢ (if i < 64 then
    if t = 0 then Nat.testBit x.lo.toNat i
    else
      if t < 64 then
        if
            (64 - t).toNat % 64 ≤ i ∧
              (if i - (64 - t).toNat % 64 < 64 - (64 - t).toNat % 64 then
                  Nat.testBit x.hi.toNat (i - (64 - t).toNat % 64) && true
                else Nat.testBit x.hi.toNat (i - (64 - t).toNat % 64) && false) =
                true then
          true
        else Nat.testBit x.lo.toNat (i + t % 64)
      else Nat.testBit x.hi.toNat (i + (t - 64).toNat % 64)
  else if t < 64 then Nat.testBit x.hi.toNat (i - 64 + t % 64) else false) =
  if i + t < 64 then Nat.testBit x.lo.toNat (i + t) else Nat.testBit x.hi.toNat (i + t - 64)

As aesop splits the ifs, it produces hypotheses like t < 64 or its negation, and e0, e1, e2, e3 are supposed to fire in order to simplify in ways that needed the conditional hypotheses. However, aesop appears to be dropping them sometimes. For example, one of the produced non-closed goals is

x: UInt128
t: ℕ
t64: t < 2 ^ 64
t64': t < UInt64.size
sl: t < 128
t127: t &&& 127 = t
t127': t &&& 127 = t
i: ℕ
e0: (64 - t).toNat % 64 = 64 - t
e1: 64 - (64 - t) = t
e2: i - (64 - t) = i + t - 64
a0: 64 ≤ i + t → False
ht: t % UInt64.size = t
h✝⁵: i < 64
h✝⁴: ¬t = 0
h✝³: t < 64
h✝²: i + t - 64 < t
h✝¹: 64 ≤ i + t → Nat.testBit x.hi.toNat (i + t - 64) = false
h: i + t < 64
⊢ Nat.testBit x.lo.toNat (i + t % 64) = Nat.testBit x.lo.toNat (i + t)

Ideally e3 would have fired and simplified t % 64 to t to close the goal, but e3 is gone for some reason.

Geoffrey Irving (Nov 27 2023 at 22:12):

To say it explicitly, let me know if an MWE is worth it and I’ll make one if so.

Jannis Limperg (Nov 30 2023 at 00:35):

This should be handled by simp_all, which Aesop calls internally -- it should use t < 64 to simplify t < 64 -> t % 64 = t to True -> t % 64 = t, then to t % 64 = t, and finally use this equation to rewrite everywhere in the goal.
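The mechanism described here can be seen in isolation with plain `simp_all` (core Lean 4, no imports). The following sketch uses a conditional hypothesis shaped like `e3`; the extra variable `x` and the goal are illustrative:

```lean
-- `e3` mirrors the conditional equation from the WE. `simp_all` uses `h`
-- to turn the premise of `e3` into `True`, reduces `True → t % 64 = t` to
-- `t % 64 = t`, and then rewrites `t % 64` to `t` in the goal, closing it.
example (t x : Nat) (h : t < 64) (e3 : t < 64 → t % 64 = t) :
    x + t % 64 = x + t := by
  simp_all
```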

Maybe the subst rule kicks in and rewrites the equation right to left before simp_all gets a chance? You can try whether aesop (erase Aesop.BuiltinRules.subst) makes more progress.
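The following shows the suggested call on the same kind of toy goal. This sketch assumes that `Aesop.BuiltinRules.subst` is the rule name, as given in the message above. It also assumes that Aesop's `trace.aesop` option is available for inspecting which rules fire:

```lean
import Aesop

-- Print Aesop's search trace for this declaration only.
set_option trace.aesop true in
example (t x : Nat) (h : t < 64) (e3 : t < 64 → t % 64 = t) :
    x + t % 64 = x + t := by
  -- Run Aesop with its built-in `subst` rule removed, so that equations
  -- among the hypotheses are left to the normalisation `simp_all`.
  aesop (erase Aesop.BuiltinRules.subst)
```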

For more debugging, I'll need a WE (doesn't need to be minimal, a branch will do).

Geoffrey Irving (Dec 02 2023 at 11:15):

I tried that erase, but it didn’t have any effect.

Here's a WE:

import Mathlib.Data.UInt
import Mathlib.Init.IteSimp
import Mathlib.Data.Nat.Bitwise

@[ext]
structure UInt128 where
  lo : UInt64
  hi : UInt64
  deriving DecidableEq, BEq

@[pp_dot] def UInt128.toNat (x : UInt128) : ℕ :=
  (x.hi.toNat * 2^64) + x.lo.toNat

/-- Divide by `2^(s % 128)`, rounding down -/
@[pp_dot] def UInt128.shiftRight (x : UInt128) (s : UInt64) : UInt128 :=
  let s := s.land 127
  { lo := bif s == 0 then x.lo
          else bif s < 64 then x.lo >>> s ||| x.hi <<< (64-s)
          else x.hi >>> (s - 64)
    hi := bif s < 64 then x.hi >>> s else 0 }

@[simp] lemma UInt64.cast_toNat (n : UInt64) : (n.toNat : UInt64) = n := sorry
@[simp] lemma UInt64.toNat_lt_2_pow_64 (n : UInt64) : n.toNat < 2^64 := Fin.prop _
lemma UInt64.eq_iff_toNat_eq (m n : UInt64) : m = n ↔ m.toNat = n.toNat := sorry
lemma UInt64.le_iff_toNat_le (m n : UInt64) : m ≤ n ↔ m.toNat ≤ n.toNat := sorry
lemma UInt64.lt_iff_toNat_lt (m n : UInt64) : m < n ↔ m.toNat < n.toNat := sorry
@[simp] lemma UInt64.toNat_cast (n : ℕ) : (n : UInt64).toNat = n % UInt64.size := sorry
@[simp] lemma UInt64.toNat_land {x y : UInt64} : (x &&& y).toNat = x.toNat &&& y.toNat := sorry
lemma UInt64.toNat_sub {x y : UInt64} (h : y ≤ x) : (x - y).toNat = x.toNat - y.toNat := sorry
lemma Nat.land_eq_mod {n k : ℕ} : n &&& (2^k-1) = n % 2^k := sorry
lemma Nat.sub_sub_assoc {a b c : ℕ} (h : c ≤ b) : a + c - b = a - (b - c) := sorry
@[simp] lemma Nat.testBit_div_two_pow {n k i : ℕ} : testBit (n / 2^k) i = testBit n (i+k) := sorry
@[simp] lemma Nat.testBit_mul_two_pow {n k i : ℕ} :
    testBit (n * 2^k) i = decide (k ≤ i ∧ testBit n (i-k)) := sorry
lemma UInt128.testBit_eq {x : UInt128} {i : ℕ} :
    x.toNat.testBit i = if i < 64 then x.lo.toNat.testBit i else x.hi.toNat.testBit (i-64) := sorry
lemma bif_eq_if {b : Bool} {x y : α} : (bif b then x else y) = if b then x else y := sorry
lemma UInt64.toNat_shiftRight' {x s : UInt64}
    : (x >>> s).toNat = x.toNat / 2^(s.toNat % 64) := sorry
lemma UInt64.toNat_shiftLeft' {x s : UInt64} :
    (x <<< s).toNat = x.toNat % 2^(64 - s.toNat % 64) * 2^(s.toNat % 64) := sorry
@[simp] lemma UInt64.toNat_lor {x y : UInt64} : (x ||| y).toNat = x.toNat ||| y.toNat := sorry
@[simp] lemma Nat.testBit_mod_two_pow {n k i : ℕ} :
    testBit (n % 2^k) i = (testBit n i && i < k) := sorry
lemma apply_decide {f : Bool → α} {p : Prop} {dp : Decidable p} :
    (f (@decide p dp)) = if p then f true else f false := sorry
@[simp] lemma UInt64.land_eq_hand {x y : UInt64} : UInt64.land x y = x &&& y := rfl
@[simp] lemma UInt64.toNat_zero : (0 : UInt64).toNat = 0 := rfl
@[simp] lemma Nat.testBit_zero' {i : ℕ} : testBit 0 i = false := sorry

/-- `shiftRight` rounds down -/
lemma UInt128.toNat_shiftRight (x : UInt128) {s : UInt64} (sl : s < 128) :
    (x.shiftRight s).toNat = x.toNat / 2^s.toNat := by
  generalize ht : s.toNat = t
  have st : s = (t : UInt64) := by rw [ht, UInt64.cast_toNat]
  have t64 : t < 2^64 := by rw [ht]; exact UInt64.toNat_lt_2_pow_64 _
  have t64' : t < UInt64.size := t64
  have p64 : (64 : UInt64).toNat = 64 := rfl
  have p127 : (127 : UInt64).toNat = 127 := rfl
  have p128 : (128 : UInt64).toNat = 128 := rfl
  simp only [st, UInt64.lt_iff_toNat_lt, UInt64.toNat_cast, Nat.mod_eq_of_lt t64', p128, st] at sl 
  clear st ht s
  have t127 : t &&& 127 = t := by nth_rw 2 [Nat.mod_eq_of_lt sl]; exact @Nat.land_eq_mod _ 7
  have t127' : (t : UInt64) &&& 127 = t := by
    simp only [UInt64.eq_iff_toNat_eq, UInt64.toNat_land, UInt64.toNat_cast, Nat.mod_eq_of_lt t64',
      p127, t127]
  refine Nat.eq_of_testBit_eq fun i ↦ ?_
  have e0 : ¬t = 0 → t < 64 → (64 - (t : UInt64)).toNat % 64 = 64 - t := by
    intro t0 t64
    rw [UInt64.toNat_sub]
    · simp only [p64, UInt64.toNat_cast, Nat.mod_eq_of_lt t64', ge_iff_le, Nat.mod_succ_eq_iff_lt]
      exact Nat.sub_lt (by norm_num) (Nat.pos_iff_ne_zero.mpr t0)
    · simp only [UInt64.le_iff_toNat_le, UInt64.toNat_cast, Nat.mod_eq_of_lt t64']
      exact t64.le
  have e1 : t < 64 → 64 - (64 - t) = t := fun h ↦ Nat.sub_sub_self h.le
  have e2 : t < 64 → i - (64 - t) = i + t - 64 := fun h ↦ by rw [Nat.sub_sub_assoc h.le]
  have e3 : t < 64 → t % 64 = t := fun h ↦ by rw [Nat.mod_eq_of_lt h]
  have a0 : 64 ≤ i + t → i + t < 64 → False := fun le lt ↦ not_lt.mpr le lt
  simp only [UInt128.shiftRight, UInt128.testBit_eq, bif_eq_if, beq_iff_eq, decide_eq_true_eq,
    apply_ite (f := UInt64.toNat), apply_ite (f := fun x ↦ Nat.testBit x i),
    apply_ite (f := fun x  Nat.testBit x (i - 64)),
    UInt64.toNat_shiftRight', UInt64.toNat_lor, Nat.testBit_lor, Nat.testBit_div_two_pow,
    Nat.testBit_mul_two_pow, UInt64.toNat_shiftLeft', Nat.testBit_mod_two_pow, apply_decide,
    UInt64.land_eq_hand, UInt64.eq_iff_toNat_eq, UInt64.toNat_zero, UInt64.toNat_land, p127,
    UInt64.toNat_cast, Nat.mod_eq_of_lt t64', UInt64.lt_iff_toNat_lt, t127, p64, Bool.or_true,
    Bool.or_false, Nat.testBit_zero', t127']
  aesop  -- There is a `t % 64` under a `t < 64` assumption in the first goal, but `e3` is gone

Last updated: Dec 20 2023 at 11:08 UTC