Zulip Chat Archive
Stream: metaprogramming / tactics
Topic: Forward chaining with `aesop`
Geoffrey Irving (Nov 26 2023 at 16:19):
I'm applying `aesop` to a theorem with a bunch of `if`s, which generate a bunch of fresh hypotheses via `split_ifs`. I'd like to do the following with the hypotheses:
- Generate all consequences of them according to a list of functions. E.g., if there are two hypotheses `h0`, `h1`, `f` is in the list of functions, and `f h0 h1` type checks, I want `f h0 h1` to become a hypothesis.
- Simplify using all available hypotheses.

Is there a good way to do that?
For example, if there are hypotheses
```
t : ℕ
_ : t < 128
_ : ↑t &&& 127 = 0 -- In UInt64
```
I want to use a lemma that takes these and produces `t = 0`, then have `aesop` simp using that.
Geoffrey Irving (Nov 26 2023 at 20:36):
Ah, there is a `forward` rule builder.
Jannis Limperg (Nov 27 2023 at 10:59):
Yes, `forward` is what you're looking for. The current implementation is very inefficient, so you may run into trouble if you have big goals or many forward rules. We're designing a better implementation right now. (In fact, if you run into performance issues, I'd appreciate a ping because the issues might make for a good case study.)
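A minimal sketch of a `forward` rule, assuming a recent Aesop and core Lean 4 only (the lemma name `mod64_self_of_lt` is hypothetical, not from the thread). A forward rule makes Aesop look for a hypothesis matching the rule's premise, here `t < 64`, and add the rule's conclusion, here `t % 64 = t`, as a new hypothesis. This is roughly the "`f h0 h1` becomes a hypothesis" behavior asked about above; Aesop's normalization step can then rewrite with the new fact.

```lean
import Aesop

-- Hypothetical helper lemma (not from the thread): a bound on `t` lets us drop `% 64`.
-- As a safe forward rule, Aesop should add `t % 64 = t` as a new hypothesis
-- whenever the context contains a hypothesis of type `t < 64`.
@[aesop safe forward]
theorem mod64_self_of_lt {t : Nat} (h : t < 64) : t % 64 = t :=
  Nat.mod_eq_of_lt h

-- Once the forward rule has fired, Aesop's `simp_all`-based normalization
-- can rewrite `t % 64` to `t` in the goal, which closes it.
example (f : Nat → Bool) (t : Nat) (h : t < 64) : f (t % 64) = f t := by
  aesop
```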
Geoffrey Irving (Nov 27 2023 at 21:42):
Hmm, actually `aesop` seems to be doing this forward chaining automatically, but inconsistently. I can make an MWE if it seems worth it, but it'd be a nontrivial amount of work, so I want to check first. Here's what's happening.
The (very non-closed, and thus nonworking!) code is
```lean
/-- `shiftRight` rounds down -/
lemma UInt128.toNat_shiftRight (x : UInt128) {s : UInt64} (sl : s < 128) :
    (x.shiftRight s).toNat = x.toNat / 2^s.toNat := by
  generalize ht : s.toNat = t
  have st : s = (t : UInt64) := by rw [←ht, UInt64.cast_toNat]
  have t64 : t < 2^64 := by rw [←ht]; exact UInt64.toNat_lt_2_pow_64 _
  have t64' : t < UInt64.size := t64
  have p64 : (64 : UInt64).toNat = 64 := rfl
  have p127 : (127 : UInt64).toNat = 127 := rfl
  have p128 : (128 : UInt64).toNat = 128 := rfl
  simp only [st, UInt64.lt_iff_toNat_lt, UInt64.toNat_cast, Nat.mod_eq_of_lt t64', p128, st] at sl ⊢
  have t127 : t &&& 127 = t := by nth_rw 2 [←Nat.mod_eq_of_lt sl]; exact @Nat.land_eq_mod _ 7
  have t127' : (t : UInt64) &&& 127 = ↑t := by
    simp only [UInt64.eq_iff_toNat_eq, UInt64.toNat_land, UInt64.toNat_cast, Nat.mod_eq_of_lt t64',
      p127, t127]
  refine Nat.eq_of_testBit_eq fun i ↦ ?_
  have e0 : ¬t = 0 → t < 64 → (64 - (t : UInt64)).toNat % 64 = 64 - t := by
    intro t0 t64
    rw [UInt64.toNat_sub]
    · simp only [p64, UInt64.toNat_cast, Nat.mod_eq_of_lt t64', ge_iff_le, Nat.mod_succ_eq_iff_lt]
      exact Nat.sub_lt (by norm_num) (Nat.pos_iff_ne_zero.mpr t0)
    · simp only [UInt64.le_iff_toNat_le, UInt64.toNat_cast, Nat.mod_eq_of_lt t64']
      exact t64.le
  have e1 : t < 64 → 64 - (64 - t) = t := fun h ↦ Nat.sub_sub_self h.le
  have e2 : t < 64 → i - (64 - t) = i + t - 64 := fun h ↦ by rw [Nat.sub_sub_assoc h.le]
  have e3 : t < 64 → t % 64 = t := fun h ↦ by rw [Nat.mod_eq_of_lt h]
  have a0 : 64 ≤ i + t → i + t < 64 → False := fun le lt ↦ not_lt.mpr le lt
  simp only [UInt128.shiftRight, UInt128.testBit_eq, bif_eq_if, beq_iff_eq, decide_eq_true_eq,
    apply_ite (f := UInt64.toNat), apply_ite (f := fun x ↦ Nat.testBit x i),
    apply_ite (f := fun x ↦ Nat.testBit x (i - 64)),
    UInt64.toNat_shiftRight', UInt64.toNat_lor, Nat.testBit_lor, Nat.testBit_div_two_pow,
    Nat.testBit_mul_two_pow, UInt64.toNat_shiftLeft', Nat.testBit_mod_two_pow, apply_decide,
    UInt64.land_eq_hand, UInt64.eq_iff_toNat_eq, UInt64.toNat_zero, UInt64.toNat_land, p127,
    UInt64.toNat_cast, Nat.mod_eq_of_lt t64', UInt64.lt_iff_toNat_lt, t127, p64, Bool.or_true,
    Bool.or_false, Nat.testBit_zero', t127']
  aesop
```
Right before the `aesop`, the state is
```
x: UInt128
s: UInt64
t: ℕ
ht: s.toNat = t
st: s = ↑t
t64: t < 2 ^ 64
t64': t < UInt64.size
p64: 64.toNat = 64
p127: 127.toNat = 127
p128: 128.toNat = 128
sl: t < 128
t127: t &&& 127 = t
t127': ↑t &&& 127 = ↑t
i: ℕ
e0: ¬t = 0 → t < 64 → (64 - ↑t).toNat % 64 = 64 - t
e1: t < 64 → 64 - (64 - t) = t
e2: t < 64 → i - (64 - t) = i + t - 64
e3: t < 64 → t % 64 = t
a0: 64 ≤ i + t → i + t < 64 → False
⊢ (if i < 64 then
      if t = 0 then Nat.testBit x.lo.toNat i
      else
        if t < 64 then
          if
              (64 - ↑t).toNat % 64 ≤ i ∧
                (if i - (64 - ↑t).toNat % 64 < 64 - (64 - ↑t).toNat % 64 then
                    Nat.testBit x.hi.toNat (i - (64 - ↑t).toNat % 64) && true
                  else Nat.testBit x.hi.toNat (i - (64 - ↑t).toNat % 64) && false) =
                  true then
            true
          else Nat.testBit x.lo.toNat (i + t % 64)
        else Nat.testBit x.hi.toNat (i + (↑t - 64).toNat % 64)
    else if t < 64 then Nat.testBit x.hi.toNat (i - 64 + t % 64) else false) =
  if i + t < 64 then Nat.testBit x.lo.toNat (i + t) else Nat.testBit x.hi.toNat (i + t - 64)
```
As `aesop` splits the `if`s, it produces hypotheses like `t < 64` or the reverse, and `e0`, `e1`, `e2`, `e3` are supposed to fire in order to simplify in ways that need the conditional hypotheses. However, `aesop` appears to be dropping them sometimes. For example, one of the produced non-closed goals is
```
x: UInt128
t: ℕ
t64: t < 2 ^ 64
t64': t < UInt64.size
sl: t < 128
t127: t &&& 127 = t
t127': ↑t &&& 127 = ↑t
i: ℕ
e0: (64 - ↑t).toNat % 64 = 64 - t
e1: 64 - (64 - t) = t
e2: i - (64 - t) = i + t - 64
a0: 64 ≤ i + t → False
ht: t % UInt64.size = t
h✝⁵: i < 64
h✝⁴: ¬t = 0
h✝³: t < 64
h✝²: i + t - 64 < t
h✝¹: 64 ≤ i + t → Nat.testBit x.hi.toNat (i + t - 64) = false
h✝: i + t < 64
⊢ Nat.testBit x.lo.toNat (i + t % 64) = Nat.testBit x.lo.toNat (i + t)
```
Ideally `e3` would have fired and simplified `t % 64` to `t` to close the goal, but `e3` is gone for some reason.
Geoffrey Irving (Nov 27 2023 at 22:12):
To say it explicitly, let me know if an MWE is worth it and I’ll make one if so.
Jannis Limperg (Nov 30 2023 at 00:35):
This should be handled by `simp_all`, which Aesop calls internally -- it should use `t < 64` to simplify `t < 64 -> t % 64 = t` to `True -> t % 64 = t`, then to `t % 64 = t`, and finally use this equation to rewrite everywhere in the goal.
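To see the mechanism described above in isolation, here is a sketch in plain `Nat` (core Lean 4 only, no `UInt64`; `f` is an arbitrary placeholder function, not from the thread). `simp_all` should use `h` to reduce the premise of `e3` to `True`, turn `e3` into `t % 64 = t`, and then rewrite the goal with it:

```lean
-- `e3` has the same shape as in the thread: an equation guarded by `t < 64`.
-- `h` discharges the guard, so `simp_all` can rewrite `t % 64` to `t` and close the goal.
example (f : Nat → Bool) (t : Nat) (h : t < 64) (e3 : t < 64 → t % 64 = t) :
    f (t % 64) = f t := by
  simp_all
```

In this small case the conditional equation is still in the context when it is needed; in the stuck goal from the thread, `e3` has already disappeared by the time `t % 64` needs rewriting.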
Maybe the `subst` rule kicks in and rewrites the equation right to left before `simp_all` gets a chance? You can try whether `aesop (erase Aesop.BuiltinRules.subst)` makes more progress.
For more debugging, I'll need a WE (doesn't need to be minimal, a branch will do).
Geoffrey Irving (Dec 02 2023 at 11:15):
I tried that erase, but it didn’t have any effect.
Here's a WE:
```lean
import Mathlib.Data.UInt
import Mathlib.Init.IteSimp
import Mathlib.Data.Nat.Bitwise

@[ext]
structure UInt128 where
  lo : UInt64
  hi : UInt64
  deriving DecidableEq, BEq

@[pp_dot] def UInt128.toNat (x : UInt128) : ℕ :=
  (x.hi.toNat * 2^64) + x.lo.toNat

/-- Divide by `2^(s % 128)`, rounding down -/
@[pp_dot] def UInt128.shiftRight (x : UInt128) (s : UInt64) : UInt128 :=
  let s := s.land 127
  { lo := bif s == 0 then x.lo
      else bif s < 64 then x.lo >>> s ||| x.hi <<< (64-s)
      else x.hi >>> (s - 64)
    hi := bif s < 64 then x.hi >>> s else 0 }

@[simp] lemma UInt64.cast_toNat (n : UInt64) : (n.toNat : UInt64) = n := sorry
@[simp] lemma UInt64.toNat_lt_2_pow_64 (n : UInt64) : n.toNat < 2^64 := Fin.prop _
lemma UInt64.eq_iff_toNat_eq (m n : UInt64) : m = n ↔ m.toNat = n.toNat := sorry
lemma UInt64.le_iff_toNat_le (m n : UInt64) : m ≤ n ↔ m.toNat ≤ n.toNat := sorry
lemma UInt64.lt_iff_toNat_lt (m n : UInt64) : m < n ↔ m.toNat < n.toNat := sorry
@[simp] lemma UInt64.toNat_cast (n : ℕ) : (n : UInt64).toNat = n % UInt64.size := sorry
@[simp] lemma UInt64.toNat_land {x y : UInt64} : (x &&& y).toNat = x.toNat &&& y.toNat := sorry
lemma UInt64.toNat_sub {x y : UInt64} (h : y ≤ x) : (x - y).toNat = x.toNat - y.toNat := sorry
lemma Nat.land_eq_mod {n k : ℕ} : n &&& (2^k-1) = n % 2^k := sorry
lemma Nat.sub_sub_assoc {a b c : ℕ} (h : c ≤ b) : a + c - b = a - (b - c) := sorry
@[simp] lemma Nat.testBit_div_two_pow {n k i : ℕ} : testBit (n / 2^k) i = testBit n (i+k) := sorry
@[simp] lemma Nat.testBit_mul_two_pow {n k i : ℕ} :
    testBit (n * 2^k) i = decide (k ≤ i ∧ testBit n (i-k)) := sorry
lemma UInt128.testBit_eq {x : UInt128} {i : ℕ} :
    x.toNat.testBit i = if i < 64 then x.lo.toNat.testBit i else x.hi.toNat.testBit (i-64) := sorry
lemma bif_eq_if {b : Bool} {x y : α} : (bif b then x else y) = if b then x else y := sorry
lemma UInt64.toNat_shiftRight' {x s : UInt64}
    : (x >>> s).toNat = x.toNat / 2^(s.toNat % 64) := sorry
lemma UInt64.toNat_shiftLeft' {x s : UInt64} :
    (x <<< s).toNat = x.toNat % 2^(64 - s.toNat % 64) * 2^(s.toNat % 64) := sorry
@[simp] lemma UInt64.toNat_lor {x y : UInt64} : (x ||| y).toNat = x.toNat ||| y.toNat := sorry
@[simp] lemma Nat.testBit_mod_two_pow {n k i : ℕ} :
    testBit (n % 2^k) i = (testBit n i && i < k) := sorry
lemma apply_decide {f : Bool → α} {p : Prop} {dp : Decidable p} :
    (f (@decide p dp)) = if p then f true else f false := sorry
@[simp] lemma UInt64.land_eq_hand {x y : UInt64} : UInt64.land x y = x &&& y := rfl
@[simp] lemma UInt64.toNat_zero : (0 : UInt64).toNat = 0 := rfl
@[simp] lemma Nat.testBit_zero' {i : ℕ} : testBit 0 i = false := sorry

/-- `shiftRight` rounds down -/
lemma UInt128.toNat_shiftRight (x : UInt128) {s : UInt64} (sl : s < 128) :
    (x.shiftRight s).toNat = x.toNat / 2^s.toNat := by
  generalize ht : s.toNat = t
  have st : s = (t : UInt64) := by rw [←ht, UInt64.cast_toNat]
  have t64 : t < 2^64 := by rw [←ht]; exact UInt64.toNat_lt_2_pow_64 _
  have t64' : t < UInt64.size := t64
  have p64 : (64 : UInt64).toNat = 64 := rfl
  have p127 : (127 : UInt64).toNat = 127 := rfl
  have p128 : (128 : UInt64).toNat = 128 := rfl
  simp only [st, UInt64.lt_iff_toNat_lt, UInt64.toNat_cast, Nat.mod_eq_of_lt t64', p128, st] at sl ⊢
  clear st ht s
  have t127 : t &&& 127 = t := by nth_rw 2 [←Nat.mod_eq_of_lt sl]; exact @Nat.land_eq_mod _ 7
  have t127' : (t : UInt64) &&& 127 = ↑t := by
    simp only [UInt64.eq_iff_toNat_eq, UInt64.toNat_land, UInt64.toNat_cast, Nat.mod_eq_of_lt t64',
      p127, t127]
  refine Nat.eq_of_testBit_eq fun i ↦ ?_
  have e0 : ¬t = 0 → t < 64 → (64 - (t : UInt64)).toNat % 64 = 64 - t := by
    intro t0 t64
    rw [UInt64.toNat_sub]
    · simp only [p64, UInt64.toNat_cast, Nat.mod_eq_of_lt t64', ge_iff_le, Nat.mod_succ_eq_iff_lt]
      exact Nat.sub_lt (by norm_num) (Nat.pos_iff_ne_zero.mpr t0)
    · simp only [UInt64.le_iff_toNat_le, UInt64.toNat_cast, Nat.mod_eq_of_lt t64']
      exact t64.le
  have e1 : t < 64 → 64 - (64 - t) = t := fun h ↦ Nat.sub_sub_self h.le
  have e2 : t < 64 → i - (64 - t) = i + t - 64 := fun h ↦ by rw [Nat.sub_sub_assoc h.le]
  have e3 : t < 64 → t % 64 = t := fun h ↦ by rw [Nat.mod_eq_of_lt h]
  have a0 : 64 ≤ i + t → i + t < 64 → False := fun le lt ↦ not_lt.mpr le lt
  simp only [UInt128.shiftRight, UInt128.testBit_eq, bif_eq_if, beq_iff_eq, decide_eq_true_eq,
    apply_ite (f := UInt64.toNat), apply_ite (f := fun x ↦ Nat.testBit x i),
    apply_ite (f := fun x ↦ Nat.testBit x (i - 64)),
    UInt64.toNat_shiftRight', UInt64.toNat_lor, Nat.testBit_lor, Nat.testBit_div_two_pow,
    Nat.testBit_mul_two_pow, UInt64.toNat_shiftLeft', Nat.testBit_mod_two_pow, apply_decide,
    UInt64.land_eq_hand, UInt64.eq_iff_toNat_eq, UInt64.toNat_zero, UInt64.toNat_land, p127,
    UInt64.toNat_cast, Nat.mod_eq_of_lt t64', UInt64.lt_iff_toNat_lt, t127, p64, Bool.or_true,
    Bool.or_false, Nat.testBit_zero', t127']
  aesop -- There is a `t % 64` under a `t < 64` assumption in the first goal, but `e3` is gone
```
Last updated: Dec 20 2023 at 11:08 UTC