Skip to content

Commit 7394c79

Browse files
authored
Merge pull request #50 from trishullab/usr/amit9oct/fixed-have-tactics
Fixed have tactics
2 parents c8a0211 + 8648e9c commit 7394c79

File tree

5 files changed

+120
-9
lines changed

5 files changed

+120
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = [
55
build-backend = "hatchling.build"
66
[project]
77
name = "itp_interface"
8-
version = "1.1.9"
8+
version = "1.1.10"
99
authors = [
1010
{ name="Amitayush Thakur", email="[email protected]" },
1111
]

src/data/test/lean4_proj/Lean4Proj/Basic.lean

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import Mathlib
12
namespace Lean4Proj1
23

34
def hello := "world"
@@ -51,4 +52,26 @@ theorem test3 (p q : Prop) (hp : p) (hq : q)
5152
exact hq
5253
exact hp
5354

55+
theorem imo_1959_p1
56+
(n : ℕ)
57+
(h₀ : 0 < n) :
58+
Nat.gcd (21*n + 4) (14*n + 3) = 1 := by
59+
rw [Nat.gcd_rec]
60+
rw [Nat.mod_eq_of_lt (by linarith)]
61+
rw [Nat.gcd_rec]
62+
rw [Nat.gcd_rec]
63+
have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by
64+
have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring
65+
rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]
66+
have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith
67+
rw [Nat.mod_eq_of_lt]
68+
rw [Nat.mod_eq_of_lt]
69+
exact h₂
70+
rw [Nat.mod_eq_of_lt]
71+
exact h₂
72+
exact h₂
73+
rw [eq₂]
74+
sorry
75+
76+
5477
end Lean4Proj2

src/itp_interface/tools/lean4_sync_executor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,7 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
831831
proof_running = 'sorries' in response or 'proofState' in response
832832
error_messages = response.get('message', None)
833833
goal_text = None
834+
goal_texts = []
834835
if error_messages is None and 'proofState' in response:
835836
error_messages = response.get('messages', None)
836837
elif error_messages is None:
@@ -840,6 +841,7 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
840841
text_msg = msg.get('data', None)
841842
if text_msg is not None and text_msg.startswith(Lean4SyncExecutor.unsolved_message):
842843
goal_text = text_msg[len(Lean4SyncExecutor.unsolved_message):]
844+
goal_texts.append(goal_text)
843845
else:
844846
error_messages.append(msg)
845847
if len(error_messages) == 0:
@@ -865,11 +867,11 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
865867
if self._proof_running:
866868
proof_state_idx = None
867869
proof_goals = []
868-
if goal_text is not None:
869-
if len(goal_text) == 0:
870-
proof_goals = []
871-
else:
872-
proof_goals = [goal_text]
870+
if len(goal_texts) == 0:
871+
proof_goals = []
872+
elif len(goal_texts) > 0:
873+
proof_goals = [g_text for g_text in goal_texts
874+
if g_text is not None and len(g_text) > 0]
873875
elif 'sorries' in response:
874876
sorries = response['sorries']
875877
# TODO: Go over all the sorries and find the one which matches the line number with idx + 1

src/test/simple_data_gen_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@ def test_proof_step_data_gen(self):
3939
dirs = sorted(os.listdir(".log/data_generation/benchmark/simple_benchmark_lean"))
4040
print(dirs)
4141
last_dir = dirs[-1]
42-
train_data = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir, "train")
43-
data_gen_file = os.path.join(train_data, "local_data_0000000025.json")
42+
train_data = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir, "train")
43+
list_files = os.listdir(train_data)
44+
data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")]
45+
assert len(data_files) == 1, f"No files found in the train directory. Expected one file. Found: {data_files}"
46+
print(data_files[0])
47+
data_gen_file = os.path.join(train_data, data_files[0])
4448
print("Data Gen File:", data_gen_file)
4549
with open(data_gen_file, "r") as f:
4650
print(f.read())

src/test/simple_env_test.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def __init__(self):
77
def build_lean4_project(self, project_folder):
88
import os
99
# Build the project
10-
with os.popen(f"cd {project_folder} && lake build") as proc:
10+
with os.popen(f"cd {project_folder} && lake exe cache get && lake build") as proc:
1111
print("Building Lean4 project...")
1212
print('-'*15 + 'Build Logs' + '-'*15)
1313
print(proc.read())
@@ -488,6 +488,88 @@ def test_simple_lean4_done_test(self):
488488
print(goal.goal)
489489
print(f"="*30)
490490

491+
def test_simple_lean4_have_test(self):
492+
from itp_interface.rl.proof_state import ProofState
493+
from itp_interface.rl.proof_action import ProofAction
494+
from itp_interface.rl.simple_proof_env import ProofEnv
495+
from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
496+
from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
497+
project_folder = "src/data/test/lean4_proj"
498+
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
499+
# Build the project
500+
# cd src/data/test/lean4_proj && lake build
501+
helper = Helper()
502+
helper.build_lean4_project(project_folder)
503+
language = ProofAction.Language.LEAN4
504+
theorem_name = "imo_1959_p1"
505+
# theorem test3 (p q : Prop) (hp : p) (hq : q)
506+
# : p ∧ q ∧ p :=
507+
proof_exec_callback = ProofExecutorCallback(
508+
project_folder=project_folder,
509+
file_path=file_path,
510+
language=language,
511+
always_use_retrieval=False,
512+
keep_local_context=True,
513+
enforce_qed=True
514+
)
515+
always_retrieve_thms = False
516+
retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
517+
env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
518+
proof_steps = [
519+
'rw [Nat.gcd_rec]',
520+
'rw [Nat.mod_eq_of_lt (by linarith)]',
521+
'rw [Nat.gcd_rec]',
522+
'rw [Nat.gcd_rec]',
523+
'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by',
524+
' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
525+
' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
526+
' have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith',
527+
' rw [Nat.mod_eq_of_lt]',
528+
' rw [Nat.mod_eq_of_lt]',
529+
' exact h₂',
530+
' rw [Nat.mod_eq_of_lt]',
531+
' exact h₂',
532+
' exact h₂',
533+
'rw [eq₂]'
534+
]
535+
with env:
536+
for proof_step in proof_steps:
537+
state, _, next_state, _, done, info = env.step(ProofAction(
538+
ProofAction.ActionType.RUN_TACTIC,
539+
language,
540+
tactics=[proof_step]))
541+
if info.error_message is not None:
542+
print(f"Error: {info.error_message}")
543+
# This prints StateChanged, StateUnchanged, Failed, or Done
544+
print(info.progress)
545+
print('-'*30)
546+
if done:
547+
raise Exception("Proof should not have finished")
548+
else:
549+
s1 : ProofState = state
550+
s2 : ProofState = next_state
551+
print(f"Current Goal:")
552+
print('-'*30)
553+
for goal in s1.training_data_format.start_goals:
554+
hyps = '\n'.join([hyp for hyp in goal.hypotheses])
555+
print(hyps)
556+
print('|- ', end='')
557+
print(goal.goal)
558+
print(f'*'*30)
559+
print(f"="*30)
560+
print(f"Action: {proof_step}")
561+
print(f"="*30)
562+
print(f"Next Goal:")
563+
print('-'*30)
564+
for goal in s2.training_data_format.start_goals:
565+
hyps = '\n'.join([hyp for hyp in goal.hypotheses])
566+
print(hyps)
567+
print('|- ', end='')
568+
print(goal.goal)
569+
print(f'*'*30)
570+
print(f"="*30)
571+
print(f"DONE: {done}")
572+
print('-'*30)
491573

492574
def main():
493575
unittest.main()

0 commit comments

Comments
 (0)