diff --git a/pyproject.toml b/pyproject.toml index 924e11b..de5e6ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ build-backend = "hatchling.build" [project] name = "itp_interface" -version = "1.1.16" +version = "1.1.17" authors = [ { name="Amitayush Thakur", email="amitayush@utexas.edu" }, ] diff --git a/src/data/test/lean4_proj/Lean4Proj/Basic.lean b/src/data/test/lean4_proj/Lean4Proj/Basic.lean index a77254e..d0cf293 100644 --- a/src/data/test/lean4_proj/Lean4Proj/Basic.lean +++ b/src/data/test/lean4_proj/Lean4Proj/Basic.lean @@ -74,4 +74,13 @@ rw [eq₂] sorry +theorem complicated_have + (a b c d e f : ℕ) + (h1 : a + b = c) + (h2 : d + e = f) : + a + b + d + e = c + f + ∧ a + d + b + e = c + f := by + apply And.intro <;> have h3 : a + b + d + e = c + f := by grind; + exact h3 ; grind + end Lean4Proj2 diff --git a/src/itp_interface/tools/dynamic_lean4_proof_exec.py b/src/itp_interface/tools/dynamic_lean4_proof_exec.py index 53eb19e..282f8d2 100644 --- a/src/itp_interface/tools/dynamic_lean4_proof_exec.py +++ b/src/itp_interface/tools/dynamic_lean4_proof_exec.py @@ -128,9 +128,8 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]: # Cancel the last tactic self.cancel_tactic_till_line(start_line_num, no_backtracking=True) if self._last_tactic_was_modified: - tactics_in_order = self._get_tactics_in_sorted_order() - assert len(tactics_in_order) > 0, "tactics_in_order must not be empty" - self.run_state.tactics_ran[-1] = tactics_in_order[-1][1] + assert self._last_modified_tactic is not None, "last_modified_tactic must not be None if last_tactic_was_modified is True" + self.run_state.tactics_ran[-1] = self._last_modified_tactic return start_line_num, not tactic_failed def get_last_tactic(self) -> typing.Optional[str]: diff --git a/src/itp_interface/tools/simple_lean4_sync_executor.py b/src/itp_interface/tools/simple_lean4_sync_executor.py index b254605..e44d471 100644 --- a/src/itp_interface/tools/simple_lean4_sync_executor.py +++ b/src/itp_interface/tools/simple_lean4_sync_executor.py @@ -29,7 +29,7 @@ class SimpleLean4SyncExecutor: theorem_regex = r"((((theorem|lemma)[\s]+([^\s:]*))|example)([\S|\s]*?)(:=|=>)[\s]*?)[\s]+" theorem_match = re.compile(theorem_regex, re.MULTILINE) - have_regex = r"(^\s*have\s+([^\s:]*):[=]*([^:]*))(:=\s*by)([\s|\S]*)" + have_regex = r"(^\s*have\s+([^:]*):([\s|\S]*))(:=\s*by)([\s|\S]*)" have_match = re.compile(have_regex, re.MULTILINE) unsolved_message = "unsolved goals" no_goals = "No goals to be solved" @@ -107,6 +107,7 @@ def __init__(self, self._run_exactly = False self._nested_have_counts = 0 self._last_tactic_was_modified = False + self._last_modified_tactic : str | None = None if self._enable_search: pass pass @@ -155,6 +156,7 @@ def reset(self, self._error_messages_since_last_thm = {} self._nested_have_counts = 0 self._last_tactic_was_modified = False + self._last_modified_tactic : str | None = None if self._enable_search: pass pass @@ -239,12 +241,17 @@ def get_current_lemma_name(self) -> Optional[str]: def _add_last_tactic(self, idx: int, stmt: str): if idx not in self._last_tactics: - stmt = self._have_preprocessing(stmt) + original_stmt = stmt + stmt = self._tactic_preprocessing(stmt) indentation = " " * self._nested_have_counts * 2 if self._nested_have_counts > 0: stmt = stmt.lstrip() stmt = indentation + stmt - self._last_tactic_was_modified = True + self._last_tactic_was_modified = original_stmt != stmt + if self._last_tactic_was_modified: + self._last_modified_tactic = stmt + else: + self._last_modified_tactic = None self._last_tactics[idx] = stmt self._last_tactic_line_idx = idx # self.logger.info(f"Proofs so far:\n{self._get_tactics_so_far()}") @@ -274,9 +281,44 @@ def _have_preprocessing(self, stmt: str) -> str: by = by.rstrip() new_stmt = f"{full_have_stmt}{by}\n{after_tactics_str}" new_stmt = new_stmt.rstrip() - self._last_tactic_was_modified = True return new_stmt + def _multiple_goals_tactic_preprocessing(self, stmt: str) -> List[str]: + # Split the tactics on multiple goals using `<;>` + initial_space_cnt = len(stmt) - len(stmt.lstrip()) + stmt_splits = stmt.split("<;>") + # Initial space cnt + indentation = " " * initial_space_cnt + stmt_splits = [ + indentation + s.strip() for s in stmt_splits + ] + return stmt_splits + + def _multiline_tactic_preprocessing(self, stmt: str) -> List[str]: + # Split the tactics with `;` + initial_space_cnt = len(stmt) - len(stmt.lstrip()) + stmt_splits = stmt.split(";") + # Initial space cnt + indentation = " " * initial_space_cnt + stmt_splits = [ + indentation + s.strip() for s in stmt_splits + ] + return stmt_splits + + def _tactic_preprocessing(self, stmt: str) -> str: + tactics_multi_goal = self._multiple_goals_tactic_preprocessing(stmt) + final_multigoal_tactic : List[str] = [] + for tactic in tactics_multi_goal: + new_tactics = self._multiline_tactic_preprocessing(tactic) + final_multiline_tactic : List[str] = [] + for new_tactic in new_tactics: + have_stmts = self._have_preprocessing(new_tactic) + final_multiline_tactic.append(have_stmts) + multi_line_stmt = ";\n".join(final_multiline_tactic) + final_multigoal_tactic.append(multi_line_stmt) + final_stmt = "<;>\n".join(final_multigoal_tactic) + return final_stmt + def _get_lean_code_with_tactics(self, idx: int, stmt: str): assert self._last_theorem is not None, "Last theorem should not be None" self._add_last_tactic(idx, stmt) diff --git a/src/test/simple_env_test.py b/src/test/simple_env_test.py index 0622260..93e1129 100644 --- a/src/test/simple_env_test.py +++ b/src/test/simple_env_test.py @@ -611,10 +611,67 @@ def test_simple_lean4_with_error(self): pretty_print(s1, s2, proof_step, done) assert proof_finished, "Proof was not finished" + def test_simple_lean4_multiline_multigoal(self): + from itp_interface.rl.proof_state import ProofState + from itp_interface.rl.proof_action import ProofAction + from itp_interface.rl.simple_proof_env import ProofEnv + from itp_interface.tools.proof_exec_callback import ProofExecutorCallback + from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy + project_folder = "src/data/test/lean4_proj" + file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" + # Build the project + # cd src/data/test/lean4_proj && lake build + helper = Helper() + helper.build_lean4_project(project_folder) + language = ProofAction.Language.LEAN4 + theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"complicated_have\"}' + # theorem complicated_have + # (a b c d e f : ℕ) + # (h1 : a + b = c) + # (h2 : d + e = f) : + # a + b + d + e = c + f + # ∧ a + d + b + e = c + f := by + proof_exec_callback = ProofExecutorCallback( + project_folder=project_folder, + file_path=file_path, + language=language, + always_use_retrieval=False, + keep_local_context=True + ) + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK + env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) + proof_steps = [ + 'apply And.intro <;> have h3 : a + b + d + e = c + f := by grind;', + 'exact h3 ; grind' + ] + with env: + proof_was_finished = False + for proof_step in proof_steps: + state, action, next_state, _, done, info = env.step(ProofAction( + ProofAction.ActionType.RUN_TACTIC, + language, + tactics=[proof_step])) + proof_step = action.kwargs.get('tactics', ['INVALID'])[0] + if info.error_message is not None: + print(f"Error: {info.error_message}") + # This prints StateChanged, StateUnchanged, Failed, or Done + print(info.progress) + print('-'*30) + if done: + pretty_print(next_state, None, proof_step, done) + proof_was_finished = True + else: + s1 : ProofState = state + s2 : ProofState = next_state + pretty_print(s1, s2, proof_step, done) + assert proof_was_finished, "Proof was not finished" + def main(): unittest.main() # Run only the Lean 4 tests # t = Lean4Test() + # t.test_simple_lean4_multiline_multigoal() # t.test_simple_lean4() # t.test_lean4_backtracking() # t.test_simple_lean4_done_test()