Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
version = "1.1.9"
version = "1.1.10"
authors = [
{ name="Amitayush Thakur", email="[email protected]" },
]
Expand Down
23 changes: 23 additions & 0 deletions src/data/test/lean4_proj/Lean4Proj/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Mathlib
namespace Lean4Proj1

def hello := "world"
Expand Down Expand Up @@ -51,4 +52,26 @@ theorem test3 (p q : Prop) (hp : p) (hq : q)
exact hq
exact hp

theorem imo_1959_p1
(n : ℕ)
(h₀ : 0 < n) :
Nat.gcd (21*n + 4) (14*n + 3) = 1 := by
rw [Nat.gcd_rec]
rw [Nat.mod_eq_of_lt (by linarith)]
rw [Nat.gcd_rec]
rw [Nat.gcd_rec]
have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by
have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring
rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]
have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith
rw [Nat.mod_eq_of_lt]
rw [Nat.mod_eq_of_lt]
exact h₂
rw [Nat.mod_eq_of_lt]
exact h₂
exact h₂
rw [eq₂]
sorry


end Lean4Proj2
12 changes: 7 additions & 5 deletions src/itp_interface/tools/lean4_sync_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
proof_running = 'sorries' in response or 'proofState' in response
error_messages = response.get('message', None)
goal_text = None
goal_texts = []
if error_messages is None and 'proofState' in response:
error_messages = response.get('messages', None)
elif error_messages is None:
Expand All @@ -840,6 +841,7 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
text_msg = msg.get('data', None)
if text_msg is not None and text_msg.startswith(Lean4SyncExecutor.unsolved_message):
goal_text = text_msg[len(Lean4SyncExecutor.unsolved_message):]
goal_texts.append(goal_text)
else:
error_messages.append(msg)
if len(error_messages) == 0:
Expand All @@ -865,11 +867,11 @@ def _update_proof_context(self, idx, response, relevant_messages, only_env_updat
if self._proof_running:
proof_state_idx = None
proof_goals = []
if goal_text is not None:
if len(goal_text) == 0:
proof_goals = []
else:
proof_goals = [goal_text]
if len(goal_texts) == 0:
proof_goals = []
elif len(goal_texts) > 0:
proof_goals = [g_text for g_text in goal_texts
if g_text is not None and len(g_text) > 0]
elif 'sorries' in response:
sorries = response['sorries']
# TODO: Go over all the sorries and find the one which matches the line number with idx + 1
Expand Down
84 changes: 83 additions & 1 deletion src/test/simple_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(self):
def build_lean4_project(self, project_folder):
import os
# Build the project
with os.popen(f"cd {project_folder} && lake build") as proc:
with os.popen(f"cd {project_folder} && lake exe cache get && lake build") as proc:
print("Building Lean4 project...")
print('-'*15 + 'Build Logs' + '-'*15)
print(proc.read())
Expand Down Expand Up @@ -488,6 +488,88 @@ def test_simple_lean4_done_test(self):
print(goal.goal)
print(f"="*30)

def test_simple_lean4_have_test(self):
from itp_interface.rl.proof_state import ProofState
from itp_interface.rl.proof_action import ProofAction
from itp_interface.rl.simple_proof_env import ProofEnv
from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
project_folder = "src/data/test/lean4_proj"
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = "imo_1959_p1"
# theorem test3 (p q : Prop) (hp : p) (hq : q)
# : p ∧ q ∧ p :=
proof_exec_callback = ProofExecutorCallback(
project_folder=project_folder,
file_path=file_path,
language=language,
always_use_retrieval=False,
keep_local_context=True,
enforce_qed=True
)
always_retrieve_thms = False
retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
proof_steps = [
'rw [Nat.gcd_rec]',
'rw [Nat.mod_eq_of_lt (by linarith)]',
'rw [Nat.gcd_rec]',
'rw [Nat.gcd_rec]',
'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by',
' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
' have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith',
' rw [Nat.mod_eq_of_lt]',
' rw [Nat.mod_eq_of_lt]',
' exact h₂',
' rw [Nat.mod_eq_of_lt]',
' exact h₂',
' exact h₂',
'rw [eq₂]'
]
with env:
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
ProofAction.ActionType.RUN_TACTIC,
language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
# This prints StateChanged, StateUnchanged, Failed, or Done
print(info.progress)
print('-'*30)
if done:
raise Exception("Proof should not have finished")
else:
s1 : ProofState = state
s2 : ProofState = next_state
print(f"Current Goal:")
print('-'*30)
for goal in s1.training_data_format.start_goals:
hyps = '\n'.join([hyp for hyp in goal.hypotheses])
print(hyps)
print('|- ', end='')
print(goal.goal)
print(f'*'*30)
print(f"="*30)
print(f"Action: {proof_step}")
print(f"="*30)
print(f"Next Goal:")
print('-'*30)
for goal in s2.training_data_format.start_goals:
hyps = '\n'.join([hyp for hyp in goal.hypotheses])
print(hyps)
print('|- ', end='')
print(goal.goal)
print(f'*'*30)
print(f"="*30)
print(f"DONE: {done}")
print('-'*30)

def main():
unittest.main()
Expand Down
Loading