Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
version = "1.1.16"
version = "1.1.17"
authors = [
{ name="Amitayush Thakur", email="[email protected]" },
]
Expand Down
9 changes: 9 additions & 0 deletions src/data/test/lean4_proj/Lean4Proj/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,13 @@ rw [eq₂]
sorry


theorem complicated_have
(a b c d e f : ℕ)
(h1 : a + b = c)
(h2 : d + e = f) :
a + b + d + e = c + f
∧ a + d + b + e = c + f := by
apply And.intro <;> have h3 : a + b + d + e = c + f := by grind;
exact h3 ; grind

end Lean4Proj2
5 changes: 2 additions & 3 deletions src/itp_interface/tools/dynamic_lean4_proof_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,8 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]:
# Cancel the last tactic
self.cancel_tactic_till_line(start_line_num, no_backtracking=True)
if self._last_tactic_was_modified:
tactics_in_order = self._get_tactics_in_sorted_order()
assert len(tactics_in_order) > 0, "tactics_in_order must not be empty"
self.run_state.tactics_ran[-1] = tactics_in_order[-1][1]
assert self._last_modified_tactic is not None, "last_modified_tactic must not be None if last_tactic_was_modified is True"
self.run_state.tactics_ran[-1] = self._last_modified_tactic
return start_line_num, not tactic_failed

def get_last_tactic(self) -> typing.Optional[str]:
Expand Down
50 changes: 46 additions & 4 deletions src/itp_interface/tools/simple_lean4_sync_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
class SimpleLean4SyncExecutor:
theorem_regex = r"((((theorem|lemma)[\s]+([^\s:]*))|example)([\S|\s]*?)(:=|=>)[\s]*?)[\s]+"
theorem_match = re.compile(theorem_regex, re.MULTILINE)
have_regex = r"(^\s*have\s+([^\s:]*):[=]*([^:]*))(:=\s*by)([\s|\S]*)"
have_regex = r"(^\s*have\s+([^:]*):([\s|\S]*))(:=\s*by)([\s|\S]*)"
have_match = re.compile(have_regex, re.MULTILINE)
unsolved_message = "unsolved goals"
no_goals = "No goals to be solved"
Expand Down Expand Up @@ -107,6 +107,7 @@ def __init__(self,
self._run_exactly = False
self._nested_have_counts = 0
self._last_tactic_was_modified = False
self._last_modified_tactic : str | None = None
if self._enable_search:
pass
pass
Expand Down Expand Up @@ -155,6 +156,7 @@ def reset(self,
self._error_messages_since_last_thm = {}
self._nested_have_counts = 0
self._last_tactic_was_modified = False
self._last_modified_tactic : str | None = None
if self._enable_search:
pass
pass
Expand Down Expand Up @@ -239,12 +241,17 @@ def get_current_lemma_name(self) -> Optional[str]:

def _add_last_tactic(self, idx: int, stmt: str):
if idx not in self._last_tactics:
stmt = self._have_preprocessing(stmt)
original_stmt = stmt
stmt = self._tactic_preprocessing(stmt)
indentation = " " * self._nested_have_counts * 2
if self._nested_have_counts > 0:
stmt = stmt.lstrip()
stmt = indentation + stmt
self._last_tactic_was_modified = True
self._last_tactic_was_modified = original_stmt != stmt
if self._last_tactic_was_modified:
self._last_modified_tactic = stmt
else:
self._last_modified_tactic = None
self._last_tactics[idx] = stmt
self._last_tactic_line_idx = idx
# self.logger.info(f"Proofs so far:\n{self._get_tactics_so_far()}")
Expand Down Expand Up @@ -274,9 +281,44 @@ def _have_preprocessing(self, stmt: str) -> str:
by = by.rstrip()
new_stmt = f"{full_have_stmt}{by}\n{after_tactics_str}"
new_stmt = new_stmt.rstrip()
self._last_tactic_was_modified = True
return new_stmt

def _multiple_goals_tactic_preprocessing(self, stmt: str) -> List[str]:
# Split the tactics on multiple goals using `<;>`
initial_space_cnt = len(stmt) - len(stmt.lstrip())
stmt_splits = stmt.split("<;>")
# Initial space cnt
indentation = " " * initial_space_cnt
stmt_splits = [
indentation + s.strip() for s in stmt_splits
]
return stmt_splits

def _multiline_tactic_preprocessing(self, stmt: str) -> List[str]:
# Split the tactics with `;`
initial_space_cnt = len(stmt) - len(stmt.lstrip())
stmt_splits = stmt.split(";")
# Initial space cnt
indentation = " " * initial_space_cnt
stmt_splits = [
indentation + s.strip() for s in stmt_splits
]
return stmt_splits

def _tactic_preprocessing(self, stmt: str) -> str:
tactics_multi_goal = self._multiple_goals_tactic_preprocessing(stmt)
final_multigoal_tactic : List[str] = []
for tactic in tactics_multi_goal:
new_tactics = self._multiline_tactic_preprocessing(tactic)
final_multiline_tactic : List[str] = []
for new_tactic in new_tactics:
have_stmts = self._have_preprocessing(new_tactic)
final_multiline_tactic.append(have_stmts)
multi_line_stmt = ";\n".join(final_multiline_tactic)
final_multigoal_tactic.append(multi_line_stmt)
final_stmt = "<;>\n".join(final_multigoal_tactic)
return final_stmt

def _get_lean_code_with_tactics(self, idx: int, stmt: str):
assert self._last_theorem is not None, "Last theorem should not be None"
self._add_last_tactic(idx, stmt)
Expand Down
57 changes: 57 additions & 0 deletions src/test/simple_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,67 @@ def test_simple_lean4_with_error(self):
pretty_print(s1, s2, proof_step, done)
assert proof_finished, "Proof was not finished"

def test_simple_lean4_multiline_multigoal(self):
from itp_interface.rl.proof_state import ProofState
from itp_interface.rl.proof_action import ProofAction
from itp_interface.rl.simple_proof_env import ProofEnv
from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
project_folder = "src/data/test/lean4_proj"
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"complicated_have\"}'
# theorem complicated_have
# (a b c d e f : ℕ)
# (h1 : a + b = c)
# (h2 : d + e = f) :
# a + b + d + e = c + f
# ∧ a + d + b + e = c + f := by
proof_exec_callback = ProofExecutorCallback(
project_folder=project_folder,
file_path=file_path,
language=language,
always_use_retrieval=False,
keep_local_context=True
)
always_retrieve_thms = False
retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
proof_steps = [
'apply And.intro <;> have h3 : a + b + d + e = c + f := by grind;',
'exact h3 ; grind'
]
with env:
proof_was_finished = False
for proof_step in proof_steps:
state, action, next_state, _, done, info = env.step(ProofAction(
ProofAction.ActionType.RUN_TACTIC,
language,
tactics=[proof_step]))
proof_step = action.kwargs.get('tactics', ['INVALID'])[0]
if info.error_message is not None:
print(f"Error: {info.error_message}")
# This prints StateChanged, StateUnchanged, Failed, or Done
print(info.progress)
print('-'*30)
if done:
pretty_print(next_state, None, proof_step, done)
proof_was_finished = True
else:
s1 : ProofState = state
s2 : ProofState = next_state
pretty_print(s1, s2, proof_step, done)
assert proof_was_finished, "Proof was not finished"

def main():
unittest.main()
# Run only the Lean 4 tests
# t = Lean4Test()
# t.test_simple_lean4_multiline_multigoal()
# t.test_simple_lean4()
# t.test_lean4_backtracking()
# t.test_simple_lean4_done_test()
Expand Down