Zulip Chat Archive
Stream: Machine Learning for Theorem Proving
Topic: a question about premise retrieval using lean dojo
Dhyan Aranha (Apr 12 2025 at 13:31):
Hi, I'm trying to use lean dojo's premise retriever model. In the sample code they give:
import torch
from typing import Union, List
from transformers import AutoTokenizer, AutoModelForTextEncoding
tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small")
model = AutoModelForTextEncoding.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small")
state = "n : ℕ\n⊢ gcd n n = n"
premises = [
"<a>vsub_eq_zero_iff_eq</a> @[simp] lemma vsub_eq_zero_iff_eq {p1 p2 : P} : p1 -ᵥ p2 = (0 : G) ↔ p1 = p2",
"<a>is_scalar_tower.coe_to_alg_hom'</a> @[simp] lemma coe_to_alg_hom' : (to_alg_hom R S A : S → A) = algebra_map S A",
"<a>polynomial.X_sub_C_ne_zero</a> theorem X_sub_C_ne_zero (r : R) : X - C r ≠ 0",
"<a>forall_true_iff</a> theorem forall_true_iff : (α → true) ↔ true",
"def <a>Nat.gcd</a> : Nat → Nat → Nat\n| 0 y := y\n| (succ x) y := have y % succ x < succ x, from mod_lt _ $ succ_pos _,\n gcd (y % succ x) (succ x)",
"@[simp] theorem <a>Nat.gcd_zero_left</a> (x : Nat) : gcd 0 x = x",
"@[simp] theorem <a>Nat.gcd_succ</a> (x y : Nat) : gcd (succ x) y = gcd (y % succ x) (succ x)",
"@[simp] theorem <a>Nat.mod_self</a> (n : Nat) : n % n = 0",
] # A corpus of premises to retrieve from.
@torch.no_grad()
def encode(s: Union[str, List[str]]) -> torch.Tensor:
    """Encode texts into feature vectors."""
    if isinstance(s, str):
        s = [s]
        should_squeeze = True
    else:
        should_squeeze = False
    tokenized_s = tokenizer(s, return_tensors="pt", padding=True)
    hidden_state = model(tokenized_s.input_ids).last_hidden_state
    lens = tokenized_s.attention_mask.sum(dim=1)
    features = (hidden_state * tokenized_s.attention_mask.unsqueeze(2)).sum(dim=1) / lens.unsqueeze(1)
    if should_squeeze:
        features = features.squeeze()
    return features

@torch.no_grad()
def retrieve(state: str, premises: List[str], k: int) -> List[str]:
    """Retrieve the top-k premises given a state."""
    state_emb = encode(state)
    premise_embs = encode(premises)
    scores = (state_emb @ premise_embs.T)
    topk = scores.topk(k).indices.tolist()
    return [premises[i] for i in topk]

for p in retrieve(state, premises, k=4):
    print(p, end="\n\n")
It seems that if I want to apply this to different statements, I need to supply "a corpus of premises to retrieve from". But in general I have no clue how to do this for, say, a different problem. Is there some way to get LeanDojo to actually generate a list of possible premises to retrieve from?
Justin Asher (Apr 12 2025 at 15:14):
After reading through the docs and the code, I believe generating the premise corpus for the LeanDojo retriever is essentially a two-step process. First, you run lean_dojo.trace on your Lean repository to analyze the code and produce detailed structured data files like *.ast.json (containing the abstract syntax tree), as detailed in the Getting Started guide. Second, once you have this raw data, you load it using LeanDojo's Python classes (e.g., TracedRepo and TracedFile, found in the lean_dojo.data_extraction.traced_data module) and parse it: iterate through the TracedFile objects and call the TracedFile.get_premise_definitions() method, documented here in the source, which navigates the AST to identify premise declarations (theorems, definitions, etc.) and extracts their clean code strings (just the statement for theorems/lemmas). Aggregating these extracted 'code' strings from all relevant files will give you the final list needed for the retriever.
For instance, the direct output of get_premise_definitions() for a single .lean file is a list of dictionaries, structured like this:
[
{'full_name': 'MyNamespace.my_definition', 'code': 'def my_definition (x : Nat) := x + 1', 'kind': 'definition', ...},
{'full_name': 'MyNamespace.my_theorem', 'code': 'theorem my_theorem (x : Nat) : my_definition x > x', 'kind': 'theorem', ...},
# ... potentially more dictionaries for other premises in the file
]
By collecting just the value associated with the 'code' key from each dictionary in these lists (across all traced files), you can aggregate the desired flat list of premise strings, like the premises variable in the initial retriever example code you showed.
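A minimal sketch of that aggregation step, assuming each traced file object exposes get_premise_definitions() returning dictionaries with a 'code' key, as described above. The collect_premise_strings helper is a hypothetical name, and it is written against duck-typed objects so that it does not depend on a particular LeanDojo version:

```python
from typing import Any, Iterable, List


def collect_premise_strings(traced_files: Iterable[Any]) -> List[str]:
    """Flatten the 'code' field of every premise definition into one list.

    Assumes each item in `traced_files` has a `get_premise_definitions()`
    method returning a list of dicts with a 'code' key. This helper is
    hypothetical and not part of LeanDojo itself.
    """
    premises: List[str] = []
    seen = set()
    for traced_file in traced_files:
        for definition in traced_file.get_premise_definitions():
            code = definition.get("code")
            # Skip entries without code and exact duplicates.
            if code and code not in seen:
                seen.add(code)
                premises.append(code)
    return premises
```

With a traced repository in hand, this would presumably be called on its list of traced files, and the result passed as the premises argument of retrieve.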
Edit: The <a>...</a> tags in the example also appear in the get_annotated_tactic method from the aforementioned file. I am uncertain whether you need to manually add these tags to the generated dataset or not.
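If the tags do turn out to be needed, one option would be a small post-processing step. The sketch below is an assumption based only on the format of the sample premises list, where the declared name is replaced by its full name wrapped in <a>...</a>. tag_premise is a hypothetical helper, and whether the retriever expects exactly this format is not confirmed here:

```python
import re


def tag_premise(full_name: str, code: str) -> str:
    """Wrap the declared name in <a>...</a>, mimicking the sample corpus."""
    short_name = full_name.split(".")[-1]
    # Match the first standalone occurrence of the short name only
    # (not preceded by a word character, '.', or "'", and not followed
    # by a word character or "'").
    pattern = r"(?<![\w.'])" + re.escape(short_name) + r"(?![\w'])"
    tagged, count = re.subn(
        pattern, lambda m: f"<a>{full_name}</a>", code, count=1
    )
    if count == 0:
        # Fall back to prefixing the tag when the name is not found.
        return f"<a>{full_name}</a> {code}"
    return tagged


print(tag_premise(
    "Nat.gcd_zero_left",
    "@[simp] theorem gcd_zero_left (x : Nat) : gcd 0 x = x",
))
```

For the input above this reproduces the shape of the Nat.gcd_zero_left entry in the sample corpus.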
Dhyan Aranha (Apr 12 2025 at 15:45):
thanks! I'll play around with this tonight and see where I get stuck :)
Dhyan Aranha (Apr 15 2025 at 19:37):
I made a test repo (here) to try out the LeanDojo tracing. However, when I run the script
import lean_dojo
from lean_dojo import LeanGitRepo, trace

def main():
    repo = LeanGitRepo("https://github.com/dhyan-aranha/tracingtest", "69180b0")
    trace(repo, dst_dir="traced_lean4-example")

if __name__ == "__main__":
    main()
    print("Trace completed.")
I get the following error message
2025-04-15 21:28:08.245 | INFO | lean_dojo.utils:execute:115 - ExtractData.lean:339:24: error: invalid field 'find?', the environment does not contain 'Std.HashMap.find?'
  env.const2ModIdx
has type
  Std.HashMap Name ModuleIdx
ExtractData.lean:340:27: error: invalid field notation, type is not of the form (C ...) where C is a constant
  modIdx
has type
  ?m.26927
2025-04-15 21:28:08.245 | ERROR | lean_dojo.utils:execute:116 -
Traceback (most recent call last):
  File "/Users/dhyanaranha/lean_dojo_1/src/lean_dojo_1/premise_selection.py", line 10, in <module>
    main()
  File "/Users/dhyanaranha/lean_dojo_1/src/lean_dojo_1/premise_selection.py", line 7, in main
    trace(repo, dst_dir="traced_lean4-example")
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/data_extraction/trace.py", line 247, in trace
    cached_path = get_traced_repo_path(repo, build_deps)
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/data_extraction/trace.py", line 213, in get_traced_repo_path
    _trace(repo, build_deps)
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/data_extraction/trace.py", line 161, in _trace
    execute(cmd, capture_output=True)
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/utils.py", line 117, in execute
    raise ex
  File "/Users/dhyanaranha/lean_dojo_1/.venv/lib/python3.10/site-packages/lean_dojo/utils.py", line 112, in execute
    res = subprocess.run(cmd, shell=True, capture_output=capture_output, check=True)
  File "/Users/dhyanaranha/.pyenv/versions/3.10.9/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command 'lake env lean --threads 8 --run ExtractData.lean' returned non-zero exit status 1.
I'd be really happy to know what these errors mean and how they might be fixed?!
@Justin Asher it looks like the script you ran worked, maybe you could share it here? :)
Auguste Poiroux (Apr 15 2025 at 19:42):
You are probably using a Lean version for which ExtractData.lean is incompatible. For instance, the LeanDojo REPL script is incompatible with Lean >= 4.12.0 (issue). Maybe ExtractData.lean is also affected.
Justin Asher (Apr 15 2025 at 19:47):
@Dhyan Aranha The lean-toolchain file of your repository says you are using leanprover/lean4:v4.19.0-rc3, so what @Auguste Poiroux mentioned is indeed likely the cause of the error. In either case, you will need to downgrade the version of Lean you are using for your project to use LeanDojo. Thanks for pointing that out, Auguste.
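A quick way to check this before tracing is to parse the lean-toolchain string and compare it against a cutoff. This is a minimal sketch: parse_toolchain and is_probably_supported are hypothetical helpers, and the (4, 12, 0) cutoff is borrowed from the REPL incompatibility mentioned above. The actual cutoff for ExtractData.lean is an assumption and may differ:

```python
import re
from typing import Tuple


def parse_toolchain(toolchain: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from e.g. 'leanprover/lean4:v4.19.0-rc3'."""
    match = re.search(r"v(\d+)\.(\d+)\.(\d+)", toolchain)
    if match is None:
        raise ValueError(f"Unrecognized toolchain string: {toolchain!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return (major, minor, patch)


def is_probably_supported(
    toolchain: str,
    first_unsupported: Tuple[int, int, int] = (4, 12, 0),  # assumed cutoff
) -> bool:
    """Return True if the toolchain is older than the assumed cutoff."""
    return parse_toolchain(toolchain) < first_unsupported


print(is_probably_supported("leanprover/lean4:v4.19.0-rc3"))  # prints False
```

Release-candidate suffixes such as -rc3 are ignored by this comparison, which is good enough for a coarse compatibility check.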
Dhyan Aranha (Apr 16 2025 at 07:01):
@Auguste Poiroux @Justin Asher, thanks! The reason I'm interested in LeanDojo is that I'd like to use it to test out the upcoming SorryDB project. But maybe LeanDojo is not the way to go here then, because in practice it should be compatible with the most up-to-date Lean version... Does anyone know of something similar that is kept more up-to-date?
Auguste Poiroux (Apr 16 2025 at 09:03):
Can you remind me of the SorryDB project? Is it the project trying to collect sorries in the wild?
If so, maybe you can use Lean REPL. When running it on a file, it will collect all sorries and return the associated goals.
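As a rough sketch of what that could look like from Python: the Lean REPL reads JSON commands and replies with JSON, and a reply for a file containing sorry includes a list of sorries with their goals. The field names below (path, sorries, goal, pos) are assumed from the REPL's README and should be checked against the version you install. The helper names are hypothetical, the sample reply is made up for illustration, and sending the command to the REPL process is not shown:

```python
import json
from typing import Any, Dict, List


def build_file_command(path: str) -> str:
    """Serialize a REPL request asking it to process a whole Lean file."""
    return json.dumps({"path": path})


def extract_sorry_goals(response: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Pull the goal text and line number of each sorry out of a REPL reply."""
    return [
        {"goal": s.get("goal"), "line": s.get("pos", {}).get("line")}
        for s in response.get("sorries", [])
    ]


# Illustrative (made-up) reply, shaped like the REPL output assumed above.
sample = json.loads(
    '{"sorries": [{"pos": {"line": 3, "column": 2}, '
    '"goal": "n : Nat\\n⊢ n + 0 = n"}], "env": 0}'
)
for entry in extract_sorry_goals(sample):
    print(entry["line"], entry["goal"])
```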
Dhyan Aranha (Apr 16 2025 at 11:11):
Yeah, it's the project trying to collect all the wild sorries.
That sounds reasonable! I'm just not familiar enough with LeanDojo's code yet to know if it does something more nuanced to get its corpus of premises from a given Lean file.
Last updated: May 02 2025 at 03:31 UTC