Zulip Chat Archive

Stream: Machine Learning for Theorem Proving

Topic: Better way for tracing tactic states


Frederick Pu (Jun 26 2025 at 01:09):

I remember having a conversation last year (2024) with @Jason Rute about how to break down `induction n with | zero => ... | succ n => ...` style blocks when generating pretraining data for LeanCopilot and other tactic-based HyperTree search systems. Apparently there had been attempts to replace each subgoal with `?_`, but they always turned out a bit hacky. However, with the new `allTactics` flag for the Lean REPL it seems quite easy to find all of the atomic tactics in a proof. We can then look up the infotree node of each atomic tactic and replace the source range of each of its children with `?_`. Here is the Python code for my example:

import subprocess
import json

from typing import List, Tuple

def apply_replacements(src: str,
                       replacements: List[Tuple[int, int, int, int, str]]) -> str:
    """Apply text replacements addressed by (line, column) positions.

    Each replacement is ``(start_line, start_col, end_line, end_col, replacement)``
    and replaces the half-open span from the start position up to (but not
    including) the end position. Lines are 1-indexed; columns are 0-indexed
    offsets into that line of the Python string.

    Lines are delimited by line feeds only, matching how Lean numbers lines.
    ``str.splitlines`` would also break on carriage returns, form feeds,
    U+2028 and other separators, shifting every later position.

    Replacements are applied from the rightmost start offset to the leftmost,
    so applying one never moves the offsets of those still pending.
    Overlapping spans are not supported.
    """
    # Absolute offset at which each (1-indexed) line begins.
    line_starts = [0] + [i + 1 for i, ch in enumerate(src) if ch == "\n"]

    def to_offset(line: int, col: int) -> int:
        return line_starts[line - 1] + col

    abs_replacements = sorted(
        ((to_offset(sl, sc), to_offset(el, ec), rep)
         for sl, sc, el, ec, rep in replacements),
        key=lambda r: r[0],
        reverse=True,
    )

    result = src
    for start, end, rep in abs_replacements:
        result = result[:start] + rep + result[end:]
    return result

def position_to_offset(src: str,
                       row: int, col: int) -> int:
    """Convert a (1-indexed row, 0-indexed column) position into a string offset.

    Rows are delimited by line feeds only, matching Lean's line numbering
    (``str.splitlines`` would also split on form feeds, U+2028, etc.).
    A source ending in ``"\\n"`` therefore has one more (empty) row, so the
    end-of-file position ``(last_row + 1, 0)`` is accepted.

    Raises:
        ValueError: if ``row`` is outside ``1..number_of_rows``.
    """
    # Absolute offset at which each (1-indexed) row begins.
    line_starts = [0] + [i + 1 for i, ch in enumerate(src) if ch == "\n"]

    if row < 1 or row > len(line_starts):
        raise ValueError(f"Row {row} out of range (1..{len(line_starts)})")

    return line_starts[row - 1] + col


def posToTup(pos):
    """Turn a REPL position dict into a ``(line, column)`` tuple."""
    line, column = pos["line"], pos["column"]
    return line, column

def posLt(pos1, pos2):
    """Return True if position ``pos1`` comes strictly before ``pos2``.

    Positions are dicts with ``"line"`` and ``"column"`` keys. Ordering is
    lexicographic: by line first, then by column. The previous version fell
    through to the column test when ``pos1`` was on a *later* line, and could
    then wrongly report it as earlier.
    """
    return (pos1["line"], pos1["column"]) < (pos2["line"], pos2["column"])

def getSubNodes(node, startPos, endPos):
    """Collect ``?_`` replacements for the sub-tactics nested in an infotree node.

    Walks the children of ``node`` (a REPL JSON infotree node). A child is
    recorded as a replacement, and not descended into, when both hold:

    - its syntax range differs from the enclosing range ``[startPos, endPos]``
      (it starts strictly after ``startPos`` or finishes strictly before
      ``endPos``);
    - its syntax pretty-printed successfully.

    Any other child is searched recursively instead. That covers a child
    spanning the same range as the tactic, a child whose pretty-printing
    failed, and a child with no usable source range.

    Args:
        node: infotree node dict; its ``"children"`` list may be absent.
        startPos, endPos: ``{"line": ..., "column": ...}`` bounds of the tactic.

    Returns:
        A duplicate-free list of
        ``(start_line, start_col, end_line, end_col, "?_")`` tuples in the shape
        ``apply_replacements`` expects.
    """
    def key(pos):
        # Lexicographic (line, column) key, so tuple `<` orders positions correctly.
        return (pos["line"], pos["column"])

    def syntax_of(tree):
        # The child's "stx" dict, or None when it is missing or has no dict range.
        try:
            stx = tree["node"]["stx"]
        except (KeyError, TypeError):
            return None
        if not isinstance(stx, dict) or not isinstance(stx.get("range"), dict):
            return None
        return stx

    start_key, end_key = key(startPos), key(endPos)
    out = []
    for tree in node.get("children", []):
        stx = syntax_of(tree)
        if (
            stx is not None
            and (start_key < key(stx["range"]["start"])
                 or key(stx["range"]["finish"]) < end_key)
            and stx.get("pp") != "<failed to pretty print>"
        ):
            found = [key(stx["range"]["start"]) + key(stx["range"]["finish"]) + ("?_",)]
        else:
            found = getSubNodes(tree, startPos, endPos)
        for item in found:
            # The same range may be reached through different children; applying
            # it twice in apply_replacements would corrupt the text.
            if item not in out:
                out.append(item)
    return out

def getTacticNode(infotree, ppName, pos=None):
    """Depth-first search of a REPL JSON infotree for a tactic's node.

    Args:
        infotree: list of infotree node dicts (with a ``"node"`` key and an
            optional ``"children"`` list).
        ppName: pretty-printed syntax to match against ``node.stx.pp``.
        pos: optional ``{"line": ..., "column": ...}`` dict. When given, a node
            matches only if its syntax range also starts at that position. Without
            it, tactics with identical text (e.g. several ``rw [Nat.add_comm]``)
            all resolve to the first occurrence.

    Returns:
        The first matching node in pre-order, or ``None``. Nodes whose
        ``stx.range`` is missing or not a dict are skipped together with their
        subtrees.
    """
    for tree in infotree:
        try:
            stxRange = tree["node"]["stx"]["range"]
        except (KeyError, TypeError):
            continue
        if not isinstance(stxRange, dict):
            continue

        if tree["node"]["stx"]["pp"] == ppName and (
            pos is None
            or (stxRange["start"]["line"], stxRange["start"]["column"])
            == (pos["line"], pos["column"])
        ):
            return tree

        found = getTacticNode(tree.get("children", []), ppName, pos)
        if found is not None:
            return found

    return None

if __name__ == "__main__":
    source_code = (
        "theorem womp (n : Nat): 2 + 2 = 5 := by\n"
        " rw [Nat.add_comm]; rw [Nat.add_comm]\n"
        " rw [Nat.add_comm]\n"
        " induction n with\n"
        " | zero => sorry\n"
        " | succ n => sorry"
    )

    # One REPL request for the whole file: "allTactics" asks for the tactic list
    # and "infotree": "original" for the infotree JSON searched below.
    payload = {
        "cmd": source_code,
        "infotree": "original",
        "allTactics": True
    }

    # subprocess.run waits for the REPL to exit and closes its pipes.
    completed = subprocess.run(
        ["lake", "exe", "repl"],
        input=json.dumps(payload),
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    if not completed.stdout.strip():
        # Surface the REPL's own error instead of an opaque JSONDecodeError.
        raise RuntimeError(
            f"REPL produced no output (exit code {completed.returncode}):\n{completed.stderr}"
        )
    feedback = json.loads(completed.stdout)

    print("Feedback loaded")

    if "infotree" not in feedback:
        raise RuntimeError(f"REPL response has no infotree: {feedback}")
    infotree = feedback["infotree"]
    tactics_info = feedback.get("tactics", [])

    for tactic in tactics_info:
        pp_name = tactic["tactic"]
        tactic_node = getTacticNode(infotree, pp_name)

        if tactic_node is None:
            print(f"Could not find node for tactic: {pp_name}")
            continue

        start = tactic_node["node"]["stx"]["range"]["start"]
        finish = tactic_node["node"]["stx"]["range"]["finish"]

        sub_ranges = getSubNodes(tactic_node, start, finish)
        modified_src = apply_replacements(source_code, sub_ranges)

        start_offset = position_to_offset(source_code, start["line"], start["column"])
        end_offset = position_to_offset(source_code, finish["line"], finish["column"])

        # Assumes every replaced sub-range lies inside [start, finish]. Then the
        # text before start_offset is unchanged, and only the end of the tactic
        # moves, by the net change in length caused by the replacements.
        shift = len(modified_src) - len(source_code)
        cleaned_tactic = modified_src[start_offset:end_offset + shift]
        print(cleaned_tactic)
        print("--")

Note that this approach should also handle `have` statements, essentially by turning them into `suffices`. If anyone has examples that break this, or general feedback, that would be great.

Frederick Pu (Jun 26 2025 at 01:11):

Also, if anyone has done something similar, please share your code.

Frederick Pu (Jun 26 2025 at 01:19):

I found a slight cropping bug, which I fixed in the version below:

import subprocess
import json

from typing import List, Tuple

def apply_replacements(src: str,
                       replacements: List[Tuple[int, int, int, int, str]]) -> str:
    """Apply text replacements addressed by (line, column) positions.

    Each replacement is ``(start_line, start_col, end_line, end_col, replacement)``
    and replaces the half-open span from the start position up to (but not
    including) the end position. Lines are 1-indexed; columns are 0-indexed
    offsets into that line of the Python string.

    Lines are delimited by line feeds only, matching how Lean numbers lines.
    ``str.splitlines`` would also break on carriage returns, form feeds,
    U+2028 and other separators, shifting every later position.

    Replacements are applied from the rightmost start offset to the leftmost,
    so applying one never moves the offsets of those still pending.
    Overlapping spans are not supported.
    """
    # Absolute offset at which each (1-indexed) line begins.
    line_starts = [0] + [i + 1 for i, ch in enumerate(src) if ch == "\n"]

    def to_offset(line: int, col: int) -> int:
        return line_starts[line - 1] + col

    abs_replacements = sorted(
        ((to_offset(sl, sc), to_offset(el, ec), rep)
         for sl, sc, el, ec, rep in replacements),
        key=lambda r: r[0],
        reverse=True,
    )

    result = src
    for start, end, rep in abs_replacements:
        result = result[:start] + rep + result[end:]
    return result

def position_to_offset(src: str,
                       row: int, col: int) -> int:
    """Convert a (1-indexed row, 0-indexed column) position into a string offset.

    Rows are delimited by line feeds only, matching Lean's line numbering
    (``str.splitlines`` would also split on form feeds, U+2028, etc.).
    A source ending in ``"\\n"`` therefore has one more (empty) row, so the
    end-of-file position ``(last_row + 1, 0)`` is accepted.

    Raises:
        ValueError: if ``row`` is outside ``1..number_of_rows``.
    """
    # Absolute offset at which each (1-indexed) row begins.
    line_starts = [0] + [i + 1 for i, ch in enumerate(src) if ch == "\n"]

    if row < 1 or row > len(line_starts):
        raise ValueError(f"Row {row} out of range (1..{len(line_starts)})")

    return line_starts[row - 1] + col


def posToTup(pos):
    """Turn a REPL position dict into a ``(line, column)`` tuple."""
    line, column = pos["line"], pos["column"]
    return line, column

def posLt(pos1, pos2):
    """Return True if position ``pos1`` comes strictly before ``pos2``.

    Positions are dicts with ``"line"`` and ``"column"`` keys. Ordering is
    lexicographic: by line first, then by column. The previous version fell
    through to the column test when ``pos1`` was on a *later* line, and could
    then wrongly report it as earlier.
    """
    return (pos1["line"], pos1["column"]) < (pos2["line"], pos2["column"])

def getSubNodes(node, startPos, endPos):
    """Collect ``?_`` replacements for the sub-tactics nested in an infotree node.

    Walks the children of ``node`` (a REPL JSON infotree node). A child is
    recorded as a replacement, and not descended into, when both hold:

    - its syntax range differs from the enclosing range ``[startPos, endPos]``
      (it starts strictly after ``startPos`` or finishes strictly before
      ``endPos``);
    - its syntax pretty-printed successfully.

    Any other child is searched recursively instead. That covers a child
    spanning the same range as the tactic, a child whose pretty-printing
    failed, and a child with no usable source range.

    Args:
        node: infotree node dict; its ``"children"`` list may be absent.
        startPos, endPos: ``{"line": ..., "column": ...}`` bounds of the tactic.

    Returns:
        A duplicate-free list of
        ``(start_line, start_col, end_line, end_col, "?_")`` tuples in the shape
        ``apply_replacements`` expects.
    """
    def key(pos):
        # Lexicographic (line, column) key, so tuple `<` orders positions correctly.
        return (pos["line"], pos["column"])

    def syntax_of(tree):
        # The child's "stx" dict, or None when it is missing or has no dict range.
        try:
            stx = tree["node"]["stx"]
        except (KeyError, TypeError):
            return None
        if not isinstance(stx, dict) or not isinstance(stx.get("range"), dict):
            return None
        return stx

    start_key, end_key = key(startPos), key(endPos)
    out = []
    for tree in node.get("children", []):
        stx = syntax_of(tree)
        if (
            stx is not None
            and (start_key < key(stx["range"]["start"])
                 or key(stx["range"]["finish"]) < end_key)
            and stx.get("pp") != "<failed to pretty print>"
        ):
            found = [key(stx["range"]["start"]) + key(stx["range"]["finish"]) + ("?_",)]
        else:
            found = getSubNodes(tree, startPos, endPos)
        for item in found:
            # The same range may be reached through different children; applying
            # it twice in apply_replacements would corrupt the text.
            if item not in out:
                out.append(item)
    return out

def getTacticNode(infotree, ppName, pos=None):
    """Depth-first search of a REPL JSON infotree for a tactic's node.

    Args:
        infotree: list of infotree node dicts (with a ``"node"`` key and an
            optional ``"children"`` list).
        ppName: pretty-printed syntax to match against ``node.stx.pp``.
        pos: optional ``{"line": ..., "column": ...}`` dict. When given, a node
            matches only if its syntax range also starts at that position. Without
            it, tactics with identical text (e.g. several ``rw [Nat.add_comm]``)
            all resolve to the first occurrence.

    Returns:
        The first matching node in pre-order, or ``None``. Nodes whose
        ``stx.range`` is missing or not a dict are skipped together with their
        subtrees.
    """
    for tree in infotree:
        try:
            stxRange = tree["node"]["stx"]["range"]
        except (KeyError, TypeError):
            continue
        if not isinstance(stxRange, dict):
            continue

        if tree["node"]["stx"]["pp"] == ppName and (
            pos is None
            or (stxRange["start"]["line"], stxRange["start"]["column"])
            == (pos["line"], pos["column"])
        ):
            return tree

        found = getTacticNode(tree.get("children", []), ppName, pos)
        if found is not None:
            return found

    return None

if __name__ == "__main__":
    source_code = (
        "theorem womp (n : Nat): 2 + 2 = 5 := by\n"
        " rw [Nat.add_comm]; rw [Nat.add_comm]\n"
        " rw [Nat.add_comm]\n"
        " induction n with\n"
        " | zero => sorry\n"
        " | succ n => sorry"
    )

    # One REPL request for the whole file: "allTactics" asks for the tactic list
    # and "infotree": "original" for the infotree JSON searched below.
    payload = {
        "cmd": source_code,
        "infotree": "original",
        "allTactics": True
    }

    # subprocess.run waits for the REPL to exit and closes its pipes.
    completed = subprocess.run(
        ["lake", "exe", "repl"],
        input=json.dumps(payload),
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    if not completed.stdout.strip():
        # Surface the REPL's own error instead of an opaque JSONDecodeError.
        raise RuntimeError(
            f"REPL produced no output (exit code {completed.returncode}):\n{completed.stderr}"
        )
    feedback = json.loads(completed.stdout)

    print("Feedback loaded")

    if "infotree" not in feedback:
        raise RuntimeError(f"REPL response has no infotree: {feedback}")
    infotree = feedback["infotree"]
    tactics_info = feedback.get("tactics", [])

    for tactic in tactics_info:
        pp_name = tactic["tactic"]
        tactic_node = getTacticNode(infotree, pp_name)

        if tactic_node is None:
            print(f"Could not find node for tactic: {pp_name}")
            continue

        start = tactic_node["node"]["stx"]["range"]["start"]
        finish = tactic_node["node"]["stx"]["range"]["finish"]

        sub_ranges = getSubNodes(tactic_node, start, finish)
        modified_src = apply_replacements(source_code, sub_ranges)

        start_offset = position_to_offset(source_code, start["line"], start["column"])
        end_offset = position_to_offset(source_code, finish["line"], finish["column"])

        # Assumes every replaced sub-range lies inside [start, finish]. Then the
        # text before start_offset is unchanged, and the end of the tactic moves
        # by (new length - old length). The previous version subtracted this delta
        # with the wrong sign.
        shift = len(modified_src) - len(source_code)
        cleaned_tactic = modified_src[start_offset:end_offset + shift]
        print(cleaned_tactic)
        print("--")

Frederick Pu (Jun 28 2025 at 15:12):

I fixed another bug with the slicing:

Frederick Pu (Jun 28 2025 at 15:12):

import subprocess
import json

from typing import List, Tuple

def apply_replacements(src: str,
                       replacements: List[Tuple[int, int, int, int, str]]) -> str:
    """Apply text replacements addressed by (line, column) positions.

    Each replacement is ``(start_line, start_col, end_line, end_col, replacement)``
    and replaces the half-open span from the start position up to (but not
    including) the end position. Lines are 1-indexed; columns are 0-indexed
    offsets into that line of the Python string.

    Lines are delimited by line feeds only, matching how Lean numbers lines.
    ``str.splitlines`` would also break on carriage returns, form feeds,
    U+2028 and other separators, shifting every later position.

    Replacements are applied from the rightmost start offset to the leftmost,
    so applying one never moves the offsets of those still pending.
    Overlapping spans are not supported.
    """
    # Absolute offset at which each (1-indexed) line begins.
    line_starts = [0] + [i + 1 for i, ch in enumerate(src) if ch == "\n"]

    def to_offset(line: int, col: int) -> int:
        return line_starts[line - 1] + col

    abs_replacements = sorted(
        ((to_offset(sl, sc), to_offset(el, ec), rep)
         for sl, sc, el, ec, rep in replacements),
        key=lambda r: r[0],
        reverse=True,
    )

    result = src
    for start, end, rep in abs_replacements:
        result = result[:start] + rep + result[end:]
    return result

def position_to_offset(src: str,
                       row: int, col: int) -> int:
    """Convert a (1-indexed row, 0-indexed column) position into a string offset.

    Rows are delimited by line feeds only, matching Lean's line numbering
    (``str.splitlines`` would also split on form feeds, U+2028, etc.).
    A source ending in ``"\\n"`` therefore has one more (empty) row, so the
    end-of-file position ``(last_row + 1, 0)`` is accepted.

    Raises:
        ValueError: if ``row`` is outside ``1..number_of_rows``.
    """
    # Absolute offset at which each (1-indexed) row begins.
    line_starts = [0] + [i + 1 for i, ch in enumerate(src) if ch == "\n"]

    if row < 1 or row > len(line_starts):
        raise ValueError(f"Row {row} out of range (1..{len(line_starts)})")

    return line_starts[row - 1] + col


def posToTup(pos):
    """Turn a REPL position dict into a ``(line, column)`` tuple."""
    line, column = pos["line"], pos["column"]
    return line, column

def posLt(pos1, pos2):
    """Return True if position ``pos1`` comes strictly before ``pos2``.

    Positions are dicts with ``"line"`` and ``"column"`` keys. Ordering is
    lexicographic: by line first, then by column. The previous version fell
    through to the column test when ``pos1`` was on a *later* line, and could
    then wrongly report it as earlier.
    """
    return (pos1["line"], pos1["column"]) < (pos2["line"], pos2["column"])

def getSubNodes(node, startPos, endPos):
    """Collect ``?_`` replacements for the sub-tactics nested in an infotree node.

    Walks the children of ``node`` (a REPL JSON infotree node). A child is
    recorded as a replacement, and not descended into, when both hold:

    - its syntax range differs from the enclosing range ``[startPos, endPos]``
      (it starts strictly after ``startPos`` or finishes strictly before
      ``endPos``);
    - its syntax pretty-printed successfully.

    Any other child is searched recursively instead. That covers a child
    spanning the same range as the tactic, a child whose pretty-printing
    failed, and a child with no usable source range.

    Args:
        node: infotree node dict; its ``"children"`` list may be absent.
        startPos, endPos: ``{"line": ..., "column": ...}`` bounds of the tactic.

    Returns:
        A duplicate-free list of
        ``(start_line, start_col, end_line, end_col, "?_")`` tuples in the shape
        ``apply_replacements`` expects.
    """
    def key(pos):
        # Lexicographic (line, column) key, so tuple `<` orders positions correctly.
        return (pos["line"], pos["column"])

    def syntax_of(tree):
        # The child's "stx" dict, or None when it is missing or has no dict range.
        try:
            stx = tree["node"]["stx"]
        except (KeyError, TypeError):
            return None
        if not isinstance(stx, dict) or not isinstance(stx.get("range"), dict):
            return None
        return stx

    start_key, end_key = key(startPos), key(endPos)
    out = []
    for tree in node.get("children", []):
        stx = syntax_of(tree)
        if (
            stx is not None
            and (start_key < key(stx["range"]["start"])
                 or key(stx["range"]["finish"]) < end_key)
            and stx.get("pp") != "<failed to pretty print>"
        ):
            found = [key(stx["range"]["start"]) + key(stx["range"]["finish"]) + ("?_",)]
        else:
            found = getSubNodes(tree, startPos, endPos)
        for item in found:
            # The same range may be reached through different children; applying
            # it twice in apply_replacements would corrupt the text.
            if item not in out:
                out.append(item)
    return out

def getTacticNode(infotree, ppName, pos=None):
    """Depth-first search of a REPL JSON infotree for a tactic's node.

    Args:
        infotree: list of infotree node dicts (with a ``"node"`` key and an
            optional ``"children"`` list).
        ppName: pretty-printed syntax to match against ``node.stx.pp``.
        pos: optional ``{"line": ..., "column": ...}`` dict. When given, a node
            matches only if its syntax range also starts at that position. Without
            it, tactics with identical text (e.g. several ``rw [Nat.add_comm]``)
            all resolve to the first occurrence.

    Returns:
        The first matching node in pre-order, or ``None``. Nodes whose
        ``stx.range`` is missing or not a dict are skipped together with their
        subtrees.
    """
    for tree in infotree:
        try:
            stxRange = tree["node"]["stx"]["range"]
        except (KeyError, TypeError):
            continue
        if not isinstance(stxRange, dict):
            continue

        if tree["node"]["stx"]["pp"] == ppName and (
            pos is None
            or (stxRange["start"]["line"], stxRange["start"]["column"])
            == (pos["line"], pos["column"])
        ):
            return tree

        found = getTacticNode(tree.get("children", []), ppName, pos)
        if found is not None:
            return found

    return None

if __name__ == "__main__":
    source_code = (
        "theorem womp (n : Nat): 2 + 2 = 5 := by\n"
        " have womp : 2 + 2 = 4 := by\n  rw [Nat.add_comm, Nat.add_comm]\n  rfl\n"
        " rw [Nat.add_comm]; rw [Nat.add_comm]\n"
        " rw [Nat.add_comm]\n"
        " induction n with\n"
        " | zero => sorry\n"
        " | succ n => sorry"
    )

    # One REPL request for the whole file: "allTactics" asks for the tactic list
    # and "infotree": "original" for the infotree JSON searched below.
    payload = {
        "cmd": source_code,
        "infotree": "original",
        "allTactics": True
    }

    # subprocess.run waits for the REPL to exit and closes its pipes.
    completed = subprocess.run(
        ["lake", "exe", "repl"],
        input=json.dumps(payload),
        capture_output=True,
        text=True,
        encoding="utf-8",
    )
    if not completed.stdout.strip():
        # Surface the REPL's own error instead of an opaque JSONDecodeError.
        raise RuntimeError(
            f"REPL produced no output (exit code {completed.returncode}):\n{completed.stderr}"
        )
    feedback = json.loads(completed.stdout)

    print("Feedback loaded")

    if "infotree" not in feedback:
        raise RuntimeError(f"REPL response has no infotree: {feedback}")
    infotree = feedback["infotree"]
    tactics_info = feedback.get("tactics", [])
    print("---")

    for tactic in tactics_info:
        pp_name = tactic["tactic"]
        tactic_node = getTacticNode(infotree, pp_name)

        if tactic_node is None:
            print(f"Could not find node for tactic: {pp_name}")
            continue

        start = tactic_node["node"]["stx"]["range"]["start"]
        finish = tactic_node["node"]["stx"]["range"]["finish"]

        sub_ranges = getSubNodes(tactic_node, start, finish)
        modified_src = apply_replacements(source_code, sub_ranges)

        start_offset = position_to_offset(source_code, start["line"], start["column"])
        end_offset = position_to_offset(source_code, finish["line"], finish["column"])

        # Assumes every replaced sub-range lies inside [start, finish]. Then the
        # text before start_offset is unchanged, and only the end of the tactic
        # moves, by the net change in length caused by the replacements.
        shift = len(modified_src) - len(source_code)
        cleaned_tactic = modified_src[start_offset:end_offset + shift]
        print(cleaned_tactic)
        print("--")

Frederick Pu (Jun 28 2025 at 15:13):

Here is some sample output:

Feedback loaded
---
have womp : 2 + 2 = 4 := by
  ?_
  ?_
--
rw [Nat.add_comm, Nat.add_comm]
--
rfl
--
rw [Nat.add_comm]
--
rw [Nat.add_comm]
--
rw [Nat.add_comm]
--
induction n with
 | zero => ?_
 | succ n => ?_
--
induction n with
 | zero => ?_
 | succ n => ?_
--
sorry
--
sorry
--

Frederick Pu (Jun 28 2025 at 15:15):

It seems like `have` statements are still handled incorrectly.

Frederick Pu (Jun 29 2025 at 22:34):

import REPL.Frontend
import REPL.Lean.InfoTree
import REPL.Lean.InfoTree.ToJson
import Std

open Lean Elab

open Std

/-- Replace every `[start, stop)` byte range in `replacements` with `replacementStr`.

The ranges are positions in the original `src` and must not overlap. Duplicate ranges
are dropped first. The same range can be collected from several infotree nodes, and
applying it twice would cut into the text that follows the first replacement. -/
def applyReplacements (src : String)
    (replacements : List (String.Pos × String.Pos)) (replacementStr : String) : String :=

  -- Step 1: Deduplicate, then sort by start position descending so earlier edits don't shift later ones
  let sorted := replacements.eraseDups.toArray.qsort (fun a b => a.fst > b.fst) |>.toList

  -- Step 2: Apply the replacements
  sorted.foldl (fun acc (startPos, endPos) =>
    acc.extract 0 startPos ++ replacementStr ++ acc.extract endPos acc.endPos
  ) src


-- Debugging aid: prints the signature of `Json.getObjValD`; nothing else in this snippet uses it.
#check Json.getObjValD

/-- Collect every `Info` node of `t` that satisfies `p`, in pre-order. Each result is
paired with the `ContextInfo` accumulated from enclosing `.context` nodes and with the
subtree rooted at that node. Holes contribute nothing. -/
partial def Lean.Elab.InfoTree.findAllInfoTree (t : InfoTree) (ctx? : Option ContextInfo) (p : Info → Bool) :
    List (Info × Option ContextInfo × InfoTree) :=
  match t with
  | .context ctx t => t.findAllInfoTree (ctx.mergeIntoOuter? ctx?) p
  | .node i ts  =>
    let info := if p i then [(i, ctx?, t)] else []
    let rest := ts.toList.flatMap (fun t => t.findAllInfoTree ctx? p)
    info ++ rest
  | _ => []

/-- All original, substantive tactic nodes of `t` whose syntax can be pretty-printed as a
tactic, each with its context and subtree. Nodes with no enclosing `ContextInfo` are
dropped. -/
def Lean.Elab.InfoTree.findTacticInfoTrees (t : InfoTree) : IO (List (TacticInfo × Option ContextInfo × InfoTree)) :=
  let infos := t.findAllInfoTree none fun i => match i with
  | .ofTacticInfo i' => i.isOriginal && i'.isSubstantive
  | _ => false
  infos.filterMapM fun p => match p with
  | (.ofTacticInfo i, some ctx, t) =>
    ctx.runMetaM {} try
      -- Only whether pretty-printing succeeds matters; the result is discarded.
      let _ ← Lean.PrettyPrinter.ppTactic i.stx
      return some (i, some ctx, t)
    catch _ =>
      pure none
  | _ => pure none


/-- Source ranges of the sub-tactics nested under `infotree` that should become holes.

A node is returned whole, and not descended into, when both hold:
* its range differs from `[startPos, endPos]` (it starts after `startPos` or stops
  before `endPos`);
* its syntax pretty-prints as a tactic.

Otherwise its children are searched. Nodes without a source range, or without an
enclosing `ContextInfo`, contribute nothing. -/
partial def getSubNodePositions (infotree : InfoTree) (startPos : String.Pos) (endPos : String.Pos) (ctx? : Option ContextInfo) : IO (List (String.Pos × String.Pos)) := do
  match infotree with
  | .context ctx t => getSubNodePositions t startPos endPos (ctx.mergeIntoOuter? ctx?)
  | .node info children =>
    match info.stx.getRange? with
    | some r =>
      match ctx? with
      | some ctx =>
        -- `true` iff the node's syntax pretty-prints as a tactic.
        let flag : Bool ← ctx.runMetaM {} do
          try
            let _ ← Lean.PrettyPrinter.ppTactic info.stx
            pure true
          catch _ =>
            pure false
        if (startPos < r.start ∨ r.stop < endPos) ∧ flag then
          return [(r.start, r.stop)]
        else
          return (← children.mapM (getSubNodePositions · startPos endPos ctx?)).toList.join
      | none => pure []
    | none => pure []
  | .hole _ => pure []

/-- Render the source text of the tactic at the root of `infotree` (a compound tactic
combinator). Each of its sub-tactics, as found by `getSubNodePositions`, is replaced
by the hole `?_`.

Returns `""` when the root has no source range or no enclosing `ContextInfo`. -/
partial def InfoTree.add_subnode_holes (src : String) (infotree : InfoTree) (ctx? : Option ContextInfo) : IO String := do
  match infotree with
  | .context ctx t =>
    add_subnode_holes src t (ctx.mergeIntoOuter? ctx?)
  | .node info _ =>
    match ctx? with
    | some ctx =>
      match info.stx.getRange? with
      | some r =>
        let subPos ← getSubNodePositions infotree r.start r.stop (some ctx)
        let temp := applyReplacements src subPos "?_"
        -- Assumes every hole lies inside `[r.start, r.stop)`. Then the start is unchanged,
        -- and the end moves by the net change in byte length. The previous code
        -- subtracted both lengths instead of their difference.
        let stop : String.Pos := ⟨r.stop.byteIdx + temp.utf8ByteSize - src.utf8ByteSize⟩
        return temp.extract r.start stop
      | none => return ""
    | none => return ""
  | .hole _ => return ""

/-- Sample proof used by `lol`. It contains a nested `have`, a `;`-sequenced `rw`, and an
`induction` with two alternatives. -/
def test :=
  "theorem womp (n : Nat): 2 + 2 = 5 := by\n" ++
  " have womp : 2 + 2 = 4 := by\n  rw [Nat.add_comm, Nat.add_comm]\n  rfl\n" ++
  " rw [Nat.add_comm]; rw [Nat.add_comm]\n" ++
  " rw [Nat.add_comm]\n" ++
  " induction n with\n" ++
  " | zero => sorry\n" ++
  " | succ n => sorry"

/-- Elaborate `test` and take the first resulting tactic info tree. For every tactic that
`findTacticInfoTrees` returns, print its source with sub-tactics replaced by `?_`. -/
def lol : IO Unit := do
  let x := ((← IO.processInput test none).2.2.flatMap InfoTree.retainTacticInfo |>.flatMap InfoTree.retainOriginal)[0]!
  let y ← InfoTree.findTacticInfoTrees x
  for (_, ctx?, tree) in y do
    IO.println "----"
    IO.println (← InfoTree.add_subnode_holes test tree ctx?)
    IO.println "----"

-- Run the demo when this file is elaborated.
#eval lol

-- NOTE(review): stale one-liner. `getSubNodePositions` is declared outside the `InfoTree`
-- namespace and now also takes start/end positions and an optional context.
-- #eval do return applyReplacements test (← InfoTree.getSubNodePositions (← InfoTree.findTacticInfoTrees ((← IO.processInput test none).2.2.flatMap InfoTree.retainOriginal)[0]!)[6]!) "?_"

Frederick Pu (Jun 29 2025 at 22:40):

My end goal is this: given a Lean proof, generate a `HashMap GoalState String` that can be used as a tactic-generation (tacgen) heuristic, allowing Aesop to immediately prove the end goal. This essentially generates a proof search tree that contains only the input proof. Then we can randomly remove nodes (or edges) from this search graph to generate a synthetic RL curriculum of problems that are hard for current tactic-state-based models to prove. Obviously, this would need to be paired with some sort of RL mechanism within Aesop.


Last updated: Dec 20 2025 at 21:32 UTC