@@ -7,7 +7,7 @@ def __init__(self):
77 def build_lean4_project (self , project_folder ):
88 import os
99 # Build the project
10- with os .popen (f"cd { project_folder } && lake build" ) as proc :
10+ with os .popen (f"cd { project_folder } && lake exe cache get && lake build" ) as proc :
1111 print ("Building Lean4 project..." )
1212 print ('-' * 15 + 'Build Logs' + '-' * 15 )
1313 print (proc .read ())
@@ -488,6 +488,88 @@ def test_simple_lean4_done_test(self):
488488 print (goal .goal )
489489 print (f"=" * 30 )
490490
def test_simple_lean4_have_test(self):
    """Step through a Lean 4 proof of ``imo_1959_p1`` that uses a nested ``have`` block.

    The project under ``src/data/test/lean4_proj`` is built first, then each
    tactic line in ``proof_steps`` is sent to a ``ProofEnv`` one at a time as
    a ``RUN_TACTIC`` action. After every step the reported progress, the
    goals before the step, the tactic itself and the goals after the step
    are printed.

    Raises:
        Exception: if the environment reports the proof as done after any
            step; the tactic list is expected to leave the proof unfinished.

    NOTE: tactic errors are only printed (``info.error_message``); this test
    makes no assertion about them.
    """
    from itp_interface.rl.proof_state import ProofState
    from itp_interface.rl.proof_action import ProofAction
    from itp_interface.rl.simple_proof_env import ProofEnv, ProofEnvReRankStrategy
    from itp_interface.tools.proof_exec_callback import ProofExecutorCallback

    def print_goals(title: str, proof_state: ProofState) -> None:
        # Dump every start goal of `proof_state`: hypotheses first, then the
        # goal itself behind a '|- ' marker.
        # NOTE(review): the original indentation was not recoverable; the
        # '*' separator is assumed to be printed once per goal -- confirm.
        print(title)
        print('-' * 30)
        for goal in proof_state.training_data_format.start_goals:
            print('\n '.join(goal.hypotheses))
            print('|- ', end='')
            print(goal.goal)
            print('*' * 30)
        print("=" * 30)

    project_folder = "src/data/test/lean4_proj"
    file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
    # Build the Lean 4 project before opening the proof environment.
    helper = Helper()
    helper.build_lean4_project(project_folder)
    language = ProofAction.Language.LEAN4
    # Theorem looked up by name in `file_path`.
    theorem_name = "imo_1959_p1"
    proof_exec_callback = ProofExecutorCallback(
        project_folder=project_folder,
        file_path=file_path,
        language=language,
        always_use_retrieval=False,
        keep_local_context=True,
        enforce_qed=True,
    )
    always_retrieve_thms = False
    retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
    env = ProofEnv(
        "test_lean4",
        proof_exec_callback,
        theorem_name,
        retrieval_strategy=retrieval_strategy,
        max_proof_depth=10,
        always_retrieve_thms=always_retrieve_thms,
    )
    # Each entry is submitted as its own tactic; the entries with a leading
    # space are the body lines of the `have eq₂ ... := by` block.
    proof_steps = [
        'rw [Nat.gcd_rec]',
        'rw [Nat.mod_eq_of_lt (by linarith)]',
        'rw [Nat.gcd_rec]',
        'rw [Nat.gcd_rec]',
        'have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by',
        ' have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring',
        ' rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]',
        ' have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith',
        ' rw [Nat.mod_eq_of_lt]',
        ' rw [Nat.mod_eq_of_lt]',
        ' exact h₂',
        ' rw [Nat.mod_eq_of_lt]',
        ' exact h₂',
        ' exact h₂',
        'rw [eq₂]',
    ]
    with env:
        for proof_step in proof_steps:
            state, _, next_state, _, done, info = env.step(ProofAction(
                ProofAction.ActionType.RUN_TACTIC,
                language,
                tactics=[proof_step]))
            if info.error_message is not None:
                print(f"Error: {info.error_message} ")
            # This prints StateChanged, StateUnchanged, Failed, or Done
            print(info.progress)
            print('-' * 30)
            if done:
                raise Exception("Proof should not have finished")
            print_goals("Current Goal:", state)
            print(f"Action: {proof_step} ")
            print("=" * 30)
            print_goals("Next Goal:", next_state)
            print(f"DONE: {done} ")
            print('-' * 30)
491573
492574def main ():
493575 unittest .main ()
0 commit comments