Zulip Chat Archive

Stream: Machine Learning for Theorem Proving

Topic: a question about premise retrieval using lean dojo


Dhyan Aranha (Apr 12 2025 at 13:31):

Hi, I'm trying to use LeanDojo's premise retriever model. The sample code they give is:

import torch
from typing import Union, List
from transformers import AutoTokenizer, AutoModelForTextEncoding

tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small")
model = AutoModelForTextEncoding.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small")

state = "n : ℕ\n⊢ gcd n n = n"
premises = [
  "<a>vsub_eq_zero_iff_eq</a> @[simp] lemma vsub_eq_zero_iff_eq {p1 p2 : P} : p1 -ᵥ p2 = (0 : G) ↔ p1 = p2",
  "<a>is_scalar_tower.coe_to_alg_hom'</a> @[simp] lemma coe_to_alg_hom' : (to_alg_hom R S A : S → A) = algebra_map S A",
  "<a>polynomial.X_sub_C_ne_zero</a> theorem X_sub_C_ne_zero (r : R) : X - C r ≠ 0",
  "<a>forall_true_iff</a> theorem forall_true_iff : (α → true) ↔ true",
  "def <a>Nat.gcd</a> : Nat → Nat → Nat\n| 0        y := y\n| (succ x) y := have y % succ x < succ x, from mod_lt _ $ succ_pos _,\n                gcd (y % succ x) (succ x)",
  "@[simp] theorem <a>Nat.gcd_zero_left</a> (x : Nat) : gcd 0 x = x",
  "@[simp] theorem <a>Nat.gcd_succ</a> (x y : Nat) : gcd (succ x) y = gcd (y % succ x) (succ x)",
  "@[simp] theorem <a>Nat.mod_self</a> (n : Nat) : n % n = 0",
]  # A corpus of premises to retrieve from.

@torch.no_grad()
def encode(s: Union[str, List[str]]) -> torch.Tensor:
    """Encode texts into feature vectors."""
    if isinstance(s, str):
        s = [s]
        should_squeeze = True
    else:
        should_squeeze = False
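    tokenized_s = tokenizer(s, return_tensors="pt", padding=True)
    # Pass the attention mask too, so the encoder ignores padding tokens.
    hidden_state = model(
        input_ids=tokenized_s.input_ids, attention_mask=tokenized_s.attention_mask
    ).last_hidden_state
    # Mean-pool the hidden states over the non-padding positions only.
    lens = tokenized_s.attention_mask.sum(dim=1)
    features = (hidden_state * tokenized_s.attention_mask.unsqueeze(2)).sum(dim=1) / lens.unsqueeze(1)
    if should_squeeze:
        features = features.squeeze()
    return features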

@torch.no_grad()
def retrieve(state: str, premises: List[str], k: int) -> List[str]:
    """Retrieve the top-k premises given a state."""
    state_emb = encode(state)
    premise_embs = encode(premises)
    scores = (state_emb @ premise_embs.T)
    topk = scores.topk(k).indices.tolist()
    return [premises[i] for i in topk]

for p in retrieve(state, premises, k=4):
    print(p, end="\n\n")

It seems that if I want to apply this to other statements, I need to supply "a corpus of premises to retrieve from". But in general I have no idea how to do this for, say, a different problem. Is there some way to get LeanDojo to actually generate a list of candidate premises to retrieve from?

Justin Asher (Apr 12 2025 at 15:14):

After reading through the docs and the code, I believe generating the premise corpus for the LeanDojo retriever is essentially a two-step process.

First, run lean_dojo.trace on your Lean repository to analyze the code and produce structured data files such as *.ast.json (containing the abstract syntax tree), as described in the Getting Started guide.

Second, load that raw data with LeanDojo's Python classes (e.g., TracedRepo and TracedFile in the lean_dojo.data_extraction.traced_data module) and parse it. Specifically, iterate through the TracedFile objects and call TracedFile.get_premise_definitions() (documented here in the source), which walks the AST to find premise declarations (theorems, definitions, etc.) and extracts their cleaned-up code strings (just the statement, for theorems and lemmas). Aggregating these 'code' strings from all relevant files gives you the final list the retriever needs.

For instance, the direct output of get_premise_definitions() for a single .lean file yields a list of dictionaries, structured like this:

[
  {'full_name': 'MyNamespace.my_definition', 'code': 'def my_definition (x : Nat) := x + 1', 'kind': 'definition', ...},
  {'full_name': 'MyNamespace.my_theorem', 'code': 'theorem my_theorem (x : Nat) : my_definition x > x', 'kind': 'theorem', ...},
  # ... potentially more dictionaries for other premises in the file
]

By collecting just the value associated with the 'code' key from each dictionary in these lists (across all traced files), you can aggregate the desired flat list of premise strings, like the premises variable in the initial retriever example code you showed.
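A minimal sketch of that aggregation step, assuming each file's get_premise_definitions() output is a list of dicts shaped like the example above. The helper name collect_premise_corpus and the de-duplication by 'full_name' are my own choices, not part of LeanDojo.

```python
from typing import Any, Dict, Iterable, List


def collect_premise_corpus(per_file_defs: Iterable[List[Dict[str, Any]]]) -> List[str]:
    """Flatten the 'code' field of every premise dict into one corpus list.

    `per_file_defs` yields one list per traced file, i.e. the return value of
    get_premise_definitions() for that file. Premises are de-duplicated by
    'full_name' (first occurrence wins); dicts without a full_name are kept.
    """
    seen = set()
    corpus: List[str] = []
    for defs in per_file_defs:
        for premise in defs:
            name = premise.get("full_name")
            if name is not None:
                if name in seen:
                    continue  # already collected from an earlier file
                seen.add(name)
            corpus.append(premise["code"])
    return corpus
```

With LeanDojo this would presumably be called as collect_premise_corpus(tf.get_premise_definitions() for tf in traced_repo.traced_files), where traced_repo is the TracedRepo returned by trace(...). I have not verified the traced_files attribute name, so check it against traced_data.py before relying on it.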

Edit: The <a>...</a> tags in the example also appear in the get_annotated_tactic method from the aforementioned file. I am uncertain whether you need to manually add these tags to the generated dataset or not.
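If the tags do turn out to be needed, one plausible approach is to replace the declared name inside each 'code' string with <a>full_name</a>, which reproduces entries such as "@[simp] theorem <a>Nat.mod_self</a> ..." in the sample corpus. The helper below (annotate_premise) is hypothetical and is not LeanDojo's or ReProver's actual code; I believe ReProver's Premise.serialize does something along these lines, but check its source to confirm the exact format the retriever was trained on.

```python
import re

MARK_START = "<a>"  # markers as they appear in the sample corpus
MARK_END = "</a>"


def annotate_premise(code: str, full_name: str) -> str:
    """Wrap the declared name in `code` with <a>full_name</a> (hypothetical helper).

    A declaration often spells only a suffix of its full name (e.g. `my_theorem`
    inside `namespace MyNamespace`), so try the longest suffix first and replace
    the first whitespace-preceded occurrence.
    """
    annotated = f"{MARK_START}{full_name}{MARK_END}"
    fields = full_name.split(".")
    for i in range(len(fields)):
        suffix = ".".join(fields[i:])
        # Allow optional «» guillemets, and require a delimiter afterwards so
        # that `foo` does not match inside `foo_bar`.
        pattern = r"(?<=\s)«?" + re.escape(suffix) + r"»?(?=[\s:({\[]|$)"
        new_code, n = re.subn(pattern, lambda _m: annotated, code, count=1)
        if n:
            return new_code
    # Fallback: prefix the tag, like the first few entries of the sample corpus.
    return f"{annotated} {code}"
```

For example, annotate_premise("@[simp] theorem Nat.mod_self (n : Nat) : n % n = 0", "Nat.mod_self") gives the exact string used in the sample corpus. The regex is a heuristic and can mis-tag unusual declarations.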

Dhyan Aranha (Apr 12 2025 at 15:45):

thanks! I'll play around with this tonight and see where I get stuck :)

Dhyan Aranha (Apr 15 2025 at 19:37):

I made a test repo to trace (here) to try out LeanDojo tracing, but when I run the script

import lean_dojo
from lean_dojo import LeanGitRepo, trace


def main():
    repo = LeanGitRepo("https://github.com/dhyan-aranha/tracingtest", "69180b0")
    trace(repo, dst_dir="traced_lean4-example")

if __name__ == "__main__":
    main()
    print("Trace completed.")

I get the following error message:

2025-04-15 21:28:08.245 | INFO     | lean_dojo.utils:execute:115 - ExtractData.lean:339:24: error: invalid field 'find?', the environment does not contain 'Std.HashMap.find?'
  env.const2ModIdx
has type
  Std.HashMap Name ModuleIdx
ExtractData.lean:340:27: error: invalid field notation, type is not of the form (C ...) where C is a constant
  modIdx
has type
  ?m.26927

2025-04-15 21:28:08.245 | ERROR    | lean_dojo.utils:execute:116 -
Traceback (most recent call last):
  File "/Users/dhyanaranha/lean_dojo_1/src/lean_dojo_1/premise_selection.py", line 10, in <module>
    main()
  File "/Users/dhyanaranha/lean_dojo_1/src/lean_dojo_1/premise_selection.py", line 7, in main
    trace(repo, dst_dir="traced_lean4-example")
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/data_extraction/trace.py", line 247, in trace
    cached_path = get_traced_repo_path(repo, build_deps)
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/data_extraction/trace.py", line 213, in get_traced_repo_path
    _trace(repo, build_deps)
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/data_extraction/trace.py", line 161, in _trace
    execute(cmd, capture_output=True)
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/utils.py", line 117, in execute
    raise ex
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/utils.py", line 112, in execute
    res = subprocess.run(cmd, shell=True, capture_output=capture_output, check=True)
  File "/Users/dhyanaranha/.pyenv/versions/3.10.9/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command 'lake env lean --threads 8 --run ExtractData.lean' returned non-zero exit status 1.

I'd be really happy to know what these errors mean and how they might be fixed?!

@Justin Asher it looks like the script you ran seemed to work, maybe you could share it here? :)

Auguste Poiroux (Apr 15 2025 at 19:42):

You are probably using a Lean version for which ExtractData.lean is incompatible. For instance, the LeanDojo REPL script is incompatible with Lean >= 4.12.0 (issue). Maybe ExtractData.lean is also affected.

Justin Asher (Apr 15 2025 at 19:47):

@Dhyan Aranha The lean-toolchain file in your repository says you are using leanprover/lean4:v4.19.0-rc3, so what @Auguste Poiroux mentioned is likely the cause of the error. Either way, you will need to downgrade the Lean version your project uses in order to use LeanDojo. Thanks for pointing that out, Auguste.
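Since the failure depends on the version pinned in lean-toolchain, a quick pre-flight check before tracing could save a long failed run. This is a minimal sketch with hypothetical helper names; the 4.12.0 cutoff is only the REPL-script threshold mentioned above, not a confirmed limit for ExtractData.lean.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Matches strings like "leanprover/lean4:v4.19.0-rc3" or "leanprover/lean4:v4.9.0".
_TOOLCHAIN_RE = re.compile(r"^leanprover/lean4:v(\d+)\.(\d+)\.(\d+)(?:-(\S+))?$")


def read_toolchain(repo_dir: str) -> str:
    """Return the contents of the repository's lean-toolchain file."""
    return (Path(repo_dir) / "lean-toolchain").read_text().strip()


def parse_toolchain(toolchain: str) -> Tuple[int, int, int, Optional[str]]:
    """Split a toolchain string into (major, minor, patch, pre-release tag or None)."""
    m = _TOOLCHAIN_RE.match(toolchain.strip())
    if m is None:
        # Nightlies and custom toolchains are not handled by this sketch.
        raise ValueError(f"unrecognised toolchain: {toolchain!r}")
    major, minor, patch, pre = m.groups()
    return int(major), int(minor), int(patch), pre


def is_at_least(toolchain: str, threshold: Tuple[int, int, int]) -> bool:
    """True if the toolchain's base version is >= threshold.

    Release candidates are compared by their base version (v4.12.0-rc1 counts
    as 4.12.0), which is good enough for a coarse compatibility cutoff.
    """
    major, minor, patch, _pre = parse_toolchain(toolchain)
    return (major, minor, patch) >= threshold
```

For example, is_at_least(read_toolchain("path/to/tracingtest"), (4, 12, 0)) would flag the v4.19.0-rc3 toolchain before any tracing starts.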

Dhyan Aranha (Apr 16 2025 at 07:01):

@Auguste Poiroux @Justin Asher, thanks! The reason I'm interested in LeanDojo is that I'd like to use it to test out the upcoming SorryDB project. But maybe LeanDojo is not the way to go here, because in practice whatever we use needs to be compatible with the most up-to-date Lean version... Does anyone know of something similar that is kept more up to date?

Auguste Poiroux (Apr 16 2025 at 09:03):

Can you remind me of the SorryDB project? Is it the project trying to collect sorries in the wild?
If so, maybe you can use Lean REPL. When running it on a file, it will collect all sorries and return the associated goals.
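A rough sketch of what that could look like from Python, assuming the JSON protocol described in the leanprover-community/repl README: a {"path": ...} command, answered by a response whose "sorries" entries carry "pos" (with "line" and "column") and "goal" fields. The helper names, and running `lake exe repl` from a built checkout of the REPL, are assumptions; check them against the README for the version you use.

```python
import json
import subprocess
from typing import Any, Dict, List, Tuple


def make_file_command(path: str) -> str:
    """Build a REPL command asking it to elaborate a whole .lean file."""
    return json.dumps({"path": path})


def extract_sorry_goals(response: Dict[str, Any]) -> List[Tuple[int, int, str]]:
    """Return (line, column, goal) for each sorry reported in a REPL response."""
    goals = []
    for s in response.get("sorries", []):
        pos = s.get("pos", {})
        goals.append((pos.get("line", -1), pos.get("column", -1), s.get("goal", "")))
    return goals


def run_repl_on_file(path: str, repl_dir: str) -> Dict[str, Any]:
    """Send one command to the REPL binary (assumed to be built in repl_dir).

    `path` should be absolute, and the file's imports must be resolvable from
    the environment the REPL runs in (an assumption about your setup).
    """
    # Commands are separated by blank lines, so end the single command with one.
    proc = subprocess.run(
        ["lake", "exe", "repl"],
        input=make_file_command(path) + "\n\n",
        capture_output=True,
        text=True,
        cwd=repl_dir,
        check=True,
    )
    return json.loads(proc.stdout)
```

extract_sorry_goals(run_repl_on_file(...)) would then give one entry per sorry, which is roughly the per-sorry goal list SorryDB needs.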

Dhyan Aranha (Apr 16 2025 at 11:11):

yeah, it's the project trying to collect all the wild sorries.

That sounds reasonable! I'm just not familiar enough with LeanDojo's code yet to know whether it does something more nuanced to build its corpus of premises from a given Lean file.


Last updated: May 02 2025 at 03:31 UTC