Zulip Chat Archive

Stream: general

Topic: extracting proof values from let decls in do notation


Frederick Pu (Sep 15 2025 at 22:08):

I'm trying to make a tactic that allows you to use the last let decl from inside a do notation as a proof about the function's return value
here's an example

import Mathlib

set_option pp.letVarTypes true

/-- Doubles `x` inside an `Id` `do` block and returns the result.
The in-body `have : Even y` keeps a proof that the returned value is even next
to the computation; the value returned is just `y`. -/
def womp (x : Nat) : Nat := Id.run <| do
  let y := 2 * x;
  -- `x` is supplied as the witness for `Even y`; `ring` closes what is left
  have : Even y := by
    use x
    ring
  return y

-- Goal: every output of `womp` is even. The block comment records the goal
-- state reached after `unfold womp`; the proof is left unfinished at that point.
example : ∀ x, Even (womp x) := by
  intro x
  unfold womp
  /-
    x : ℕ
    ⊢ Even
      (let y : ℕ := 2 * x;
        have this : Even y := ⋯;
      pure y).run
  -/

does anybody know what a good way to approach this would be?

Aaron Liu (Sep 15 2025 at 22:11):

are you looking for extract_lets?

Frederick Pu (Sep 15 2025 at 22:27):

thanks that's exactly what i was looking for

Frederick Pu (Sep 15 2025 at 22:27):

now i can do this:

import Mathlib

set_option pp.letVarTypes true

/-
  simple tail assertion `have : Even y`
-/
/-- Same computation as `double`, with a proof that the result is even stated
inside the `do` block just before the `return`. -/
def double_proof (x : Nat) : Nat := Id.run <| do
  let y := 2 * x;
  -- `x` is supplied as the witness for `Even y`; `ring` closes what is left
  have : Even y := by
    use x
    ring
  return y

/-- Returns `2 * x`, computed in an `Id` `do` block with no proof annotations. -/
def double (x : Nat) : Nat := Id.run <| do
  let y := 2 * x;
  return y

/-- Every value of `double` is even.
The proof swaps `double` for `double_proof` and then reuses the `have : Even y`
that `double_proof` carries in its body. -/
theorem even_double : ∀ x, Even (double x) := by
  intro x
  -- the two definitions differ only by the in-body `have`, so `rfl` proves this
  have : double = double_proof := rfl
  rw [this]
  unfold double_proof
  -- move the `let y` and the in-body `have` from the goal into the local context
  extract_lets y this
  (expose_names; exact this)

Frederick Pu (Sep 15 2025 at 22:45):

any tips if im trying to do something like this?

/-
  loop invariant example
-/
/-- Sums `i` over `List.range n`, with proof annotations inside the `do` block:
an initial fact about `out` and, in each iteration, an implication from a
closed form for the previous accumulator to the closed form for the new one. -/
def sumN_proof (n : Nat) : Nat := Id.run <| do
  let mut out := 0
  -- fact about the initial accumulator, stated in `ℤ`
  have : (out : Nat) = ((0:ℕ) - 1 : ℤ) * 0 / 2 := rfl
  for i in List.range n do
    let prev := out
    out := out + i
    let new := out
    -- step: closed form for `prev` at `i` implies closed form for `new` at `i + 1`
    have : prev = (i-1 : ℤ) * i / 2 → new = (i + 1 - 1 : ℤ) * ((i + 1)) / 2 := by
      intro h
      have : new = prev + i := by tauto
      zify at this
      rw [this, h]
      -- move to `ℚ` so the divisions by 2 can be cleared
      qify
      rw [Int.cast_div, Int.cast_div]
      field_simp
      push_cast
      ring
      -- presumably the side goals left by the two `Int.cast_div` rewrites
      convert Int.two_dvd_mul_add_one (i + 1 - 1 : ℤ)
      ring
      norm_num
      convert Int.two_dvd_mul_add_one (i - 1 : ℤ)
      ring
      norm_num

  return out

/-- Adds up every `i` in `List.range n`, starting from `0`, in an `Id` `do` block. -/
def sumN (n : Nat) : Nat := Id.run <| do
  let mut out := 0
  for i in List.range n do
    out := out + i
  return out

-- Attempted closed-form spec for `sumN`. The proof stops after unfolding
-- `sumN_proof`; the block comment records the goal state at that point.
theorem sumN_spec : ∀ n : Nat, sumN n = n * (n + 1) / 2 := by
  have : sumN = sumN_proof := by rfl
  intro n
  rw [this]
  unfold sumN_proof
  /-
    this : sumN = sumN_proof
    n : ℕ
    ⊢ (let out : ℕ := 0;
        have this : ↑out = (↑0 - 1) * 0 / 2 := sumN_proof._proof_1;
        do
        let r ←
          forIn (List.range n) out fun i r =>
              let out : ℕ := r;
              let prev : ℕ := out;
              let out : ℕ := out + i;
              let new : ℕ := out;
              have this : ↑prev = (↑i - 1) * ↑i / 2 → ↑new = (↑i + 1 - 1) * (↑i + 1) / 2 := ⋯;
              do
              pure PUnit.unit
              pure (ForInStep.yield out)
        have out : ℕ := r
        pure out).run =
      n * (n + 1) / 2
  -/

Aaron Liu (Sep 15 2025 at 22:56):

you may be interested in #general > new monadic program verification framework

Frederick Pu (Sep 15 2025 at 23:01):

im worried that mvcgen will remove the locally annotated invariants during simplification or smth

Frederick Pu (Sep 15 2025 at 23:02):

hopefully if that's the case there's a flag or smth

Frederick Pu (Sep 15 2025 at 23:08):

im also getting:

unexpected token '⌜'; expected '_' or identifier

Frederick Pu (Sep 15 2025 at 23:08):

when i try to declare

-- Hoare-triple form of the spec; the `⦃ ⦄` / `⌜ ⌝` notation only parses once the `Std.Do` namespace is opened
theorem sumN_spec (n : Nat) : ⦃⌜True⌝⦄ sumN n ⦃⇓r => r = n * (n + 1) / 2⦄ := by sorry

Aaron Liu (Sep 15 2025 at 23:24):

You need to open something

Frederick Pu (Sep 15 2025 at 23:25):

i did

open Std Do

theorem sumN_spec (n : Nat) : ⦃⌜True⌝⦄ sumN n ⦃⇓r => r = n * (n + 1) / 2 ⦄ := by

but now i get

Type of sumN n is not a type application: ℕ

Frederick Pu (Sep 15 2025 at 23:26):

k i changed the function to return Id Nat but now i get:

Type mismatch
  r = n * (n + 1) / 2
has type
  Prop
but is expected to have type
  Assertion PostShape.pure

Frederick Pu (Sep 16 2025 at 03:10):

so something like this should work at extracting proofs out of for loops:

/-
  loop invariant example
-/
/-- `Id`-valued version of the annotated sum: sums `i` over `List.range n` and
carries, inside the `do` block, an initial fact about `out` plus a per-iteration
implication between closed forms for the old and new accumulator. -/
def sumN_proof (n : Nat) : Id Nat := do
  let mut out := 0
  -- fact about the initial accumulator, stated in `ℤ`
  have : (out : Nat) = ((0:ℕ) - 1 : ℤ) * 0 / 2 := rfl
  for i in List.range n do
    let prev := out
    out := out + i
    let new := out
    -- step: closed form for `prev` at `i` implies closed form for `new` at `i + 1`
    have : prev = (i-1 : ℤ) * i / 2 → new = (i + 1 - 1 : ℤ) * ((i + 1)) / 2 := by
      intro h
      have : new = prev + i := by tauto
      zify at this
      rw [this, h]
      -- move to `ℚ` so the divisions by 2 can be cleared
      qify
      rw [Int.cast_div, Int.cast_div]
      field_simp
      push_cast
      ring
      -- presumably the side goals left by the two `Int.cast_div` rewrites
      convert Int.two_dvd_mul_add_one (i + 1 - 1 : ℤ)
      ring
      norm_num
      convert Int.two_dvd_mul_add_one (i - 1 : ℤ)
      ring
      norm_num

  return out

/-- Adds up every `i` in `List.range n` in `Id`.
The `have` after the loop is a probe for inspecting the proof context at that
point; its tactic block does not close the goal. -/
def sumN (n : Nat) : Id Nat := do
  let mut out := 0
  for i in List.range n do
    out := out + i
  -- probe only: `expose_names` renames hypotheses and leaves `out = 2` open
  have : out = 2 := by
    expose_names

  return out

open Std Do

universe u v w u₁ u₂

#check Invariant
#check ForInStep.value

/-- Loop-invariant rule for `forIn` over `List.range n` in `Id`.
Given `inv 0 b` and a step hypothesis taking `inv i b` to `inv (i + 1)` of the
body's `.value`, it concludes `inv n` of the loop result. The proof is `sorry`.

NOTE(review): `ind` only looks at `.value`, so it does not distinguish
`ForInStep.done` from `ForInStep.yield`. A body that stops early looks like a
counterexample to the statement as written; confirm before replacing `sorry`. -/
theorem crux {β}
  (n : Nat) (b : β)
  (f : Nat → β → Id (ForInStep β))
  (inv : Nat → β → Prop)
  (base : inv 0 b)
  (ind : ∀ b : β, ∀ i : Nat, inv i b → inv (i + 1) (Id.run (f i b)).value) :
  inv n (Id.run (forIn (List.range n) b f)) := by
  sorry

-- #check forIn
#check Invariant.withEarlyReturn
#check PostCond
/-- Closed form for `sumN n` when `n > 0`.
The proof rewrites `sumN` to `sumN_proof`, restates the loop by hand as `l`,
applies `crux` with the invariant `inv` over `ℤ`, and then transfers the result
back to `ℕ` using `hn`. -/
theorem sumN_spec (n : Nat) (hn : n > 0) :
  sumN n = (n-1) * n / 2:= by
  /-
  unexpected token '⌜'; expected '_' or identifier
  -/
  have : sumN = sumN_proof := by rfl
  rw [this]
  unfold sumN_proof
  extract_lets out this
  -- the loop from the unfolded goal, written out by hand so it can be named
  let l : Id ℕ := (forIn (List.range n) out (fun i r =>
          let out : ℕ := r;
          let prev : ℕ := out;
          let out : ℕ := out + i;
          let new : ℕ := out;
          have this : prev = (i - 1 : ℤ) * i / 2 → new = (i + 1 - 1 : ℤ) * (i + 1) / 2 := by
            intro h
            have : new = prev + i := by tauto
            zify at this
            rw [this, h]
            qify
            rw [Int.cast_div, Int.cast_div]
            field_simp
            push_cast
            ring
            convert Int.two_dvd_mul_add_one (i + 1 - 1 : ℤ)
            ring
            norm_num
            convert Int.two_dvd_mul_add_one (i - 1 : ℤ)
            ring
            norm_num
          do
          pure PUnit.unit
          pure (ForInStep.yield out)))
  -- invariant: the accumulator `r` equals `(i - 1) * i / 2`, read in `ℤ`
  let inv (i r : Nat) := r = (i - 1 : ℤ) * i / 2
  have : (do l) = pure ((n - 1 : ℤ) * n / 2) := by
    apply crux _ _ _ inv
    exact this
    intro b i h
    extract_lets
    tauto
  -- transfer the `ℤ` statement above back to `ℕ`
  have : l = pure ((n - 1) * n / 2) := by
    suffices : Id.run l = Id.run (pure ((n - 1) * n / 2))
    exact this
    simp
    have : Id.run (do l : Id ℤ) = Id.run (pure ((n - 1 : ℤ) * n / 2)) := by exact this
    simp at this
    zify
    convert this
    -- `hn` is what lets the cast of `n - 1` become `↑n - 1`
    exact Int.natCast_pred_of_pos hn
  dsimp only [l] at this
  dsimp
  rw [this]
  rfl

Frederick Pu (Sep 16 2025 at 03:11):

In reality the definition of l would be replaced by proof automation. i think at least for Id this is much cleaner than using Hoare Triples and the Invariant class.

Frederick Pu (Sep 16 2025 at 03:14):

actually all that's required to get l easily would be if extract_lets worked with do notation

Frederick Pu (Sep 16 2025 at 04:23):

also why dont we get access to r here?

/-- Adds up every `i` in `List.range n` in `Id`.
The `have` after the loop is a probe: the block comment inside it records the
proof context seen at that point, and the tactic block does not close the goal. -/
def sumN (n : Nat) : Id Nat := do
  let mut out := 0
  for i in List.range n do
    out := out + i
  have : out = 2 := by
    /-
      n : ℕ
      out_1 : ℕ := 0
      r : ℕ
      out : ℕ := r
      ⊢ out = 2
    -/
    expose_names

  return out

Frederick Pu (Sep 16 2025 at 06:42):

was able to get it working with this:

/-- Walks `e` looking for applications of `Bind.bind` and turns the bound
monadic value of each one found into a local definition named `c` in the goal.

* If `e` is a `Bind.bind` application with six arguments, the fifth argument
  `c` is added to `mvarId` with `MVarId.define`, introduced with `intros`, and
  the walk continues on the sixth argument `d`.
* Otherwise the walk recurses into every argument of `e`, threading the goal.

Returns the resulting goal. Throws if a `Bind.bind` application does not have
exactly six arguments. -/
partial def decomposeBinds (mvarId : MVarId) (e : Expr) : MetaM MVarId := do
  if e.isAppOf' ``Bind.bind then
    dbg_trace "bruh"
    -- only the last two arguments are used: the bound value and the continuation
    let #[_, _, _, _, c, d] := e.getAppArgs' | throwError "unexpected args to Bind.bind"
    dbg_trace "bind found:"
    dbg_trace "  monad value: {c}"
    dbg_trace "  continuation: {d}"

    dbg_trace c
    -- Add `c` as a let in the main goal
    let mvarId ← mvarId.define `c (← Meta.inferType c) c
    let (_, mvarId) ← mvarId.intros
    decomposeBinds mvarId d
  else
    let mut mvarId := mvarId
    -- recurse into all subterms
    for arg in e.getAppArgs' do
      mvarId ← decomposeBinds mvarId arg
    return mvarId

/-- Tactic front end for `decomposeBinds`: runs it on the main goal's type and
replaces the main goal with the goal it returns. -/
elab "decomposeBinds" : tactic => do
  let g ← getMainGoal
  let t ← g.getType
  -- dbg_trace "goal type: {t}"
  let g' ← decomposeBinds g t
  replaceMainGoal [g']

/-- Closed form for `sumN n` when `n > 0`.
Same argument as the hand-written version, except that the loop is pulled into
the context as `c` by the `decomposeBinds` tactic instead of being restated. -/
theorem sumN_spec (n : Nat) (hn : n > 0) :
  sumN n = (n-1) * n / 2:= by
  /-
  unexpected token '⌜'; expected '_' or identifier
  -/
  have : sumN = sumN_proof := by rfl
  rw [this]
  unfold sumN_proof
  extract_lets out this
  -- introduces the loop as the local definition `c`
  decomposeBinds
  expose_names
  -- invariant: the accumulator `r` equals `(i - 1) * i / 2`, read in `ℤ`
  let inv (i r : Nat) := r = (i - 1 : ℤ) * i / 2
  have : (do c) = pure ((n - 1 : ℤ) * n / 2) := by
    apply crux _ _ _ inv
    exact this
    intro b i h
    extract_lets
    tauto
  -- transfer the `ℤ` statement above back to `ℕ`
  have : c = pure ((n - 1) * n / 2) := by
    suffices : Id.run c = Id.run (pure ((n - 1) * n / 2))
    exact this
    simp
    have : Id.run (do c : Id ℤ) = Id.run (pure ((n - 1 : ℤ) * n / 2)) := by exact this
    simp at this
    zify
    convert this
    -- `hn` is what lets the cast of `n - 1` become `↑n - 1`
    exact Int.natCast_pred_of_pos hn
  dsimp only [c] at this
  dsimp
  rw [this]
  rfl

Frederick Pu (Sep 16 2025 at 06:48):

once you find the right invariant you can often golf it even more:

/-- Golfed spec: states the closed form directly in `ℤ`, so `crux` applies to
the unfolded goal without a separate transfer back to `ℕ`. -/
theorem sumN_spec (n : Nat) : (do sumN n) = pure ((n-1 : ℤ) * (n:ℤ) / 2) := by
  have : sumN = sumN_proof := by rfl
  rw [this]
  -- invariant: the accumulator `r` equals `(i - 1) * i / 2`, read in `ℤ`
  let inv (i r : Nat) := r = (i - 1 : ℤ) * i / 2
  unfold sumN_proof
  extract_lets
  apply crux _ _ _ inv
  · tauto
  · intro b i h
    extract_lets
    tauto

Frederick Pu (Sep 16 2025 at 07:18):

here's a slightly less trivial example:

/-
  loop invariant example
-/
/-- Annotated sum whose result is offset by 2: sums `i` over `List.range n`,
carries the same in-body proof annotations as the plain version, and returns
`out + 2`. -/
def sumN_proof (n : Nat) : Id Nat := do
  let mut out := 0
  -- fact about the initial accumulator, stated in `ℤ`
  have : (out : Nat) = ((0:ℕ) - 1 : ℤ) * 0 / 2 := rfl
  for i in List.range n do
    let prev := out
    out := out + i
    let new := out
    -- step: closed form for `prev` at `i` implies closed form for `new` at `i + 1`
    have : prev = (i-1 : ℤ) * i / 2 → new = (i + 1 - 1 : ℤ) * ((i + 1)) / 2 := by
      intro h
      have : new = prev + i := by tauto
      zify at this
      rw [this, h]
      -- move to `ℚ` so the divisions by 2 can be cleared
      qify
      rw [Int.cast_div, Int.cast_div]
      field_simp
      push_cast
      ring
      -- presumably the side goals left by the two `Int.cast_div` rewrites
      convert Int.two_dvd_mul_add_one (i + 1 - 1 : ℤ)
      ring
      norm_num
      convert Int.two_dvd_mul_add_one (i - 1 : ℤ)
      ring
      norm_num

  return out + 2

/-- Adds up every `i` in `List.range n` in `Id` and returns that sum plus 2. -/
def sumN (n : Nat) : Id Nat := do
  let mut out := 0
  for i in List.range n do
    out := out + i
  return out + 2


open Std Do

universe u v w u₁ u₂

#check Invariant
#check ForInStep.value

/-- Loop-invariant rule for `forIn` over `List.range n` in `Id`.
Given `inv 0 b` and a step hypothesis taking `inv i b` to `inv (i + 1)` of the
body's `.value`, it concludes `inv n` of the loop result. The proof is `sorry`.

NOTE(review): `ind` only looks at `.value`, so it does not distinguish
`ForInStep.done` from `ForInStep.yield`. A body that stops early looks like a
counterexample to the statement as written; confirm before replacing `sorry`. -/
theorem crux {β}
  (n : Nat) (b : β)
  (f : Nat → β → Id (ForInStep β))
  (inv : Nat → β → Prop)
  (base : inv 0 b)
  (ind : ∀ b : β, ∀ i : Nat, inv i b → inv (i + 1) (Id.run (f i b)).value) :
  inv n (Id.run (forIn (List.range n) b f)) := by
  sorry

-- #check forIn
#check Invariant.withEarlyReturn
#check PostCond

open Lean
#check MVarId.extractLets

open Lean Meta Elab Tactic

open Qq Lean
#check Hypothesis

/-- Walks `e` looking for applications of `Bind.bind` and turns the bound
monadic value of each one found into a local definition named `c` in the goal.

* If `e` is a `Bind.bind` application with six arguments, the fifth argument
  `c` is added to `mvarId` with `MVarId.define`, introduced with `intros`, and
  the walk continues inside `c` itself.
* Otherwise the walk recurses into every argument of `e`, threading the goal.

Returns the resulting goal. Throws if a `Bind.bind` application does not have
exactly six arguments. -/
partial def decomposeBinds (mvarId : MVarId) (e : Expr) : MetaM MVarId := do
  if e.isAppOf' ``Bind.bind then
    dbg_trace "bruh"
    -- only the last two arguments are used: the bound value and the continuation
    let #[_, _, _, _, c, d] := e.getAppArgs' | throwError "unexpected args to Bind.bind"
    dbg_trace "bind found:"
    dbg_trace "  monad value: {c}"
    dbg_trace "  continuation: {d}"

    dbg_trace c
    -- Add `c` as a let in the main goal
    let mvarId ← mvarId.define `c (← Meta.inferType c) c
    let (_, mvarId) ← mvarId.intros
    decomposeBinds mvarId c
  else
    let mut mvarId := mvarId
    -- recurse into all subterms
    for arg in e.getAppArgs' do
      mvarId ← decomposeBinds mvarId arg
    return mvarId

/-- Tactic front end for `decomposeBinds`: runs it on the main goal's type and
replaces the main goal with the goal it returns. -/
elab "decomposeBinds" : tactic => do
  let g ← getMainGoal
  let t ← g.getType
  -- dbg_trace "goal type: {t}"
  let g' ← decomposeBinds g t
  replaceMainGoal [g']

#check congrArg

/-- Closed form for the offset sum: for `n > 0`, `sumN n` runs to
`(n-1) * n / 2 + 2`.
The statement is first proved in `Id ℤ` via `crux` and the invariant `inv`,
then pushed through `Id.run` and cast back to `ℕ` using `hn`. -/
theorem sumN_spec (n : Nat) (hn : n > 0 ): Id.run (sumN n) = ((n-1) * (n) / 2 + 2) := by
  have : sumN = sumN_proof := by rfl
  -- invariant: the accumulator `r` equals `(i - 1) * i / 2`, read in `ℤ`
  let inv (i r : Nat) := r = (i - 1 : ℤ) * i / 2
  have : (do sumN n) = pure ((n - 1 : ℤ) * n / 2 + 2) := by
    rw [this]
    unfold sumN_proof
    extract_lets
    -- introduces the loop as a local definition (named `c_1` after `expose_names`)
    decomposeBinds
    expose_names
    have : inv n (Id.run c_1) := by
      apply crux _ _ _ inv
      · exact this_1
      · intro b i h
        extract_lets
        tauto
    simp [inv] at this
    have : c_1 = pure (n * (n - 1) / 2) := by
      -- in `Id`, equality of `Id.run` results is equality of the computations
      have q : ∀ x y : Id Nat, Id.run x = Id.run y → x = y := by exact fun x y a => a
      apply q
      simp
      zify
      rw [Nat.cast_sub hn]
      push_cast
      rw [this]
      ring
    dsimp only [c_1] at this
    rw [this]
    have : ∀ x y : Id Int, Id.run x = Id.run y → x = y := by exact fun x y a => a
    apply this
    simp
    rw [mul_comm, Nat.cast_sub hn]
    rfl
  -- push the `Id ℤ` equation through `Id.run`, then cast the goal to `ℤ` to use it
  have := congrArg (Id.run) this
  simp at this
  zify
  rw [this, Nat.cast_sub hn]
  ring

Last updated: Dec 20 2025 at 21:32 UTC