Zulip Chat Archive
Stream: Machine Learning for Theorem Proving
Topic: Better way for tracing tactic states
Frederick Pu (Jun 26 2025 at 01:09):
I remember a conversation last year (2024) with @Jason Rute about how to break down `induction n with | zero => ... | succ n => ...`-style blocks when generating pretraining data for LeanCopilot and other tactic-based HyperTree search systems. Apparently there had been attempts to replace each subgoal with `?_`, but they always turned out a bit hacky. However, with the new `allTactics` flag for the Lean REPL it is quite easy to find all of the atomic tactics in a proof. We can then look up the infotree node of each atomic tactic and replace the syntax range of each of its children with `?_`. Here is the Python code for my example:
import subprocess
import json
from typing import List, Tuple
def apply_replacements(src: str,
                       replacements: List[Tuple[int, int, int, int, str]]) -> str:
    """Splice replacement strings into ``src`` at Lean REPL source positions.

    Args:
        src: The source text the positions refer to.
        replacements: ``(start_line, start_col, end_line, end_col, replacement)``
            tuples. Lines are 1-indexed, columns are 0-indexed, and the end
            position is exclusive.

    Returns:
        ``src`` with the spans replaced. An exact duplicate span is applied
        once, and a span that overlaps an earlier-starting (or, for equal
        starts, wider) span is ignored. Nested holes therefore collapse into
        the outermost one instead of corrupting the text.
    """
    # Lean breaks lines only on "\n"; str.splitlines would also break on
    # "\x0c", "\u2028", ... and shift every later position.
    line_starts = [0]
    for line in src.split("\n"):
        line_starts.append(line_starts[-1] + len(line) + 1)

    def to_offset(line: int, col: int) -> int:
        return line_starts[line - 1] + col

    # Ascending by start; on equal starts the wider span comes first.
    spans = sorted(
        ((to_offset(sl, sc), to_offset(el, ec), rep)
         for (sl, sc, el, ec, rep) in replacements),
        key=lambda span: (span[0], -span[1]),
    )
    # Each splice below assumes the text after it is still untouched, so
    # duplicates and overlapping spans must be dropped.
    kept = []
    for start, end, rep in spans:
        if kept and start < kept[-1][1]:
            continue
        kept.append((start, end, rep))

    # Apply right-to-left so the offsets of earlier spans stay valid.
    result = src
    for start, end, rep in reversed(kept):
        result = result[:start] + rep + result[end:]
    return result
def position_to_offset(src: str,
                       row: int, col: int) -> int:
    """Convert a Lean REPL ``(line, column)`` position into a string offset.

    Args:
        src: The source text the position refers to.
        row: 1-indexed line number.
        col: 0-indexed column (in characters) within that line.

    Returns:
        The absolute character offset into ``src``.

    Raises:
        ValueError: If ``row`` is not a line of ``src``.
    """
    # Split on "\n" only, as Lean does. Unlike str.splitlines this also yields
    # the empty line after a trailing newline, so end-of-file positions
    # reported on that line are accepted.
    lines = src.split("\n")
    if row < 1 or row > len(lines):
        raise ValueError(f"Row {row} out of range (1..{len(lines)})")
    return sum(len(line) + 1 for line in lines[:row - 1]) + col
def posToTup(pos):
    """Return a REPL position dict as a ``(line, column)`` tuple."""
    line, column = pos["line"], pos["column"]
    return line, column
def posLt(pos1, pos2):
    """Return True iff ``pos1`` comes strictly before ``pos2``.

    Positions are ``{"line", "column"}`` dicts. They are ordered by line first
    and by column only within the same line. The previous version compared
    columns even when ``pos1`` was on a later line.
    """
    return (pos1["line"], pos1["column"]) < (pos2["line"], pos2["column"])
def getSubNodes(node, startPos, endPos):
    """Collect the spans under ``node`` that should be replaced by ``?_`` holes.

    The REPL JSON infotree below ``node`` is walked in pre-order. A descendant
    becomes a hole when two conditions hold:

    * its syntax range is not the whole ``[startPos, endPos]`` span, i.e. it
      starts after ``startPos`` or finishes before ``endPos``;
    * it pretty-printed successfully.

    Holes are not searched further. Every other descendant is searched
    recursively: one covering the whole span, one that failed to
    pretty-print, or one without a usable range.

    Args:
        node: Infotree node (dict) whose ``"children"`` are searched.
        startPos, endPos: ``{"line", "column"}`` dicts bounding the tactic.

    Returns:
        ``(start_line, start_col, end_line, end_col, "?_")`` tuples, in the
        format accepted by ``apply_replacements``, without duplicates.
    """
    span_start = (startPos["line"], startPos["column"])
    span_end = (endPos["line"], endPos["column"])
    holes = []

    def syntax_of(tree):
        # Nodes lacking stx/range data, or with a non-dict range, are
        # treated as transparent wrappers.
        try:
            stx = tree["node"]["stx"]
            rng = stx["range"]
        except (KeyError, TypeError):
            return None
        return stx if isinstance(rng, dict) else None

    def visit(tree):
        if not isinstance(tree, dict):
            return
        for child in tree.get("children") or []:
            stx = syntax_of(child)
            if stx is None:
                visit(child)
                continue
            rng = stx["range"]
            start = (rng["start"]["line"], rng["start"]["column"])
            finish = (rng["finish"]["line"], rng["finish"]["column"])
            narrower = span_start < start or finish < span_end
            if narrower and stx.get("pp") != "<failed to pretty print>":
                hole = start + finish + ("?_",)
                # Deduplicate across the whole walk. The same syntax can
                # appear in several infotree nodes, and applying a span twice
                # corrupts the text.
                if hole not in holes:
                    holes.append(hole)
            else:
                visit(child)

    visit(node)
    return holes
def getTacticNode(infotree, ppName, pos=None):
    """Find the first infotree node (pre-order) whose syntax pretty-prints as ``ppName``.

    Args:
        infotree: List of REPL JSON infotree nodes to search.
        ppName: Pretty-printed tactic text compared with ``stx["pp"]``.
        pos: Optional ``{"line", "column"}`` dict. When given, the node's
            range must also start there. This tells apart repeated tactics
            with identical text, such as several ``rw [Nat.add_comm]`` or
            ``sorry``.

    Returns:
        The matching node dict, or ``None`` if there is none.
    """
    for tree in infotree:
        if not isinstance(tree, dict):
            continue
        try:
            stx = tree["node"]["stx"]
            stx_range = stx["range"]
        except (KeyError, TypeError):
            stx, stx_range = None, None
        if isinstance(stx_range, dict) and stx.get("pp") == ppName:
            start = stx_range["start"]
            if pos is None or (start["line"], start["column"]) == (pos["line"], pos["column"]):
                return tree
        # Descend even when this node has no usable range; a tactic can sit
        # below such a node.
        found = getTacticNode(tree.get("children") or [], ppName, pos)
        if found is not None:
            return found
    return None


if __name__ == "__main__":
    source_code = (
        "theorem womp (n : Nat): 2 + 2 = 5 := by\n"
        " rw [Nat.add_comm]; rw [Nat.add_comm]\n"
        " rw [Nat.add_comm]\n"
        " induction n with\n"
        " | zero => sorry\n"
        " | succ n => sorry"
    )
    payload = {
        "cmd": source_code,
        "infotree": "original",
        "allTactics": True
    }
    completed = subprocess.run(
        ["lake", "exe", "repl"],
        input=json.dumps(payload),
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    if completed.returncode != 0 or not completed.stdout.strip():
        raise RuntimeError(
            f"REPL failed (exit code {completed.returncode}):\n{completed.stderr}"
        )
    feedback = json.loads(completed.stdout)
    print("Feedback loaded")
    infotree = feedback.get("infotree", [])
    tactics_info = feedback.get("tactics", [])
    for tactic in tactics_info:
        pp_name = tactic["tactic"]
        # Match on position as well, so repeated tactics map to their own
        # node; fall back to a text-only match if no node starts there.
        tactic_node = getTacticNode(infotree, pp_name, tactic.get("pos"))
        if tactic_node is None:
            tactic_node = getTacticNode(infotree, pp_name)
        if tactic_node is None:
            print(f"Could not find node for tactic: {pp_name}")
            continue
        start = tactic_node["node"]["stx"]["range"]["start"]
        finish = tactic_node["node"]["stx"]["range"]["finish"]
        sub_ranges = getSubNodes(tactic_node, start, finish)
        modified_src = apply_replacements(source_code, sub_ranges)
        start_offset = position_to_offset(source_code, start["line"], start["column"])
        end_offset = position_to_offset(source_code, finish["line"], finish["column"])
        # Every hole lies inside [start_offset, end_offset), so the tactic's
        # end moves by exactly the net length change of the source.
        end_offset += len(modified_src) - len(source_code)
        cleaned_tactic = modified_src[start_offset:end_offset]
        print(cleaned_tactic)
        print("--")
Note that this approach should also handle `have` statements, essentially by turning them into `suffices`. If anyone has examples that break this, or general feedback, that would be great.
Frederick Pu (Jun 26 2025 at 01:11):
Also, if anyone has done something similar, please share your code.
Frederick Pu (Jun 26 2025 at 01:19):
I found a slight cropping bug, which I fixed in the version below:
import subprocess
import json
from typing import List, Tuple
def apply_replacements(src: str,
                       replacements: List[Tuple[int, int, int, int, str]]) -> str:
    """Splice replacement strings into ``src`` at Lean REPL source positions.

    Args:
        src: The source text the positions refer to.
        replacements: ``(start_line, start_col, end_line, end_col, replacement)``
            tuples. Lines are 1-indexed, columns are 0-indexed, and the end
            position is exclusive.

    Returns:
        ``src`` with the spans replaced. An exact duplicate span is applied
        once, and a span that overlaps an earlier-starting (or, for equal
        starts, wider) span is ignored. Nested holes therefore collapse into
        the outermost one instead of corrupting the text.
    """
    # Lean breaks lines only on "\n"; str.splitlines would also break on
    # "\x0c", "\u2028", ... and shift every later position.
    line_starts = [0]
    for line in src.split("\n"):
        line_starts.append(line_starts[-1] + len(line) + 1)

    def to_offset(line: int, col: int) -> int:
        return line_starts[line - 1] + col

    # Ascending by start; on equal starts the wider span comes first.
    spans = sorted(
        ((to_offset(sl, sc), to_offset(el, ec), rep)
         for (sl, sc, el, ec, rep) in replacements),
        key=lambda span: (span[0], -span[1]),
    )
    # Each splice below assumes the text after it is still untouched, so
    # duplicates and overlapping spans must be dropped.
    kept = []
    for start, end, rep in spans:
        if kept and start < kept[-1][1]:
            continue
        kept.append((start, end, rep))

    # Apply right-to-left so the offsets of earlier spans stay valid.
    result = src
    for start, end, rep in reversed(kept):
        result = result[:start] + rep + result[end:]
    return result
def position_to_offset(src: str,
                       row: int, col: int) -> int:
    """Convert a Lean REPL ``(line, column)`` position into a string offset.

    Args:
        src: The source text the position refers to.
        row: 1-indexed line number.
        col: 0-indexed column (in characters) within that line.

    Returns:
        The absolute character offset into ``src``.

    Raises:
        ValueError: If ``row`` is not a line of ``src``.
    """
    # Split on "\n" only, as Lean does. Unlike str.splitlines this also yields
    # the empty line after a trailing newline, so end-of-file positions
    # reported on that line are accepted.
    lines = src.split("\n")
    if row < 1 or row > len(lines):
        raise ValueError(f"Row {row} out of range (1..{len(lines)})")
    return sum(len(line) + 1 for line in lines[:row - 1]) + col
def posToTup(pos):
    """Return a REPL position dict as a ``(line, column)`` tuple."""
    line, column = pos["line"], pos["column"]
    return line, column
def posLt(pos1, pos2):
    """Return True iff ``pos1`` comes strictly before ``pos2``.

    Positions are ``{"line", "column"}`` dicts. They are ordered by line first
    and by column only within the same line. The previous version compared
    columns even when ``pos1`` was on a later line.
    """
    return (pos1["line"], pos1["column"]) < (pos2["line"], pos2["column"])
def getSubNodes(node, startPos, endPos):
    """Collect the spans under ``node`` that should be replaced by ``?_`` holes.

    The REPL JSON infotree below ``node`` is walked in pre-order. A descendant
    becomes a hole when two conditions hold:

    * its syntax range is not the whole ``[startPos, endPos]`` span, i.e. it
      starts after ``startPos`` or finishes before ``endPos``;
    * it pretty-printed successfully.

    Holes are not searched further. Every other descendant is searched
    recursively: one covering the whole span, one that failed to
    pretty-print, or one without a usable range.

    Args:
        node: Infotree node (dict) whose ``"children"`` are searched.
        startPos, endPos: ``{"line", "column"}`` dicts bounding the tactic.

    Returns:
        ``(start_line, start_col, end_line, end_col, "?_")`` tuples, in the
        format accepted by ``apply_replacements``, without duplicates.
    """
    span_start = (startPos["line"], startPos["column"])
    span_end = (endPos["line"], endPos["column"])
    holes = []

    def syntax_of(tree):
        # Nodes lacking stx/range data, or with a non-dict range, are
        # treated as transparent wrappers.
        try:
            stx = tree["node"]["stx"]
            rng = stx["range"]
        except (KeyError, TypeError):
            return None
        return stx if isinstance(rng, dict) else None

    def visit(tree):
        if not isinstance(tree, dict):
            return
        for child in tree.get("children") or []:
            stx = syntax_of(child)
            if stx is None:
                visit(child)
                continue
            rng = stx["range"]
            start = (rng["start"]["line"], rng["start"]["column"])
            finish = (rng["finish"]["line"], rng["finish"]["column"])
            narrower = span_start < start or finish < span_end
            if narrower and stx.get("pp") != "<failed to pretty print>":
                hole = start + finish + ("?_",)
                # Deduplicate across the whole walk. The same syntax can
                # appear in several infotree nodes, and applying a span twice
                # corrupts the text.
                if hole not in holes:
                    holes.append(hole)
            else:
                visit(child)

    visit(node)
    return holes
def getTacticNode(infotree, ppName, pos=None):
    """Find the first infotree node (pre-order) whose syntax pretty-prints as ``ppName``.

    Args:
        infotree: List of REPL JSON infotree nodes to search.
        ppName: Pretty-printed tactic text compared with ``stx["pp"]``.
        pos: Optional ``{"line", "column"}`` dict. When given, the node's
            range must also start there. This tells apart repeated tactics
            with identical text, such as several ``rw [Nat.add_comm]`` or
            ``sorry``.

    Returns:
        The matching node dict, or ``None`` if there is none.
    """
    for tree in infotree:
        if not isinstance(tree, dict):
            continue
        try:
            stx = tree["node"]["stx"]
            stx_range = stx["range"]
        except (KeyError, TypeError):
            stx, stx_range = None, None
        if isinstance(stx_range, dict) and stx.get("pp") == ppName:
            start = stx_range["start"]
            if pos is None or (start["line"], start["column"]) == (pos["line"], pos["column"]):
                return tree
        # Descend even when this node has no usable range; a tactic can sit
        # below such a node.
        found = getTacticNode(tree.get("children") or [], ppName, pos)
        if found is not None:
            return found
    return None


if __name__ == "__main__":
    source_code = (
        "theorem womp (n : Nat): 2 + 2 = 5 := by\n"
        " rw [Nat.add_comm]; rw [Nat.add_comm]\n"
        " rw [Nat.add_comm]\n"
        " induction n with\n"
        " | zero => sorry\n"
        " | succ n => sorry"
    )
    payload = {
        "cmd": source_code,
        "infotree": "original",
        "allTactics": True
    }
    completed = subprocess.run(
        ["lake", "exe", "repl"],
        input=json.dumps(payload),
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    if completed.returncode != 0 or not completed.stdout.strip():
        raise RuntimeError(
            f"REPL failed (exit code {completed.returncode}):\n{completed.stderr}"
        )
    feedback = json.loads(completed.stdout)
    print("Feedback loaded")
    infotree = feedback.get("infotree", [])
    tactics_info = feedback.get("tactics", [])
    for tactic in tactics_info:
        pp_name = tactic["tactic"]
        # Match on position as well, so repeated tactics map to their own
        # node; fall back to a text-only match if no node starts there.
        tactic_node = getTacticNode(infotree, pp_name, tactic.get("pos"))
        if tactic_node is None:
            tactic_node = getTacticNode(infotree, pp_name)
        if tactic_node is None:
            print(f"Could not find node for tactic: {pp_name}")
            continue
        start = tactic_node["node"]["stx"]["range"]["start"]
        finish = tactic_node["node"]["stx"]["range"]["finish"]
        sub_ranges = getSubNodes(tactic_node, start, finish)
        modified_src = apply_replacements(source_code, sub_ranges)
        start_offset = position_to_offset(source_code, start["line"], start["column"])
        end_offset = position_to_offset(source_code, finish["line"], finish["column"])
        # Every hole lies inside [start_offset, end_offset), so the tactic's
        # end moves by exactly len(modified_src) - len(source_code). The
        # previous version subtracted this delta, i.e. had the sign inverted.
        end_offset += len(modified_src) - len(source_code)
        cleaned_tactic = modified_src[start_offset:end_offset]
        print(cleaned_tactic)
        print("--")
Frederick Pu (Jun 28 2025 at 15:12):
I fixed another bug in the slicing:
Frederick Pu (Jun 28 2025 at 15:12):
import subprocess
import json
from typing import List, Tuple
def apply_replacements(src: str,
                       replacements: List[Tuple[int, int, int, int, str]]) -> str:
    """Splice replacement strings into ``src`` at Lean REPL source positions.

    Args:
        src: The source text the positions refer to.
        replacements: ``(start_line, start_col, end_line, end_col, replacement)``
            tuples. Lines are 1-indexed, columns are 0-indexed, and the end
            position is exclusive.

    Returns:
        ``src`` with the spans replaced. An exact duplicate span is applied
        once, and a span that overlaps an earlier-starting (or, for equal
        starts, wider) span is ignored. Nested holes therefore collapse into
        the outermost one instead of corrupting the text.
    """
    # Lean breaks lines only on "\n"; str.splitlines would also break on
    # "\x0c", "\u2028", ... and shift every later position.
    line_starts = [0]
    for line in src.split("\n"):
        line_starts.append(line_starts[-1] + len(line) + 1)

    def to_offset(line: int, col: int) -> int:
        return line_starts[line - 1] + col

    # Ascending by start; on equal starts the wider span comes first.
    spans = sorted(
        ((to_offset(sl, sc), to_offset(el, ec), rep)
         for (sl, sc, el, ec, rep) in replacements),
        key=lambda span: (span[0], -span[1]),
    )
    # Each splice below assumes the text after it is still untouched, so
    # duplicates and overlapping spans must be dropped.
    kept = []
    for start, end, rep in spans:
        if kept and start < kept[-1][1]:
            continue
        kept.append((start, end, rep))

    # Apply right-to-left so the offsets of earlier spans stay valid.
    result = src
    for start, end, rep in reversed(kept):
        result = result[:start] + rep + result[end:]
    return result
def position_to_offset(src: str,
                       row: int, col: int) -> int:
    """Convert a Lean REPL ``(line, column)`` position into a string offset.

    Args:
        src: The source text the position refers to.
        row: 1-indexed line number.
        col: 0-indexed column (in characters) within that line.

    Returns:
        The absolute character offset into ``src``.

    Raises:
        ValueError: If ``row`` is not a line of ``src``.
    """
    # Split on "\n" only, as Lean does. Unlike str.splitlines this also yields
    # the empty line after a trailing newline, so end-of-file positions
    # reported on that line are accepted.
    lines = src.split("\n")
    if row < 1 or row > len(lines):
        raise ValueError(f"Row {row} out of range (1..{len(lines)})")
    return sum(len(line) + 1 for line in lines[:row - 1]) + col
def posToTup(pos):
    """Return a REPL position dict as a ``(line, column)`` tuple."""
    line, column = pos["line"], pos["column"]
    return line, column
def posLt(pos1, pos2):
    """Return True iff ``pos1`` comes strictly before ``pos2``.

    Positions are ``{"line", "column"}`` dicts. They are ordered by line first
    and by column only within the same line. The previous version compared
    columns even when ``pos1`` was on a later line.
    """
    return (pos1["line"], pos1["column"]) < (pos2["line"], pos2["column"])
def getSubNodes(node, startPos, endPos):
    """Collect the spans under ``node`` that should be replaced by ``?_`` holes.

    The REPL JSON infotree below ``node`` is walked in pre-order. A descendant
    becomes a hole when two conditions hold:

    * its syntax range is not the whole ``[startPos, endPos]`` span, i.e. it
      starts after ``startPos`` or finishes before ``endPos``;
    * it pretty-printed successfully.

    Holes are not searched further. Every other descendant is searched
    recursively: one covering the whole span, one that failed to
    pretty-print, or one without a usable range.

    Args:
        node: Infotree node (dict) whose ``"children"`` are searched.
        startPos, endPos: ``{"line", "column"}`` dicts bounding the tactic.

    Returns:
        ``(start_line, start_col, end_line, end_col, "?_")`` tuples, in the
        format accepted by ``apply_replacements``, without duplicates.
    """
    span_start = (startPos["line"], startPos["column"])
    span_end = (endPos["line"], endPos["column"])
    holes = []

    def syntax_of(tree):
        # Nodes lacking stx/range data, or with a non-dict range, are
        # treated as transparent wrappers.
        try:
            stx = tree["node"]["stx"]
            rng = stx["range"]
        except (KeyError, TypeError):
            return None
        return stx if isinstance(rng, dict) else None

    def visit(tree):
        if not isinstance(tree, dict):
            return
        for child in tree.get("children") or []:
            stx = syntax_of(child)
            if stx is None:
                visit(child)
                continue
            rng = stx["range"]
            start = (rng["start"]["line"], rng["start"]["column"])
            finish = (rng["finish"]["line"], rng["finish"]["column"])
            narrower = span_start < start or finish < span_end
            if narrower and stx.get("pp") != "<failed to pretty print>":
                hole = start + finish + ("?_",)
                # Deduplicate across the whole walk. The same syntax can
                # appear in several infotree nodes, and applying a span twice
                # corrupts the text.
                if hole not in holes:
                    holes.append(hole)
            else:
                visit(child)

    visit(node)
    return holes
def getTacticNode(infotree, ppName, pos=None):
    """Find the first infotree node (pre-order) whose syntax pretty-prints as ``ppName``.

    Args:
        infotree: List of REPL JSON infotree nodes to search.
        ppName: Pretty-printed tactic text compared with ``stx["pp"]``.
        pos: Optional ``{"line", "column"}`` dict. When given, the node's
            range must also start there. This tells apart repeated tactics
            with identical text, such as several ``rw [Nat.add_comm]`` or
            ``sorry``.

    Returns:
        The matching node dict, or ``None`` if there is none.
    """
    for tree in infotree:
        if not isinstance(tree, dict):
            continue
        try:
            stx = tree["node"]["stx"]
            stx_range = stx["range"]
        except (KeyError, TypeError):
            stx, stx_range = None, None
        if isinstance(stx_range, dict) and stx.get("pp") == ppName:
            start = stx_range["start"]
            if pos is None or (start["line"], start["column"]) == (pos["line"], pos["column"]):
                return tree
        # Descend even when this node has no usable range; a tactic can sit
        # below such a node.
        found = getTacticNode(tree.get("children") or [], ppName, pos)
        if found is not None:
            return found
    return None


if __name__ == "__main__":
    source_code = (
        "theorem womp (n : Nat): 2 + 2 = 5 := by\n"
        # The proof of `have` must be indented deeper than the enclosing
        # tactic block; otherwise the tactics that follow are parsed as part
        # of it.
        " have womp : 2 + 2 = 4 := by\n   rw [Nat.add_comm, Nat.add_comm]\n   rfl\n"
        " rw [Nat.add_comm]; rw [Nat.add_comm]\n"
        " rw [Nat.add_comm]\n"
        " induction n with\n"
        " | zero => sorry\n"
        " | succ n => sorry"
    )
    payload = {
        "cmd": source_code,
        "infotree": "original",
        "allTactics": True
    }
    completed = subprocess.run(
        ["lake", "exe", "repl"],
        input=json.dumps(payload),
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    if completed.returncode != 0 or not completed.stdout.strip():
        raise RuntimeError(
            f"REPL failed (exit code {completed.returncode}):\n{completed.stderr}"
        )
    feedback = json.loads(completed.stdout)
    print("Feedback loaded")
    infotree = feedback.get("infotree", [])
    tactics_info = feedback.get("tactics", [])
    print("---")
    for tactic in tactics_info:
        pp_name = tactic["tactic"]
        # Match on position as well, so repeated tactics map to their own
        # node; fall back to a text-only match if no node starts there.
        tactic_node = getTacticNode(infotree, pp_name, tactic.get("pos"))
        if tactic_node is None:
            tactic_node = getTacticNode(infotree, pp_name)
        if tactic_node is None:
            print(f"Could not find node for tactic: {pp_name}")
            continue
        start = tactic_node["node"]["stx"]["range"]["start"]
        finish = tactic_node["node"]["stx"]["range"]["finish"]
        sub_ranges = getSubNodes(tactic_node, start, finish)
        modified_src = apply_replacements(source_code, sub_ranges)
        start_offset = position_to_offset(source_code, start["line"], start["column"])
        end_offset = position_to_offset(source_code, finish["line"], finish["column"])
        # Every hole lies inside [start_offset, end_offset), so the tactic's
        # end moves by exactly the net length change of the source.
        end_offset += len(modified_src) - len(source_code)
        cleaned_tactic = modified_src[start_offset:end_offset]
        print(cleaned_tactic)
        print("--")
Frederick Pu (Jun 28 2025 at 15:13):
Here is some sample output:
Feedback loaded
---
have womp : 2 + 2 = 4 := by
?_
?_
--
rw [Nat.add_comm, Nat.add_comm]
--
rfl
--
rw [Nat.add_comm]
--
rw [Nat.add_comm]
--
rw [Nat.add_comm]
--
induction n with
| zero => ?_
| succ n => ?_
--
induction n with
| zero => ?_
| succ n => ?_
--
sorry
--
sorry
--
Frederick Pu (Jun 28 2025 at 15:15):
It seems `have` statements are still handled incorrectly.
Frederick Pu (Jun 29 2025 at 22:34):
import REPL.Frontend
import REPL.Lean.InfoTree
import REPL.Lean.InfoTree.ToJson
import Std
open Lean Elab
open Std
/--
Replace each `[start, stop)` byte range of `src` listed in `replacements` with
`replacementStr`.

A duplicate range is applied once, and a range that overlaps an earlier-starting
(or, for equal starts, wider) range is skipped. Nested holes therefore collapse
into the outermost one instead of corrupting the text: applying the same range
twice would splice into the already-edited string at stale positions.
-/
def applyReplacements (src : String)
    (replacements : List (String.Pos × String.Pos)) (replacementStr : String) : String :=
  -- Ascending by start; on equal starts the wider range comes first.
  let asc := replacements.toArray.qsort fun a b => a.1 < b.1 || (a.1 == b.1 && a.2 > b.2)
  -- Keep only ranges that start at or after the end of the last kept range.
  let kept : Array (String.Pos × String.Pos) := asc.foldl (fun acc r =>
    match acc.back? with
    | some prev => if r.1 < prev.2 then acc else acc.push r
    | none => acc.push r) #[]
  -- Apply right-to-left (`foldr` visits the last range first) so the earlier
  -- positions stay valid.
  kept.foldr (fun (startPos, endPos) acc =>
    acc.extract 0 startPos ++ replacementStr ++ acc.extract endPos acc.endPos) src
/--
All nodes of `t` whose `Info` satisfies `p`, in pre-order. Each node is paired
with the `ContextInfo` in effect at it and with the subtree rooted at it.
-/
partial def Lean.Elab.InfoTree.findAllInfoTree (t : InfoTree) (ctx? : Option ContextInfo)
    (p : Info → Bool) : List (Info × Option ContextInfo × InfoTree) :=
  match t with
  | .context ctx inner => inner.findAllInfoTree (ctx.mergeIntoOuter? ctx?) p
  | .node i children =>
    let below := children.toList.flatMap fun child => child.findAllInfoTree ctx? p
    if p i then (i, ctx?, t) :: below else below
  | .hole _ => []
/-- Collect the tactic nodes of `t` to be traced.
A node is kept when all of the following hold:
* its info is a `TacticInfo` for which `Info.isOriginal` and
  `TacticInfo.isSubstantive` hold;
* it has a `ContextInfo`;
* its syntax can be pretty-printed with `ppTactic` under that context.
Each result is `(tactic info, some context, subtree rooted at the node)`. -/
def Lean.Elab.InfoTree.findTacticInfoTrees (t : InfoTree) : IO (List (TacticInfo × Option ContextInfo × InfoTree)) :=
let infos := t.findAllInfoTree none fun i => match i with
| .ofTacticInfo i' => i.isOriginal && i'.isSubstantive
| _ => false
-- Drop nodes without a context and nodes whose syntax `ppTactic` rejects.
infos.filterMapM fun p => match p with
| (.ofTacticInfo i, some ctx, t) =>
ctx.runMetaM {} try
-- Run only to check that the syntax pretty-prints; the output is discarded.
let _ ← Lean.PrettyPrinter.ppTactic ⟨i.stx⟩
return (i, some ctx, t)
catch _ =>
pure none
| _ => pure none
/-- Collect the source ranges of the sub-nodes of `infotree` that should become `?_` holes.
For a `.node` that has a syntax range and a known `ContextInfo`:
* its own range is returned, without descending, when `flag` holds and the range
  differs from `[startPos, endPos]` (it starts after `startPos` or stops before
  `endPos`);
* otherwise its children are searched.
A node with no syntax range or no context contributes nothing, and its children
are not searched. A hole also contributes nothing. -/
partial def getSubNodePositions (infotree : InfoTree) (startPos : String.Pos) (endPos : String.Pos) (ctx? : Option ContextInfo) : IO (List (String.Pos × String.Pos)) := do
match infotree with
| .context ctx t => getSubNodePositions t startPos endPos (ctx.mergeIntoOuter? ctx?)
| .node info children =>
match info.stx.getRange? with
| some r =>
match ctx? with
| some ctx =>
-- `flag` is `false` when `ppTactic` throws on `info.stx`.
-- NOTE(review): `q` is unused, and the trailing `return true` appears to make the
-- result `true` for every info kind once `ppTactic` succeeds, so the match on
-- `.ofTacticInfo` may have no effect. Confirm the intended structure; the
-- original indentation is lost in this paste.
let flag : Bool ← ctx.runMetaM {} try
let q ← Lean.PrettyPrinter.ppTactic ⟨info.stx⟩
match info with
| .ofTacticInfo info => return true
| _ => pure false
return true
catch _ =>
pure false
if (startPos < r.start ∨ r.stop < endPos) ∧ flag then
return [(r.start, r.stop)]
else
return (← (children.mapM (getSubNodePositions · startPos endPos ctx?))).toList.join
| none => pure []
| none => pure []
| .hole mvarId => pure []
/-
Add `?_` holes for all sub-nodes of the compound tactic `infotree`, and return the
tactic's source text (taken from `src`) with those holes in place.
Returns `""` for holes and for nodes that have no syntax range or no `ContextInfo`.
-/
partial def InfoTree.add_subnode_holes (src : String) (infotree : InfoTree)
    (ctx? : Option ContextInfo) : IO String := do
  match infotree with
  | .context ctx t =>
    add_subnode_holes src t (ctx.mergeIntoOuter? ctx?)
  | .node info _ =>
    match ctx?, info.stx.getRange? with
    | some ctx, some r =>
      let subPos ← getSubNodePositions infotree r.start r.stop ctx
      let temp := applyReplacements src subPos "?_"
      -- `String.Pos` is a UTF-8 byte offset, so the end must be shifted by the
      -- change in byte size (`utf8ByteSize`), not in character count (`length`).
      -- Add before subtracting: `Nat` subtraction truncates at zero, which gave
      -- the wrong end whenever the holes made the text longer.
      let stop : String.Pos := ⟨r.stop.byteIdx + temp.utf8ByteSize - src.utf8ByteSize⟩
      return temp.extract r.start stop
    | _, _ => return ""
  | .hole _ => return ""
/-- Example proof traced by `lol`.
The proof of `have womp` is indented past the enclosing tactic block, so only
`rw` and `rfl` belong to it. At the same column, every following tactic would
be parsed as part of the `have`. -/
def test :=
  String.join ["theorem womp (n : Nat): 2 + 2 = 5 := by\n",
    " have womp : 2 + 2 = 4 := by\n   rw [Nat.add_comm, Nat.add_comm]\n   rfl\n",
    " rw [Nat.add_comm]; rw [Nat.add_comm]\n",
    " rw [Nat.add_comm]\n",
    " induction n with\n",
    " | zero => sorry\n",
    " | succ n => sorry"]
/-- Elaborate `test`, then print each traced tactic with its sub-nodes
replaced by `?_`, framed by `----` lines. -/
def lol : IO Unit := do
  let results ← IO.processInput test none
  let trees := results.2.2.flatMap InfoTree.retainTacticInfo |>.flatMap InfoTree.retainOriginal
  let root := trees[0]!
  let tacticTrees ← InfoTree.findTacticInfoTrees root
  for (_, ctx?, subtree) in tacticTrees do
    IO.println "----"
    IO.println (← InfoTree.add_subnode_holes test subtree ctx?)
    IO.println "----"

#eval lol
Frederick Pu (Jun 29 2025 at 22:40):
My end goal is that, given a Lean proof, you can generate a `HashMap GoalState String` that can be used as a tactic-generation heuristic that lets Aesop immediately prove the end goal. This essentially generates a proof search tree that contains only the input proof. Then we can randomly remove nodes (or edges) from this search graph to generate a synthetic RL curriculum of problems that are hard for current tactic-state-based models to prove. Obviously, this would need to go alongside some sort of RL mechanism within Aesop.
Last updated: Dec 20 2025 at 21:32 UTC