Skip to content

Commit 54889a5

Browse files
authored
Merge pull request #63 from trishullab/bug/complicated-have
Bug/complicated have
2 parents effc925 + fb78f63 commit 54889a5

File tree

5 files changed

+115
-8
lines changed

5 files changed

+115
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = [
55
build-backend = "hatchling.build"
66
[project]
77
name = "itp_interface"
8-
version = "1.1.16"
8+
version = "1.1.17"
99
authors = [
1010
{ name="Amitayush Thakur", email="[email protected]" },
1111
]

src/data/test/lean4_proj/Lean4Proj/Basic.lean

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,13 @@ rw [eq₂]
7474
sorry
7575

7676

77+
theorem complicated_have
78+
(a b c d e f : ℕ)
79+
(h1 : a + b = c)
80+
(h2 : d + e = f) :
81+
a + b + d + e = c + f
82+
∧ a + d + b + e = c + f := by
83+
apply And.intro <;> have h3 : a + b + d + e = c + f := by grind;
84+
exact h3 ; grind
85+
7786
end Lean4Proj2

src/itp_interface/tools/dynamic_lean4_proof_exec.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,8 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]:
128128
# Cancel the last tactic
129129
self.cancel_tactic_till_line(start_line_num, no_backtracking=True)
130130
if self._last_tactic_was_modified:
131-
tactics_in_order = self._get_tactics_in_sorted_order()
132-
assert len(tactics_in_order) > 0, "tactics_in_order must not be empty"
133-
self.run_state.tactics_ran[-1] = tactics_in_order[-1][1]
131+
assert self._last_modified_tactic is not None, "last_modified_tactic must not be None if last_tactic_was_modified is True"
132+
self.run_state.tactics_ran[-1] = self._last_modified_tactic
134133
return start_line_num, not tactic_failed
135134

136135
def get_last_tactic(self) -> typing.Optional[str]:

src/itp_interface/tools/simple_lean4_sync_executor.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
class SimpleLean4SyncExecutor:
3030
theorem_regex = r"((((theorem|lemma)[\s]+([^\s:]*))|example)([\S|\s]*?)(:=|=>)[\s]*?)[\s]+"
3131
theorem_match = re.compile(theorem_regex, re.MULTILINE)
32-
have_regex = r"(^\s*have\s+([^\s:]*):[=]*([^:]*))(:=\s*by)([\s|\S]*)"
32+
have_regex = r"(^\s*have\s+([^:]*):([\s|\S]*))(:=\s*by)([\s|\S]*)"
3333
have_match = re.compile(have_regex, re.MULTILINE)
3434
unsolved_message = "unsolved goals"
3535
no_goals = "No goals to be solved"
@@ -107,6 +107,7 @@ def __init__(self,
107107
self._run_exactly = False
108108
self._nested_have_counts = 0
109109
self._last_tactic_was_modified = False
110+
self._last_modified_tactic : str | None = None
110111
if self._enable_search:
111112
pass
112113
pass
@@ -155,6 +156,7 @@ def reset(self,
155156
self._error_messages_since_last_thm = {}
156157
self._nested_have_counts = 0
157158
self._last_tactic_was_modified = False
159+
self._last_modified_tactic : str | None = None
158160
if self._enable_search:
159161
pass
160162
pass
@@ -239,12 +241,17 @@ def get_current_lemma_name(self) -> Optional[str]:
239241

240242
def _add_last_tactic(self, idx: int, stmt: str):
241243
if idx not in self._last_tactics:
242-
stmt = self._have_preprocessing(stmt)
244+
original_stmt = stmt
245+
stmt = self._tactic_preprocessing(stmt)
243246
indentation = " " * self._nested_have_counts * 2
244247
if self._nested_have_counts > 0:
245248
stmt = stmt.lstrip()
246249
stmt = indentation + stmt
247-
self._last_tactic_was_modified = True
250+
self._last_tactic_was_modified = original_stmt != stmt
251+
if self._last_tactic_was_modified:
252+
self._last_modified_tactic = stmt
253+
else:
254+
self._last_modified_tactic = None
248255
self._last_tactics[idx] = stmt
249256
self._last_tactic_line_idx = idx
250257
# self.logger.info(f"Proofs so far:\n{self._get_tactics_so_far()}")
@@ -274,9 +281,44 @@ def _have_preprocessing(self, stmt: str) -> str:
274281
by = by.rstrip()
275282
new_stmt = f"{full_have_stmt}{by}\n{after_tactics_str}"
276283
new_stmt = new_stmt.rstrip()
277-
self._last_tactic_was_modified = True
278284
return new_stmt
279285

286+
def _multiple_goals_tactic_preprocessing(self, stmt: str) -> List[str]:
287+
# Split the tactics on multiple goals using `<;>`
288+
initial_space_cnt = len(stmt) - len(stmt.lstrip())
289+
stmt_splits = stmt.split("<;>")
290+
# Initial space cnt
291+
indentation = " " * initial_space_cnt
292+
stmt_splits = [
293+
indentation + s.strip() for s in stmt_splits
294+
]
295+
return stmt_splits
296+
297+
def _multiline_tactic_preprocessing(self, stmt: str) -> List[str]:
298+
# Split the tactics with `;`
299+
initial_space_cnt = len(stmt) - len(stmt.lstrip())
300+
stmt_splits = stmt.split(";")
301+
# Initial space cnt
302+
indentation = " " * initial_space_cnt
303+
stmt_splits = [
304+
indentation + s.strip() for s in stmt_splits
305+
]
306+
return stmt_splits
307+
308+
def _tactic_preprocessing(self, stmt: str) -> str:
309+
tactics_multi_goal = self._multiple_goals_tactic_preprocessing(stmt)
310+
final_multigoal_tactic : List[str] = []
311+
for tactic in tactics_multi_goal:
312+
new_tactics = self._multiline_tactic_preprocessing(tactic)
313+
final_multiline_tactic : List[str] = []
314+
for new_tactic in new_tactics:
315+
have_stmts = self._have_preprocessing(new_tactic)
316+
final_multiline_tactic.append(have_stmts)
317+
multi_line_stmt = ";\n".join(final_multiline_tactic)
318+
final_multigoal_tactic.append(multi_line_stmt)
319+
final_stmt = "<;>\n".join(final_multigoal_tactic)
320+
return final_stmt
321+
280322
def _get_lean_code_with_tactics(self, idx: int, stmt: str):
281323
assert self._last_theorem is not None, "Last theorem should not be None"
282324
self._add_last_tactic(idx, stmt)

src/test/simple_env_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,10 +611,67 @@ def test_simple_lean4_with_error(self):
611611
pretty_print(s1, s2, proof_step, done)
612612
assert proof_finished, "Proof was not finished"
613613

614+
def test_simple_lean4_multiline_multigoal(self):
615+
from itp_interface.rl.proof_state import ProofState
616+
from itp_interface.rl.proof_action import ProofAction
617+
from itp_interface.rl.simple_proof_env import ProofEnv
618+
from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
619+
from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
620+
project_folder = "src/data/test/lean4_proj"
621+
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
622+
# Build the project
623+
# cd src/data/test/lean4_proj && lake build
624+
helper = Helper()
625+
helper.build_lean4_project(project_folder)
626+
language = ProofAction.Language.LEAN4
627+
theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"complicated_have\"}'
628+
# theorem complicated_have
629+
# (a b c d e f : ℕ)
630+
# (h1 : a + b = c)
631+
# (h2 : d + e = f) :
632+
# a + b + d + e = c + f
633+
# ∧ a + d + b + e = c + f := by
634+
proof_exec_callback = ProofExecutorCallback(
635+
project_folder=project_folder,
636+
file_path=file_path,
637+
language=language,
638+
always_use_retrieval=False,
639+
keep_local_context=True
640+
)
641+
always_retrieve_thms = False
642+
retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
643+
env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
644+
proof_steps = [
645+
'apply And.intro <;> have h3 : a + b + d + e = c + f := by grind;',
646+
'exact h3 ; grind'
647+
]
648+
with env:
649+
proof_was_finished = False
650+
for proof_step in proof_steps:
651+
state, action, next_state, _, done, info = env.step(ProofAction(
652+
ProofAction.ActionType.RUN_TACTIC,
653+
language,
654+
tactics=[proof_step]))
655+
proof_step = action.kwargs.get('tactics', ['INVALID'])[0]
656+
if info.error_message is not None:
657+
print(f"Error: {info.error_message}")
658+
# This prints StateChanged, StateUnchanged, Failed, or Done
659+
print(info.progress)
660+
print('-'*30)
661+
if done:
662+
pretty_print(next_state, None, proof_step, done)
663+
proof_was_finished = True
664+
else:
665+
s1 : ProofState = state
666+
s2 : ProofState = next_state
667+
pretty_print(s1, s2, proof_step, done)
668+
assert proof_was_finished, "Proof was not finished"
669+
614670
def main():
615671
unittest.main()
616672
# Run only the Lean 4 tests
617673
# t = Lean4Test()
674+
# t.test_simple_lean4_multiline_multigoal()
618675
# t.test_simple_lean4()
619676
# t.test_lean4_backtracking()
620677
# t.test_simple_lean4_done_test()

0 commit comments

Comments
 (0)