diff --git a/.github/workflows/github-build-actions-python314t.yaml b/.github/workflows/github-build-actions-python314t.yaml
index 6a7ba1a..010a72b 100644
--- a/.github/workflows/github-build-actions-python314t.yaml
+++ b/.github/workflows/github-build-actions-python314t.yaml
@@ -117,3 +117,12 @@ jobs:
eval $(opam env)
source $HOME/.elan/env
python src/test/simple_data_gen_test.py
+
+ - name: Run Data Extraction Test
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ eval $(opam env)
+ source $HOME/.elan/env
+ python src/test/simple_data_extract_test.py
diff --git a/.github/workflows/github-build-actions.yaml b/.github/workflows/github-build-actions.yaml
index 302425c..7814b40 100644
--- a/.github/workflows/github-build-actions.yaml
+++ b/.github/workflows/github-build-actions.yaml
@@ -75,10 +75,32 @@ jobs:
eval $(opam env)
source $HOME/.elan/env
python src/test/simple_env_test.py
+
+ - name: Ray Cleanup
+ shell: bash
+ run: |
+ rm -rf /tmp/* --verbose
- name: Run Data Gen Test
shell: bash
run: |
eval $(opam env)
source $HOME/.elan/env
- python src/test/simple_data_gen_test.py
\ No newline at end of file
+ python src/test/simple_data_gen_test.py
+
+ - name: Ray Cleanup
+ shell: bash
+ run: |
+ rm -rf /tmp/* --verbose
+
+ - name: Run Data Extraction Test
+ shell: bash
+ run: |
+ eval $(opam env)
+ source $HOME/.elan/env
+ python src/test/simple_data_extract_test.py
+
+ - name: Ray Cleanup
+ shell: bash
+ run: |
+ rm -rf /tmp/* --verbose
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index f2668ee..e018019 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
-version = "1.1.14"
+version = "1.1.15"
authors = [
{ name="Amitayush Thakur", email="amitayush@utexas.edu" },
]
@@ -45,6 +45,12 @@ dependencies = [
"grpcio>=1.51.3; python_version<'3.14'"
]
+[project.optional-dependencies]
+app = [
+ "flask>=2.3.0",
+ "flask-cors>=4.0.0"
+]
+
[project.urls]
Homepage = "https://github.com/trishullab/itp-interface"
Issues = "https://github.com/trishullab/itp-interface/issues"
diff --git a/src/app/itp-gui/app.py b/src/app/itp-gui/app.py
new file mode 100644
index 0000000..c42eb9c
--- /dev/null
+++ b/src/app/itp-gui/app.py
@@ -0,0 +1,447 @@
+#!/usr/bin/env python3
+"""
+ITP GUI - Interactive Theorem Proving Graphical User Interface
+A web-based interface for interacting with Lean 4 proofs using Lean4SyncExecutor
+"""
+
+import sys
+import os
+
+# Add the parent directory to the path
+root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
+if root_dir not in sys.path:
+ sys.path.insert(0, root_dir)
+
+from flask import Flask, render_template, request, jsonify, send_from_directory
+from flask_cors import CORS
+import logging
+import traceback
+from typing import Optional, Dict, Any, List
+import json
+
+from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.lean_server.lean_context import ProofContext
+
+app = Flask(__name__, static_folder='static', template_folder='templates')
+CORS(app)
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Global state to store the executor
+executor_state: Dict[str, Any] = {
+ 'executor': None,
+ 'context_manager': None,
+ 'history': [], # List of tactics executed
+ 'project_root': None,
+ 'file_path': None,
+ 'lemma_name': None
+}
+
+
+def get_debug_info() -> Dict[str, Any]:
+ """Extract debug information from the Lean4SyncExecutor private variables"""
+ if not executor_state['executor']:
+ return {}
+
+ executor = executor_state['executor']
+ assert isinstance(executor, SimpleLean4SyncExecutor)
+
+ # Format proof_context
+ proof_context_info = None
+ if executor.proof_context and executor.proof_context != ProofContext.empty():
+ proof_context_info = {
+ 'num_goals': len(executor.proof_context.all_goals),
+ 'goals': [
+ {
+ 'hypotheses': [str(hyp) for hyp in goal.hypotheses],
+ 'goal': str(goal.goal)
+ }
+ for goal in executor.proof_context.all_goals
+ ]
+ }
+
+ debug_info = {
+ 'line_num': executor.line_num,
+ 'current_stmt': executor.current_stmt,
+ 'execution_complete': executor.execution_complete,
+ 'curr_lemma_name': executor.curr_lemma_name,
+ 'curr_lemma': executor.curr_lemma,
+ '_last_tactics': executor._last_tactics,
+ "_nested_have_counts": executor._nested_have_counts,
+ '_last_tactic_was_modified': executor._last_tactic_was_modified,
+
+ # Requested private variables
+ 'proof_context': proof_context_info,
+ '_proof_running': executor._proof_running,
+ 'lean_error_messages': executor.lean_error_messages,
+ '_error_messages_since_last_thm': dict(executor._error_messages_since_last_thm),
+ '_error_messages_so_far': list(executor._error_messages_so_far),
+
+ # Other useful info
+ 'proof_start_idx': executor._proof_start_idx,
+ 'import_end_idx': executor._import_end_idx,
+ 'theorem_started': executor._theorem_started,
+ 'enforce_qed': executor._enforce_qed,
+ 'anon_theorem_count': executor._anon_theorem_count,
+
+ # Debug traces and proof tactics
+ 'debug_enabled': executor.debug_enabled,
+ 'debug_traces': list(executor._debug_traces),
+ 'possible_proof_tactics': executor.possible_proof_tactics
+ }
+
+ return debug_info
+
+
+def get_proof_state() -> Dict[str, Any]:
+ """Get current proof state information"""
+ if not executor_state['executor']:
+ return {
+ 'initialized': False,
+ 'error': 'Executor not initialized'
+ }
+
+ executor = executor_state['executor']
+ proof_context = executor.proof_context
+
+ result = {
+ 'initialized': True,
+ 'lemma_name': executor.curr_lemma_name,
+ 'lemma_stmt': executor.curr_lemma,
+ 'is_in_proof_mode': executor.is_in_proof_mode(),
+ 'execution_complete': executor.execution_complete,
+ 'error_messages': executor.lean_error_messages,
+ 'goals': [],
+ 'history': executor_state['history']
+ }
+
+ if proof_context and proof_context != ProofContext.empty():
+ for goal in proof_context.all_goals:
+ goal_dict = {
+ 'hypotheses': [str(hyp) for hyp in goal.hypotheses],
+ 'goal': str(goal.goal)
+ }
+ result['goals'].append(goal_dict)
+
+ return result
+
+
+@app.route('/')
+def index():
+ """Serve the main HTML page"""
+ return send_from_directory('static', 'index.html')
+
+
+@app.route('/api/initialize', methods=['POST'])
+def initialize():
+ """Initialize the Lean4SyncExecutor with given parameters"""
+ try:
+ data = request.json
+ project_root = data.get('project_root')
+ file_path = data.get('file_path')
+ lemma_name = data.get('lemma_name')
+
+ if not all([project_root, file_path, lemma_name]):
+ return jsonify({
+ 'success': False,
+ 'error': 'Missing required parameters: project_root, file_path, lemma_name'
+ }), 400
+
+ # Validate paths
+ if not os.path.exists(project_root):
+ return jsonify({
+ 'success': False,
+ 'error': f'Project root does not exist: {project_root}'
+ }), 400
+
+ if not os.path.exists(file_path):
+ return jsonify({
+ 'success': False,
+ 'error': f'File path does not exist: {file_path}'
+ }), 400
+
+ # Close existing executor if any
+ if executor_state['context_manager']:
+ try:
+ executor_state['context_manager'].__exit__(None, None, None)
+ except:
+ pass
+
+ # Initialize new executor
+ executor = SimpleLean4SyncExecutor(
+ project_root=project_root,
+ main_file=file_path,
+ timeout_in_sec=60,
+ use_human_readable_proof_context=True,
+ suppress_error_log=False,
+ logger=logger
+ )
+
+ # Enter context manager
+ executor.__enter__()
+
+ # Enable debug mode
+ executor.debug_enabled = True
+
+ # Store in global state
+ executor_state['executor'] = executor
+ executor_state['context_manager'] = executor
+ executor_state['project_root'] = project_root
+ executor_state['file_path'] = file_path
+ executor_state['lemma_name'] = lemma_name
+ executor_state['history'] = []
+
+ # Skip to the specified theorem
+ try:
+ executor._skip_to_theorem(lemma_name)
+ except Exception as e:
+ return jsonify({
+ 'success': False,
+ 'error': f'Failed to find lemma "{lemma_name}": {str(e)}'
+ }), 400
+
+ # Get initial state
+ state = get_proof_state()
+ debug = get_debug_info()
+
+ return jsonify({
+ 'success': True,
+ 'message': f'Initialized with lemma: {lemma_name}',
+ 'state': state,
+ 'debug': debug
+ })
+
+ except Exception as e:
+ logger.error(f"Error initializing: {str(e)}")
+ logger.error(traceback.format_exc())
+ return jsonify({
+ 'success': False,
+ 'error': str(e),
+ 'traceback': traceback.format_exc()
+ }), 500
+
+
+@app.route('/api/run_tactic', methods=['POST'])
+def run_tactic():
+ """Run a tactic on the current proof state"""
+ try:
+ if not executor_state['executor']:
+ return jsonify({
+ 'success': False,
+ 'error': 'Executor not initialized. Please initialize first.'
+ }), 400
+
+ data = request.json
+ tactic = data.get('tactic', '')
+
+ if not tactic or not tactic.strip():
+ return jsonify({
+ 'success': False,
+ 'error': 'Tactic cannot be empty'
+ }), 400
+
+ executor = executor_state['executor']
+
+ # Store current state before running
+ prev_line_num = executor.line_num
+
+ # Manually create a proof step iterator with the tactic
+ # We need to inject this tactic into the executor
+ old_iter = executor.main_file_iter
+
+ # Create a simple iterator that yields the tactic as-is (no splitting)
+ def tactic_iter():
+ yield tactic
+ # Then continue with the rest of the file
+ while True:
+ try:
+ yield next(old_iter)
+ except StopIteration:
+ break
+
+ executor.main_file_iter = tactic_iter()
+
+ # Run the tactic (single execution)
+ success = executor.run_next()
+
+ # Add to history
+ executor_state['history'].append({
+ 'tactic': tactic,
+ 'line_num': prev_line_num,
+ 'success': success,
+ 'errors': list(executor.lean_error_messages) if executor.lean_error_messages else []
+ })
+
+ # Get updated state
+ state = get_proof_state()
+ debug = get_debug_info()
+
+ return jsonify({
+ 'success': True,
+ 'tactic_executed': tactic,
+ 'state': state,
+ 'debug': debug
+ })
+
+ except Exception as e:
+ logger.error(f"Error running tactic: {str(e)}")
+ logger.error(traceback.format_exc())
+ return jsonify({
+ 'success': False,
+ 'error': str(e),
+ 'traceback': traceback.format_exc()
+ }), 500
+
+
+@app.route('/api/state', methods=['GET'])
+def get_state():
+ """Get the current proof state"""
+ try:
+ state = get_proof_state()
+ debug = get_debug_info()
+
+ return jsonify({
+ 'success': True,
+ 'state': state,
+ 'debug': debug
+ })
+
+ except Exception as e:
+ logger.error(f"Error getting state: {str(e)}")
+ return jsonify({
+ 'success': False,
+ 'error': str(e)
+ }), 500
+
+
+@app.route('/api/validate', methods=['POST'])
+def validate_proof():
+ """Validate the current proof using lake lean (independent of REPL)"""
+ try:
+ if not executor_state['executor']:
+ return jsonify({
+ 'success': False,
+ 'error': 'Executor not initialized. Please initialize first.'
+ }), 400
+
+ executor = executor_state['executor']
+
+ # Get optional parameters from request
+ data = request.json or {}
+ timeout_sec = data.get('timeout_sec', 30)
+ keep_temp_file = data.get('keep_temp_file', True) # Default to keeping the file
+
+ # Run validation
+ validation_result = executor.validate_proof(
+ timeout_sec=timeout_sec,
+ keep_temp_file=keep_temp_file
+ )
+
+ return jsonify({
+ 'success': True,
+ 'validation': validation_result
+ })
+
+ except Exception as e:
+ logger.error(f"Error validating proof: {str(e)}")
+ logger.error(traceback.format_exc())
+ return jsonify({
+ 'success': False,
+ 'error': str(e),
+ 'traceback': traceback.format_exc()
+ }), 500
+
+
+@app.route('/api/kill', methods=['POST'])
+def kill_executor():
+ """Kill/exit the executor and clean up resources"""
+ try:
+ if executor_state['context_manager']:
+ logger.info("Killing executor and cleaning up resources")
+ executor_state['context_manager'].__exit__(None, None, None)
+ executor_state['executor'] = None
+ executor_state['context_manager'] = None
+ executor_state['history'] = []
+ executor_state['project_root'] = None
+ executor_state['file_path'] = None
+ executor_state['lemma_name'] = None
+
+ return jsonify({
+ 'success': True,
+ 'message': 'Executor killed and resources cleaned up'
+ })
+ else:
+ return jsonify({
+ 'success': True,
+ 'message': 'No active executor to kill'
+ })
+
+ except Exception as e:
+ logger.error(f"Error killing executor: {str(e)}")
+ logger.error(traceback.format_exc())
+ return jsonify({
+ 'success': False,
+ 'error': str(e),
+ 'traceback': traceback.format_exc()
+ }), 500
+
+
+@app.route('/api/reset', methods=['POST'])
+def reset():
+ """Reset the executor to initial state"""
+ try:
+ if executor_state['context_manager']:
+ executor_state['context_manager'].__exit__(None, None, None)
+
+ executor_state['executor'] = None
+ executor_state['context_manager'] = None
+ executor_state['history'] = []
+
+ return jsonify({
+ 'success': True,
+ 'message': 'Executor reset successfully'
+ })
+
+ except Exception as e:
+ logger.error(f"Error resetting: {str(e)}")
+ return jsonify({
+ 'success': False,
+ 'error': str(e)
+ }), 500
+
+
+@app.route('/api/health', methods=['GET'])
+def health():
+ """Health check endpoint"""
+ return jsonify({
+ 'status': 'healthy',
+ 'initialized': executor_state['executor'] is not None
+ })
+
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='ITP GUI Server')
+ parser.add_argument('--host', default='127.0.0.1', help='Host to bind to')
+ parser.add_argument('--port', type=int, default=5000, help='Port to bind to')
+ parser.add_argument('--debug', action='store_true', help='Enable debug mode')
+
+ args = parser.parse_args()
+
+ logger.info(f"Starting ITP GUI on {args.host}:{args.port}")
+
+ try:
+ app.run(host=args.host, port=args.port, debug=args.debug)
+ finally:
+ # Cleanup on exit
+ if executor_state['context_manager']:
+ try:
+ executor_state['context_manager'].__exit__(None, None, None)
+ except:
+ pass
diff --git a/src/app/itp-gui/static/app.js b/src/app/itp-gui/static/app.js
new file mode 100644
index 0000000..01e8df5
--- /dev/null
+++ b/src/app/itp-gui/static/app.js
@@ -0,0 +1,538 @@
+// ITP GUI - Frontend JavaScript
+
+const API_BASE = '';
+
+// DOM Elements
+let initButton, resetButton, runTacticButton, toggleDebugButton, validateButton, killButton, killButtonProof;
+let projectRootInput, filePathInput, lemmaNameInput, tacticInput;
+let initPanel, proofInterface;
+let initStatus, tacticStatus, validationStatus;
+let currentLemmaName, currentLemmaStmt, goalsContainer, errorsContainer;
+let tacticHistory, debugContent, debugInfo, errorMessages;
+let validationResults, validationLeanCode, validationOutput, validationErrors, validationDebugTraces;
+
+// Application State
+let state = {
+ initialized: false,
+ projectRoot: '',
+ filePath: '',
+ lemmaName: ''
+};
+
+// Initialize DOM references
+document.addEventListener('DOMContentLoaded', () => {
+ // Buttons
+ initButton = document.getElementById('initButton');
+ resetButton = document.getElementById('resetButton');
+ runTacticButton = document.getElementById('runTacticButton');
+ toggleDebugButton = document.getElementById('toggleDebug');
+ validateButton = document.getElementById('validateButton');
+ killButton = document.getElementById('killButton');
+ killButtonProof = document.getElementById('killButtonProof');
+
+ // Inputs
+ projectRootInput = document.getElementById('projectRoot');
+ filePathInput = document.getElementById('filePath');
+ lemmaNameInput = document.getElementById('lemmaName');
+ tacticInput = document.getElementById('tacticInput');
+
+ // Panels
+ initPanel = document.getElementById('initPanel');
+ proofInterface = document.getElementById('proofInterface');
+
+ // Status
+ initStatus = document.getElementById('initStatus');
+ tacticStatus = document.getElementById('tacticStatus');
+ validationStatus = document.getElementById('validationStatus');
+
+ // Proof State
+ currentLemmaName = document.getElementById('currentLemmaName');
+ currentLemmaStmt = document.getElementById('currentLemmaStmt');
+ goalsContainer = document.getElementById('goalsContainer');
+ errorsContainer = document.getElementById('errorsContainer');
+ errorMessages = document.getElementById('errorMessages');
+
+ // History and Debug
+ tacticHistory = document.getElementById('tacticHistory');
+ debugContent = document.getElementById('debugContent');
+ debugInfo = document.getElementById('debugInfo');
+
+ // Validation Results
+ validationResults = document.getElementById('validationResults');
+ validationLeanCode = document.getElementById('validationLeanCode');
+ validationOutput = document.getElementById('validationOutput');
+ validationErrors = document.getElementById('validationErrors');
+ validationDebugTraces = document.getElementById('validationDebugTraces');
+
+ // Event Listeners
+ initButton.addEventListener('click', initializeSession);
+ resetButton.addEventListener('click', resetSession);
+ runTacticButton.addEventListener('click', runTactic);
+ toggleDebugButton.addEventListener('click', toggleDebug);
+ validateButton.addEventListener('click', validateProof);
+ killButton.addEventListener('click', killExecutor);
+ killButtonProof.addEventListener('click', killExecutor);
+
+ // Enter key in tactic input
+ tacticInput.addEventListener('keypress', (e) => {
+ if (e.key === 'Enter' && !e.shiftKey) {
+ e.preventDefault();
+ runTactic();
+ }
+ });
+
+ // Load saved values from localStorage
+ loadSavedInputs();
+
+ // Tab switching
+ document.addEventListener('click', (e) => {
+ if (e.target.classList.contains('tab-button')) {
+ const tabName = e.target.getAttribute('data-tab');
+ switchTab(tabName, e.target);
+ }
+ });
+});
+
+// Switch tabs
+function switchTab(tabName, button) {
+ // Remove active class from all buttons and panes
+ const tabButtons = document.querySelectorAll('.tab-button');
+ const tabPanes = document.querySelectorAll('.tab-pane');
+
+ tabButtons.forEach(btn => btn.classList.remove('active'));
+ tabPanes.forEach(pane => pane.classList.remove('active'));
+
+ // Add active class to selected button and pane
+ button.classList.add('active');
+ document.getElementById(tabName).classList.add('active');
+}
+
+// Save input values to localStorage
+function saveInputs() {
+ localStorage.setItem('itp_project_root', projectRootInput.value);
+ localStorage.setItem('itp_file_path', filePathInput.value);
+ localStorage.setItem('itp_lemma_name', lemmaNameInput.value);
+}
+
+// Load saved input values
+function loadSavedInputs() {
+ const savedProjectRoot = localStorage.getItem('itp_project_root');
+ const savedFilePath = localStorage.getItem('itp_file_path');
+ const savedLemmaName = localStorage.getItem('itp_lemma_name');
+
+ if (savedProjectRoot) projectRootInput.value = savedProjectRoot;
+ if (savedFilePath) filePathInput.value = savedFilePath;
+ if (savedLemmaName) lemmaNameInput.value = savedLemmaName;
+}
+
+// Show status message
+function showStatus(element, message, type) {
+ element.textContent = message;
+ element.className = 'status-message ' + type;
+ element.style.display = 'block';
+}
+
+// Hide status message
+function hideStatus(element) {
+ element.style.display = 'none';
+}
+
+// Initialize Session
+async function initializeSession() {
+ const projectRoot = projectRootInput.value.trim();
+ const filePath = filePathInput.value.trim();
+ const lemmaName = lemmaNameInput.value.trim();
+
+ if (!projectRoot || !filePath || !lemmaName) {
+ showStatus(initStatus, 'Please fill in all fields', 'error');
+ return;
+ }
+
+ showStatus(initStatus, 'Initializing...', 'info');
+ initButton.disabled = true;
+
+ try {
+ const response = await fetch(`${API_BASE}/api/initialize`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json'
+ },
+ body: JSON.stringify({
+ project_root: projectRoot,
+ file_path: filePath,
+ lemma_name: lemmaName
+ })
+ });
+
+ const data = await response.json();
+
+ if (data.success) {
+ state.initialized = true;
+ state.projectRoot = projectRoot;
+ state.filePath = filePath;
+ state.lemmaName = lemmaName;
+
+ saveInputs();
+
+ showStatus(initStatus, 'Session initialized successfully!', 'success');
+
+ // Switch to proof interface
+ setTimeout(() => {
+ initPanel.style.display = 'none';
+ proofInterface.style.display = 'block';
+ updateProofState(data.state);
+ updateDebugInfo(data.debug);
+ }, 500);
+ } else {
+ showStatus(initStatus, 'Error: ' + data.error, 'error');
+ }
+ } catch (error) {
+ showStatus(initStatus, 'Network error: ' + error.message, 'error');
+ } finally {
+ initButton.disabled = false;
+ }
+}
+
+// Kill Executor
+async function killExecutor() {
+ if (!confirm('Are you sure you want to kill the executor? This will terminate the REPL process and clean up all resources.')) {
+ return;
+ }
+
+ try {
+ const response = await fetch(`${API_BASE}/api/kill`, {
+ method: 'POST'
+ });
+
+ const data = await response.json();
+
+ if (data.success) {
+ state.initialized = false;
+ proofInterface.style.display = 'none';
+ initPanel.style.display = 'block';
+ tacticHistory.innerHTML = '
No tactics executed yet.
';
+ showStatus(initStatus, data.message, 'success');
+ setTimeout(() => hideStatus(initStatus), 3000);
+ } else {
+ showStatus(initStatus, 'Error killing executor: ' + data.error, 'error');
+ }
+ } catch (error) {
+ showStatus(initStatus, 'Error killing executor: ' + error.message, 'error');
+ }
+}
+
+// Reset Session
+async function resetSession() {
+ if (!confirm('Are you sure you want to reset the session?')) {
+ return;
+ }
+
+ try {
+ const response = await fetch(`${API_BASE}/api/reset`, {
+ method: 'POST'
+ });
+
+ const data = await response.json();
+
+ if (data.success) {
+ state.initialized = false;
+ proofInterface.style.display = 'none';
+ initPanel.style.display = 'block';
+ tacticHistory.innerHTML = 'No tactics executed yet.
';
+ hideStatus(initStatus);
+ }
+ } catch (error) {
+ showStatus(initStatus, 'Error resetting: ' + error.message, 'error');
+ }
+}
+
+// Run Tactic
+async function runTactic() {
+ // Don't trim - preserve exact whitespace including tabs/spaces for Lean formatting
+ const tactic = tacticInput.value;
+
+ // Only check if it's completely empty (no trimming)
+ if (!tactic || tactic.trim().length === 0) {
+ showStatus(tacticStatus, 'Please enter a tactic', 'error');
+ return;
+ }
+
+ showStatus(tacticStatus, 'Running tactic...', 'info');
+ runTacticButton.disabled = true;
+
+ try {
+ const response = await fetch(`${API_BASE}/api/run_tactic`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json'
+ },
+ body: JSON.stringify({
+ tactic: tactic
+ })
+ });
+
+ const data = await response.json();
+
+ if (data.success) {
+ showStatus(tacticStatus, 'Tactic executed!', 'success');
+ tacticInput.value = '';
+ updateProofState(data.state);
+ updateDebugInfo(data.debug);
+
+ setTimeout(() => hideStatus(tacticStatus), 2000);
+ } else {
+ showStatus(tacticStatus, 'Error: ' + data.error, 'error');
+ }
+ } catch (error) {
+ showStatus(tacticStatus, 'Network error: ' + error.message, 'error');
+ } finally {
+ runTacticButton.disabled = false;
+ tacticInput.focus();
+ }
+}
+
+// Update Proof State
+function updateProofState(proofState) {
+ // Update lemma info
+ currentLemmaName.textContent = proofState.lemma_name || 'Unknown';
+
+ // Update lemma statement with syntax highlighting
+ const lemmaStmt = proofState.lemma_stmt || 'No statement available';
+ currentLemmaStmt.innerHTML = '';
+ const codeElement = document.createElement('code');
+ codeElement.className = 'language-haskell';
+ codeElement.textContent = lemmaStmt;
+ currentLemmaStmt.appendChild(codeElement);
+
+ // Apply syntax highlighting
+ if (typeof Prism !== 'undefined') {
+ Prism.highlightElement(codeElement);
+ }
+
+ // Update goals
+ goalsContainer.innerHTML = '';
+ if (proofState.goals && proofState.goals.length > 0) {
+ proofState.goals.forEach((goal, index) => {
+ const goalDiv = document.createElement('div');
+ goalDiv.className = 'goal-item';
+
+ const goalTitle = document.createElement('h4');
+ goalTitle.textContent = `Goal ${index + 1}:`;
+ goalDiv.appendChild(goalTitle);
+
+ if (goal.hypotheses && goal.hypotheses.length > 0) {
+ const hypothesesDiv = document.createElement('div');
+ hypothesesDiv.className = 'hypotheses';
+
+ const hypTitle = document.createElement('strong');
+ hypTitle.textContent = 'Hypotheses:';
+ hypothesesDiv.appendChild(hypTitle);
+
+ goal.hypotheses.forEach(hyp => {
+ const hypDiv = document.createElement('div');
+ hypDiv.className = 'hypothesis';
+ hypDiv.textContent = hyp;
+ hypothesesDiv.appendChild(hypDiv);
+ });
+
+ goalDiv.appendChild(hypothesesDiv);
+ }
+
+ const conclusionDiv = document.createElement('div');
+ conclusionDiv.className = 'goal-conclusion';
+ const conclusionTitle = document.createElement('strong');
+ conclusionTitle.textContent = '⊢ Goal:';
+ conclusionDiv.appendChild(conclusionTitle);
+
+ const goalPre = document.createElement('pre');
+ goalPre.textContent = goal.goal || 'No goal';
+ conclusionDiv.appendChild(goalPre);
+
+ goalDiv.appendChild(conclusionDiv);
+ goalsContainer.appendChild(goalDiv);
+ });
+ } else {
+ const noGoals = document.createElement('p');
+ noGoals.className = 'empty-message';
+ noGoals.textContent = proofState.is_in_proof_mode ? 'No goals remaining! Proof complete!' : 'Not in proof mode';
+ goalsContainer.appendChild(noGoals);
+ }
+
+ // Update error messages
+ if (proofState.error_messages && proofState.error_messages.length > 0) {
+ errorMessages.style.display = 'block';
+ errorsContainer.innerHTML = '';
+ proofState.error_messages.forEach(error => {
+ const errorDiv = document.createElement('div');
+ errorDiv.className = 'error-item';
+ errorDiv.textContent = error;
+ errorsContainer.appendChild(errorDiv);
+ });
+ } else {
+ errorMessages.style.display = 'none';
+ }
+
+ // Update history
+ updateHistory(proofState.history);
+}
+
+// Update Tactic History
+function updateHistory(history) {
+ if (!history || history.length === 0) {
+ tacticHistory.innerHTML = 'No tactics executed yet.
';
+ return;
+ }
+
+ tacticHistory.innerHTML = '';
+ history.forEach((item, index) => {
+ const historyDiv = document.createElement('div');
+ historyDiv.className = 'tactic-item';
+ if (item.errors && item.errors.length > 0) {
+ historyDiv.classList.add('error');
+ }
+
+ const tacticCode = document.createElement('div');
+ tacticCode.className = 'tactic-code';
+ tacticCode.textContent = `${index + 1}. ${item.tactic}`;
+ historyDiv.appendChild(tacticCode);
+
+ const tacticMeta = document.createElement('div');
+ tacticMeta.className = 'tactic-meta';
+ tacticMeta.textContent = `Line: ${item.line_num} | Success: ${item.success}`;
+ if (item.errors && item.errors.length > 0) {
+ tacticMeta.textContent += ' | Errors: ' + item.errors.join(', ');
+ }
+ historyDiv.appendChild(tacticMeta);
+
+ tacticHistory.appendChild(historyDiv);
+ });
+
+ // Scroll to bottom
+ tacticHistory.scrollTop = tacticHistory.scrollHeight;
+}
+
+// Update Debug Info
+function updateDebugInfo(debug) {
+ if (!debug) return;
+
+ const formatted = JSON.stringify(debug, null, 2);
+ debugContent.textContent = formatted;
+}
+
+// Toggle Debug Panel
+function toggleDebug() {
+ if (debugInfo.style.display === 'none') {
+ debugInfo.style.display = 'block';
+ } else {
+ debugInfo.style.display = 'none';
+ }
+}
+
+// Validate Proof
+async function validateProof() {
+ showStatus(validationStatus, 'Validating proof with lake lean...', 'info');
+ validateButton.disabled = true;
+
+ try {
+ const response = await fetch(`${API_BASE}/api/validate`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json'
+ },
+ body: JSON.stringify({
+ timeout_sec: 30
+ })
+ });
+
+ const data = await response.json();
+
+ if (data.success) {
+ const validation = data.validation;
+
+ if (validation.success) {
+ showStatus(validationStatus, '✓ Proof is complete and valid!', 'success');
+ } else if (!validation.compilation_ok) {
+ showStatus(validationStatus, '✗ Compilation failed: ' + validation.error_message, 'error');
+ } else if (validation.has_sorries) {
+ showStatus(validationStatus, '✗ Proof has unsolved goals (sorries)', 'error');
+ } else {
+ showStatus(validationStatus, '✗ ' + validation.error_message, 'error');
+ }
+
+ // Display validation results
+ displayValidationResults(validation);
+
+ // Auto-hide success message after 3 seconds
+ if (validation.success) {
+ setTimeout(() => hideStatus(validationStatus), 3000);
+ }
+ } else {
+ showStatus(validationStatus, 'Error: ' + data.error, 'error');
+ }
+ } catch (error) {
+ showStatus(validationStatus, 'Network error: ' + error.message, 'error');
+ } finally {
+ validateButton.disabled = false;
+ }
+}
+
+// Display Validation Results in Tabs
+function displayValidationResults(validation) {
+ // Show the validation results panel
+ validationResults.style.display = 'block';
+
+ // Populate Lean Code tab with syntax highlighting
+ const leanCode = validation.lean_code || 'No code available';
+ validationLeanCode.innerHTML = '';
+ const codeElement = document.createElement('code');
+ codeElement.className = 'language-haskell'; // Using Haskell as closest syntax
+ codeElement.textContent = leanCode;
+ const preElement = document.createElement('pre');
+ preElement.className = 'line-numbers';
+ preElement.appendChild(codeElement);
+ validationLeanCode.appendChild(preElement);
+
+ // Apply syntax highlighting
+ if (typeof Prism !== 'undefined') {
+ Prism.highlightElement(codeElement);
+ }
+
+ // Populate Output tab
+ const outputText = `File: ${validation.temp_filename || 'N/A'}
+Full Path: ${validation.temp_file_path || 'N/A'}
+File Kept: ${validation.temp_file_kept ? 'Yes' : 'No'}
+Return Code: ${validation.return_code}
+Success: ${validation.success}
+Compilation OK: ${validation.compilation_ok}
+Has Sorries: ${validation.has_sorries}
+Error Message: ${validation.error_message}
+
+${validation.full_output || (validation.stdout + '\n' + validation.stderr)}`;
+ validationOutput.textContent = outputText;
+
+ // Populate Errors tab
+ if (validation.errors && validation.errors.length > 0) {
+ const errorsText = validation.errors.map((err, idx) =>
+ `[${idx + 1}] ${err.file}:${err.line}:${err.column} (${err.severity})\n${err.message}`
+ ).join('\n\n');
+ validationErrors.textContent = errorsText;
+ } else {
+ validationErrors.textContent = 'No errors found';
+ }
+
+ // Populate Debug Traces tab
+ if (validation.debug_traces && validation.debug_traces.length > 0) {
+ const tracesText = validation.debug_traces.map((trace, idx) =>
+ `[${idx + 1}] ${trace}`
+ ).join('\n');
+ validationDebugTraces.textContent = tracesText;
+ } else {
+ validationDebugTraces.textContent = 'No debug traces available';
+ }
+
+ // Add proof tactics info to output if available
+ if (validation.possible_proof_tactics) {
+ const currentOutput = validationOutput.textContent;
+ validationOutput.textContent = currentOutput + '\n\n=== PROOF TACTICS (from possible_proof_tactics) ===\n' + validation.possible_proof_tactics;
+ }
+}
diff --git a/src/app/itp-gui/static/index.html b/src/app/itp-gui/static/index.html
new file mode 100644
index 0000000..9a98279
--- /dev/null
+++ b/src/app/itp-gui/static/index.html
@@ -0,0 +1,135 @@
+
+
+
+
+
+ ITP GUI - Interactive Theorem Proving
+
+
+
+
+
+
+
+
+
+
+
+
+
Initialize Proof Session
+
+
+
+ Absolute path to the Lean 4 project root directory
+
+
+
+
+ Absolute path to the .lean file containing the lemma
+
+
+
+
+ Name of the theorem/lemma to prove
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Current Proof State
+
+
+
+
+
+
+
+
+
+
+
+
+
Validation Results
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Tactics History
+
+
No tactics executed yet.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/app/itp-gui/static/style.css b/src/app/itp-gui/static/style.css
new file mode 100644
index 0000000..681abe6
--- /dev/null
+++ b/src/app/itp-gui/static/style.css
@@ -0,0 +1,460 @@
+/* Reset and base styles */
+* {
+ margin: 0;
+ padding: 0;
+ box-sizing: border-box;
+}
+
+body {
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+ color: #333;
+ line-height: 1.6;
+ min-height: 100vh;
+ padding: 20px;
+}
+
+.container {
+ max-width: 1600px;
+ margin: 0 auto;
+}
+
+/* Header */
+header {
+ text-align: center;
+ color: white;
+ margin-bottom: 30px;
+}
+
+header h1 {
+ font-size: 2.5em;
+ margin-bottom: 10px;
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
+}
+
+.subtitle {
+ font-size: 1.1em;
+ opacity: 0.9;
+}
+
+/* Main Content */
+.main-content {
+ background: white;
+ border-radius: 12px;
+ padding: 30px;
+ box-shadow: 0 10px 30px rgba(0,0,0,0.3);
+}
+
+/* Panels */
+.panel {
+ background: #f8f9fa;
+ border-radius: 8px;
+ padding: 20px;
+ margin-bottom: 20px;
+ border: 1px solid #dee2e6;
+}
+
+.panel h2 {
+ color: #495057;
+ margin-bottom: 15px;
+ font-size: 1.5em;
+ border-bottom: 2px solid #667eea;
+ padding-bottom: 10px;
+}
+
+.panel h3 {
+ color: #6c757d;
+ margin-top: 15px;
+ margin-bottom: 10px;
+ font-size: 1.2em;
+}
+
+/* Form Groups */
+.form-group {
+ margin-bottom: 20px;
+}
+
+.form-group label {
+ display: block;
+ margin-bottom: 5px;
+ font-weight: 600;
+ color: #495057;
+}
+
+.form-group input,
+.form-group textarea {
+ width: 100%;
+ padding: 10px;
+ border: 2px solid #dee2e6;
+ border-radius: 6px;
+ font-size: 14px;
+ font-family: 'Monaco', 'Courier New', monospace;
+ transition: border-color 0.3s;
+}
+
+.form-group input:focus,
+.form-group textarea:focus {
+ outline: none;
+ border-color: #667eea;
+}
+
+.form-group small {
+ display: block;
+ margin-top: 5px;
+ color: #6c757d;
+ font-size: 0.85em;
+}
+
+/* Buttons */
+.btn {
+ padding: 12px 24px;
+ border: none;
+ border-radius: 6px;
+ font-size: 16px;
+ font-weight: 600;
+ cursor: pointer;
+ transition: all 0.3s;
+ margin-right: 10px;
+ margin-bottom: 10px;
+}
+
+.button-group {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 10px;
+ margin-bottom: 10px;
+}
+
+.button-group .btn {
+ margin-right: 0;
+ margin-bottom: 0;
+}
+
+.btn-primary {
+ background: #667eea;
+ color: white;
+}
+
+.btn-primary:hover {
+ background: #5568d3;
+ transform: translateY(-2px);
+ box-shadow: 0 4px 8px rgba(102, 126, 234, 0.4);
+}
+
+.btn-secondary {
+ background: #6c757d;
+ color: white;
+}
+
+.btn-secondary:hover {
+ background: #5a6268;
+}
+
+.btn-danger {
+ background: #dc3545;
+ color: white;
+}
+
+.btn-danger:hover {
+ background: #c82333;
+ transform: translateY(-2px);
+ box-shadow: 0 4px 8px rgba(220, 53, 69, 0.4);
+}
+
+.btn-small {
+ padding: 5px 10px;
+ font-size: 12px;
+ background: #667eea;
+ color: white;
+ border: none;
+ border-radius: 4px;
+ cursor: pointer;
+ float: right;
+}
+
+.btn-small:hover {
+ background: #5568d3;
+}
+
+/* Status Messages */
+.status-message {
+ margin-top: 15px;
+ padding: 12px;
+ border-radius: 6px;
+ display: none;
+}
+
+.status-message.success {
+ background: #d4edda;
+ color: #155724;
+ border: 1px solid #c3e6cb;
+ display: block;
+}
+
+.status-message.error {
+ background: #f8d7da;
+ color: #721c24;
+ border: 1px solid #f5c6cb;
+ display: block;
+}
+
+.status-message.info {
+ background: #d1ecf1;
+ color: #0c5460;
+ border: 1px solid #bee5eb;
+ display: block;
+}
+
+/* Two Column Layout */
+.two-column {
+ display: grid;
+ grid-template-columns: 1fr 1fr;
+ gap: 20px;
+}
+
+.column {
+ min-width: 0;
+}
+
+/* Code Blocks */
+.code-block {
+ padding: 0;
+ border-radius: 6px;
+ overflow-x: auto;
+ font-family: 'Monaco', 'Courier New', monospace;
+ font-size: 13px;
+ line-height: 1.5;
+}
+
+/* Prism.js code blocks */
+.code-block pre {
+ margin: 0;
+ border-radius: 6px;
+}
+
+.code-block pre code {
+ font-family: 'Monaco', 'Courier New', monospace;
+ font-size: 13px;
+}
+
+/* Line numbers styling */
+.line-numbers .line-numbers-rows {
+ border-right: 1px solid #555;
+}
+
+/* Lemma Info */
+.lemma-info {
+ margin-bottom: 20px;
+}
+
+.lemma-info h3 {
+ color: #667eea;
+ font-size: 1.3em;
+}
+
+.lemma-info pre code {
+ background: transparent;
+}
+
+/* Proof Goals */
+.proof-goals {
+ margin-top: 20px;
+}
+
+.goal-item {
+ background: white;
+ border-left: 4px solid #667eea;
+ padding: 15px;
+ margin-bottom: 15px;
+ border-radius: 4px;
+}
+
+.goal-item h4 {
+ color: #495057;
+ margin-bottom: 10px;
+}
+
+.hypotheses {
+ margin-bottom: 10px;
+}
+
+.hypothesis {
+ font-family: 'Monaco', 'Courier New', monospace;
+ font-size: 13px;
+ color: #495057;
+ padding: 3px 0;
+}
+
+.goal-conclusion {
+ border-top: 2px solid #dee2e6;
+ padding-top: 10px;
+ margin-top: 10px;
+}
+
+.goal-conclusion pre {
+ font-family: 'Monaco', 'Courier New', monospace;
+ font-size: 13px;
+ color: #28a745;
+ font-weight: 600;
+}
+
+/* Error Messages */
+.error-messages {
+ margin-top: 20px;
+ background: #f8d7da;
+ border: 1px solid #f5c6cb;
+ border-radius: 6px;
+ padding: 15px;
+}
+
+.error-item {
+ color: #721c24;
+ font-family: 'Monaco', 'Courier New', monospace;
+ font-size: 13px;
+ margin-bottom: 5px;
+}
+
+/* Tactic History */
+.tactic-history {
+ max-height: 400px;
+ overflow-y: auto;
+ background: white;
+ padding: 10px;
+ border-radius: 6px;
+}
+
+.tactic-item {
+ background: #e9ecef;
+ padding: 10px;
+ margin-bottom: 8px;
+ border-radius: 4px;
+ border-left: 4px solid #28a745;
+}
+
+.tactic-item.error {
+ border-left-color: #dc3545;
+}
+
+.tactic-item .tactic-code {
+ font-family: 'Monaco', 'Courier New', monospace;
+ font-size: 13px;
+ color: #495057;
+ font-weight: 600;
+}
+
+.tactic-item .tactic-meta {
+ font-size: 11px;
+ color: #6c757d;
+ margin-top: 5px;
+}
+
+.empty-message {
+ text-align: center;
+ color: #6c757d;
+ padding: 20px;
+ font-style: italic;
+}
+
+/* Debug Panel */
+.debug-info {
+ background: #282c34;
+ border-radius: 6px;
+ padding: 15px;
+ max-height: 500px;
+ overflow-y: auto;
+}
+
+.debug-info pre {
+ color: #abb2bf;
+ font-family: 'Monaco', 'Courier New', monospace;
+ font-size: 12px;
+ line-height: 1.5;
+ white-space: pre-wrap;
+ word-wrap: break-word;
+}
+
+/* Scrollbar Styling */
+::-webkit-scrollbar {
+ width: 10px;
+ height: 10px;
+}
+
+::-webkit-scrollbar-track {
+ background: #f1f1f1;
+ border-radius: 5px;
+}
+
+::-webkit-scrollbar-thumb {
+ background: #888;
+ border-radius: 5px;
+}
+
+::-webkit-scrollbar-thumb:hover {
+ background: #555;
+}
+
+/* Tabs */
+.tabs {
+ display: flex;
+ border-bottom: 2px solid #dee2e6;
+ margin-bottom: 15px;
+}
+
+.tab-button {
+ background: none;
+ border: none;
+ padding: 10px 20px;
+ font-size: 14px;
+ font-weight: 600;
+ color: #6c757d;
+ cursor: pointer;
+ border-bottom: 3px solid transparent;
+ transition: all 0.3s;
+}
+
+.tab-button:hover {
+ color: #495057;
+ background: #f8f9fa;
+}
+
+.tab-button.active {
+ color: #667eea;
+ border-bottom-color: #667eea;
+}
+
+.tab-content {
+ position: relative;
+}
+
+.tab-pane {
+ display: none;
+}
+
+.tab-pane.active {
+ display: block;
+}
+
+.validation-results-panel {
+ margin-top: 20px;
+}
+
+/* Responsive Design */
+@media (max-width: 1200px) {
+ .two-column {
+ grid-template-columns: 1fr;
+ }
+}
+
+@media (max-width: 768px) {
+ header h1 {
+ font-size: 2em;
+ }
+
+ .main-content {
+ padding: 20px;
+ }
+
+ .panel {
+ padding: 15px;
+ }
+}
diff --git a/src/data/test/Mathlib/lake-manifest.json b/src/data/test/Mathlib/lake-manifest.json
index 2a8d9e3..eabbf16 100644
--- a/src/data/test/Mathlib/lake-manifest.json
+++ b/src/data/test/Mathlib/lake-manifest.json
@@ -1,68 +1,95 @@
-{"version": 7,
+{"version": "1.1.0",
"packagesDir": ".lake/packages",
"packages":
- [{"url": "https://github.com/leanprover/std4",
+ [{"url": "https://github.com/leanprover-community/mathlib4",
"type": "git",
"subDir": null,
- "rev": "e5306c3b0edefe722370b7387ee9bcd4631d6c17",
- "name": "std",
+ "scope": "",
+ "rev": "f897ebcf72cd16f89ab4577d0c826cd14afaafc7",
+ "name": "mathlib",
+ "manifestFile": "lake-manifest.json",
+ "inputRev": "v4.24.0",
+ "inherited": false,
+ "configFile": "lakefile.lean"},
+ {"url": "https://github.com/leanprover-community/plausible",
+ "type": "git",
+ "subDir": null,
+ "scope": "leanprover-community",
+ "rev": "dfd06ebfe8d0e8fa7faba9cb5e5a2e74e7bd2805",
+ "name": "plausible",
"manifestFile": "lake-manifest.json",
"inputRev": "main",
"inherited": true,
- "configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover-community/quote4",
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover-community/LeanSearchClient",
"type": "git",
"subDir": null,
- "rev": "fd760831487e6835944e7eeed505522c9dd47563",
- "name": "Qq",
+ "scope": "leanprover-community",
+ "rev": "99657ad92e23804e279f77ea6dbdeebaa1317b98",
+ "name": "LeanSearchClient",
"manifestFile": "lake-manifest.json",
- "inputRev": "master",
+ "inputRev": "main",
"inherited": true,
- "configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover-community/aesop",
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover-community/import-graph",
"type": "git",
"subDir": null,
- "rev": "8be30c25e3caa06937feeb62f7ca898370f80ee9",
- "name": "aesop",
+ "scope": "leanprover-community",
+ "rev": "d768126816be17600904726ca7976b185786e6b9",
+ "name": "importGraph",
"manifestFile": "lake-manifest.json",
- "inputRev": "master",
+ "inputRev": "main",
"inherited": true,
- "configFile": "lakefile.lean"},
+ "configFile": "lakefile.toml"},
{"url": "https://github.com/leanprover-community/ProofWidgets4",
"type": "git",
"subDir": null,
- "rev": "fb65c476595a453a9b8ffc4a1cea2db3a89b9cd8",
+ "scope": "leanprover-community",
+ "rev": "556caed0eadb7901e068131d1be208dd907d07a2",
"name": "proofwidgets",
"manifestFile": "lake-manifest.json",
- "inputRev": "v0.0.30",
+ "inputRev": "v0.0.74",
"inherited": true,
"configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover/lean4-cli",
+ {"url": "https://github.com/leanprover-community/aesop",
"type": "git",
"subDir": null,
- "rev": "be8fa79a28b8b6897dce0713ef50e89c4a0f6ef5",
- "name": "Cli",
+ "scope": "leanprover-community",
+ "rev": "725ac8cd67acd70a7beaf47c3725e23484c1ef50",
+ "name": "aesop",
"manifestFile": "lake-manifest.json",
- "inputRev": "main",
+ "inputRev": "master",
"inherited": true,
- "configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover-community/import-graph.git",
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover-community/quote4",
"type": "git",
"subDir": null,
- "rev": "61a79185b6582573d23bf7e17f2137cd49e7e662",
- "name": "importGraph",
+ "scope": "leanprover-community",
+ "rev": "dea6a3361fa36d5a13f87333dc506ada582e025c",
+ "name": "Qq",
+ "manifestFile": "lake-manifest.json",
+ "inputRev": "master",
+ "inherited": true,
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover-community/batteries",
+ "type": "git",
+ "subDir": null,
+ "scope": "leanprover-community",
+ "rev": "8da40b72fece29b7d3fe3d768bac4c8910ce9bee",
+ "name": "batteries",
"manifestFile": "lake-manifest.json",
"inputRev": "main",
"inherited": true,
- "configFile": "lakefile.lean"},
- {"url": "https://github.com/leanprover-community/mathlib4",
+ "configFile": "lakefile.toml"},
+ {"url": "https://github.com/leanprover/lean4-cli",
"type": "git",
"subDir": null,
- "rev": "fe4454af900584467d21f4fd4fe951d29d9332a7",
- "name": "mathlib",
+ "scope": "leanprover",
+ "rev": "91c18fa62838ad0ab7384c03c9684d99d306e1da",
+ "name": "Cli",
"manifestFile": "lake-manifest.json",
- "inputRev": "v4.7.0-rc2",
- "inherited": false,
- "configFile": "lakefile.lean"}],
+ "inputRev": "main",
+ "inherited": true,
+ "configFile": "lakefile.toml"}],
"name": "«repl-mathlib-tests»",
"lakeDir": ".lake"}
diff --git a/src/data/test/Mathlib/lakefile.lean b/src/data/test/Mathlib/lakefile.lean
index 34c561d..b1f61b0 100644
--- a/src/data/test/Mathlib/lakefile.lean
+++ b/src/data/test/Mathlib/lakefile.lean
@@ -3,7 +3,7 @@ open Lake DSL
package «repl-mathlib-tests» where
-- add package configuration options here
- require mathlib from git "https://github.com/leanprover-community/mathlib4" @ "v4.7.0-rc2"
+ require mathlib from git "https://github.com/leanprover-community/mathlib4" @ "v4.24.0"
@[default_target]
lean_lib «ReplMathlibTests» where
diff --git a/src/data/test/Mathlib/lean-toolchain b/src/data/test/Mathlib/lean-toolchain
index e35881c..c00a535 100644
--- a/src/data/test/Mathlib/lean-toolchain
+++ b/src/data/test/Mathlib/lean-toolchain
@@ -1 +1 @@
-leanprover/lean4:v4.7.0-rc2
+leanprover/lean4:v4.24.0
diff --git a/src/itp_interface/lean_server/lean4_utils.py b/src/itp_interface/lean_server/lean4_utils.py
index b29e7d0..607dc49 100644
--- a/src/itp_interface/lean_server/lean4_utils.py
+++ b/src/itp_interface/lean_server/lean4_utils.py
@@ -142,6 +142,7 @@ def parse_proof_context_human_readable(proof_context_str: str) -> ProofContext:
# raise
return ProofContext(goals, [], [], [])
+ @staticmethod
def parse_proof_context_human_readable_as_goals(proof_context_str: str) -> typing.List[Obligation]:
if len(proof_context_str) == 0 and Lean4Utils.proof_context_separator not in proof_context_str:
return None
diff --git a/src/itp_interface/main/config.py b/src/itp_interface/main/config.py
index 979b396..5d0c042 100644
--- a/src/itp_interface/main/config.py
+++ b/src/itp_interface/main/config.py
@@ -1,10 +1,5 @@
#!/usr/bin/env python3
-import sys
-
-root_dir = f"{__file__.split('itp_interface')[0]}"
-if root_dir not in sys.path:
- sys.path.append(root_dir)
import typing
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json
@@ -73,11 +68,17 @@ class EvalFile(object):
path: str
theorems: typing.Union[str, typing.List[str]]
+@dataclass_json
+@dataclass
+class ExtractFile(object):
+ path: str
+ declarations: typing.Union[str, typing.List[str]]
+
@dataclass_json
@dataclass
class EvalDataset(object):
project: str
- files: typing.List[EvalFile]
+ files: typing.Union[typing.List[EvalFile], typing.List[ExtractFile]]
@dataclass_json
@dataclass
@@ -90,6 +91,7 @@ class EvalBenchmark(object):
few_shot_metadata_filename_for_retrieval: str = None
dfs_data_path_for_retrieval: str = None
dfs_metadata_filename_for_retrieval: str = None
+ is_extraction_request: bool = False
setup_cmds: typing.List[str] = field(default_factory=list)
@dataclass_json
@@ -133,6 +135,7 @@ def add_theorem_to_maps(self, path: str, theorem: str, proof_result: ProofSearch
def parse_config(cfg):
+ is_extraction_request = False
env_settings_cfg = cfg["env_settings"]
env_settings = EnvSettings(
name=env_settings_cfg["name"],
@@ -167,14 +170,28 @@ def parse_config(cfg):
files_cfg = list(dataset_cfg["files"])
eval_files = []
for file_cfg in files_cfg:
- theorems = None
- if type(file_cfg["theorems"]) == str:
- theorems = file_cfg["theorems"]
+ if "theorems" in file_cfg:
+ theorems = None
+ if type(file_cfg["theorems"]) == str:
+ theorems = file_cfg["theorems"]
+ else:
+ theorems = list(file_cfg["theorems"])
+ eval_files.append(EvalFile(
+ path=file_cfg["path"],
+ theorems=theorems))
+ is_extraction_request = False
+ elif "declarations" in file_cfg:
+ declarations = None
+ if type(file_cfg["declarations"]) == str:
+ declarations = file_cfg["declarations"]
+ else:
+ declarations = list(file_cfg["declarations"])
+ eval_files.append(ExtractFile(
+ path=file_cfg["path"],
+ declarations=declarations))
+ is_extraction_request = True
else:
- theorems = list(file_cfg["theorems"])
- eval_files.append(EvalFile(
- path=file_cfg["path"],
- theorems=theorems))
+ raise ValueError(f"File config must have either 'theorems' or 'declarations': {file_cfg}")
eval_datasets.append(EvalDataset(
project=dataset_cfg["project"],
files=eval_files))
@@ -188,5 +205,6 @@ def parse_config(cfg):
few_shot_metadata_filename_for_retrieval=benchmark_cfg["few_shot_metadata_filename_for_retrieval"],
dfs_data_path_for_retrieval=benchmark_cfg["dfs_data_path_for_retrieval"],
dfs_metadata_filename_for_retrieval=benchmark_cfg["dfs_metadata_filename_for_retrieval"],
+ is_extraction_request=is_extraction_request and benchmark_cfg.get("is_extraction_request", False),
setup_cmds=benchmark_cfg["setup_cmds"] if "setup_cmds" in benchmark_cfg else [])
return Experiments(env_settings=env_settings, run_settings=eval_settings, benchmark=benchmark)
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml b/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml
new file mode 100644
index 0000000..2ed8c1d
--- /dev/null
+++ b/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml
@@ -0,0 +1,13 @@
+name: simple_benchmark_lean_ext
+num_files: 1
+language: LEAN4
+few_shot_data_path_for_retrieval:
+few_shot_metadata_filename_for_retrieval:
+dfs_data_path_for_retrieval:
+dfs_metadata_filename_for_retrieval:
+is_extraction_request: true
+datasets:
+ - project: src/data/test/lean4_proj
+ files:
+ - path: Lean4Proj/Basic.lean
+ declarations: "*"
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/simple_lean_data_extract.yaml b/src/itp_interface/main/configs/simple_lean_data_extract.yaml
new file mode 100644
index 0000000..4fed34c
--- /dev/null
+++ b/src/itp_interface/main/configs/simple_lean_data_extract.yaml
@@ -0,0 +1,13 @@
+defaults:
+ # - benchmark: simple_benchmark_lean_training_data
+ # - run_settings: default_lean_data_generation_transforms
+ # - benchmark: simple_benchmark_1
+ # - run_settings: default_lean4_data_generation_transforms
+ - benchmark: simple_benchmark_lean_ext
+ - run_settings: default_lean4_data_generation_transforms
+ - env_settings: no_retrieval
+ - override hydra/job_logging: 'disabled'
+
+run_settings:
+ output_dir: .log/data_generation/benchmark/simple_benchmark_lean_ext
+ pool_size: 2
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/simple_lean_data_gen.yaml b/src/itp_interface/main/configs/simple_lean_data_gen.yaml
index 135f353..d93dd0f 100644
--- a/src/itp_interface/main/configs/simple_lean_data_gen.yaml
+++ b/src/itp_interface/main/configs/simple_lean_data_gen.yaml
@@ -9,4 +9,5 @@ defaults:
- override hydra/job_logging: 'disabled'
run_settings:
- output_dir: .log/data_generation/benchmark/simple_benchmark_lean
\ No newline at end of file
+ output_dir: .log/data_generation/benchmark/simple_benchmark_lean
+ pool_size: 2
\ No newline at end of file
diff --git a/src/itp_interface/main/filter_dataset.py b/src/itp_interface/main/filter_dataset.py
index 788ab13..3c376f9 100644
--- a/src/itp_interface/main/filter_dataset.py
+++ b/src/itp_interface/main/filter_dataset.py
@@ -11,9 +11,9 @@
import typing
import copy
from itp_interface.tools.log_utils import setup_logger
-from itp_interface.tools.training_data import TrainingData, TrainingDataFormat
+from itp_interface.tools.training_data import TrainingData, TheoremProvingTrainingDataFormat
-def filter_keyword(tdf: TrainingDataFormat, keywords: typing.List[str]) -> bool:
+def filter_keyword(tdf: TheoremProvingTrainingDataFormat, keywords: typing.List[str]) -> bool:
"""
Filter out the training data format if it contains any of the keywords
"""
@@ -40,7 +40,7 @@ def filter_dataset(dataset, metafilename, output, keywords, max_parallelism=8, l
metadata = copy.deepcopy(training_data.meta)
logger.info(f"Start filtering datasets")
cloned_metadata = copy.deepcopy(metadata)
- cloned_metadata.total_proof_step_cnt = 0
+ cloned_metadata.total_data_count = 0
cloned_metadata.last_training_data = 0
logger.info(f"Cloned metadata:\n {cloned_metadata}")
filtered_td = TrainingData(
@@ -67,7 +67,7 @@ def filter_dataset(dataset, metafilename, output, keywords, max_parallelism=8, l
total_data_points += 1
logger.info(f"In total, skipped {skipped_cnt} data points.")
logger.info(f"Total data points: {total_data_points}")
- logger.info(f"Total proof steps: {metadata.total_proof_step_cnt}")
+ logger.info(f"Total data points so far: {metadata.total_data_count}")
logger.info(f"Lenght of filtered dataset: {len(filtered_td)}")
logger.info(f"Length of training data: {len(training_data)}")
logger.info("Finished filtering dataset.")
@@ -83,7 +83,7 @@ def filter_dataset(dataset, metafilename, output, keywords, max_parallelism=8, l
logger=logger
)
new_merged_td.load()
- assert len(new_merged_td) == metadata.total_proof_step_cnt - skipped_cnt, "Filtered dataset is not correct"
+ assert len(new_merged_td) == metadata.total_data_count - skipped_cnt, "Filtered dataset is not correct"
logger.info("Verified the filtered dataset.")
pass
diff --git a/src/itp_interface/main/install.py b/src/itp_interface/main/install.py
index d993148..c6e8222 100644
--- a/src/itp_interface/main/install.py
+++ b/src/itp_interface/main/install.py
@@ -1,6 +1,7 @@
import os
import random
import string
+from itp_interface.tools.tactic_parser import build_tactic_parser_if_needed
file_path = os.path.abspath(__file__)
@@ -15,60 +16,23 @@ def install_itp_interface():
print("Installing itp_interface")
itp_dir = os.path.dirname(os.path.dirname(file_path))
tools_dir = os.path.join(itp_dir, "tools")
- repl_dir = os.path.join(tools_dir, "repl")
- assert os.path.exists(repl_dir), f"repl_dir: {repl_dir} does not exist"
- assert os.path.exists(os.path.join(repl_dir, "lean-toolchain")
- ), f"lean-toolchain does not exist in {repl_dir}, build has failed"
- print("repl_dir: ", repl_dir)
+ tactic_parser_dir = os.path.join(tools_dir, "tactic_parser")
+ assert os.path.exists(tactic_parser_dir), f"tactic_parser_dir: {tactic_parser_dir} does not exist"
+ assert os.path.exists(os.path.join(tactic_parser_dir, "lean-toolchain")), f"lean-toolchain does not exist in {tactic_parser_dir}, build has failed"
+ print("tactic_parser_dir: ", tactic_parser_dir)
+ with open(os.path.join(tactic_parser_dir, "lean-toolchain"), "r") as f:
+ lean_toolchain_content = f.read().strip()
+ print("Lean toolchain version for tactic_parser: ", lean_toolchain_content)
+ print(f"LEAN_VERSION set: {os.environ.get('LEAN_VERSION', 'Not Set')}")
print("Building itp_interface")
- os.system(f"cd {repl_dir} && lake build repl")
-
+ build_tactic_parser_if_needed()
def install_lean_repl():
print("Updating Lean")
- itp_dir = os.path.dirname(os.path.dirname(file_path))
- tools_dir = os.path.join(itp_dir, "tools")
- repl_dir = os.path.join(tools_dir, "repl")
- assert os.path.exists(repl_dir), f"repl_dir: {repl_dir} does not exist"
- assert os.path.exists(os.path.join(repl_dir, "lean-toolchain")
- ), f"lean-toolchain does not exist in {repl_dir}, build has failed"
- print("repl_dir: ", repl_dir)
- assert os.system("git --version") == 0, "git is not installed"
- print("[OK] git is installed")
print("Checking if Lean version is set in environment variables as LEAN_VERSION")
print("If not, defaulting to 4.24.0")
lean_version = os.environ.get("LEAN_VERSION", "4.24.0")
- github_repo = "https://github.com/amit9oct/repl.git"
- if lean_version.strip() == "4.24.0":
- print("Lean version is set to 4.24.0, not cloning the REPL")
- else:
- # Clone the repl fresh
- print("Cloning the REPL fresh")
- os.system(f"rm -rf {repl_dir}")
- os.system(f"git clone {github_repo} {repl_dir}")
- # escape the version number
- lean_version_esc = lean_version.replace(".", "\.")
- print("Switching to the right REPL version", lean_version_esc)
- cmd_to_run = f"cd {repl_dir} && git log --grep \"v{lean_version_esc}\" --pretty=\"%h %s\""
- print("Running: ", cmd_to_run)
- output = os.popen(cmd_to_run).read()
- print("Output: ", output)
- if output == "":
- print(
- f"Could not find a commit with message containing {lean_version}")
- print("Probably this version does not exist in the git history of the REPL")
- lean_version = "4.24.0"
- print("Switching to v4.24.0 (latest default)")
- os.system(f"cd {repl_dir} && git checkout main")
- else:
- # Split on first space
- for line in output.split("\n"):
- if line:
- commit, message = line.split(" ", 1)
- if lean_version in message:
- print(f"Switching to commit {commit}")
- os.system(f"cd {repl_dir} && git checkout {commit}")
- break
+
# Make sure that .elan is installed
print("Checking if .elan is installed")
if os.system("elan --version") == 0:
diff --git a/src/itp_interface/main/merge_dataset.py b/src/itp_interface/main/merge_dataset.py
index 754f35c..d50f12c 100644
--- a/src/itp_interface/main/merge_dataset.py
+++ b/src/itp_interface/main/merge_dataset.py
@@ -11,7 +11,7 @@
import typing
import copy
from itp_interface.tools.log_utils import setup_logger
-from itp_interface.tools.training_data import TrainingData, TrainingDataFormat
+from itp_interface.tools.training_data import TrainingData, TheoremProvingTrainingDataFormat
def filter_training_data(training_data: TrainingData, max_distance_to_good: int = 10):
@@ -42,7 +42,7 @@ def _reconstruct_proof_tree(prev_state_id_map: typing.Dict[int, typing.Set[int]]
if end_state not in good_state_ids:
bad_state_ids.add(end_state)
- def _reconstruct_prev_state_id_map(training_datas: typing.List[TrainingDataFormat]) -> typing.Tuple[typing.Dict[int, int], int]:
+ def _reconstruct_prev_state_id_map(training_datas: typing.List[TheoremProvingTrainingDataFormat]) -> typing.Tuple[typing.Dict[int, int], int]:
prev_state_id_map = {}
done_state = None
for training_data in training_datas:
@@ -57,11 +57,11 @@ def _reconstruct_prev_state_id_map(training_datas: typing.List[TrainingDataForma
done_state = end_state
return prev_state_id_map, done_state
- filtered_training_data : typing.List[TrainingDataFormat] = []
- proof_id_maps : typing.Dict[str, typing.List[TrainingDataFormat]] = {}
+ filtered_training_data : typing.List[TheoremProvingTrainingDataFormat] = []
+ proof_id_maps : typing.Dict[str, typing.List[TheoremProvingTrainingDataFormat]] = {}
for idx in range(len(training_data)):
example = training_data[idx]
- training_datas : typing.List[TrainingDataFormat] = proof_id_maps.get(example.proof_id, [])
+ training_datas : typing.List[TheoremProvingTrainingDataFormat] = proof_id_maps.get(example.proof_id, [])
training_datas.append(example)
proof_id_maps[example.proof_id] = training_datas
for proof_id, training_datas in proof_id_maps.items():
@@ -118,12 +118,12 @@ def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=No
logger.info(f"Inited training data for {dataset}")
metadata = copy.deepcopy(tds[-1].meta)
for td in tds[:-1]:
- metadata.total_proof_step_cnt += td.meta.total_proof_step_cnt
+ metadata.total_data_count += td.meta.total_data_count
metadata.last_training_data += td.meta.last_training_data
logger.info(f"Merged metadata:\n {metadata}")
logger.info(f"Start merging datasets")
cloned_metadata = copy.deepcopy(metadata)
- cloned_metadata.total_proof_step_cnt = 0
+ cloned_metadata.total_data_count = 0
cloned_metadata.last_training_data = 0
merged_td = TrainingData(
folder=output,
@@ -136,7 +136,7 @@ def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=No
logger.info(f"Start loading {td.folder} ...")
td.load()
- filtered_training_data_points : typing.List[TrainingDataFormat] = None
+ filtered_training_data_points : typing.List[TheoremProvingTrainingDataFormat] = None
if should_filter_data and td.folder != datasets[0]:
# TODO: move to the right location
filtered_training_data_points = filter_training_data(td)
@@ -166,7 +166,7 @@ def merge_datasets(datasets, metafilenames, output, max_parallelism=8, logger=No
)
if not should_filter_data:
new_merged_td.load()
- assert len(new_merged_td) == metadata.total_proof_step_cnt, "Merged dataset is not correct"
+ assert len(new_merged_td) == metadata.total_data_count, "Merged dataset is not correct"
assert new_merged_td.meta.last_training_data == metadata.last_training_data, "Merged dataset is not correct"
assert new_merged_td.meta.last_proof_id == metadata.last_proof_id, "Merged dataset is not correct"
logger.info("Merged dataset is correct.")
diff --git a/src/itp_interface/main/run_tool.py b/src/itp_interface/main/run_tool.py
index 7307d27..7fbd7da 100644
--- a/src/itp_interface/main/run_tool.py
+++ b/src/itp_interface/main/run_tool.py
@@ -13,7 +13,6 @@
import numpy as np
import yaml
import uuid
-import threading
from concurrent.futures import ThreadPoolExecutor
# Conditional Ray import
@@ -34,15 +33,14 @@
from itp_interface.tools.coq_local_data_generation_transform import LocalDataGenerationTransform as CoqLocalDataGenerationTransform
from itp_interface.tools.lean_local_data_generation_transform import LocalDataGenerationTransform as LeanLocalDataGenerationTransform
from itp_interface.tools.lean4_local_data_generation_transform import Local4DataGenerationTransform
+from itp_interface.tools.lean4_local_data_extraction_transform import Local4DataExtractionTransform
from itp_interface.tools.isabelle_local_data_generation_transform import LocalDataGenerationTransform as IsabelleLocalDataGenerationTransform
from itp_interface.tools.run_data_generation_transforms import RunDataGenerationTransforms
from itp_interface.tools.log_utils import setup_logger
-from itp_interface.main.config import Experiments, EvalRunCheckpointInfo, TransformType, parse_config
-from itp_interface.tools.isabelle_executor import IsabelleExecutor
+from itp_interface.main.config import EvalFile, ExtractFile, Experiments, EvalRunCheckpointInfo, EvalBenchmark, TransformType, parse_config
from itp_interface.tools.dynamic_coq_proof_exec import DynamicProofExecutor as DynamicCoqProofExecutor
from itp_interface.tools.dynamic_lean_proof_exec import DynamicProofExecutor as DynamicLeanProofExecutor
from itp_interface.tools.dynamic_lean4_proof_exec import DynamicProofExecutor as DynamicLean4ProofExecutor
-from itp_interface.tools.dynamic_isabelle_proof_exec import DynamicProofExecutor as DynamicIsabelleProofExecutor
from itp_interface.tools.coq_executor import get_all_lemmas_in_file as get_all_lemmas_coq
from itp_interface.tools.lean4_sync_executor import get_all_theorems_in_file as get_all_lemmas_lean4, get_fully_qualified_theorem_name as get_fully_qualified_theorem_name_lean4, get_theorem_name_resembling as get_theorem_name_resembling_lean4
from itp_interface.tools.isabelle_executor import get_all_lemmas_in_file as get_all_lemmas_isabelle
@@ -122,6 +120,16 @@ def _get_all_lemmas_impl(
logger.info(f"Discovered {len(lemmas_to_prove)} lemmas")
return lemmas_to_prove
+
+def _get_all_lean_files_in_folder_recursively(
+ project_folder: str) -> typing.List[str]:
+ lean_files = []
+ for root, dirs, files in os.walk(project_folder):
+ for file in files:
+ if file.endswith(".lean"):
+ lean_files.append(os.path.join(root, file))
+ return lean_files
+
# Create Ray remote version if Ray is available
if HAS_RAY:
get_all_lemmas = ray.remote(num_cpus=0.5)(_get_all_lemmas_impl)
@@ -139,6 +147,7 @@ def partition_data(project_to_theorems: typing.Dict[str, typing.Dict[str, typing
# Go over each project and classify into three categories
proj_file_to_theorems_named_tuple = typing.NamedTuple("proj_file_to_theorems", [("project", str), ("file", str), ("theorems", typing.List[str])])
proj_file_thms = []
+ logger.info(f"Partitioning: {project_to_theorems}")
for project, file_to_theorems in project_to_theorems.items():
for file, theorems in file_to_theorems.items():
# Generate a random number between 0 and 1
@@ -211,7 +220,8 @@ def partition_data(project_to_theorems: typing.Dict[str, typing.Dict[str, typing
logger.info(f"Actual division Train: {train_cnt}, Eval: {eval_cnt}, Test: {test_cnt}")
return train_project_to_theorems, eval_project_to_theorems, test_project_to_theorems
-def create_yaml(project_to_theorems, name, language, output_file):
+def create_yaml(project_to_theorems, name, eval_benchmark: EvalBenchmark, output_file):
+ language = eval_benchmark.language
data = {
"name": name,
"num_files": 0,
@@ -219,6 +229,7 @@ def create_yaml(project_to_theorems, name, language, output_file):
"few_shot_data_path_for_retrieval": None,
"few_shot_metadata_filename_for_retrieval": None,
"dfs_data_path_for_retrieval": None,
+ "is_extraction_request": eval_benchmark.is_extraction_request,
"dfs_metadata_filename_for_retrieval": "local.meta.json",
"theorem_cnt": 0,
"datasets": []
@@ -228,12 +239,184 @@ def create_yaml(project_to_theorems, name, language, output_file):
for file_path, theorems in file_dict.items():
data["num_files"] += 1
data["theorem_cnt"] += len(theorems)
- dataset["files"].append({"path": file_path, "theorems": theorems})
+ if eval_benchmark.is_extraction_request:
+ dataset["files"].append({"path": file_path, "declarations": theorems})
+ else:
+ dataset["files"].append({"path": file_path, "theorems": theorems})
data["datasets"].append(dataset)
with open(output_file, 'w') as yaml_file:
yaml.dump(data, yaml_file, sort_keys=False)
+def add_transform(experiment: Experiments, clone_dir: str, resources: list, transforms: list, logger: logging.Logger = None):
+ global ray_resource_pool
+ if experiment.run_settings.transform_type == TransformType.LOCAL:
+ if experiment.benchmark.language == ProofAction.Language.LEAN:
+ transform = LeanLocalDataGenerationTransform(
+ experiment.run_settings.dep_depth,
+ max_search_results=experiment.run_settings.max_search_results,
+ buffer_size=experiment.run_settings.buffer_size,
+ logger=logger)
+ os.makedirs(clone_dir, exist_ok=True)
+ elif experiment.benchmark.language == ProofAction.Language.LEAN4:
+ if experiment.benchmark.is_extraction_request:
+ transform = Local4DataExtractionTransform(
+ experiment.run_settings.dep_depth,
+ buffer_size=experiment.run_settings.buffer_size,
+ logger=logger)
+ else:
+ transform = Local4DataGenerationTransform(
+ experiment.run_settings.dep_depth,
+ max_search_results=experiment.run_settings.max_search_results,
+ buffer_size=experiment.run_settings.buffer_size,
+ logger=logger)
+ clone_dir = None
+ elif experiment.benchmark.language == ProofAction.Language.COQ:
+ only_proof_state = experiment.env_settings.retrieval_strategy == ProofEnvReRankStrategy.NO_RE_RANK
+ transform = CoqLocalDataGenerationTransform(
+ experiment.run_settings.dep_depth,
+ max_search_results=experiment.run_settings.max_search_results,
+ buffer_size=experiment.run_settings.buffer_size,
+ logger=logger,
+ no_dfns=only_proof_state,
+ no_thms=only_proof_state)
+ clone_dir = None
+ elif experiment.benchmark.language == ProofAction.Language.ISABELLE:
+ if HAS_RAY:
+ ray_resource_pool = RayResourcePoolActor.remote(resources)
+ else:
+ # Thread-based resource pool (simplified version)
+ from itp_interface.tools.thread_resource_pool import ThreadResourcePool
+ ray_resource_pool = ThreadResourcePool(resources)
+ transform = IsabelleLocalDataGenerationTransform(
+ experiment.run_settings.dep_depth,
+ max_search_results=experiment.run_settings.max_search_results,
+ buffer_size=experiment.run_settings.buffer_size,
+ logger=logger,
+ resource_pool=ray_resource_pool
+ )
+ clone_dir = None
+ # os.makedirs(clone_dir, exist_ok=True)
+ else:
+ raise ValueError(f"Unexpected language: {experiment.benchmark.language}")
+ transforms.append(transform)
+ else:
+ raise ValueError(f"Unexpected transform_type: {experiment.run_settings.transform_type}")
+ return clone_dir
+
+def get_decl_lemmas_to_parse(
+ experiment: Experiments,
+ lemma_discovery_remotes: list,
+ project_to_theorems: typing.Dict[str, typing.Any],
+ other_args: typing.Dict[str, typing.Any],
+ log_dir: str,
+ logger: logging.Logger):
+ for idx, dataset in enumerate(experiment.benchmark.datasets):
+ if dataset.project not in project_to_theorems:
+ project_to_theorems[dataset.project] = {}
+ other_args[dataset.project] = {}
+ file_to_theorems = project_to_theorems[dataset.project]
+ file_args = other_args[dataset.project]
+ if experiment.benchmark.language == ProofAction.Language.LEAN4 \
+ and experiment.benchmark.is_extraction_request:
+ if len(dataset.files) == 0:
+ # List all the files recursively in the project folder
+ files_in_dataset = _get_all_lean_files_in_folder_recursively(dataset.project)
+ for file_path in files_in_dataset:
+ file_to_theorems[file_path] = ["*"]
+ file_args[file_path] = {}
+ else:
+ for file in dataset.files:
+ if file.path not in file_to_theorems:
+ file_to_theorems[file.path] = []
+ file_args[file.path] = {}
+ decls_or_thms = []
+ assert isinstance(file, ExtractFile)
+ assert experiment.benchmark.is_extraction_request, "Extraction request must be true for ExtractFile"
+ decls_or_thms = file.declarations
+ if isinstance(decls_or_thms, list):
+ file_to_theorems[file.path].extend(decls_or_thms)
+ else:
+ assert isinstance(decls_or_thms, str) and decls_or_thms.strip() == "*", "Only '*' is supported for Lean4 extraction request"
+ file_to_theorems[file.path].append("*")
+ else:
+ for file_idx, file in enumerate(dataset.files):
+ if file.path not in file_to_theorems:
+ file_to_theorems[file.path] = []
+ file_args[file.path] = {}
+ decls_or_thms = []
+ if isinstance(file, EvalFile):
+ assert not experiment.benchmark.is_extraction_request, "Extraction request must be false for EvalFile"
+ decls_or_thms = file.theorems
+ else:
+ assert isinstance(file, ExtractFile)
+ assert experiment.benchmark.is_extraction_request, "Extraction request must be true for ExtractFile"
+ decls_or_thms = file.declarations
+ if isinstance(decls_or_thms, list):
+ # if language is Lean4 then change the theorem names to fully qualified names
+ if experiment.benchmark.language == ProofAction.Language.LEAN4:
+ full_file_path = os.path.join(dataset.project, file.path)
+ if experiment.benchmark.is_extraction_request:
+ theorems_in_file = decls_or_thms
+ else:
+ theorems_in_file = [get_theorem_name_resembling_lean4(full_file_path, theorem, use_cache=True) for theorem in decls_or_thms]
+ else:
+ assert not experiment.benchmark.is_extraction_request, "Extraction request with list of declarations is not supported"
+ theorems_in_file = decls_or_thms
+ file_to_theorems[file.path].extend(theorems_in_file)
+ else:
+ if not experiment.benchmark.is_extraction_request:
+ discover_log_file = os.path.join(log_dir, f"discover{idx}_{file_idx}.log")
+ if HAS_RAY:
+ timed_exec = TimedRayExec.remote(get_all_lemmas, kwargs=dict(
+ project_folder=dataset.project,
+ file_path=os.path.join(dataset.project, file.path),
+ language=experiment.benchmark.language,
+ use_hammer=False,
+ timeout_in_secs=experiment.run_settings.timeout_in_secs,
+ use_human_readable_proof_context=experiment.run_settings.use_human_readable,
+ suppress_error_log=True,
+ always_use_retrieval=False,
+ setup_cmds=experiment.benchmark.setup_cmds,
+ log_file=discover_log_file))
+ timeout_in_secs = experiment.run_settings.timeout_in_secs * 100
+ timed_exec_remote = timed_exec.execute_with_timeout.remote(timeout=timeout_in_secs)
+ lemma_discovery_remotes.append(timed_exec_remote)
+ else:
+ # Thread-based execution
+ lemma_discovery_remotes.append((dataset.project, file.path, discover_log_file))
+ pass
+ if len(lemma_discovery_remotes) > 0:
+ assert not experiment.benchmark.is_extraction_request, "Lemma discovery is not needed for extraction request"
+ if HAS_RAY:
+ lemmas = ray.get(lemma_discovery_remotes)
+ else:
+ # Thread-based lemma discovery
+ with ThreadPoolExecutor(max_workers=experiment.run_settings.pool_size) as executor:
+ futures = []
+ for proj, fpath, log_file in lemma_discovery_remotes:
+ future = executor.submit(_get_all_lemmas_impl,
+ project_folder=proj,
+ file_path=os.path.join(proj, fpath),
+ language=experiment.benchmark.language,
+ use_hammer=False,
+ timeout_in_secs=experiment.run_settings.timeout_in_secs,
+ use_human_readable_proof_context=experiment.run_settings.use_human_readable,
+ suppress_error_log=True,
+ always_use_retrieval=False,
+ setup_cmds=experiment.benchmark.setup_cmds,
+ log_file=log_file)
+ futures.append(future)
+ lemmas = [f.result() for f in futures]
+ _idx = 0
+ for idx, dataset in enumerate(experiment.benchmark.datasets):
+ for file_idx, file in enumerate(dataset.files):
+ if lemmas[_idx] is not None:
+ project_to_theorems[dataset.project][file.path].extend(lemmas[_idx])
+ else:
+ logger.error(f"Discovering lemmas failed because of timeout for {dataset.project}/{file.path}")
+ _idx += 1
+
def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoint_info: EvalRunCheckpointInfo, logger: logging.Logger = None):
global ray_resource_pool
pisa_servers = []
@@ -276,124 +459,19 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi
transforms = []
str_time = time.strftime("%Y%m%d-%H%M%S")
clone_dir = os.path.join(experiment.run_settings.output_dir, "clone{}".format(str_time))
- if experiment.run_settings.transform_type == TransformType.LOCAL:
- if experiment.benchmark.language == ProofAction.Language.LEAN:
- transform = LeanLocalDataGenerationTransform(
- experiment.run_settings.dep_depth,
- max_search_results=experiment.run_settings.max_search_results,
- buffer_size=experiment.run_settings.buffer_size,
- logger=logger)
- os.makedirs(clone_dir, exist_ok=True)
- elif experiment.benchmark.language == ProofAction.Language.LEAN4:
- transform = Local4DataGenerationTransform(
- experiment.run_settings.dep_depth,
- max_search_results=experiment.run_settings.max_search_results,
- buffer_size=experiment.run_settings.buffer_size,
- logger=logger)
- clone_dir = None
- elif experiment.benchmark.language == ProofAction.Language.COQ:
- only_proof_state = experiment.env_settings.retrieval_strategy == ProofEnvReRankStrategy.NO_RE_RANK
- transform = CoqLocalDataGenerationTransform(
- experiment.run_settings.dep_depth,
- max_search_results=experiment.run_settings.max_search_results,
- buffer_size=experiment.run_settings.buffer_size,
- logger=logger,
- no_dfns=only_proof_state,
- no_thms=only_proof_state)
- clone_dir = None
- elif experiment.benchmark.language == ProofAction.Language.ISABELLE:
- if HAS_RAY:
- ray_resource_pool = RayResourcePoolActor.remote(resources)
- else:
- # Thread-based resource pool (simplified version)
- from itp_interface.tools.thread_resource_pool import ThreadResourcePool
- ray_resource_pool = ThreadResourcePool(resources)
- transform = IsabelleLocalDataGenerationTransform(
- experiment.run_settings.dep_depth,
- max_search_results=experiment.run_settings.max_search_results,
- buffer_size=experiment.run_settings.buffer_size,
- logger=logger,
- resource_pool=ray_resource_pool
- )
- clone_dir = None
- # os.makedirs(clone_dir, exist_ok=True)
- else:
- raise ValueError(f"Unexpected language: {experiment.benchmark.language}")
- transforms.append(transform)
- else:
- raise ValueError(f"Unexpected transform_type: {experiment.run_settings.transform_type}")
+ clone_dir = add_transform(experiment, clone_dir, resources, transforms, logger)
# Find all the lemmas to prove
project_to_theorems = {}
other_args = {}
lemma_discovery_remotes = []
- for idx, dataset in enumerate(experiment.benchmark.datasets):
- if dataset.project not in project_to_theorems:
- project_to_theorems[dataset.project] = {}
- other_args[dataset.project] = {}
- file_to_theorems = project_to_theorems[dataset.project]
- file_args = other_args[dataset.project]
- for file_idx, file in enumerate(dataset.files):
- if file.path not in file_to_theorems:
- file_to_theorems[file.path] = []
- file_args[file.path] = {}
- if isinstance(file.theorems, list):
- # if language is Lean4 then change the theorem names to fully qualified names
- if experiment.benchmark.language == ProofAction.Language.LEAN4:
- full_file_path = os.path.join(dataset.project, file.path)
- theorems_in_file = [get_theorem_name_resembling_lean4(full_file_path, theorem, use_cache=True) for theorem in file.theorems]
- else:
- theorems_in_file = file.theorems
- file_to_theorems[file.path].extend(theorems_in_file)
- else:
- discover_log_file = os.path.join(log_dir, f"discover{idx}_{file_idx}.log")
- if HAS_RAY:
- timed_exec = TimedRayExec.remote(get_all_lemmas, kwargs=dict(
- project_folder=dataset.project,
- file_path=os.path.join(dataset.project, file.path),
- language=experiment.benchmark.language,
- use_hammer=False,
- timeout_in_secs=experiment.run_settings.timeout_in_secs,
- use_human_readable_proof_context=experiment.run_settings.use_human_readable,
- suppress_error_log=True,
- always_use_retrieval=False,
- setup_cmds=experiment.benchmark.setup_cmds,
- log_file=discover_log_file))
- timeout_in_secs = experiment.run_settings.timeout_in_secs * 100
- timed_exec_remote = timed_exec.execute_with_timeout.remote(timeout=timeout_in_secs)
- lemma_discovery_remotes.append(timed_exec_remote)
- else:
- # Thread-based execution
- lemma_discovery_remotes.append((dataset.project, file.path, discover_log_file))
- pass
- if len(lemma_discovery_remotes) > 0:
- if HAS_RAY:
- lemmas = ray.get(lemma_discovery_remotes)
- else:
- # Thread-based lemma discovery
- with ThreadPoolExecutor(max_workers=experiment.run_settings.pool_size) as executor:
- futures = []
- for proj, fpath, log_file in lemma_discovery_remotes:
- future = executor.submit(_get_all_lemmas_impl,
- project_folder=proj,
- file_path=os.path.join(proj, fpath),
- language=experiment.benchmark.language,
- use_hammer=False,
- timeout_in_secs=experiment.run_settings.timeout_in_secs,
- use_human_readable_proof_context=experiment.run_settings.use_human_readable,
- suppress_error_log=True,
- always_use_retrieval=False,
- setup_cmds=experiment.benchmark.setup_cmds,
- log_file=log_file)
- futures.append(future)
- lemmas = [f.result() for f in futures]
- _idx = 0
- for idx, dataset in enumerate(experiment.benchmark.datasets):
- for file_idx, file in enumerate(dataset.files):
- if lemmas[_idx] is not None:
- project_to_theorems[dataset.project][file.path].extend(lemmas[_idx])
- else:
- logger.error(f"Discovering lemmas failed because of timeout for {dataset.project}/{file.path}")
- _idx += 1
+ get_decl_lemmas_to_parse(
+ experiment,
+ lemma_discovery_remotes,
+ project_to_theorems,
+ other_args,
+ log_dir,
+ logger
+ )
data_transform = RunDataGenerationTransforms(transforms,
log_dir,
save_intermidiat_transforms=len(transforms) > 1 or \
@@ -409,7 +487,7 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi
# dump a yaml file with the partition
partition_name = f"{experiment.benchmark.name}_{dataset_partition}"
yaml_file = os.path.join(new_output_dir, f"{partition_name}.yaml")
- create_yaml(partition_project_to_theorems, partition_name, experiment.benchmark.language, yaml_file)
+ create_yaml(partition_project_to_theorems, partition_name, experiment.benchmark, yaml_file)
if len(partition_project_to_theorems) == 0:
logger.info(f"==============================>No projects to process for {dataset_partition}<==============================")
continue
diff --git a/src/itp_interface/retrieval/coq_bm25_reranker.py b/src/itp_interface/retrieval/coq_bm25_reranker.py
index 3a23d13..146b53b 100644
--- a/src/itp_interface/retrieval/coq_bm25_reranker.py
+++ b/src/itp_interface/retrieval/coq_bm25_reranker.py
@@ -13,7 +13,7 @@
from rank_bm25 import BM25Okapi
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.training_data import TrainingData
-from itp_interface.tools.training_data_format import TrainingDataFormat
+from itp_interface.tools.training_data_format import TheoremProvingTrainingDataFormat
from itp_interface.retrieval.abstraction import ReRanker
class CoqBm25ReRanker(ReRanker):
@@ -106,7 +106,7 @@ def load(self) -> None:
self.logger.info(f"BM25 initialized.")
self._loaded = True
- def find_relevant_training_data(self, query: str, num_results: int = 1) -> typing.List[typing.Tuple[float, TrainingDataFormat]]:
+ def find_relevant_training_data(self, query: str, num_results: int = 1) -> typing.List[typing.Tuple[float, TheoremProvingTrainingDataFormat]]:
assert self.is_loaded
query_tokens = list(CoqExecutor.tokenize(query))
scores = self.bm25.get_scores(query_tokens)
diff --git a/src/itp_interface/rl/proof_state.py b/src/itp_interface/rl/proof_state.py
index b0538ed..9c805c7 100644
--- a/src/itp_interface/rl/proof_state.py
+++ b/src/itp_interface/rl/proof_state.py
@@ -12,14 +12,14 @@
from itp_interface.tools.dynamic_isabelle_proof_exec import DynamicProofExecutor as DynamicIsabelleProofExecutor
from itp_interface.rl.abstraction import State
from itp_interface.rl.proof_action import ProofAction
-from itp_interface.tools.training_data_format import TrainingDataFormat
+from itp_interface.tools.training_data_format import TheoremProvingTrainingDataFormat
from dataclasses_json import dataclass_json
from dataclasses import dataclass
@dataclass_json
@dataclass
class ProofState(State):
- training_data_format: TrainingDataFormat
+ training_data_format: TheoremProvingTrainingDataFormat
was_reset: bool = False
language: ProofAction.Language = ProofAction.Language.COQ
theorem_statement_with_name: typing.Optional[str] = None
@@ -76,7 +76,7 @@ def __ge__(self, __o: object) -> bool:
return self.training_data_format == __o.training_data_format
if self == FailedProofState:
return True
- assert isinstance(self.training_data_format, TrainingDataFormat)
+ assert isinstance(self.training_data_format, TheoremProvingTrainingDataFormat)
if self.language == ProofAction.Language.COQ:
desc_cmp = DynamicCoqProofExecutor.goal_description_compare(self.training_data_format.goal_description, __o.training_data_format.goal_description)
elif self.language == ProofAction.Language.LEAN:
@@ -109,7 +109,7 @@ def __le__(self, __o: object) -> bool:
return self.training_data_format == __o.training_data_format
if __o == FailedProofState:
return True
- assert isinstance(self.training_data_format, TrainingDataFormat)
+ assert isinstance(self.training_data_format, TheoremProvingTrainingDataFormat)
if self.language == ProofAction.Language.COQ:
desc_cmp = DynamicCoqProofExecutor.goal_description_compare(self.training_data_format.goal_description, __o.training_data_format.goal_description)
elif self.language == ProofAction.Language.LEAN:
@@ -127,13 +127,13 @@ def __le__(self, __o: object) -> bool:
def __lt__(self, __o: object) -> bool:
assert isinstance(__o, ProofState)
- assert isinstance(self.training_data_format, TrainingDataFormat)
+ assert isinstance(self.training_data_format, TheoremProvingTrainingDataFormat)
assert self.language == __o.language, f"self.language: {self.language}, __o.language: {__o.language}"
return self.training_data_format != __o.training_data_format and self.training_data_format <= __o.training_data_format
def __gt__(self, __o: object) -> bool:
assert isinstance(__o, ProofState)
- assert isinstance(self.training_data_format, TrainingDataFormat)
+ assert isinstance(self.training_data_format, TheoremProvingTrainingDataFormat)
assert self.language == __o.language, f"self.language: {self.language}, __o.language: {__o.language}"
return self.training_data_format != __o.training_data_format and self.training_data_format >= __o.training_data_format
diff --git a/src/itp_interface/rl/proof_tree.py b/src/itp_interface/rl/proof_tree.py
index c1c256e..9e6618e 100644
--- a/src/itp_interface/rl/proof_tree.py
+++ b/src/itp_interface/rl/proof_tree.py
@@ -9,12 +9,12 @@
from dataclasses_json import dataclass_json
from dataclasses import dataclass, field
from itp_interface.rl.proof_action import ProofAction
-from itp_interface.tools.training_data_format import TrainingDataFormat
+from itp_interface.tools.training_data_format import TheoremProvingTrainingDataFormat
@dataclass_json
@dataclass
class ProofTree(object):
- tactics: typing.List[typing.Tuple[int, TrainingDataFormat]] = field(default_factory=list)
+ tactics: typing.List[typing.Tuple[int, TheoremProvingTrainingDataFormat]] = field(default_factory=list)
actions: typing.List[typing.Optional[ProofAction]] = field(default_factory=list)
def __len__(self):
@@ -23,7 +23,7 @@ def __len__(self):
def __getitem__(self, index):
return self.tactics[index]
- def try_add_tactic(self, line_num, tactic: TrainingDataFormat, force_add: bool = False, action: ProofAction = None):
+ def try_add_tactic(self, line_num, tactic: TheoremProvingTrainingDataFormat, force_add: bool = False, action: ProofAction = None):
# Make sure that the tactic is not more hard than any of the previous tactics
if not force_add:
for _, prev_tactic in self.tactics:
@@ -40,7 +40,7 @@ def try_remove_last_tactic(self):
return line_num, tactic
return None, None
- def _convert_to_str(self, tactic: TrainingDataFormat) -> typing.Tuple[str, list, list]:
+ def _convert_to_str(self, tactic: TheoremProvingTrainingDataFormat) -> typing.Tuple[str, list, list]:
# sort the goals
goal_set = set([goal.goal for goal in tactic.start_goals])
hyp_set = set([hyp for goal in tactic.start_goals for hyp in goal.hypotheses])
@@ -55,7 +55,7 @@ class ProofSearchResult(object):
proof_file: typing.Optional[str]
proof_found: bool
lemma_name: str
- proof_steps: typing.List[TrainingDataFormat]
+ proof_steps: typing.List[TheoremProvingTrainingDataFormat]
proof_time_in_secs: float
inferences_taken: int
possible_failed_paths: int
diff --git a/src/itp_interface/rl/simple_proof_env.py b/src/itp_interface/rl/simple_proof_env.py
index 0431ab6..3451c27 100644
--- a/src/itp_interface/rl/simple_proof_env.py
+++ b/src/itp_interface/rl/simple_proof_env.py
@@ -10,7 +10,7 @@
from itp_interface.rl.proof_action import ProofAction
from itp_interface.rl.abstraction import State, Action, Env
from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
-from itp_interface.tools.training_data_format import TrainingDataFormat
+from itp_interface.tools.training_data_format import TheoremProvingTrainingDataFormat
from itp_interface.tools.dynamic_coq_proof_exec import DynamicProofExecutor as DynamicCoqProofExecutor
from itp_interface.tools.dynamic_lean_proof_exec import DynamicProofExecutor as DynamicLeanProofExecutor
from itp_interface.tools.dynamic_lean4_proof_exec import DynamicProofExecutor as DynamicLean4ProofExecutor
@@ -73,7 +73,7 @@ def __init__(self,
assert isinstance(max_proof_depth, int)
assert isinstance(always_retrieve_thms, bool)
self.dynamic_proof_executor_callback = dynamic_proof_executor_callback
- self._dynamic_proof_executor : typing.Union[DynamicCoqProofExecutor, DynamicLeanProofExecutor, DynamicIsabelleProofExecutor] = None
+ self._dynamic_proof_executor : typing.Union[DynamicCoqProofExecutor, DynamicLeanProofExecutor, DynamicIsabelleProofExecutor, DynamicLean4ProofExecutor] = None
self._loaded = False
self._history : typing.List[typing.Tuple[ProofState, ProofAction, ProofState, float, bool, ProofEnvInfo]] = []
self._name = name
@@ -230,6 +230,18 @@ def step(self, action: Action) -> typing.Tuple[State, Action, State, float, bool
self.inferences_used += 1
return self._history[-1][0], self._history[-1][1], self._history[-1][2], self._history[-1][3], self._history[-1][4], self._history[-1][5]
+ def validate_proof_completion(self, timeout_in_secs = 120, keep_validation_file = False) -> dict[str, typing.Any]:
+ assert self._loaded, "Env not loaded, call reset() first"
+ assert self.done, "Proof is not yet complete, cannot validate"
+ if self.language == ProofAction.Language.LEAN4:
+ assert isinstance(self._dynamic_proof_executor, DynamicLean4ProofExecutor), "Dynamic proof executor must be of type DynamicLean4ProofExecutor"
+ return self._dynamic_proof_executor.validate_proof(
+ timeout_sec=timeout_in_secs,
+ keep_temp_file=keep_validation_file
+ )
+ else:
+ return {}
+
def checkpoint(self):
return super().checkpoint()
@@ -278,7 +290,7 @@ def dump_proof(self, dump_file_name: str = None, additional_info: typing.Dict[st
assert self._loaded, "Env not loaded, call reset() first"
self.goal_end_time = time.time()
self.time_taken = self.goal_end_time - self.goal_start_time
- proof_steps = [TrainingDataFormat(proof_steps=tactic.proof_steps) for _, tactic in self._p_tree.tactics]
+ proof_steps = [TheoremProvingTrainingDataFormat(proof_steps=tactic.proof_steps) for _, tactic in self._p_tree.tactics]
additional_info = additional_info if additional_info is not None else {}
self.proof_search_res = ProofSearchResult(
self._dynamic_proof_executor.main_file,
@@ -336,9 +348,22 @@ def _run_tactic(self, history_idx: int = None):
self._history[history_idx] = (state, action, next_state, reward, done, env_info)
+ def _fix_tactics(self, tactics: typing.List[str], action: ProofAction):
+ if self.language == ProofAction.Language.LEAN4 and len(tactics) > 0:
+ # It is possible that Lean4 modifies the last tactic (especially in case of `have` tactics)
+ modified_last_tactic = self._dynamic_proof_executor.get_last_tactic()
+ if modified_last_tactic is None:
+ return
+ tactics[len(tactics) - 1] = modified_last_tactic
+ if "tactics" in action.kwargs:
+ tactics_in_action = action.kwargs["tactics"]
+ tactics_in_action[len(tactics_in_action) - 1] = modified_last_tactic
+ action.kwargs["tactics"] = tactics_in_action
+
def _run_tactics(self, tactics: typing.List[str], state: ProofState, action: ProofAction, env_info: ProofEnvInfo):
env_info = copy.deepcopy(env_info)
tactic_line_num, ran_successfully = self._dynamic_proof_executor.run_tactics(tactics)
+ self._fix_tactics(tactics, action)
proof_progressed = False
if ran_successfully:
previous_proof_state = state
@@ -354,7 +379,9 @@ def _run_tactics(self, tactics: typing.List[str], state: ProofState, action: Pro
self._possible_failure_paths += 1
assert len(self._p_tree) == self.current_proof_depth, "proof_tree must have the same length as current_depth"
# cancel anything which might got executed
- self._dynamic_proof_executor.cancel_tactic_till_line(tactic_line_num)
+ if self.language != ProofAction.Language.LEAN4:
+ # Lean4 automatically cancels the failed tactic
+ self._dynamic_proof_executor.cancel_tactic_till_line(tactic_line_num)
reward = 0.0
depth_ratio = self.current_proof_depth/self.max_proof_depth
if depth_ratio > 1.0:
diff --git a/src/itp_interface/tools/coq_context_helper.py b/src/itp_interface/tools/coq_context_helper.py
index 4ea55ae..18226de 100644
--- a/src/itp_interface/tools/coq_context_helper.py
+++ b/src/itp_interface/tools/coq_context_helper.py
@@ -8,7 +8,7 @@
import logging
import typing
from itp_interface.tools.coq_executor import CoqExecutor
-from itp_interface.tools.training_data_format import Goal, LemmaRefWithScore, LemmaReferences, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, LemmaRefWithScore, LemmaReferences, TheoremProvingTrainingDataFormat
from typing import List
class CoqContextHelper(object):
@@ -113,7 +113,7 @@ def _get_variables_in_hyp(self, hyps: List[str], coq_exec: CoqExecutor) -> typin
variables.update(possible_vars)
return variables
- def _get_changed_goal_idx(self, training_data_point: TrainingDataFormat) -> typing.List[int]:
+ def _get_changed_goal_idx(self, training_data_point: TheoremProvingTrainingDataFormat) -> typing.List[int]:
# Figure out the subset of start goals which were changed
start_goals = dict()
for goal in training_data_point.start_goals:
@@ -168,7 +168,7 @@ def get_local_lemmas(self, coq_executor: CoqExecutor, logger: logging.Logger = N
lemmas.append((lemma_name, lemma_val))
return lemmas
- def set_relevant_defns_in_training_data_point(self, training_data_point: TrainingDataFormat, coq_executor: CoqExecutor, logger: logging.Logger = None, depth: int = None, should_print_symbol: bool = False, only_local: bool = False):
+ def set_relevant_defns_in_training_data_point(self, training_data_point: TheoremProvingTrainingDataFormat, coq_executor: CoqExecutor, logger: logging.Logger = None, depth: int = None, should_print_symbol: bool = False, only_local: bool = False):
logger = logger if logger is not None else self.logger
depth = self.depth if depth is None else depth
unique_defns = {defn: idx for idx, defn in enumerate(training_data_point.all_useful_defns_theorems)}
@@ -210,7 +210,7 @@ def set_relevant_defns_in_training_data_point(self, training_data_point: Trainin
useful_defns = [LemmaRefWithScore(unique_defns[defn], score) for defn, _, score in useful_defns]
goal.relevant_defns = useful_defns
- def set_all_type_matched_query_result(self, training_data_point: TrainingDataFormat, coq_executor: CoqExecutor, logger: logging.Logger = None, depth: int = None, should_print_symbol: bool = False, only_local: bool = False):
+ def set_all_type_matched_query_result(self, training_data_point: TheoremProvingTrainingDataFormat, coq_executor: CoqExecutor, logger: logging.Logger = None, depth: int = None, should_print_symbol: bool = False, only_local: bool = False):
# Use the hypothesis to find the definition
# Recursively find the definition of the definition to a fixed depth
# dump useful_hyps and current stmt into a stack
@@ -265,7 +265,7 @@ def set_all_type_matched_query_result(self, training_data_point: TrainingDataFor
goal.possible_useful_theorems_external = [LemmaRefWithScore(defn_idx, score) for defn_idx, score in useful_external_theorems if score <= CoqContextHelper.max_relevance_score]
goal.possible_useful_theorems_local = [LemmaRefWithScore(defn_idx, score) for defn_idx, score in useful_local_theorems if score <= CoqContextHelper.max_relevance_score]
- def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: str, training_data_point: TrainingDataFormat, coq_executor: CoqExecutor, logger: logging.Logger = None, depth: int = None, max_search_res: typing.Optional[int] = None, should_print_symbol: bool = False, only_local: bool = False):
+ def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: str, training_data_point: TheoremProvingTrainingDataFormat, coq_executor: CoqExecutor, logger: logging.Logger = None, depth: int = None, max_search_res: typing.Optional[int] = None, should_print_symbol: bool = False, only_local: bool = False):
# Use the hypothesis to find the definition
# Recursively find the definition of the definition to a fixed depth
# dump useful_hyps and current stmt into a stack
@@ -339,7 +339,7 @@ def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: s
goal.possible_useful_theorems_local = [LemmaRefWithScore(defn_idx, score) for defn_idx, score in useful_local_theorems if score <= CoqContextHelper.max_relevance_score]
goal.possible_useful_theorems_external = [LemmaRefWithScore(defn_idx, score) for defn_idx, score in useful_external_theorems if score <= CoqContextHelper.max_relevance_score]
- def set_local_thms_dfns(self, training_data_point: TrainingDataFormat, coq_executor: CoqExecutor, logger: logging.Logger = None):
+ def set_local_thms_dfns(self, training_data_point: TheoremProvingTrainingDataFormat, coq_executor: CoqExecutor, logger: logging.Logger = None):
local_lemmas = self.get_local_lemmas(coq_executor, logger)
unique_thms = {defn.lemma_name: idx for idx, defn in enumerate(training_data_point.all_useful_defns_theorems)}
useful_local_theorems = []
diff --git a/src/itp_interface/tools/coq_local_data_generation_transform.py b/src/itp_interface/tools/coq_local_data_generation_transform.py
index 00936a1..2f90dd3 100644
--- a/src/itp_interface/tools/coq_local_data_generation_transform.py
+++ b/src/itp_interface/tools/coq_local_data_generation_transform.py
@@ -9,7 +9,7 @@
import os
from itp_interface.tools.coq_context_helper import CoqContextHelper
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
-from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataMetadataFormat, TrainingDataCollection, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataMetadataFormat, TheoremProvingTrainingDataCollection, TheoremProvingTrainingDataFormat
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.training_data import TrainingData
@@ -33,13 +33,13 @@ def get_meta_object(self) -> MergableCollection:
return TrainingDataMetadataFormat(training_data_buffer_size=self.buffer_size)
def get_data_collection_object(self) -> MergableCollection:
- return TrainingDataCollection()
+ return TheoremProvingTrainingDataCollection()
def load_meta_from_file(self, file_path) -> MergableCollection:
return TrainingDataMetadataFormat.load_from_file(file_path)
def load_data_from_file(self, file_path) -> MergableCollection:
- return TrainingDataCollection.load_from_file(file_path, self.logger)
+ return TheoremProvingTrainingDataCollection.load_from_file(file_path, self.logger)
def __call__(self, training_data: TrainingData, project_id : str, coq_executor: CoqExecutor, print_coq_executor_callback: typing.Callable[[], CoqExecutor], theorems: typing.List[str] = None, other_args: dict = {}) -> TrainingData:
print_coq_executor = print_coq_executor_callback()
@@ -66,7 +66,7 @@ def __call__(self, training_data: TrainingData, project_id : str, coq_executor:
prev_goal : typing.List[Goal] = [Goal(goal.hypotheses, goal.goal) for goal in prev_goal]
next_goal : typing.List[Goal] = coq_context_helper.get_focussed_goals(coq_executor)
if len(prev_goal) > 0 and cmd_exec != "Proof.":
- training_data_format = TrainingDataFormat(
+ training_data_format = TheoremProvingTrainingDataFormat(
proof_id=proof_id,
all_useful_defns_theorems=[],
start_goals=prev_goal,
diff --git a/src/itp_interface/tools/coq_theorem_proof_pair_generation_transform.py b/src/itp_interface/tools/coq_theorem_proof_pair_generation_transform.py
index 9154e07..e6e3f0b 100644
--- a/src/itp_interface/tools/coq_theorem_proof_pair_generation_transform.py
+++ b/src/itp_interface/tools/coq_theorem_proof_pair_generation_transform.py
@@ -9,7 +9,7 @@
import typing
from itp_interface.tools.coq_context_helper import CoqContextHelper
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
-from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataMetadataFormat, TrainingDataCollection, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataMetadataFormat, TheoremProvingTrainingDataCollection, TheoremProvingTrainingDataFormat
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.training_data import TrainingData
@@ -29,13 +29,13 @@ def get_meta_object(self) -> MergableCollection:
return TrainingDataMetadataFormat(training_data_buffer_size=self.buffer_size)
def get_data_collection_object(self) -> MergableCollection:
- return TrainingDataCollection()
+ return TheoremProvingTrainingDataCollection()
def load_meta_from_file(self, file_path) -> MergableCollection:
return TrainingDataMetadataFormat.load_from_file(file_path)
def load_data_from_file(self, file_path) -> MergableCollection:
- return TrainingDataCollection.load_from_file(file_path, self.logger)
+ return TheoremProvingTrainingDataCollection.load_from_file(file_path, self.logger)
def __call__(self, training_data: TrainingData, project_id : str, coq_executor: CoqExecutor, print_coq_executor_callback: typing.Callable[[], CoqExecutor]) -> TrainingData:
print_coq_executor = print_coq_executor_callback()
@@ -60,7 +60,7 @@ def __call__(self, training_data: TrainingData, project_id : str, coq_executor:
prev_goal : typing.List[Goal] = [Goal(goal.hypotheses, goal.goal) for goal in prev_goal]
next_goal : typing.List[Goal] = coq_context_helper.get_focussed_goals(coq_executor)
if len(prev_goal) > 0 and cmd_exec != "Proof.":
- training_data_format = TrainingDataFormat(
+ training_data_format = TheoremProvingTrainingDataFormat(
proof_id=proof_id,
all_useful_defns_theorems=[],
start_goals=prev_goal,
diff --git a/src/itp_interface/tools/coq_training_data_generator.py b/src/itp_interface/tools/coq_training_data_generator.py
index 910cea4..371dd14 100644
--- a/src/itp_interface/tools/coq_training_data_generator.py
+++ b/src/itp_interface/tools/coq_training_data_generator.py
@@ -8,7 +8,7 @@
import typing
import logging
import enum
-from itp_interface.tools.training_data_format import MergableCollection, TrainingDataFormat
+from itp_interface.tools.training_data_format import MergableCollection, TheoremProvingTrainingDataFormat
logger = logging.getLogger("CoqTrainingGenerator")
class TrainingDataGenerationType(enum.Enum):
@@ -24,7 +24,7 @@ def __init__(self):
self.logger = None
pass
- def filter_best_context(self, partial_data: TrainingDataFormat) -> TrainingDataFormat:
+ def filter_best_context(self, partial_data: TheoremProvingTrainingDataFormat) -> TheoremProvingTrainingDataFormat:
raise NotImplementedError("retrieve_best_context must be implemented")
def set_logger(self, logger: logging.Logger):
diff --git a/src/itp_interface/tools/dynamic_coq_proof_exec.py b/src/itp_interface/tools/dynamic_coq_proof_exec.py
index 9e0814e..ea0c578 100644
--- a/src/itp_interface/tools/dynamic_coq_proof_exec.py
+++ b/src/itp_interface/tools/dynamic_coq_proof_exec.py
@@ -10,7 +10,7 @@
import os
import enum
import logging
-from itp_interface.tools.training_data_format import Goal, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, TheoremProvingTrainingDataFormat
from itp_interface.tools.coq_parse_utils import CoqLineByLineReader
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.coq_context_helper import CoqContextHelper
@@ -120,43 +120,43 @@ def get_unfocussed_goals(self) -> typing.List[Goal]:
return []
return self.coq_context_helper.get_unfocussed_goals(self)
- def get_current_proof_state_as_training_data(self) -> TrainingDataFormat:
+ def get_current_proof_state_as_training_data(self) -> TheoremProvingTrainingDataFormat:
# get the current goal
if self.needs_cut_close():
current_goals = self.get_unfocussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.UnfocussedGoalsDescription
elif not self.is_in_proof_mode():
current_goals = self.get_focussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.NotInProofModeDescription
elif self.needs_qed():
current_goals = self.get_focussed_goals()
assert len(current_goals) == 0, "There should be no goals when needs_qed is True"
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.ProofFinishedDescription
else:
current_goals = self.get_focussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = None
return training_data_format
- def get_all_relevant_thms(self, should_print_symbol: bool = False) -> TrainingDataFormat:
+ def get_all_relevant_thms(self, should_print_symbol: bool = False) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.coq_context_helper.set_all_type_matched_query_result(training_data_format, self, self.logger, should_print_symbol=should_print_symbol)
return training_data_format
- def get_all_relevant_thms_within_local_context(self) -> TrainingDataFormat:
+ def get_all_relevant_thms_within_local_context(self) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.coq_context_helper.set_local_thms_dfns(training_data_format, self, self.logger)
return training_data_format
- def get_all_relevant_defns(self) -> TrainingDataFormat:
+ def get_all_relevant_defns(self) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.coq_context_helper.set_relevant_defns_in_training_data_point(training_data_format, self, self.logger)
return training_data_format
- def get_all_relevant_defns_and_thms(self, should_print_symbol: bool = False, only_local: bool = False, only_proof_state: bool = False) -> TrainingDataFormat:
+ def get_all_relevant_defns_and_thms(self, should_print_symbol: bool = False, only_local: bool = False, only_proof_state: bool = False) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
if not only_proof_state:
self.coq_context_helper.set_relevant_defns_in_training_data_point(training_data_format, self, self.logger, should_print_symbol=should_print_symbol, only_local=only_local)
@@ -196,6 +196,11 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]:
break
return start_line_num, not tactic_failed
+ def get_last_tactic(self) -> typing.Optional[str]:
+ if len(self.run_state.tatics_ran) == 0:
+ return None
+ return self.run_state.tatics_ran[-1]
+
def get_last_exception(self) -> typing.Optional[str]:
last_exception = self.run_state.last_exception
self.run_state.last_exception = None
diff --git a/src/itp_interface/tools/dynamic_isabelle_proof_exec.py b/src/itp_interface/tools/dynamic_isabelle_proof_exec.py
index f18e7d9..cee3aa5 100644
--- a/src/itp_interface/tools/dynamic_isabelle_proof_exec.py
+++ b/src/itp_interface/tools/dynamic_isabelle_proof_exec.py
@@ -11,7 +11,7 @@
import enum
import logging
from itp_interface.rl.proof_action import ProofAction
-from itp_interface.tools.training_data_format import Goal, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, TheoremProvingTrainingDataFormat
from itp_interface.tools.isabelle_parse_utils import IsabelleLineByLineReader
from itp_interface.tools.isabelle_executor import IsabelleExecutor, HammerMode
from itp_interface.tools.isabelle_context_helper import IsabelleContextHelper
@@ -122,43 +122,43 @@ def get_unfocussed_goals(self) -> typing.List[Goal]:
return []
return self.isabelle_context_helper.get_unfocussed_goals(self)
- def get_current_proof_state_as_training_data(self) -> TrainingDataFormat:
+ def get_current_proof_state_as_training_data(self) -> TheoremProvingTrainingDataFormat:
# get the current goal
if self.needs_cut_close():
current_goals = self.get_unfocussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.UnfocussedGoalsDescription
elif not self.is_in_proof_mode():
current_goals = self.get_focussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.NotInProofModeDescription
elif self.needs_qed():
current_goals = self.get_focussed_goals()
assert len(current_goals) == 0, "There should be no goals when needs_qed is True"
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.ProofFinishedDescription
else:
current_goals = self.get_focussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = None
return training_data_format
- def get_all_relevant_thms(self, should_print_symbol: bool = False) -> TrainingDataFormat:
+ def get_all_relevant_thms(self, should_print_symbol: bool = False) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.isabelle_context_helper.set_all_type_matched_query_result(training_data_format, self, self.logger)
return training_data_format
- def get_all_relevant_thms_within_local_context(self) -> TrainingDataFormat:
+ def get_all_relevant_thms_within_local_context(self) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.isabelle_context_helper.set_local_thms_dfns(training_data_format, self, self.logger)
return training_data_format
- def get_all_relevant_defns(self) -> TrainingDataFormat:
+ def get_all_relevant_defns(self) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.isabelle_context_helper.set_relevant_defns_in_training_data_point(training_data_format, self, self.logger)
return training_data_format
- def get_all_relevant_defns_and_thms(self, should_print_symbol: bool = False, only_local: bool = False, only_proof_state: bool = False) -> TrainingDataFormat:
+ def get_all_relevant_defns_and_thms(self, should_print_symbol: bool = False, only_local: bool = False, only_proof_state: bool = False) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
# self.isabelle_context_helper.set_relevant_defns_in_training_data_point(training_data_format, self, self.logger)
if not only_proof_state:
@@ -201,6 +201,11 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]:
break
return start_line_num, not tactic_failed
+ def get_last_tactic(self) -> typing.Optional[str]:
+ if len(self.run_state.tatics_ran) == 0:
+ return None
+ return self.run_state.tatics_ran[-1]
+
def get_last_exception(self) -> typing.Optional[str]:
last_exception = self.run_state.last_exception
self.run_state.last_exception = None
diff --git a/src/itp_interface/tools/dynamic_lean4_proof_exec.py b/src/itp_interface/tools/dynamic_lean4_proof_exec.py
index bb365b2..1d51f67 100644
--- a/src/itp_interface/tools/dynamic_lean4_proof_exec.py
+++ b/src/itp_interface/tools/dynamic_lean4_proof_exec.py
@@ -1,61 +1,22 @@
#!/usr/bin/env python3
import sys
-
-root_dir = f"{__file__.split('itp_interface')[0]}"
-if root_dir not in sys.path:
- sys.path.append(root_dir)
import typing
import os
import copy
import enum
import logging
-from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor
-from itp_interface.tools.training_data_format import Goal, TrainingDataFormat
+from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.tools.training_data_format import Goal, TheoremProvingTrainingDataFormat
from itp_interface.tools.lean_parse_utils import LeanLineByLineReader
from itp_interface.tools.lean_context_helper import Lean3ContextHelper
from itp_interface.tools.misc_defns import HammerMode
+from itp_interface.tools.iter_helpers import IntertwinedIterator
-class IntertwinedIterator(object):
- def __init__(self, iterator: typing.Optional[typing.Iterator[str]] = None):
- self.base_iterator = iterator
- self.next_instruction: typing.Optional[str] = None
- self.base_iterator_stopped = iterator is None # if the base iterator is None, then it is stopped
-
- def set_next_instruction(self, instruction: str):
- assert self.next_instruction is None, "next_instruction must be None"
- assert instruction is not None, "instruction must not be None"
- self.next_instruction = instruction
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self.next_instruction is not None:
- # Return the next instruction if it is set
- next_instruction = self.next_instruction
- self.next_instruction = None
- return next_instruction
- # Otherwise, get the next instruction from the base iterator
- if self.base_iterator is not None and not self.base_iterator_stopped:
- try:
- instruction = next(self.base_iterator)
- return instruction
- except StopIteration:
- self.base_iterator_stopped = True
- raise
- else:
- raise StopIteration()
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.base_iterator is not None:
- self.base_iterator.close()
- pass
-
-class DynamicProofExecutor(Lean4SyncExecutor):
+class DynamicProofExecutor(SimpleLean4SyncExecutor):
class RunState(object):
def __init__(self):
- self.tatics_ran = []
+ self.tactics_ran : typing.List[str] = []
self.last_exception : typing.Optional[str] = None
self.line_tactic_map = {}
self.line_proof_context_map = {}
@@ -87,13 +48,12 @@ def goal_description_compare(description1: str, descripton2: str) -> int:
else:
return -1
-
def __init__(self, coq_context_helper: Lean3ContextHelper, project_folder: str = None, proof_file: str = None, instruction_iter: typing.Optional[str] = None, use_hammer: typing.Union[bool, HammerMode] = False, timeout_in_seconds: int = 60, use_human_readable_proof_context: bool = True, suppress_error_log: bool = True, context_type: ContextType = ContextType.NoContext, keep_local_context = False, enforce_qed: bool = False):
assert proof_file is None or os.path.exists(proof_file), f"Proof file {proof_file} does not exist"
assert coq_context_helper is not None, "coq_context_helper must not be None"
self.proof_file = proof_file
self.context_type = context_type
- self.lean_file_iter = LeanLineByLineReader(proof_file, remove_comments=True, no_strip=True).instruction_step_generator() if proof_file is not None else instruction_iter
+ self.lean_file_iter = LeanLineByLineReader(proof_file, remove_comments=False, no_strip=True).instruction_step_generator() if proof_file is not None else instruction_iter
self.tactic_switch_iterator = IntertwinedIterator(self.lean_file_iter)
self.run_state = DynamicProofExecutor.RunState()
self.logger = None
@@ -126,85 +86,58 @@ def get_unfocussed_goals(self) -> typing.List[Goal]:
return []
return self.lean_context_helper.get_unfocussed_goals(self)
- def get_current_proof_state_as_training_data(self) -> TrainingDataFormat:
+ def get_current_proof_state_as_training_data(self) -> TheoremProvingTrainingDataFormat:
# get the current goal
if self.needs_cut_close():
current_goals = self.get_unfocussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.UnfocussedGoalsDescription
elif not self.is_in_proof_mode():
current_goals = self.get_focussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.NotInProofModeDescription
elif self.needs_qed():
current_goals = self.get_focussed_goals()
assert len(current_goals) == 0, "There should be no goals when needs_qed is True"
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.ProofFinishedDescription
else:
current_goals = self.get_focussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
if len(self.lean_error_messages) > 0:
training_data_format.goal_description = '\n'.join(self.lean_error_messages)
else:
training_data_format.goal_description = None
return training_data_format
-
- def get_all_relevant_thms(self) -> TrainingDataFormat:
- training_data_format = self.get_current_proof_state_as_training_data()
- # self.lean_context_helper.set_all_type_matched_query_result(training_data_format, self, self.logger)
- return training_data_format
-
- def get_all_relevant_thms_within_local_context(self) -> TrainingDataFormat:
- training_data_format = self.get_current_proof_state_as_training_data()
- self.lean_context_helper.set_local_thms_dfns(training_data_format, self, self.logger)
- return training_data_format
-
- def get_all_relevant_defns(self) -> TrainingDataFormat:
- training_data_format = self.get_current_proof_state_as_training_data()
- self.lean_context_helper.set_relevant_defns_in_training_data_point(training_data_format, self, self.logger)
- return training_data_format
-
- def get_all_relevant_defns_and_thms(self, should_print_symbol: bool = False, only_local: bool = False, only_proof_state: bool = False) -> TrainingDataFormat:
- training_data_format = self.get_current_proof_state_as_training_data()
- # self.lean_context_helper.set_relevant_defns_in_training_data_point(training_data_format, self, self.logger)
- # if not only_proof_state:
- # self.lean_context_helper.set_all_type_matched_query_result(training_data_format, self, self.logger)
- return training_data_format
-
- def run_cmds(self, cmds: typing.List[str], raise_exception=False) -> typing.Tuple[int, bool]:
- cmd_failed = False
- start_line_num = self.line_num
- for cmd in cmds:
- self.tactic_switch_iterator.set_next_instruction(cmd)
- try:
- self.run_next()
- except Exception:
- self.line_num -= 1
- cmd_failed = True
- if raise_exception:
- raise
- else:
- break
- return start_line_num, not cmd_failed
def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]:
tactic_failed = False
start_line_num = self.line_num
- self.run_state.line_tactic_map[self.line_num] = len(self.run_state.tatics_ran)
+ self.run_state.line_tactic_map[self.line_num] = len(self.run_state.tactics_ran)
self.run_state.line_proof_context_map[self.line_num] = copy.deepcopy(self.proof_context)
for tactic in tactics:
self.tactic_switch_iterator.set_next_instruction(tactic)
self.run_next()
- self.run_state.tatics_ran.append(tactic)
+ self.run_state.tactics_ran.append(tactic)
self.run_state.line_proof_context_map[self.line_num] = copy.deepcopy(self.proof_context)
if len(self.lean_error_messages) > 0:
current_thm_name = self.get_lemma_name_if_running()
assert current_thm_name is not None, "current_thm_name must not be None"
tactic_failed = True
self.run_state.last_exception = '\n'.join(self.lean_error_messages)
+ # Cancel the last tactic
+ self.cancel_tactic_till_line(start_line_num, no_backtracking=True)
+ if self._last_tactic_was_modified:
+ tactics_in_order = self._get_tactics_in_sorted_order()
+ assert len(tactics_in_order) > 0, "tactics_in_order must not be empty"
+ self.run_state.tactics_ran[-1] = tactics_in_order[-1][1]
return start_line_num, not tactic_failed
+ def get_last_tactic(self) -> typing.Optional[str]:
+ if len(self.run_state.tactics_ran) == 0:
+ return None
+ return self.run_state.tactics_ran[-1]
+
def get_last_exception(self) -> typing.Optional[str]:
last_exception = self.run_state.last_exception
self.run_state.last_exception = None
@@ -214,15 +147,13 @@ def skip_to_theorem(self, theorem_name: str):
self._skip_to_theorem(theorem_name)
# [TODO] change this for bactracking
- def cancel_tactic_till_line(self, tactic_line_num: int) -> bool:
+ def cancel_tactic_till_line(self, tactic_line_num: int, no_backtracking: bool = False) -> bool:
assert tactic_line_num <= self.line_num, "tactic_line_num must be <= self.line_num"
assert tactic_line_num >= 0, "tactic_line_num must be >= 0"
cancelled_some_tactics = False
if tactic_line_num < self.line_num:
- self._lines_executed = self._lines_executed[:tactic_line_num]
state_num = self.run_state.line_tactic_map[tactic_line_num]
- self.run_state.tatics_ran = self.run_state.tatics_ran[:state_num]
- self.proof_context = self.run_state.line_proof_context_map[tactic_line_num]
+ self.run_state.tactics_ran = self.run_state.tactics_ran[:state_num]
line_tactic_map_keys = list(self.run_state.line_tactic_map.keys())
for line_num in line_tactic_map_keys:
if line_num >= tactic_line_num:
@@ -231,7 +162,9 @@ def cancel_tactic_till_line(self, tactic_line_num: int) -> bool:
for line_num in line_proof_context_map_keys:
if line_num >= tactic_line_num:
del self.run_state.line_proof_context_map[line_num]
- self.line_num = tactic_line_num
- cancelled_some_tactics = self._backtrack_tactic_line(tactic_line_num)
- self._proof_running = self.proof_context is not None
+ if not no_backtracking:
+ self.proof_context = self.run_state.line_proof_context_map[tactic_line_num]
+ self.line_num = tactic_line_num
+ cancelled_some_tactics = self._backtrack_tactic_line(tactic_line_num)
+ self._proof_running = self.proof_context is not None
return cancelled_some_tactics
\ No newline at end of file
diff --git a/src/itp_interface/tools/dynamic_lean_proof_exec.py b/src/itp_interface/tools/dynamic_lean_proof_exec.py
index c27f484..f0bfea3 100644
--- a/src/itp_interface/tools/dynamic_lean_proof_exec.py
+++ b/src/itp_interface/tools/dynamic_lean_proof_exec.py
@@ -11,7 +11,7 @@
import enum
import logging
from itp_interface.tools.lean_cmd_executor import Lean3Executor
-from itp_interface.tools.training_data_format import Goal, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, TheoremProvingTrainingDataFormat
from itp_interface.tools.lean_parse_utils import LeanLineByLineReader
from itp_interface.tools.lean_context_helper import Lean3ContextHelper
from itp_interface.tools.misc_defns import HammerMode
@@ -122,43 +122,43 @@ def get_unfocussed_goals(self) -> typing.List[Goal]:
return []
return self.lean_context_helper.get_unfocussed_goals(self)
- def get_current_proof_state_as_training_data(self) -> TrainingDataFormat:
+ def get_current_proof_state_as_training_data(self) -> TheoremProvingTrainingDataFormat:
# get the current goal
if self.needs_cut_close():
current_goals = self.get_unfocussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.UnfocussedGoalsDescription
elif not self.is_in_proof_mode():
current_goals = self.get_focussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.NotInProofModeDescription
elif self.needs_qed():
current_goals = self.get_focussed_goals()
assert len(current_goals) == 0, "There should be no goals when needs_qed is True"
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = DynamicProofExecutor.ProofFinishedDescription
else:
current_goals = self.get_focussed_goals()
- training_data_format = TrainingDataFormat(start_goals=current_goals)
+ training_data_format = TheoremProvingTrainingDataFormat(start_goals=current_goals)
training_data_format.goal_description = None
return training_data_format
- def get_all_relevant_thms(self) -> TrainingDataFormat:
+ def get_all_relevant_thms(self) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.lean_context_helper.set_all_type_matched_query_result(training_data_format, self, self.logger)
return training_data_format
- def get_all_relevant_thms_within_local_context(self) -> TrainingDataFormat:
+ def get_all_relevant_thms_within_local_context(self) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.lean_context_helper.set_local_thms_dfns(training_data_format, self, self.logger)
return training_data_format
- def get_all_relevant_defns(self) -> TrainingDataFormat:
+ def get_all_relevant_defns(self) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
self.lean_context_helper.set_relevant_defns_in_training_data_point(training_data_format, self, self.logger)
return training_data_format
- def get_all_relevant_defns_and_thms(self, should_print_symbol: bool = False, only_local: bool = False, only_proof_state: bool = False) -> TrainingDataFormat:
+ def get_all_relevant_defns_and_thms(self, should_print_symbol: bool = False, only_local: bool = False, only_proof_state: bool = False) -> TheoremProvingTrainingDataFormat:
training_data_format = self.get_current_proof_state_as_training_data()
# self.lean_context_helper.set_relevant_defns_in_training_data_point(training_data_format, self, self.logger)
if not only_proof_state:
@@ -198,6 +198,11 @@ def run_tactics(self, tactics: typing.List[str]) -> typing.Tuple[int, bool]:
self.run_state.last_exception = '\n'.join(self.lean_error_messages)
return start_line_num, not tactic_failed
+ def get_last_tactic(self) -> typing.Optional[str]:
+ if len(self.run_state.tatics_ran) == 0:
+ return None
+ return self.run_state.tatics_ran[-1]
+
def get_last_exception(self) -> typing.Optional[str]:
last_exception = self.run_state.last_exception
self.run_state.last_exception = None
diff --git a/src/itp_interface/tools/isabelle_context_helper.py b/src/itp_interface/tools/isabelle_context_helper.py
index 6bb11a1..bd8117b 100644
--- a/src/itp_interface/tools/isabelle_context_helper.py
+++ b/src/itp_interface/tools/isabelle_context_helper.py
@@ -8,7 +8,7 @@
import logging
import typing
from itp_interface.tools.isabelle_executor import IsabelleExecutor
-from itp_interface.tools.training_data_format import Goal, LemmaRefWithScore, LemmaReferences, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, LemmaRefWithScore, LemmaReferences, TheoremProvingTrainingDataFormat
from typing import List
class IsabelleContextHelper(object):
@@ -39,11 +39,11 @@ def get_local_lemmas(self, isabelle_executor: IsabelleExecutor, logger: logging.
# Search is not supported as of now
raise Exception("Search is not supported in Isabelle as of now")
- def set_relevant_defns_in_training_data_point(self, training_data_point: TrainingDataFormat, isabelle_executor: IsabelleExecutor, logger: logging.Logger = None, depth: int = None):
+ def set_relevant_defns_in_training_data_point(self, training_data_point: TheoremProvingTrainingDataFormat, isabelle_executor: IsabelleExecutor, logger: logging.Logger = None, depth: int = None):
# Search is not supported as of now
raise Exception("Search is not supported in Isabelle as of now")
- def set_all_type_matched_query_result(self, training_data_point: TrainingDataFormat, isabelle_executor: IsabelleExecutor, logger: logging.Logger = None, depth: int = None):
+ def set_all_type_matched_query_result(self, training_data_point: TheoremProvingTrainingDataFormat, isabelle_executor: IsabelleExecutor, logger: logging.Logger = None, depth: int = None):
unique_thms = {defn.lemma_name: idx for idx, defn in enumerate(training_data_point.all_useful_defns_theorems)}
# query = training_data_point.get_human_readable_serialized_goal(idx, skip_special_tokens=True)
relevant_thms = isabelle_executor.search_type_matching_defns("") # Here the search simply returns everything
@@ -57,10 +57,10 @@ def set_all_type_matched_query_result(self, training_data_point: TrainingDataFor
goal.possible_useful_theorems_external = [LemmaRefWithScore(unique_thms[thm.name], 1.0) for thm in relevant_thms]
goal.possible_useful_theorems_local = []
- def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: str, training_data_point: TrainingDataFormat, isabelle_executor: IsabelleExecutor, logger: logging.Logger = None, depth: int = None, max_search_res: typing.Optional[int] = None):
+ def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: str, training_data_point: TheoremProvingTrainingDataFormat, isabelle_executor: IsabelleExecutor, logger: logging.Logger = None, depth: int = None, max_search_res: typing.Optional[int] = None):
# Search is not supported as of now
raise Exception("Search is not supported in Isabelle as of now")
- def set_local_thms_dfns(self, training_data_point: TrainingDataFormat, isabelle_executor: IsabelleExecutor, logger: logging.Logger = None):
+ def set_local_thms_dfns(self, training_data_point: TheoremProvingTrainingDataFormat, isabelle_executor: IsabelleExecutor, logger: logging.Logger = None):
# Search is not supported as of now
raise Exception("Search is not supported in Isabelle as of now")
\ No newline at end of file
diff --git a/src/itp_interface/tools/isabelle_local_data_generation_transform.py b/src/itp_interface/tools/isabelle_local_data_generation_transform.py
index 9186bda..7ec7da3 100644
--- a/src/itp_interface/tools/isabelle_local_data_generation_transform.py
+++ b/src/itp_interface/tools/isabelle_local_data_generation_transform.py
@@ -10,7 +10,7 @@
from itp_interface.tools.isabelle_executor import IsabelleExecutor
from itp_interface.tools.isabelle_context_helper import IsabelleContextHelper
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
-from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataMetadataFormat, TrainingDataCollection, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataMetadataFormat, TheoremProvingTrainingDataCollection, TheoremProvingTrainingDataFormat
from itp_interface.tools.training_data import TrainingData
# See this for running transformation on AFP
@@ -35,13 +35,13 @@ def get_meta_object(self) -> MergableCollection:
return TrainingDataMetadataFormat(training_data_buffer_size=self.buffer_size)
def get_data_collection_object(self) -> MergableCollection:
- return TrainingDataCollection()
+ return TheoremProvingTrainingDataCollection()
def load_meta_from_file(self, file_path) -> MergableCollection:
return TrainingDataMetadataFormat.load_from_file(file_path)
def load_data_from_file(self, file_path) -> MergableCollection:
- return TrainingDataCollection.load_from_file(file_path, self.logger)
+ return TheoremProvingTrainingDataCollection.load_from_file(file_path, self.logger)
def __call__(self, training_data: TrainingData, project_id : str, isabelle_executor: IsabelleExecutor, print_coq_executor_callback: typing.Callable[[], IsabelleExecutor], theorems: typing.List[str] = None, other_args: dict = {}) -> TrainingData:
print_isabelle_executor = print_coq_executor_callback()
@@ -68,7 +68,7 @@ def __call__(self, training_data: TrainingData, project_id : str, isabelle_execu
prev_goal : typing.List[Goal] = [Goal(goal.hypotheses, goal.goal) for goal in prev_goal]
next_goal : typing.List[Goal] = isabelle_context_helper.get_focussed_goals(isabelle_executor)
if len(prev_goal) > 0:
- training_data_format = TrainingDataFormat(
+ training_data_format = TheoremProvingTrainingDataFormat(
proof_id=proof_id,
all_useful_defns_theorems=[],
start_goals=prev_goal,
diff --git a/src/itp_interface/tools/iter_helpers.py b/src/itp_interface/tools/iter_helpers.py
new file mode 100644
index 0000000..b80035b
--- /dev/null
+++ b/src/itp_interface/tools/iter_helpers.py
@@ -0,0 +1,67 @@
+import typing
+from abc import ABC, abstractmethod
+
+class ClonableIterator(ABC):
+ @abstractmethod
+ def __iter__(self) -> typing.Iterator[str]:
+ pass
+
+ @abstractmethod
+ def __next__(self) -> str:
+ pass
+
+ @abstractmethod
+ def set_to_index(self, index: int):
+ pass
+
+ @abstractmethod
+ def clone(self) -> 'ClonableIterator':
+ pass
+
+
+class IntertwinedIterator(ClonableIterator):
+ def __init__(self, iterator: typing.Optional[ClonableIterator] = None):
+ self.base_iterator = iterator
+ self.next_instruction: typing.Optional[str] = None
+ self.base_iterator_stopped = iterator is None # if the base iterator is None, then it is stopped
+
+ def set_next_instruction(self, instruction: str):
+ assert self.next_instruction is None, "next_instruction must be None"
+ assert instruction is not None, "instruction must not be None"
+ self.next_instruction = instruction
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.next_instruction is not None:
+ # Return the next instruction if it is set
+ next_instruction = self.next_instruction
+ self.next_instruction = None
+ return next_instruction
+ # Otherwise, get the next instruction from the base iterator
+ if self.base_iterator is not None and not self.base_iterator_stopped:
+ try:
+ instruction = next(self.base_iterator)
+ return instruction
+ except StopIteration:
+ self.base_iterator_stopped = True
+ raise
+ else:
+ raise StopIteration()
+
+ def set_to_index(self, index: int):
+ if self.base_iterator is not None:
+ self.base_iterator.set_to_index(index)
+ self.base_iterator_stopped = False
+
+ def clone(self) -> 'IntertwinedIterator':
+ cloned_iterator = IntertwinedIterator()
+ if self.base_iterator is not None:
+ cloned_iterator.base_iterator = self.base_iterator.clone()
+ cloned_iterator.base_iterator_stopped = self.base_iterator_stopped
+ cloned_iterator.next_instruction = self.next_instruction
+ return cloned_iterator
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
\ No newline at end of file
diff --git a/src/itp_interface/tools/lean4_context_helper.py b/src/itp_interface/tools/lean4_context_helper.py
index 3ac1383..6a98597 100644
--- a/src/itp_interface/tools/lean4_context_helper.py
+++ b/src/itp_interface/tools/lean4_context_helper.py
@@ -8,7 +8,7 @@
import logging
import typing
from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor, ProofContext
-from itp_interface.tools.training_data_format import Goal, LemmaRefWithScore, LemmaReferences, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, LemmaRefWithScore, LemmaReferences, TheoremProvingTrainingDataFormat
from typing import List
class Lean4ContextHelper(object):
@@ -44,11 +44,11 @@ def get_local_lemmas(self, lean_executor: Lean4SyncExecutor, logger: logging.Log
raise Exception("Search is not supported in Lean as of now")
- def set_relevant_defns_in_training_data_point(self, training_data_point: TrainingDataFormat, lean_executor: Lean4SyncExecutor, logger: logging.Logger = None, depth: int = None):
+ def set_relevant_defns_in_training_data_point(self, training_data_point: TheoremProvingTrainingDataFormat, lean_executor: Lean4SyncExecutor, logger: logging.Logger = None, depth: int = None):
# Search is not supported in Lean as of now
raise Exception("Search is not supported in Lean as of now")
- def set_all_type_matched_query_result(self, training_data_point: TrainingDataFormat, lean_executor: Lean4SyncExecutor, logger: logging.Logger = None, depth: int = None):
+ def set_all_type_matched_query_result(self, training_data_point: TheoremProvingTrainingDataFormat, lean_executor: Lean4SyncExecutor, logger: logging.Logger = None, depth: int = None):
unique_thms = {defn.lemma_name: idx for idx, defn in enumerate(training_data_point.all_useful_defns_theorems)}
# query = training_data_point.get_human_readable_serialized_goal(idx, skip_special_tokens=True)
relevant_thms = self.search_executor.search_type_matching_defns("") # Here the search simply returns everything
@@ -63,10 +63,10 @@ def set_all_type_matched_query_result(self, training_data_point: TrainingDataFor
goal.possible_useful_theorems_external = [LemmaRefWithScore(unique_thms[f"{thm.namespace}.{thm.name}"], 1.0) for thm in relevant_thms]
goal.possible_useful_theorems_local = []
- def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: str, training_data_point: TrainingDataFormat, lean_executor: Lean4SyncExecutor, logger: logging.Logger = None, depth: int = None, max_search_res: typing.Optional[int] = None):
+ def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: str, training_data_point: TheoremProvingTrainingDataFormat, lean_executor: Lean4SyncExecutor, logger: logging.Logger = None, depth: int = None, max_search_res: typing.Optional[int] = None):
# Search is not supported in Lean as of now
raise Exception("Search is not supported in Lean as of now")
- def set_local_thms_dfns(self, training_data_point: TrainingDataFormat, lean_executor: Lean4SyncExecutor, logger: logging.Logger = None):
+ def set_local_thms_dfns(self, training_data_point: TheoremProvingTrainingDataFormat, lean_executor: Lean4SyncExecutor, logger: logging.Logger = None):
# Search is not supported in Lean as of now
raise Exception("Search is not supported in Lean as of now")
diff --git a/src/itp_interface/tools/lean4_local_data_extraction_transform.py b/src/itp_interface/tools/lean4_local_data_extraction_transform.py
new file mode 100644
index 0000000..9205638
--- /dev/null
+++ b/src/itp_interface/tools/lean4_local_data_extraction_transform.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+dir_name = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+root_dir = os.path.abspath(dir_name)
+if root_dir not in sys.path:
+ sys.path.append(root_dir)
+import typing
+import uuid
+from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
+from itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, ExtractionDataCollection, TheoremProvingTrainingDataFormat
+from itp_interface.tools.training_data import TrainingData, DataLayoutFormat
+
+class Local4DataExtractionTransform(GenericTrainingDataGenerationTransform):
+ def __init__(self,
+ depth = None,
+ max_search_results = None,
+ buffer_size : int = 10000,
+ logger = None,
+ max_parallelism : int = 4):
+ super().__init__(TrainingDataGenerationType.LOCAL, buffer_size, logger)
+ self.depth = depth
+ self.max_search_results = max_search_results
+ self.max_parallelism = max_parallelism
+
+ def get_meta_object(self) -> TrainingDataMetadataFormat:
+ return TrainingDataMetadataFormat(
+ training_data_buffer_size=self.buffer_size,
+ data_filename_prefix="extraction_data_",
+ lemma_ref_filename_prefix="extraction_lemma_refs_")
+
+ def get_data_collection_object(self) -> MergableCollection:
+ return ExtractionDataCollection()
+
+ def load_meta_from_file(self, file_path) -> MergableCollection:
+ return TrainingDataMetadataFormat.load_from_file(file_path)
+
+ def load_data_from_file(self, file_path) -> MergableCollection:
+ return ExtractionDataCollection.load_from_file(file_path, self.logger)
+
+ def __call__(self,
+ training_data: TrainingData,
+ project_id : str,
+ lean_executor: SimpleLean4SyncExecutor,
+ print_coq_executor_callback: typing.Callable[[], SimpleLean4SyncExecutor],
+ theorems: typing.List[str] = None,
+ other_args: dict = {}) -> TrainingData:
+ file_namespace = lean_executor.main_file.replace('/', '.')
+ self.logger.info(f"=========================Processing {file_namespace}=========================")
+ theorem_id = str(uuid.uuid4())
+ if isinstance(theorems, list) and len(theorems) == 1 and theorems[0] == "*":
+ theorems = None
+ else:
+ theorems = set(theorems) if theorems is not None else None
+ cnt = 0
+ temp_dir = os.path.join(training_data.folder, "temp")
+ os.makedirs(temp_dir, exist_ok=True)
+ json_output_path = f"{temp_dir}/{file_namespace.replace('.', '_')}.lean.deps.json"
+ file_dep_analyses = lean_executor.extract_all_theorems_and_definitions(json_output_path=json_output_path)
+ self.logger.info(f"Extracted {len(file_dep_analyses)} FileDependencyAnalysis objects from {file_namespace}")
+ self.logger.info(f"file_dep_analyses: {file_dep_analyses}")
+ assert len(file_dep_analyses) == 1, "Expected exactly one FileDependencyAnalysis object"
+ file_dep_analysis = file_dep_analyses[0]
+ for decls in file_dep_analysis.declarations:
+ line_info = decls.decl_info
+ if theorems is not None and line_info.name not in theorems:
+ continue
+ training_data.merge(decls)
+ cnt += 1
+ training_data.meta.last_proof_id = theorem_id
+ self.logger.info(f"===============Finished processing {file_namespace}=====================")
+ self.logger.info(f"Total declarations processed in this transform: {cnt}")
+ return training_data
+
+
+if __name__ == "__main__":
+ import os
+ import logging
+ import time
+ os.chdir(root_dir)
+ # project_dir = 'data/test/lean4_proj/'
+ project_dir = 'data/test/Mathlib'
+ # file_name = 'data/test/lean4_proj/Lean4Proj/Basic.lean'
+ file_name = 'data/test/Mathlib/.lake/packages/mathlib/Mathlib/Algebra/Divisibility/Basic.lean'
+ project_id = project_dir.replace('/', '.')
+ time_str = time.strftime("%Y%m%d-%H%M%S")
+ output_path = f".log/local_data_generation_transform/data/{time_str}"
+ log_path = f".log/local_data_generation_transform/log/{time_str}"
+ log_file = f"{log_path}/local_data_generation_transform-{time_str}.log"
+ os.makedirs(output_path, exist_ok=True)
+ os.makedirs(log_path, exist_ok=True)
+ logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
+ logger = logging.getLogger(__name__)
+ def _print_lean_executor_callback():
+ search_lean_exec = SimpleLean4SyncExecutor(main_file=file_name, project_root=project_dir)
+ search_lean_exec.__enter__()
+ return search_lean_exec
+ transform = Local4DataExtractionTransform(0, buffer_size=1000)
+ training_data = TrainingData(
+ output_path,
+ "training_metadata.json",
+ training_meta=transform.get_meta_object(),
+ logger=logger,
+ layout=DataLayoutFormat.DECLARATION_EXTRACTION)
+ with SimpleLean4SyncExecutor(project_root=project_dir, main_file=file_name, use_human_readable_proof_context=True, suppress_error_log=True) as coq_exec:
+ transform(training_data, project_id, coq_exec, _print_lean_executor_callback, theorems=["*"])
+ save_info = training_data.save()
+ logger.info(f"Saved training data to {save_info}")
\ No newline at end of file
diff --git a/src/itp_interface/tools/lean4_local_data_generation_transform.py b/src/itp_interface/tools/lean4_local_data_generation_transform.py
index 46e4cd6..4854b02 100644
--- a/src/itp_interface/tools/lean4_local_data_generation_transform.py
+++ b/src/itp_interface/tools/lean4_local_data_generation_transform.py
@@ -6,10 +6,10 @@
sys.path.append(root_dir)
import typing
import uuid
-from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor
+from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
from itp_interface.tools.lean4_context_helper import Lean4ContextHelper
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
-from itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, TrainingDataCollection, TrainingDataFormat
+from itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, TheoremProvingTrainingDataCollection, TheoremProvingTrainingDataFormat
from itp_interface.tools.training_data import TrainingData
class Local4DataGenerationTransform(GenericTrainingDataGenerationTransform):
@@ -28,15 +28,15 @@ def get_meta_object(self) -> MergableCollection:
return TrainingDataMetadataFormat(training_data_buffer_size=self.buffer_size)
def get_data_collection_object(self) -> MergableCollection:
- return TrainingDataCollection()
+ return TheoremProvingTrainingDataCollection()
def load_meta_from_file(self, file_path) -> MergableCollection:
return TrainingDataMetadataFormat.load_from_file(file_path)
def load_data_from_file(self, file_path) -> MergableCollection:
- return TrainingDataCollection.load_from_file(file_path, self.logger)
+ return TheoremProvingTrainingDataCollection.load_from_file(file_path, self.logger)
- def __call__(self, training_data: TrainingData, project_id : str, lean_executor: Lean4SyncExecutor, print_coq_executor_callback: typing.Callable[[], Lean4SyncExecutor], theorems: typing.List[str] = None, other_args: dict = {}) -> TrainingData:
+ def __call__(self, training_data: TrainingData, project_id : str, lean_executor: SimpleLean4SyncExecutor, print_coq_executor_callback: typing.Callable[[], SimpleLean4SyncExecutor], theorems: typing.List[str] = None, other_args: dict = {}) -> TrainingData:
print_lean_executor = print_coq_executor_callback()
lean_context_helper = Lean4ContextHelper(print_lean_executor, self.depth, self.logger)
lean_context_helper.__enter__()
@@ -63,7 +63,7 @@ def __call__(self, training_data: TrainingData, project_id : str, lean_executor:
end_goals = []
if len(start_goals) > 0 and \
(len(start_goals) != len(end_goals) or not all(s_g == e_g for s_g, e_g in zip(start_goals, end_goals))):
- tdf = TrainingDataFormat(
+ tdf = TheoremProvingTrainingDataFormat(
proof_id=proof_id,
all_useful_defns_theorems=[],
start_goals=start_goals,
@@ -107,7 +107,7 @@ def __call__(self, training_data: TrainingData, project_id : str, lean_executor:
logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)
def _print_lean_executor_callback():
- search_lean_exec = Lean4SyncExecutor(main_file=file_name, project_root=project_dir)
+ search_lean_exec = SimpleLean4SyncExecutor(main_file=file_name, project_root=project_dir)
search_lean_exec.__enter__()
return search_lean_exec
transform = Local4DataGenerationTransform(0, buffer_size=1000)
@@ -116,7 +116,7 @@ def _print_lean_executor_callback():
"training_metadata.json",
training_meta=transform.get_meta_object(),
logger=logger)
- with Lean4SyncExecutor(project_root=project_dir, main_file=file_name, use_human_readable_proof_context=True, suppress_error_log=True) as coq_exec:
+ with SimpleLean4SyncExecutor(project_root=project_dir, main_file=file_name, use_human_readable_proof_context=True, suppress_error_log=True) as coq_exec:
transform(training_data, project_id, coq_exec, _print_lean_executor_callback, theorems=['{"namespace": "Lean4Proj1", "name": "test2"}'])
save_info = training_data.save()
logger.info(f"Saved training data to {save_info}")
\ No newline at end of file
diff --git a/src/itp_interface/tools/lean4_sync_executor.py b/src/itp_interface/tools/lean4_sync_executor.py
index 1e7ba6c..4dc705d 100644
--- a/src/itp_interface/tools/lean4_sync_executor.py
+++ b/src/itp_interface/tools/lean4_sync_executor.py
@@ -129,7 +129,10 @@ def __init__(self,
self._file_handle = None
self._in_tactic_mode = False
self._env_idx_last_thm = None
+ self._debug_traces = []
+ self.debug_enabled = False
self._last_tactics = {}
+ self.possible_proof_tactics = ""
self._last_tactic_line_idx = None
self._error_messages_so_far = set()
self._error_messages_since_last_thm = {}
@@ -185,7 +188,10 @@ def reset(self,
self._in_tactic_mode = False
self._env_idx_last_thm = None
self._last_tactics = {}
+ self.possible_proof_tactics = ""
self._last_tactic_line_idx = None
+ self._debug_traces = []
+ self.debug_enabled = False
self._error_messages_so_far = set()
self._error_messages_since_last_thm = {}
if self._enable_search:
@@ -692,6 +698,7 @@ def _clear_tacitcs(self):
tactics_so_far = sorted(tactics_so_far, key=lambda x: x[0])
tactics_so_far = [v for _, v in tactics_so_far]
self._write_lean_file(self._last_tactic_line_idx, "\n".join(tactics_so_far))
+ self.possible_proof_tactics += "\n".join(tactics_so_far)
self._last_tactics = {}
self._last_tactic_line_idx = None
self._error_messages_since_last_thm = {}
@@ -772,6 +779,8 @@ def _parse_response(self, idx, response, relevant_messages = [], dont_update_err
# Go over all sev after the line number and check if there is an error
for msg_idx, msg in enumerate(messages):
full_error_msg = f"Line: {idx} " + self._get_error_msg(msg_idx, msg)
+ if self.debug_enabled:
+ self._debug_traces.append(f"Debug Trace: Processing message at line {idx}: {full_error_msg}")
unsolved_goal_never_seen_before = not (full_error_msg in self._error_messages_since_last_thm.values())
if msg['severity'] == 'error' and 'pos' in msg and 'endPos' in msg and \
((msg['endPos'] is not None and 'line' in msg['endPos']) or \
@@ -1014,7 +1023,219 @@ def _parse_proof_context(self, proof_goals: list) -> ProofContext:
return ProofContext.empty()
else:
return ProofContext(goals, [], [], [])
-
+
+ def validate_proof_with_lake(self, theorem_name: Optional[str] = None, timeout_sec: int = 30, keep_temp_file: bool = True) -> typing.Dict[str, typing.Any]:
+ """
+ Validate the current proof state by running 'lake lean' on a temporary file.
+ This provides an independent verification without relying on the REPL.
+
+ Args:
+ theorem_name: Name of the theorem to validate (optional, for logging only)
+ timeout_sec: Timeout in seconds for the lake lean process
+ keep_temp_file: If True, keeps the temporary file after validation (default: True)
+
+ Returns:
+ Dictionary with validation results:
+ {
+ 'success': bool, # True if proof is complete with no errors
+ 'compilation_ok': bool, # True if file compiles
+ 'has_sorries': bool, # True if there are unsolved goals (sorries)
+ 'error_message': str, # Error message if any
+ 'errors': list, # List of error details
+ 'lean_code': str, # The code that was validated
+ 'theorem_name': str # Name of theorem being validated
+ }
+ """
+ import subprocess
+ import re
+
+ # Get theorem name for logging/reporting, but don't require it
+ if theorem_name is None:
+ theorem_name = self.curr_lemma_name or "unknown"
+
+ # Create the Lean code with all executed lines up to current point
+ lines_executed_str = '\n'.join(self._lines_executed)
+
+ if not lines_executed_str or not lines_executed_str.strip():
+ return {
+ 'success': False,
+ 'compilation_ok': False,
+ 'has_sorries': False,
+ 'error_message': 'No code to validate',
+ 'errors': [],
+ 'lean_code': '',
+ 'theorem_name': theorem_name,
+ 'full_output': 'No code available to validate',
+ 'stdout': '',
+ 'stderr': '',
+ 'return_code': -1,
+ 'temp_filename': 'N/A',
+ 'temp_file_path': 'N/A',
+ 'temp_file_kept': False,
+ 'debug_traces': list(self._debug_traces),
+ 'possible_proof_tactics': self.possible_proof_tactics
+ }
+
+ # Build the complete Lean code with actual proof tactics
+ # The proof tactics are accumulated in self.possible_proof_tactics
+ actual_proof = "" # Track the actual proof for sorry checking
+ proof_tactics_source = self.possible_proof_tactics
+
+ # If possible_proof_tactics is empty, try to use _last_tactics as fallback
+ if not proof_tactics_source or not proof_tactics_source.strip():
+ if self._last_tactics:
+ # Extract tactics from _last_tactics (same logic as _clear_tacitcs)
+ tactics_so_far = [(k, v) for k, v in self._last_tactics.items()]
+ tactics_so_far = sorted(tactics_so_far, key=lambda x: x[0])
+ tactics_so_far = [v for _, v in tactics_so_far]
+ proof_tactics_source = "\n".join(tactics_so_far)
+
+ # If both are empty, raise an error
+ if not proof_tactics_source or not proof_tactics_source.strip():
+ raise ValueError("No proof tactics available. Neither 'possible_proof_tactics' nor '_last_tactics' contain any proof steps.")
+
+ # Now build the Lean code with the proof tactics
+ if proof_tactics_source and proof_tactics_source.strip():
+ # Find the last ':=' in lines_executed
+ last_assign_idx = lines_executed_str.rfind(':=')
+ if last_assign_idx != -1:
+ # Take everything up to and including the last ':=' from lines_executed
+ code_prefix = lines_executed_str[:last_assign_idx + 2]
+
+ # In proof_tactics_source, find the first ':= by' (with flexible whitespace)
+ # Use regex to find ':=' followed by optional whitespace and 'by'
+ assign_by_match = re.search(r':=\s+by', proof_tactics_source)
+
+ if assign_by_match:
+ # Extract everything after ':= by' (excluding the match itself)
+ match_end_idx = assign_by_match.end()
+ actual_proof = proof_tactics_source[match_end_idx:].strip()
+ lean_code = code_prefix + ' by\n' + actual_proof
+ else:
+ # No ':= by' found, use proof_tactics_source as-is
+ actual_proof = proof_tactics_source.strip()
+ lean_code = code_prefix + ' by\n' + actual_proof
+ else:
+ # No ':=' found, just use lines_executed as-is
+ lean_code = lines_executed_str
+ else:
+ # No proof tactics available, use lines_executed as-is
+ lean_code = lines_executed_str
+
+ # Create a unique temporary file
+ temp_filename = f"validation_{self.ticks}_{self.random_num}.lean"
+ temp_file_path = os.path.join(self.project_root, temp_filename)
+
+ try:
+ # Write the Lean code to the temporary file
+ with open(temp_file_path, 'w') as f:
+ f.write(lean_code)
+
+ # Run lake lean on the file
+ try:
+ result = subprocess.run(
+ ['lake', 'lean', temp_filename],
+ cwd=self.project_root,
+ capture_output=True,
+ text=True,
+ timeout=timeout_sec
+ )
+
+ stdout = result.stdout
+ stderr = result.stderr
+ output = stdout + '\n' + stderr
+
+ except subprocess.TimeoutExpired:
+ # Don't delete temp file on timeout so it can be inspected
+ return {
+ 'success': False,
+ 'compilation_ok': False,
+ 'has_sorries': False,
+ 'error_message': f'Timeout after {timeout_sec} seconds',
+ 'errors': [],
+ 'lean_code': lean_code,
+ 'theorem_name': theorem_name,
+ 'full_output': f'Process timed out after {timeout_sec} seconds',
+ 'stdout': '',
+ 'stderr': '',
+ 'return_code': -1,
+ 'temp_filename': temp_filename,
+ 'temp_file_path': temp_file_path,
+ 'temp_file_kept': True, # Keep file on timeout for debugging
+ 'debug_traces': list(self._debug_traces),
+ 'possible_proof_tactics': self.possible_proof_tactics
+ }
+
+ # Parse the output for errors and warnings
+ errors = []
+ error_pattern = re.compile(r'(\S+):(\d+):(\d+):\s*(warning|error):\s*(.+)')
+
+ for line in output.split('\n'):
+ match = error_pattern.match(line)
+ if match:
+ filename, line_num, col_num, severity, message = match.groups()
+ errors.append({
+ 'file': filename,
+ 'line': int(line_num),
+ 'column': int(col_num),
+ 'severity': severity,
+ 'message': message
+ })
+
+ # Check for 'sorry' only in the actual proof we generated
+ has_sorries = 'sorry' in actual_proof.lower()
+
+ # Only fail on actual errors (not warnings)
+ # Also check for "unsolved goals" in error messages
+ theorem_has_error = False
+ for error in errors:
+ if error['severity'] == 'error':
+ theorem_has_error = True
+ # Also check if the error mentions unsolved goals
+ if 'unsolved goals' in error['message'].lower():
+ has_sorries = True
+
+ # Determine success: compilation ok, no sorries in actual proof, no errors (ignore warnings)
+ compilation_ok = result.returncode == 0
+ success = compilation_ok and not has_sorries and not theorem_has_error
+
+ error_message = ''
+ if not compilation_ok:
+ error_message = 'Compilation failed'
+ elif has_sorries:
+ error_message = 'Proof has unsolved goals (sorries)'
+ elif theorem_has_error:
+ error_message = 'Theorem has errors'
+ else:
+ error_message = 'Proof is complete'
+
+ # Combine full raw output for debugging
+ full_output = f"=== STDOUT ===\n{stdout}\n\n=== STDERR ===\n{stderr}"
+
+ return {
+ 'success': success,
+ 'compilation_ok': compilation_ok,
+ 'has_sorries': has_sorries,
+ 'error_message': error_message,
+ 'errors': errors,
+ 'lean_code': lean_code,
+ 'return_code': result.returncode,
+ 'stdout': stdout,
+ 'stderr': stderr,
+ 'full_output': full_output,
+ 'theorem_name': theorem_name,
+ 'temp_filename': temp_filename,
+ 'temp_file_path': temp_file_path,
+ 'temp_file_kept': keep_temp_file,
+ 'debug_traces': list(self._debug_traces),
+ 'possible_proof_tactics': self.possible_proof_tactics
+ }
+
+ finally:
+ # Clean up the temporary file only if requested
+ if not keep_temp_file and os.path.exists(temp_file_path):
+ os.remove(temp_file_path)
+
theorem_names_in_file_cache: Dict[str, List[TheoremDetails]] = {}
namespace_regex = r"^namespace[ ]+([\S]+)"
diff --git a/src/itp_interface/tools/lean_context_helper.py b/src/itp_interface/tools/lean_context_helper.py
index 92352b9..cf7ff8f 100644
--- a/src/itp_interface/tools/lean_context_helper.py
+++ b/src/itp_interface/tools/lean_context_helper.py
@@ -8,7 +8,7 @@
import logging
import typing
from itp_interface.tools.lean_cmd_executor import Lean3Executor
-from itp_interface.tools.training_data_format import Goal, LemmaRefWithScore, LemmaReferences, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, LemmaRefWithScore, LemmaReferences, TheoremProvingTrainingDataFormat
from typing import List
class Lean3ContextHelper(object):
@@ -122,7 +122,7 @@ def get_local_lemmas(self, lean_executor: Lean3Executor, logger: logging.Logger
# lemmas.append((lemma_name, lemma_val))
# return lemmas
- def set_relevant_defns_in_training_data_point(self, training_data_point: TrainingDataFormat, lean_executor: Lean3Executor, logger: logging.Logger = None, depth: int = None):
+ def set_relevant_defns_in_training_data_point(self, training_data_point: TheoremProvingTrainingDataFormat, lean_executor: Lean3Executor, logger: logging.Logger = None, depth: int = None):
# Search is not supported in Lean as of now
raise Exception("Search is not supported in Lean as of now")
# logger = logger if logger is not None else self.logger
@@ -166,7 +166,7 @@ def set_relevant_defns_in_training_data_point(self, training_data_point: Trainin
# useful_defns = [LemmaRefWithScore(unique_defns[defn], score) for defn, _, score in useful_defns]
# goal.relevant_defns = useful_defns
- def set_all_type_matched_query_result(self, training_data_point: TrainingDataFormat, lean_executor: Lean3Executor, logger: logging.Logger = None, depth: int = None):
+ def set_all_type_matched_query_result(self, training_data_point: TheoremProvingTrainingDataFormat, lean_executor: Lean3Executor, logger: logging.Logger = None, depth: int = None):
unique_thms = {defn.lemma_name: idx for idx, defn in enumerate(training_data_point.all_useful_defns_theorems)}
# query = training_data_point.get_human_readable_serialized_goal(idx, skip_special_tokens=True)
relevant_thms = self.search_executor.search_type_matching_defns("") # Here the search simply returns everything
@@ -234,7 +234,7 @@ def set_all_type_matched_query_result(self, training_data_point: TrainingDataFor
# goal.possible_useful_theorems_external = [LemmaRefWithScore(defn_idx, score) for defn_idx, score in useful_external_theorems if score <= Lean3ContextHelper.max_relevance_score]
# goal.possible_useful_theorems_local = [LemmaRefWithScore(defn_idx, score) for defn_idx, score in useful_local_theorems if score <= Lean3ContextHelper.max_relevance_score]
- def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: str, training_data_point: TrainingDataFormat, lean_executor: Lean3Executor, logger: logging.Logger = None, depth: int = None, max_search_res: typing.Optional[int] = None):
+ def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: str, training_data_point: TheoremProvingTrainingDataFormat, lean_executor: Lean3Executor, logger: logging.Logger = None, depth: int = None, max_search_res: typing.Optional[int] = None):
# Search is not supported in Lean as of now
raise Exception("Search is not supported in Lean as of now")
# # Use the hypothesis to find the definition
@@ -310,7 +310,7 @@ def set_useful_defns_theorems_for_training_data_generation(self, current_stmt: s
# goal.possible_useful_theorems_local = [LemmaRefWithScore(defn_idx, score) for defn_idx, score in useful_local_theorems if score <= Lean3ContextHelper.max_relevance_score]
# goal.possible_useful_theorems_external = [LemmaRefWithScore(defn_idx, score) for defn_idx, score in useful_external_theorems if score <= Lean3ContextHelper.max_relevance_score]
- def set_local_thms_dfns(self, training_data_point: TrainingDataFormat, lean_executor: Lean3Executor, logger: logging.Logger = None):
+ def set_local_thms_dfns(self, training_data_point: TheoremProvingTrainingDataFormat, lean_executor: Lean3Executor, logger: logging.Logger = None):
# Search is not supported in Lean as of now
raise Exception("Search is not supported in Lean as of now")
# local_lemmas = self.get_local_lemmas(lean_executor, logger)
diff --git a/src/itp_interface/tools/lean_dojo_data_generation_transform.py b/src/itp_interface/tools/lean_dojo_data_generation_transform.py
index 2a10421..018c8cf 100644
--- a/src/itp_interface/tools/lean_dojo_data_generation_transform.py
+++ b/src/itp_interface/tools/lean_dojo_data_generation_transform.py
@@ -12,7 +12,7 @@
from itp_interface.lean_server.lean_context import ProofContext
from itp_interface.lean_server.lean4_utils import Lean4Utils
from itp_interface.tools.training_data import TrainingData
-from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataCollection, TrainingDataFormat, TrainingDataMetadataFormat
+from itp_interface.tools.training_data_format import Goal, MergableCollection, TheoremProvingTrainingDataCollection, TheoremProvingTrainingDataFormat, TrainingDataMetadataFormat
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
class LocalDataGenerationTransform(GenericTrainingDataGenerationTransform):
@@ -31,13 +31,13 @@ def get_meta_object(self) -> MergableCollection:
return TrainingDataMetadataFormat(training_data_buffer_size=self.buffer_size)
def get_data_collection_object(self) -> MergableCollection:
- return TrainingDataCollection()
+ return TheoremProvingTrainingDataCollection()
def load_meta_from_file(self, file_path) -> MergableCollection:
return TrainingDataMetadataFormat.load_from_file(file_path)
def load_data_from_file(self, file_path) -> MergableCollection:
- return TrainingDataCollection.load_from_file(file_path, self.logger)
+ return TheoremProvingTrainingDataCollection.load_from_file(file_path, self.logger)
def dump_theorems_from_file(self, file_path: str, output_path: str, output_filename: str, logger = None):
assert file_path.endswith('.json'), f"Invalid file path {file_path}"
@@ -148,7 +148,7 @@ def __call__(self, training_data: TrainingData, project_id : str, executor, prin
raise
if len(start_goals.all_goals) > 0:
# Create a training data object
- training_data_format = TrainingDataFormat(
+ training_data_format = TheoremProvingTrainingDataFormat(
proof_id=theorem_id,
start_goals=[Goal(goal.hypotheses, goal.goal) for goal in start_goals.all_goals],
end_goals=[Goal(goal.hypotheses, goal.goal) for goal in end_goals.all_goals],
diff --git a/src/itp_interface/tools/lean_local_data_generation_transform.py b/src/itp_interface/tools/lean_local_data_generation_transform.py
index 93df1bc..4f25f5b 100644
--- a/src/itp_interface/tools/lean_local_data_generation_transform.py
+++ b/src/itp_interface/tools/lean_local_data_generation_transform.py
@@ -9,7 +9,7 @@
from itp_interface.tools.lean_cmd_executor import Lean3Executor
from itp_interface.tools.lean_context_helper import Lean3ContextHelper
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
-from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataMetadataFormat, TrainingDataCollection, TrainingDataFormat
+from itp_interface.tools.training_data_format import Goal, MergableCollection, TrainingDataMetadataFormat, TheoremProvingTrainingDataCollection, TheoremProvingTrainingDataFormat
from itp_interface.tools.training_data import TrainingData
class LocalDataGenerationTransform(GenericTrainingDataGenerationTransform):
@@ -28,13 +28,13 @@ def get_meta_object(self) -> MergableCollection:
return TrainingDataMetadataFormat(training_data_buffer_size=self.buffer_size)
def get_data_collection_object(self) -> MergableCollection:
- return TrainingDataCollection()
+ return TheoremProvingTrainingDataCollection()
def load_meta_from_file(self, file_path) -> MergableCollection:
return TrainingDataMetadataFormat.load_from_file(file_path)
def load_data_from_file(self, file_path) -> MergableCollection:
- return TrainingDataCollection.load_from_file(file_path, self.logger)
+ return TheoremProvingTrainingDataCollection.load_from_file(file_path, self.logger)
def __call__(self, training_data: TrainingData, project_id : str, lean_executor: Lean3Executor, print_coq_executor_callback: typing.Callable[[], Lean3Executor], theorems: typing.List[str] = None, other_args: dict = {}) -> TrainingData:
print_lean_executor = print_coq_executor_callback()
@@ -61,7 +61,7 @@ def __call__(self, training_data: TrainingData, project_id : str, lean_executor:
prev_goal : typing.List[Goal] = [Goal(goal.hypotheses, goal.goal) for goal in prev_goal]
next_goal : typing.List[Goal] = lean_context_helper.get_focussed_goals(lean_executor)
if len(prev_goal) > 0:
- training_data_format = TrainingDataFormat(
+ training_data_format = TheoremProvingTrainingDataFormat(
proof_id=proof_id,
all_useful_defns_theorems=[],
start_goals=prev_goal,
diff --git a/src/itp_interface/tools/lean_parse_utils.py b/src/itp_interface/tools/lean_parse_utils.py
index 8b3283f..60fd06b 100644
--- a/src/itp_interface/tools/lean_parse_utils.py
+++ b/src/itp_interface/tools/lean_parse_utils.py
@@ -1,15 +1,34 @@
#!/usr/bin/env python3
-import sys
-
-
-root_dir = f"{__file__.split('itp_interface')[0]}"
-if root_dir not in sys.path:
- sys.path.append(root_dir)
import typing
+from itp_interface.tools.iter_helpers import ClonableIterator
from itp_interface.lean_server.lean_utils import Lean3Utils
class LeanLineByLineReader(object):
+ class LineByLineIterator(ClonableIterator):
+ def __init__(self, lines: typing.List[str]):
+ self.lines = lines
+ self.current_index = 0
+
+ def __iter__(self) -> typing.Iterator[str]:
+ return self
+
+ def __next__(self) -> str:
+ if self.current_index >= len(self.lines):
+ raise StopIteration()
+ line = self.lines[self.current_index]
+ self.current_index += 1
+ return line
+
+ def set_to_index(self, index: int):
+ assert 0 <= index < len(self.lines), f"Index {index} out of bounds for lines of length {len(self.lines)}"
+ self.current_index = index
+
+ def clone(self) -> 'LeanLineByLineReader.LineByLineIterator':
+ cloned_iterator = LeanLineByLineReader.LineByLineIterator(self.lines)
+ cloned_iterator.current_index = self.current_index
+ return cloned_iterator
+
def __init__(self, file_name: str = None, file_content: str = None, remove_comments: bool = False, no_strip: bool = False):
assert file_name is not None or file_content is not None, "Either file_name or file_content must be provided"
assert file_name is None or file_content is None, "Only one of file_name or file_content must be provided"
@@ -21,12 +40,9 @@ def __init__(self, file_name: str = None, file_content: str = None, remove_comme
self.file_content : str = fd.read()
if remove_comments:
self.file_content = Lean3Utils.remove_comments(self.file_content)
-
- def instruction_step_generator(self) -> typing.Iterator[str]:
+
+ def instruction_step_generator(self) -> ClonableIterator:
lines = self.file_content.split('\n')
- for line in lines:
- if not self.no_strip:
- line = line.strip()
- else:
- line = line
- yield line
\ No newline at end of file
+ if not self.no_strip:
+ lines = [line.strip() for line in lines]
+ return LeanLineByLineReader.LineByLineIterator(lines)
\ No newline at end of file
diff --git a/src/itp_interface/tools/proof_exec_callback.py b/src/itp_interface/tools/proof_exec_callback.py
index 0494c1e..bcb9744 100644
--- a/src/itp_interface/tools/proof_exec_callback.py
+++ b/src/itp_interface/tools/proof_exec_callback.py
@@ -15,6 +15,7 @@
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.lean_cmd_executor import Lean3Executor
from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor
+from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
from itp_interface.tools.isabelle_executor import IsabelleExecutor
from itp_interface.tools.dynamic_coq_proof_exec import DynamicProofExecutor as DynamicCoqProofExecutor
from itp_interface.tools.dynamic_lean_proof_exec import DynamicProofExecutor as DynamicLeanProofExecutor
@@ -68,7 +69,7 @@ def get_proof_executor(self) -> typing.Union[DynamicCoqProofExecutor, DynamicLea
lean_context_helper = Lean3ContextHelper(search_exec, self.search_depth, logger=self.logger)
return DynamicLeanProofExecutor(lean_context_helper, self.project_folder, self.file_path, context_type=DynamicLeanProofExecutor.ContextType.NoContext, use_hammer=self.use_hammer, timeout_in_seconds=self.timeout_in_secs, suppress_error_log=self.suppress_error_log, use_human_readable_proof_context=self.use_human_readable_proof_context, keep_local_context=self.keep_local_context)
elif self.language == ProofAction.Language.LEAN4:
- search_exec = Lean4SyncExecutor(self.project_folder, self.prefix, self.file_path, use_hammer=self.use_hammer, timeout_in_sec=self.timeout_in_secs, suppress_error_log=self.suppress_error_log, use_human_readable_proof_context=self.use_human_readable_proof_context, enable_search=self.always_use_retrieval, keep_local_context=self.keep_local_context, enforce_qed=self.enforce_qed)
+ search_exec = SimpleLean4SyncExecutor(self.project_folder, self.prefix, self.file_path, use_hammer=self.use_hammer, timeout_in_sec=self.timeout_in_secs, suppress_error_log=self.suppress_error_log, use_human_readable_proof_context=self.use_human_readable_proof_context, enable_search=self.always_use_retrieval, keep_local_context=self.keep_local_context, enforce_qed=self.enforce_qed)
lean4_context_helper = Lean4ContextHelper(search_exec, self.search_depth, logger=self.logger)
return DynamicLean4ProofExecutor(lean4_context_helper, self.project_folder, self.file_path, context_type=DynamicLeanProofExecutor.ContextType.NoContext, use_hammer=self.use_hammer, timeout_in_seconds=self.timeout_in_secs, suppress_error_log=self.suppress_error_log, use_human_readable_proof_context=self.use_human_readable_proof_context, keep_local_context=self.keep_local_context, enforce_qed=self.enforce_qed)
elif self.language == ProofAction.Language.ISABELLE:
diff --git a/src/itp_interface/tools/ray_utils.py b/src/itp_interface/tools/ray_utils.py
index 087b062..807ef7d 100644
--- a/src/itp_interface/tools/ray_utils.py
+++ b/src/itp_interface/tools/ray_utils.py
@@ -10,6 +10,7 @@
import typing
import logging
import gc
+import os
class RayUtils(object):
@@ -18,7 +19,15 @@ def init_ray(num_of_cpus: int = 10, object_store_memory_in_gb: float = 25, memor
gb = 2**30
object_store_memory = int(object_store_memory_in_gb * gb)
memory = int(memory_in_gb * gb)
- return ray.init(num_cpus=num_of_cpus, object_store_memory=object_store_memory, _memory=memory, ignore_reinit_error=True, runtime_env=runtime_env)
+ os.environ["RAY_INITIALIZED"] = "1"
+ obj = ray.init(num_cpus=num_of_cpus, object_store_memory=object_store_memory, _memory=memory, ignore_reinit_error=True, runtime_env=runtime_env)
+ return obj
+
+ @staticmethod
+ def is_ray_initialized() -> bool:
+ if os.environ.get("RAY_INITIALIZED", "0") == "1":
+ return True
+ return ray.is_initialized()
@staticmethod
def ray_run_within_parallel_limits(
diff --git a/src/itp_interface/tools/run_data_generation_transforms.py b/src/itp_interface/tools/run_data_generation_transforms.py
index 950429b..4701c36 100644
--- a/src/itp_interface/tools/run_data_generation_transforms.py
+++ b/src/itp_interface/tools/run_data_generation_transforms.py
@@ -10,9 +10,8 @@
import typing
import shutil
import gc
-import threading
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
-from itp_interface.tools.training_data import TrainingData
+from itp_interface.tools.training_data import TrainingData, DataLayoutFormat
# Conditional Ray import
try:
@@ -25,11 +24,12 @@
RayUtils = None
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.lean_cmd_executor import Lean3Executor
-from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor
+from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
from itp_interface.tools.isabelle_executor import IsabelleExecutor
from itp_interface.tools.coq_local_data_generation_transform import LocalDataGenerationTransform as CoqLocalDataGenerationTransform
from itp_interface.tools.lean_local_data_generation_transform import LocalDataGenerationTransform as LeanLocalDataGenerationTransform
from itp_interface.tools.lean4_local_data_generation_transform import Local4DataGenerationTransform as Lean4LocalDataGenerationTransform
+from itp_interface.tools.lean4_local_data_extraction_transform import Local4DataExtractionTransform as Lean4LocalDataExtractionTransform
from itp_interface.tools.isabelle_local_data_generation_transform import LocalDataGenerationTransform as IsabelleLocalDataGenerationTransform
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
@@ -136,7 +136,7 @@ def _print_lean_callback():
search_lean_exec.__enter__()
return search_lean_exec
def _print_lean4_callback():
- search_lean4_exec = Lean4SyncExecutor(project_path, None, file_path, use_human_readable_proof_context=use_human_readable, suppress_error_log=log_error)
+ search_lean4_exec = SimpleLean4SyncExecutor(project_path, None, file_path, use_human_readable_proof_context=use_human_readable, suppress_error_log=log_error)
search_lean4_exec.__enter__()
return search_lean4_exec
def _print_isabelle_callback():
@@ -144,7 +144,11 @@ def _print_isabelle_callback():
search_isabelle_exec = IsabelleExecutor(project_path, file_path, use_human_readable_proof_context=use_human_readable, suppress_error_log=log_error, port=port)
search_isabelle_exec.__enter__()
return search_isabelle_exec
- if isinstance(transform, CoqLocalDataGenerationTransform) or isinstance(transform, LeanLocalDataGenerationTransform) or isinstance(transform, IsabelleLocalDataGenerationTransform) or isinstance(transform, Lean4LocalDataGenerationTransform):
+ if isinstance(transform, CoqLocalDataGenerationTransform) or \
+ isinstance(transform, LeanLocalDataGenerationTransform) or \
+ isinstance(transform, IsabelleLocalDataGenerationTransform) or \
+ isinstance(transform, Lean4LocalDataGenerationTransform) or \
+ isinstance(transform, Lean4LocalDataExtractionTransform):
if isinstance(transform, IsabelleLocalDataGenerationTransform) and transform.ray_resource_pool is not None:
# This is a blocking call
port = ray.get(transform.ray_resource_pool.wait_and_acquire.remote(1))[0]
@@ -157,7 +161,9 @@ def _print_isabelle_callback():
elif isinstance(transform, IsabelleLocalDataGenerationTransform):
exec = IsabelleExecutor(project_path, file_path, use_human_readable_proof_context=use_human_readable, suppress_error_log=log_error, port=port)
elif isinstance(transform, Lean4LocalDataGenerationTransform):
- exec = Lean4SyncExecutor(project_path, None, file_path, use_human_readable_proof_context=use_human_readable, suppress_error_log=log_error)
+ exec = SimpleLean4SyncExecutor(project_path, None, file_path, use_human_readable_proof_context=use_human_readable, suppress_error_log=log_error)
+ elif isinstance(transform, Lean4LocalDataExtractionTransform):
+ exec = SimpleLean4SyncExecutor(project_path, None, file_path, use_human_readable_proof_context=use_human_readable, suppress_error_log=log_error)
else:
raise Exception("Unknown transform")
with exec:
@@ -170,6 +176,8 @@ def _print_isabelle_callback():
transform(training_data, project_id, exec, _print_isabelle_callback, theorems, other_args)
elif isinstance(transform, Lean4LocalDataGenerationTransform):
transform(training_data, project_id, exec, _print_lean4_callback, theorems, other_args)
+ elif isinstance(transform, Lean4LocalDataExtractionTransform):
+ transform(training_data, project_id, exec, _print_lean4_callback, theorems, other_args)
else:
raise Exception("Unknown transform")
finally:
@@ -195,13 +203,18 @@ def get_training_data_object(transform, output_dir, logger: logging.Logger):
metadata.data_filename_suffix = RunDataGenerationTransforms.get_data_filename_suffix(transform)
metadata.lemma_ref_filename_prefix = RunDataGenerationTransforms.get_lemma_ref_filename_prefix(transform)
metadata.lemma_ref_filename_suffix = RunDataGenerationTransforms.get_lemma_ref_filename_suffix(transform)
+ if isinstance(transform, Lean4LocalDataExtractionTransform):
+ layout = DataLayoutFormat.DECLARATION_EXTRACTION
+ else:
+ layout = DataLayoutFormat.THEOREM_PROVING
training_data = TrainingData(
output_dir,
RunDataGenerationTransforms.get_meta_file_name(transform),
metadata,
transform.max_parallelism,
remove_from_store_after_loading=True,
- logger=logger)
+ logger=logger,
+ layout=layout)
return training_data
@staticmethod
@@ -265,7 +278,10 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD
if self._use_ray:
object_store_memory_in_gb = 100
memory_in_gb = 5
- ray_dashboard = RayUtils.init_ray(num_of_cpus=pool_size, object_store_memory_in_gb=object_store_memory_in_gb)
+ if not RayUtils.is_ray_initialized():
+ ray_dashboard = RayUtils.init_ray(num_of_cpus=pool_size, object_store_memory_in_gb=object_store_memory_in_gb, memory_in_gb=memory_in_gb)
+ else:
+ ray_dashboard = "Ray already initialized"
self.logger.info(f"==============================>[{transform.name}] Ray initialized with {transform.max_parallelism} CPUs, Memory=({memory_in_gb} GiB, Object Memory = {object_store_memory_in_gb} GiB)<==============================")
self.logger.info(f"Ray Context:\n {ray_dashboard}")
else:
@@ -296,7 +312,7 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD
os.makedirs(temp_file_dir, exist_ok=True)
log_file = os.path.join(self.logging_dir, f"{relative_file_path}.log")
theorems = projects[project][file_path]
- if isinstance(transform, Lean4LocalDataGenerationTransform):
+ if isinstance(transform, Lean4LocalDataGenerationTransform) or isinstance(transform, Lean4LocalDataExtractionTransform):
# For every theorem we need to create a separate job
for _idx, theorem in enumerate(theorems):
log_file = os.path.join(self.logging_dir, f"{relative_file_path}-{_idx}.log")
@@ -317,13 +333,18 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD
final_training_meta.data_filename_suffix = RunDataGenerationTransforms.get_data_filename_suffix(transform)
final_training_meta.lemma_ref_filename_prefix = RunDataGenerationTransforms.get_lemma_ref_filename_prefix(transform)
final_training_meta.lemma_ref_filename_suffix = RunDataGenerationTransforms.get_lemma_ref_filename_suffix(transform)
+ if isinstance(transform, Lean4LocalDataExtractionTransform):
+ layout = DataLayoutFormat.DECLARATION_EXTRACTION
+ else:
+ layout = DataLayoutFormat.THEOREM_PROVING
final_training_data = TrainingData(
new_output_dir,
RunDataGenerationTransforms.get_meta_file_name(transform),
final_training_meta,
transform.max_parallelism,
remove_from_store_after_loading=True,
- logger=self.logger)
+ logger=self.logger,
+ layout=layout)
last_job_idx = 0
tds = [None]*len(job_spec)
num_theorems = 0
diff --git a/src/itp_interface/tools/simple_lean4_sync_executor.py b/src/itp_interface/tools/simple_lean4_sync_executor.py
new file mode 100644
index 0000000..b254605
--- /dev/null
+++ b/src/itp_interface/tools/simple_lean4_sync_executor.py
@@ -0,0 +1,948 @@
+#!/usr/bin/env python3
+
+import os
+import copy
+import random
+import logging
+import re
+import time
+import json
+import typing
+import bisect
+import subprocess
+from itp_interface.tools.tactic_parser import (
+ TacticParser,
+ ErrorInfo,
+ LeanLineInfo,
+ FileDependencyAnalysis,
+ RequestType,
+ print_tactics
+)
+from itp_interface.lean_server.lean_context import ProofContext
+from itp_interface.lean_server.lean4_utils import Lean4Utils
+from itp_interface.tools.lean_parse_utils import LeanLineByLineReader
+from itp_interface.tools.theorem_details import TheoremDetails
+from itp_interface.tools.misc_defns import HammerMode
+from itp_interface.tools.iter_helpers import ClonableIterator
+from typing import List, Optional, Tuple, OrderedDict, Dict
+
+class SimpleLean4SyncExecutor:
+ theorem_regex = r"((((theorem|lemma)[\s]+([^\s:]*))|example)([\S|\s]*?)(:=|=>)[\s]*?)[\s]+"
+ theorem_match = re.compile(theorem_regex, re.MULTILINE)
+ have_regex = r"(^\s*have\s+([^\s:]*):[=]*([^:]*))(:=\s*by)([\s|\S]*)"
+ have_match = re.compile(have_regex, re.MULTILINE)
+ unsolved_message = "unsolved goals"
+ no_goals = "No goals to be solved"
+ missing_closure_message = "unexpected end of input; expected '{'"
+ def __init__(self,
+ project_root: Optional[str] = None,
+ prefix: Optional[str] = None,
+ main_file: Optional[str] = None,
+ use_hammer: typing.Union[bool, HammerMode] = False,
+ timeout_in_sec: int = 60,
+ use_human_readable_proof_context: bool = True,
+ proof_step_iter: Optional[ClonableIterator] = None,
+ suppress_error_log: bool = False,
+ enable_search: bool = False,
+ keep_local_context: bool = False,
+ enforce_qed: bool = False,
+ logger: Optional[logging.Logger] = None):
+ assert proof_step_iter is None or isinstance(proof_step_iter, ClonableIterator), \
+ "proof_step_iter must be an iterator"
+ assert main_file is not None or proof_step_iter is not None, \
+ "Either main_file or proof_step_iter must be provided"
+ assert main_file is None or proof_step_iter is None, \
+ "Only one of main_file or proof_step_iter must be provided"
+ assert main_file is None or (os.path.exists(main_file) and main_file.endswith(".lean")), \
+ f"main_file must be a valid path to a '.lean' file ({main_file})"
+ assert project_root is None or (os.path.exists(project_root) and os.path.isdir(project_root)), \
+ "project_root must be a valid path to a directory"
+ assert not use_hammer, "Hammer is not supported for Lean4"
+ self.use_human_readable_proof_context = use_human_readable_proof_context
+ self.project_root = project_root if project_root is not None else "."
+ self.main_file = main_file
+ self.ticks = str(time.time()).replace(".", "") # This ensures that the temp file name is unique and doesn't clash with other temp files
+ # This helps in running parallel instances of prover
+ self.random_num = str(random.randint(0, 100000000))
+ self.temp_filename_suffix = f"temptodel{self.ticks}{self.random_num}.lean"
+ self.temp_file = os.path.join(prefix, self.temp_filename_suffix) if prefix is not None else self.temp_filename_suffix
+ self.temp_file_full_path = os.path.join(self.project_root, self.temp_file)
+ self.temp_file_full_path = os.path.abspath(self.temp_file_full_path)
+ self.use_hammer = use_hammer
+ self.timeout_in_sec = min(timeout_in_sec, 120) # Maximum 120s timeout
+ self.current_stmt = None
+ self.line_num = 0
+ self.main_file_iter = proof_step_iter
+ self.suppress_error_log = suppress_error_log
+ self.tactic_parser: TacticParser | None = None
+ self.execution_complete = False
+ self._enforce_qed = enforce_qed
+ self._ready_to_accept_proof = not self._enforce_qed
+ self._max_memory_in_mib = 40000 # 40 GiB is needed for mathlib to work seemlessly
+ self._lines_executed = []
+ self.proof_context : ProofContext| None = None
+ self.curr_lemma_name : Optional[str] = None
+ self.curr_lemma : Optional[str] = None
+ self.lean_error_messages : List[str] = []
+ self._proof_running = False
+ self._file_content = ""
+ self.local_file_lemmas: OrderedDict[str, str] = OrderedDict()
+ self.local_theorem_lemma_description: OrderedDict[str, str] = OrderedDict()
+ self._proof_start_idx: Optional[int] = None
+ self._import_end_idx: Optional[int] = None
+ self.logger = logger if logger is not None else logging.getLogger(__name__)
+ self.use_file = False
+ self._enable_search = enable_search
+ self._theorem_started = False
+ self._content_till_last_theorem_stmt = None
+ self._last_theorem = None
+ self._anon_theorem_count = 0
+ self._debug_traces = []
+ self.debug_enabled = False
+ self._last_tactics : dict[int, str] = {}
+ self.possible_proof_tactics = ""
+ self._last_tactic_line_idx = None
+ self._error_messages_so_far = set()
+ self._error_messages_since_last_thm = {}
+ self._run_exactly = False
+ self._nested_have_counts = 0
+ self._last_tactic_was_modified = False
+ if self._enable_search:
+ pass
+ pass
+
+ def set_run_exactly(self):
+ self._run_exactly = True
+
+ def unset_run_exactly(self):
+ self._run_exactly = False
+
+ def run_exactly(self):
+ return self._run_exactly
+
+ def reset(self,
+ proof_step_iter: Optional[ClonableIterator] = None):
+ # Note: We CANNOT reset the main_file_iter as it is a generator
+ assert (proof_step_iter is not None and isinstance(proof_step_iter, ClonableIterator)) or self.main_file is not None, \
+ "Either proof_step_iter must be provided or main_file must be set"
+ self.current_stmt = None
+ self.line_num = 0
+ self.main_file_iter = proof_step_iter if proof_step_iter is not None else self.main_file_iter
+ self.tactic_parser: TacticParser | None = None
+ self.execution_complete = False
+ self._lines_executed = []
+ self.proof_context : ProofContext | None = None
+ self.curr_lemma_name : Optional[str] = None
+ self.curr_lemma : Optional[str] = None
+ self.lean_error_messages : List[str] = []
+ self._proof_running = False
+ self._file_content = ""
+ self.local_file_lemmas: OrderedDict[str, str] = OrderedDict()
+ self.local_theorem_lemma_description: OrderedDict[str, str] = OrderedDict()
+ self._proof_start_idx: Optional[int] = None
+ self._import_end_idx: Optional[int] = None
+ self._theorem_started = False
+ self._content_till_last_theorem_stmt: str|None = None
+ self._content_till_after_theorem_stmt: str|None = None
+ self._last_theorem = None
+ self._anon_theorem_count = 0
+ self._last_tactics : dict[int, str] = {}
+ self.possible_proof_tactics = ""
+ self._last_tactic_line_idx = None
+ self._debug_traces = []
+ self.debug_enabled = False
+ self._error_messages_so_far = set()
+ self._error_messages_since_last_thm = {}
+ self._nested_have_counts = 0
+ self._last_tactic_was_modified = False
+ if self._enable_search:
+ pass
+ pass
+
+ def __enter__(self):
+ self.tactic_parser = TacticParser(project_path=self.project_root, logger=self.logger)
+ if self.main_file_iter is None:
+ assert self.main_file is not None, "main_file must be set if main_file_iter is None"
+ self.main_file_iter = LeanLineByLineReader(self.main_file, remove_comments=True, no_strip=True).instruction_step_generator()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if self.tactic_parser is not None:
+ self.tactic_parser.close()
+ if os.path.exists(self.temp_file_full_path):
+ os.remove(self.temp_file_full_path)
+
+ def is_in_proof_mode(self):
+ return True if self.proof_context else (len(self.lean_error_messages) > 0) # It is still in proof mode if we encountered a wrong proof
+
+ def needs_qed(self):
+ return self.proof_context is not None and len(self.proof_context.all_goals) == 0
+
+ def needs_cut_close(self):
+ return self.proof_context is not None and len(self.proof_context.fg_goals) == 0 and len(self.proof_context.all_goals) > 0
+
+ def run_next(self) -> bool:
+ try:
+ assert self.main_file_iter is not None, "main_file_iter should not be None"
+ next_line = next(self.main_file_iter)
+ except StopIteration:
+ self.execution_complete = True
+ return False
+ self.current_stmt = next_line
+ self.line_num += 1
+ try:
+ idx = len(self._lines_executed)
+ self._run_stmt_on_lean_server(idx, self.current_stmt)
+ except:
+ if not self.suppress_error_log:
+ self.logger.error(f"Got an exception while running '{self.current_stmt}' on lean. File name: {self.main_file}")
+ self.logger.exception(f"Exception Log")
+ raise
+ if self.run_exactly():
+ self._lines_executed.append(self.current_stmt)
+ else:
+ self._lines_executed.append("") # Add an empty line to keep the line numbers in sync
+ return True
+
+ def extract_all_theorems_and_definitions(self, json_output_path: str|None = None) -> List[FileDependencyAnalysis]:
+ assert self.main_file is not None, "main_file must be set to extract theorems and definitions"
+ assert self.tactic_parser is not None, "tactic_parser must be initialized to extract theorems and definitions"
+
+ json_output_path = json_output_path if json_output_path is not None else self.main_file + ".dependency_analysis.json"
+ file_dependency_analysis, _ = self.tactic_parser.parse_file(
+ self.main_file,
+ parse_type=RequestType.PARSE_DEPENDS,
+ json_output_path=json_output_path)
+ return file_dependency_analysis
+
+ def get_lemma_name_if_running(self) -> Optional[str]:
+ if not self.is_in_proof_mode():
+ return None
+ else:
+ try:
+ return self.curr_lemma_name
+ except:
+ return None
+
+ def get_lemma_stmt_if_running(self) -> Optional[str]:
+ try:
+ assert self.curr_lemma_name is not None, "Current lemma name should not be None"
+ return self.local_theorem_lemma_description[self.curr_lemma_name]
+ except:
+ return None
+
+ def get_current_lemma_name(self) -> Optional[str]:
+ if self.curr_lemma_name is None:
+ return None
+ else:
+ return self.curr_lemma_name
+
+ def _add_last_tactic(self, idx: int, stmt: str):
+ if idx not in self._last_tactics:
+ stmt = self._have_preprocessing(stmt)
+ indentation = " " * self._nested_have_counts * 2
+ if self._nested_have_counts > 0:
+ stmt = stmt.lstrip()
+ stmt = indentation + stmt
+ self._last_tactic_was_modified = True
+ self._last_tactics[idx] = stmt
+ self._last_tactic_line_idx = idx
+ # self.logger.info(f"Proofs so far:\n{self._get_tactics_so_far()}")
+
+ def _have_preprocessing(self, stmt: str) -> str:
+ stmt_match = SimpleLean4SyncExecutor.have_match.match(stmt)
+ if not stmt_match:
+ return stmt
+ else:
+ full_have_stmt = stmt_match.group(1)
+ by = stmt_match.group(4)
+ after_tactics = stmt_match.group(5)
+ # self.logger.info(f"Processing 'have' statement: {full_have_stmt} with by: {by} and after tactics: {after_tactics}")
+ assert by is not None, "By should not be None"
+ assert full_have_stmt is not None, "Full have statement should not be None"
+ if after_tactics is None:
+ # There is no tactic to apply afterwards to just leave it as it is
+ return stmt
+ else:
+ # split the after tactics by new lines
+ after_tactics = after_tactics.split("\n")
+ for i, tactic in enumerate(after_tactics):
+ indentation = " " * (self._nested_have_counts + 1) * 2
+ after_tactics[i] = indentation + tactic.lstrip()
+ after_tactics_str = "\n".join(after_tactics)
+ # Reconstruct the have statement with the tactics applied afterwards
+ by = by.rstrip()
+ new_stmt = f"{full_have_stmt}{by}\n{after_tactics_str}"
+ new_stmt = new_stmt.rstrip()
+ self._last_tactic_was_modified = True
+ return new_stmt
+
+ def _get_lean_code_with_tactics(self, idx: int, stmt: str):
+ assert self._last_theorem is not None, "Last theorem should not be None"
+ self._add_last_tactic(idx, stmt)
+ tactics_so_far = self._get_tactics_so_far()
+ assert len(tactics_so_far) > 0, "There should be at least one tactic so far"
+ _ , _, theorem_stmt = self._last_theorem
+ return theorem_stmt + tactics_so_far
+
+ def _backtrack_tactic_line(self, idx: int):
+ # identify the keys to remove
+ self._lines_executed = self._lines_executed[:idx]
+ idx_to_remove = []
+ backtracked = False
+ for k in self._last_tactics.keys():
+ if k >= idx:
+ idx_to_remove.append(k)
+ for k in idx_to_remove:
+ backtracked = True
+ del self._last_tactics[k]
+ idx_to_remove = []
+ for k in self._error_messages_since_last_thm.keys():
+ if k >= idx:
+ idx_to_remove.append(k)
+ for k in idx_to_remove:
+ backtracked = True
+ msg = self._error_messages_since_last_thm[k]
+ if msg in self._error_messages_so_far:
+ self._error_messages_so_far.remove(msg)
+ del self._error_messages_since_last_thm[k]
+ self._last_tactic_line_idx = max(self._last_tactics.keys(), default=None)
+ return backtracked
+
+ def _get_tactics_in_sorted_order(self) -> List[Tuple[int, str]]:
+ tactics_so_far = [(k, v) for k, v in self._last_tactics.items()]
+ tactics_so_far = sorted(tactics_so_far, key=lambda x: x[0])
+ return tactics_so_far
+
+ def _get_tactics_so_far(self) -> str:
+ tactics_so_far = self._get_tactics_in_sorted_order()
+ tactics_so_far = [v for _, v in tactics_so_far]
+ return "\n".join(tactics_so_far)
+
+ def _clear_tactics(self):
+ tactics_so_far = self._get_tactics_so_far()
+ self.possible_proof_tactics += tactics_so_far
+ self._last_tactics : dict[int, str] = {}
+ self._last_tactic_line_idx = None
+ self._error_messages_since_last_thm = {}
+ pass
+
+ def _theorem_started_init(self):
+ if self._theorem_started:
+ assert self._last_theorem is not None, "Last theorem should not be None"
+ theorem_name, theorem_stmt, full_stmt = self._last_theorem
+ self.curr_lemma_name = theorem_name
+ self.curr_lemma = theorem_stmt
+ if len(theorem_name) == 0:
+ self._anon_theorem_count += 1
+ theorem_name = f"anon_theorem____{self._anon_theorem_count}"
+ self.local_file_lemmas[theorem_name] = theorem_stmt
+ self.local_theorem_lemma_description[theorem_name] = full_stmt
+
+ def _format_error_message(self, error_info: ErrorInfo) -> str:
+ return f"L {error_info.position.line}, C {error_info.position.column}: {error_info.message}"
+
+ def _reset_proof_context(self):
+ self.proof_context = None
+ self.curr_lemma = None
+ self.curr_lemma_name = None
+ self._clear_tactics()
+ self._proof_running = False
+
+ def _set_proof_context(self,
+ proof_is_running: bool,
+ proof_goal_messages: List[str],
+ last_tactic: LeanLineInfo):
+ self._proof_running = proof_is_running
+ if self._proof_running:
+ proof_goals = []
+ if len(proof_goal_messages) == 0:
+ proof_goals = []
+ else:
+ proof_goals = [g_text for g_text in proof_goal_messages
+ if g_text is not None and len(g_text) > 0]
+ self.proof_context = self._parse_proof_context(proof_goals)
+ if self.proof_context == ProofContext.empty() and \
+ ((self._enforce_qed and last_tactic.text.strip() == "done") or not self._enforce_qed):
+ self._reset_proof_context()
+ else:
+ self.proof_context : ProofContext | None = None
+ self.lean_error_messages.clear()
+
+ def _get_nested_haves_count(self, tactics: List[LeanLineInfo], errors: List[ErrorInfo]) -> int:
+ # See all goal related error messages
+ goal_related : List[ErrorInfo] = []
+ for error in errors:
+ if error.message.startswith(SimpleLean4SyncExecutor.unsolved_message):
+ # Check if the last tactic before this error was a 'have' tactic
+ goal_related.append(error)
+ nested_have_count = 0
+ for tactic in reversed(tactics):
+ if tactic.text.strip().startswith("have"):
+ # Check if there is any goal related error after this tactic
+ for error in goal_related:
+ if error.position.line == tactic.line:
+ nested_have_count += 1
+ return nested_have_count
+
+ def _update_proof_context(self, idx : int, tactics: List[LeanLineInfo], errors: List[ErrorInfo]):
+ proof_goal_messages: list[str] = []
+ error_messages: list[str] = []
+ assert len(tactics) >= 0, "Tactics should not be None"
+ last_tactic: LeanLineInfo = tactics[-1]
+ if not tactics and not errors:
+ raise ValueError(f"Response is None cannot update proof context for line number {idx}")
+ for error in errors:
+ if error.message.startswith(SimpleLean4SyncExecutor.unsolved_message):
+ # Always take the last unsolved goals message
+ proof_goal_messages.append(error.message[len(SimpleLean4SyncExecutor.unsolved_message):])
+ elif error.message.startswith(SimpleLean4SyncExecutor.no_goals):
+ proof_goal_messages.append(error.message)
+ else:
+ if not error.message.startswith(SimpleLean4SyncExecutor.missing_closure_message):
+ if error.position.line >= last_tactic.line:
+ self._error_messages_since_last_thm[idx] = self._format_error_message(error)
+ error_messages.append(error.message) # Always take the last error message
+ self._error_messages_so_far.add(self._format_error_message(error))
+ proof_is_running = False
+ proof_goal_messages = [msg for msg in proof_goal_messages if not msg.startswith(SimpleLean4SyncExecutor.no_goals)]
+ if len(proof_goal_messages) >= 0 and len(error_messages) == 0:
+ proof_is_running = True
+ if len(error_messages) == 0:
+ assert proof_is_running, f"Proof is not running but no error message is present, errors:\n{errors}, \nlemma: \n{self.curr_lemma_name}, \nlemma_stmt: \n{self.curr_lemma}, \nline_num: \n{self.line_num}"
+ self._nested_have_counts = self._get_nested_haves_count(tactics, errors)
+ self._set_proof_context(proof_is_running, proof_goal_messages, last_tactic)
+ else:
+ new_failed_tactic_error_lines = set()
+ if len(self._last_tactics) >= 2:
+ all_tactics = self._get_tactics_in_sorted_order()
+ last_tactic_line = all_tactics[-2][0]
+ # for error_info in errors:
+ # self.logger.info(f"Error at line {error_info.position.line}, col {error_info.position.column}: {error_info.message}")
+ # self.logger.info(f"Last tactic at line {last_tactic.line}, col {last_tactic.column}: {last_tactic.text}")
+ # Rollback the last tactic if there was an error
+ tactics_before_backtrack = self._get_tactics_so_far()
+ # errors after last tactic
+ errors_after_last_tactic = [e for e in errors if e.position.line > last_tactic_line]
+ # for error in errors_after_last_tactic:
+ # self.logger.info(f"Error after last tactic at line {error.position.line}, col {error.position.column}: {error.message}")
+ for error in errors_after_last_tactic:
+ new_failed_tactic_error_lines.add(error.position.line)
+ # self.logger.info(f"New failed tactic error lines: {new_failed_tactic_error_lines}")
+ tactics_which_failed = [t for t in tactics if t.line in new_failed_tactic_error_lines]
+ tactics_which_failed_str = "\n".join([t.text for t in tactics_which_failed])
+ self._backtrack_tactic_line(idx)
+ if len(new_failed_tactic_error_lines) >= 1 and len(tactics_which_failed) >= 1:
+ tactics_so_far = self._get_tactics_so_far()
+ # This should be (tactics_before_backtrack - tactics_so_far) - (tactics_which_failed)
+ # Where `-` is basically removing that part of the string
+ # self.logger.info(f"Backtracking tactics at line {idx}.\n Tactics so far:\n{tactics_so_far}\nTactics before backtrack:\n{tactics_before_backtrack}\nTactics which failed:\n{tactics_which_failed_str}")
+ # print_tactics(tactics, self.logger)
+ assert tactics_before_backtrack.startswith(tactics_so_far), \
+ "Tactics before backtrack should start with tactics so far"
+ tactics_tried = tactics_before_backtrack[len(tactics_so_far):]
+ # self.logger.info(f"Tactics tried:\n{tactics_tried}\nTactics which failed:\n{tactics_which_failed_str}")
+ assert tactics_tried.endswith(tactics_which_failed_str), "Tactics tried should end with tactics which failed"
+ partially_executed_tactics = tactics_tried[:-len(tactics_which_failed_str)] if len(tactics_which_failed_str) > 0 else tactics_tried
+ # self.logger.info(f"Partially executed tactics:\n{partially_executed_tactics}")
+ # Add the partially executed tactics back, and push the state update
+ if len(partially_executed_tactics.strip()) > 0:
+ partially_executed_tactics = partially_executed_tactics.strip()
+ self._run_stmt_on_lean_server(idx, partially_executed_tactics)
+ self.lean_error_messages = copy.deepcopy(error_messages)
+
+ def _run_stmt_on_lean_server(self, idx : int, stmt: str, theorem_started: bool = False):
+ assert self.tactic_parser is not None, "Tactic parser is not initialized"
+ assert self._content_till_last_theorem_stmt is not None, "Content till last theorem statement should not be None"
+ if "sorry" in stmt and self._proof_running:
+ # We don't need to run the sorry statements. This should be treated as a failed proof step
+ self.lean_error_messages = ["The tactic 'sorry' was found in the statement, this is not allowed"]
+ return
+ elif len(stmt.strip()) == 0 and self._proof_running:
+ # We don't need to run the empty statements. This should be treated as a failed proof step
+ self.lean_error_messages = ["There is no tactic in the statement, it is just empty line or whitespace"]
+ return
+ elif self.proof_context == ProofContext.empty() and \
+ self._proof_running and \
+ stmt != "done":
+ self.lean_error_messages = [
+ "The proof is about to finish, please use 'done' to finish the proof."]
+ return
+ elif stmt == "done" and \
+ self._proof_running and \
+ self.proof_context != ProofContext.empty():
+ self.lean_error_messages = [
+ "The proof is not finished, please complete the proof before using 'done'."]
+ return
+
+ proof_should_run = False
+ if theorem_started:
+ # Load the theorem context at once
+ self.tactic_parser.parse(
+ self._content_till_last_theorem_stmt,
+ fail_on_error=True,
+ parse_type=RequestType.CHKPT_TACTICS
+ )
+ if theorem_started or not self._proof_running:
+ proof_should_run = self._theorem_started
+ self._theorem_started_init()
+ if not self._proof_running and not proof_should_run:
+ return
+ code_was_executed = False
+ while not code_was_executed:
+ # Run the statement in tactic mode
+ code = self._get_lean_code_with_tactics(idx, stmt)
+ tactics, error_info = self.tactic_parser.parse(
+ code,
+ fail_on_error=False,
+ parse_type=RequestType.PARSE_TACTICS)
+ code_was_executed = True
+ self._update_proof_context(idx, tactics, error_info)
+ if self.debug_enabled:
+ tactics_json = [tactic.to_json() for tactic in tactics]
+ errors_json = [error.to_json() for error in error_info]
+ trace = ("
\n" + "-"*20 + "\n").join(tactics_json + errors_json)
+ self._debug_traces.append(trace)
+ pass
+
+ def _skip_to_theorem(self, theorem: str):
+ # Skip to the given theorem
+ found_theorem = False
+ thm_namespace, given_theorem_name = parse_thm_name(theorem)
+ # Scan the whole file first
+ lines : list[str] = []
+ assert self.main_file_iter is not None, "main_file_iter should not be None"
+ assert self.tactic_parser is not None, "Tactic parser is not initialized"
+ while True:
+ try:
+ stmt = next(self.main_file_iter)
+ except StopIteration:
+ break
+ lines.append(stmt)
+ full_file = "\n".join(lines) + "\n"
+ # run the tactic parser in theorem parsing mode
+ lean_line_infos, _ = self.tactic_parser.parse(
+ full_file,
+ fail_on_error=False,
+ parse_type=RequestType.PARSE_THEOREM)
+ # Filter out theorems and lemmas
+ theorems = [info for info in lean_line_infos if info.decl_type == "theorem" or info.decl_type == "lemma"]
+ found_theorem = False
+ for thm in theorems:
+ name = thm.name
+ assert name is not None, "Theorem name should not be None"
+ if given_theorem_name == name:
+ actual_namespace = thm.namespace if thm.namespace is not None else ""
+ if actual_namespace == thm_namespace:
+ # Found the theorem
+ found_theorem = True
+ line_num = thm.line
+ break
+ if not found_theorem:
+ raise ValueError(f"The theorem '{theorem}' was not found in the file '{self.main_file}'")
+ assert line_num > 0, "Theorem line number should be greater than 0"
+ self._lines_executed = lines[:line_num - 1]
+ theorem_text = thm.text
+
+ content_until_after_theorem = "\n".join(self._lines_executed) + "\n" + theorem_text
+
+ # Parse out tactics now
+ all_tactics_till_now, _ = self.tactic_parser.parse(content_until_after_theorem, fail_on_error=True, parse_type=RequestType.PARSE_TACTICS)
+ # Find the first index which is after the theorem line
+ first_idx_after_theorem = None
+ for idx, tactic_info in enumerate(all_tactics_till_now):
+ if tactic_info.line >= line_num:
+ first_idx_after_theorem = idx
+ break
+ if first_idx_after_theorem is None:
+ msg = "Could not find the first tactic after the theorem" + \
+ f" only tactic based proofs are supported. Theorem: '{theorem}' on line {line_num}, file: '{self.main_file}'" + \
+ " is probably not followed by any tactic based proof." + \
+ " All tactics parsed till now:\n" + \
+ "\n".join([f"L {t.line}, C {t.column}: {t.text}" for t in all_tactics_till_now]) + \
+ "\n^^^ Cannot see tactics for the theorem."
+ raise NotImplementedError(msg)
+ start_tactic = all_tactics_till_now[first_idx_after_theorem]
+ tactic_start_line = start_tactic.line
+ tactic_start_col = start_tactic.column
+ assert tactic_start_line > 0, "Tactic start line should be greater than 0"
+ content_until_after_theorem = "\n".join(lines[:tactic_start_line - 1] + [lines[tactic_start_line - 1][:tactic_start_col]])
+ self._content_till_after_theorem_stmt = content_until_after_theorem
+ self._content_till_after_theorem_stmt = self._content_till_after_theorem_stmt.strip()
+ assert self._content_till_after_theorem_stmt.endswith(':='), "Content till last theorem statement should end with ':='"
+ content_until_before_theorem = "\n".join(lines[:line_num - 1])
+ self._content_till_last_theorem_stmt = content_until_before_theorem
+ theorem_stmt = "\n".join(lines[line_num - 1:tactic_start_line - 1] + [lines[tactic_start_line - 1][:tactic_start_col]])
+ theorem_stmt = theorem_stmt.strip()
+ self._last_theorem = (given_theorem_name, theorem_stmt, theorem_stmt)
+ self._theorem_started = True
+ self._lines_executed.extend(lines[line_num - 1:tactic_start_line - 1] + [lines[tactic_start_line - 1][:tactic_start_col]])
+ self._run_stmt_on_lean_server(tactic_start_line, "by", theorem_started=True)
+ self._lines_executed.append('by')
+ # Reset the iterator to the line of the theorem
+ if lines[tactic_start_line - 1].strip().endswith("by"):
+ self.main_file_iter.set_to_index(tactic_start_line)
+ else:
+ self.main_file_iter.set_to_index(tactic_start_line + 1)
+ self.line_num = len(self._lines_executed)
+
+ def _parse_proof_context(self, proof_goals: list) -> ProofContext:
+ goals = []
+ for proof_goal in proof_goals:
+ if self.use_human_readable_proof_context:
+ goals.extend(Lean4Utils.parse_proof_context_human_readable_as_goals(proof_goal))
+ else:
+ raise NotImplementedError("Parsing of non-human readable proof context is not implemented")
+ if len(goals) == 0:
+ return ProofContext.empty()
+ else:
+ return ProofContext(goals, [], [], [])
+
+ def validate_proof(self, timeout_sec: int = 30, keep_temp_file: bool = True) -> typing.Dict[str, typing.Any]:
+ """
+ Validate the current proof state by running 'lake lean' on a temporary file.
+ This provides an independent verification without relying on the TacticParser.
+
+ Args:
+ timeout_sec: Timeout in seconds for the lake lean process
+ keep_temp_file: If True, keeps the temporary file after validation (default: True)
+
+ Returns:
+ Dictionary with validation results:
+ {
+ 'success': bool, # True if proof is complete with no errors
+ 'compilation_ok': bool, # True if file compiles
+ 'has_sorries': bool, # True if there are unsolved goals (sorries)
+ 'error_message': str, # Error message if any
+ 'errors': list, # List of error details
+ 'lean_code': str, # The code that was validated
+ 'theorem_name': str # Name of theorem being validated
+ }
+ """
+
+ # Get theorem name for logging/reporting, but don't require it
+ assert self._last_theorem is not None, "Either last theorem should not be None or there should be some executed lines"
+ assert self._content_till_last_theorem_stmt is not None, "Content till last theorem statement should not be None"
+ theorem_name, _, full_thm_stmt = self._last_theorem
+ code_before_thm = self._content_till_last_theorem_stmt
+
+ # Create the Lean code with all executed lines up to current point
+ lines_before_thm = code_before_thm + "\n" + full_thm_stmt + "\n"
+
+ # Build the complete Lean code with actual proof tactics
+ # The proof tactics are accumulated in self.possible_proof_tactics
+ actual_proof = "" # Track the actual proof for sorry checking
+ proof_tactics_source = self.possible_proof_tactics
+
+ # If possible_proof_tactics is empty, try to use _last_tactics as fallback
+ if not proof_tactics_source or not proof_tactics_source.strip():
+ if self._last_tactics:
+ # Extract tactics from _last_tactics (same logic as _clear_tacitcs)
+ tactics_so_far = [(k, v) for k, v in self._last_tactics.items()]
+ tactics_so_far = sorted(tactics_so_far, key=lambda x: x[0])
+ tactics_so_far = [v for _, v in tactics_so_far]
+ proof_tactics_source = "\n".join(tactics_so_far)
+
+ # If both are empty, raise an error
+ if not proof_tactics_source or not proof_tactics_source.strip():
+ raise ValueError("No proof tactics available. Neither 'possible_proof_tactics' nor '_last_tactics' contain any proof steps.")
+
+ # Now build the Lean code with the proof tactics
+ lean_code = lines_before_thm.rstrip() + "\n" + proof_tactics_source.strip() + "\n"
+
+ # Create a unique temporary file
+ temp_filename = f"validation_{self.ticks}_{self.random_num}.lean"
+ temp_file_path = os.path.join(self.project_root, temp_filename)
+
+ try:
+ # Write the Lean code to the temporary file
+ with open(temp_file_path, 'w') as f:
+ f.write(lean_code)
+
+ # Run lake lean on the file
+ try:
+ result = subprocess.run(
+ ['lake', 'lean', temp_filename],
+ cwd=self.project_root,
+ capture_output=True,
+ text=True,
+ timeout=timeout_sec
+ )
+
+ stdout = result.stdout
+ stderr = result.stderr
+ output = stdout + '\n' + stderr
+
+ except subprocess.TimeoutExpired:
+ # Don't delete temp file on timeout so it can be inspected
+ return {
+ 'success': False,
+ 'compilation_ok': False,
+ 'has_sorries': False,
+ 'error_message': f'Timeout after {timeout_sec} seconds',
+ 'errors': [],
+ 'lean_code': lean_code,
+ 'theorem_name': theorem_name,
+ 'full_output': f'Process timed out after {timeout_sec} seconds',
+ 'stdout': '',
+ 'stderr': '',
+ 'return_code': -1,
+ 'temp_filename': temp_filename,
+ 'temp_file_path': temp_file_path,
+ 'temp_file_kept': True, # Keep file on timeout for debugging
+ 'debug_traces': list(self._debug_traces),
+ 'possible_proof_tactics': self.possible_proof_tactics
+ }
+
+ # Parse the output for errors and warnings
+ errors = []
+ error_pattern = re.compile(r'(\S+):(\d+):(\d+):\s*(warning|error):\s*(.+)')
+
+ for line in output.split('\n'):
+ match = error_pattern.match(line)
+ if match:
+ filename, line_num, col_num, severity, message = match.groups()
+ errors.append({
+ 'file': filename,
+ 'line': int(line_num),
+ 'column': int(col_num),
+ 'severity': severity,
+ 'message': message
+ })
+
+ # Check for 'sorry' only in the actual proof we generated
+ has_sorries = 'sorry' in actual_proof.lower()
+
+ # Only fail on actual errors (not warnings)
+ # Also check for "unsolved goals" in error messages
+ theorem_has_error = False
+ for error in errors:
+ if error['severity'] == 'error':
+ theorem_has_error = True
+ # Also check if the error mentions unsolved goals
+ if 'unsolved goals' in error['message'].lower():
+ has_sorries = True
+
+ # Determine success: compilation ok, no sorries in actual proof, no errors (ignore warnings)
+ compilation_ok = result.returncode == 0
+ success = compilation_ok and not has_sorries and not theorem_has_error
+
+ error_message = ''
+ if not compilation_ok:
+ error_message = 'Compilation failed'
+ elif has_sorries:
+ error_message = 'Proof has unsolved goals (sorries)'
+ elif theorem_has_error:
+ error_message = 'Theorem has errors'
+ else:
+ error_message = 'Proof is complete'
+
+ # Combine full raw output for debugging
+ full_output = f"=== STDOUT ===\n{stdout}\n\n=== STDERR ===\n{stderr}"
+
+ return {
+ 'success': success,
+ 'compilation_ok': compilation_ok,
+ 'has_sorries': has_sorries,
+ 'error_message': error_message,
+ 'errors': errors,
+ 'lean_code': lean_code,
+ 'return_code': result.returncode,
+ 'stdout': stdout,
+ 'stderr': stderr,
+ 'full_output': full_output,
+ 'theorem_name': theorem_name,
+ 'temp_filename': temp_filename,
+ 'temp_file_path': temp_file_path,
+ 'temp_file_kept': keep_temp_file,
+ 'debug_traces': list(self._debug_traces),
+ 'possible_proof_tactics': self.possible_proof_tactics
+ }
+
+ finally:
+ # Clean up the temporary file only if requested
+ if not keep_temp_file and os.path.exists(temp_file_path):
+ os.remove(temp_file_path)
+
+
+theorem_names_in_file_cache: Dict[str, List[TheoremDetails]] = {}
+namespace_regex = r"^namespace[ ]+([\S]+)"
+namespace_match = re.compile(namespace_regex, re.MULTILINE)
+namespace_end_regex = r"^end[ ]+([\S]+)*"
+namespace_end_match = re.compile(namespace_end_regex, re.MULTILINE)
+
+def parse_thm_name(theorem_name: str) -> Tuple[str, str]:
+ if theorem_name.startswith("{") and theorem_name.endswith("}"):
+ thm_dict = json.loads(theorem_name)
+ return thm_dict["namespace"], thm_dict["name"]
+ else:
+ return "", theorem_name
+
+def process_namespaces(file_cotent: str, open_namespaces: List[str], is_full_content: bool=False):
+ # Match the namespace regex
+ # Break the content line by line and match the namespace and end namespace
+ file_lines = file_cotent.split('\n')
+ for line in file_lines:
+ namespace_matches = namespace_match.findall(line)
+ namespace_end_matches = namespace_end_match.findall(line)
+ for ns in namespace_matches:
+ if not is_full_content or ns not in open_namespaces:
+ open_namespaces.append(ns)
+ for ns in namespace_end_matches:
+ try:
+ open_namespaces.remove(ns)
+ except ValueError:
+ pass
+
+def get_all_theorems_in_file(file_path: str, use_cache: bool=False) -> List[TheoremDetails]:
+ if use_cache and file_path in theorem_names_in_file_cache:
+ return theorem_names_in_file_cache[file_path]
+ file_content = ""
+ open_namespaces = []
+ with open(file_path, "r") as f:
+ file_content = f.read()
+ line_by_line_reader = LeanLineByLineReader(file_content=file_content, remove_comments=True, no_strip=True)
+ all_stmts = list(line_by_line_reader.instruction_step_generator())
+ line_positions = [0] + [len(stmt) + 1 for stmt in all_stmts]
+ # Cumulative sum of the line positions
+ for i in range(1, len(line_positions)):
+ line_positions[i] += line_positions[i - 1]
+ full_content = '\n'.join(all_stmts)
+ # all_matches = Lean4SyncExecutor.theorem_match.findall(full_content)
+ all_matches = list(SimpleLean4SyncExecutor.theorem_match.finditer(full_content))
+ all_theorems = []
+ last_namespace_processed_idx = 0
+ for match in all_matches:
+ span_start, span_end = match.span()
+ process_namespaces(full_content[last_namespace_processed_idx:span_start], open_namespaces)
+ theorem_name = match.group(5)
+ theorem_name = theorem_name if theorem_name is not None else f"\"{match.group(6).strip(': ')}\""
+ theorem_namespace = '.'.join(open_namespaces) if len(open_namespaces) > 0 else ''
+ line_number_start = bisect.bisect_left(line_positions, span_start)
+ line_number_end = bisect.bisect_left(line_positions, span_end)
+ theorem_pos = {
+ 'line_start': line_number_start + 1,
+ 'line_end': line_number_end + 1,
+ 'global_pos_start': span_start,
+ 'global_pos_end': span_end,
+ 'line_pos_start': span_start - line_positions[line_number_start] if line_number_start < len(line_positions) else 0,
+ 'line_pos_end': span_end - line_positions[line_number_end] if line_number_end < len(line_positions) else 0
+ }
+ theorem_details = TheoremDetails(
+ theorem_name=theorem_name,
+ theorem_namespace=theorem_namespace,
+ theorem_file_path=file_path,
+ theorem_pos=theorem_pos)
+ all_theorems.append(theorem_details)
+ last_namespace_processed_idx = span_end
+ if use_cache:
+ theorem_names_in_file_cache[file_path] = all_theorems
+ return all_theorems
+
+def get_fully_qualified_theorem_name(theorem_details: TheoremDetails) -> str:
+ if len(theorem_details.theorem_namespace) == 0:
+ return theorem_details.theorem_name
+ else:
+ dict_thm = {"namespace": theorem_details.theorem_namespace, "name": theorem_details.theorem_name}
+ return json.dumps(dict_thm)
+
+def get_theorem_name_resembling(file_path: str, theorem_name: str, use_cache: bool=False) -> Optional[str]:
+ all_theorems = get_all_theorems_in_file(file_path, use_cache=use_cache)
+ all_theorems_name_unique_map : Dict[str, List[TheoremDetails]] = {}
+ for thm in all_theorems:
+ if thm.theorem_name in all_theorems_name_unique_map:
+ all_theorems_name_unique_map[thm.theorem_name].append(thm)
+ else:
+ all_theorems_name_unique_map[thm.theorem_name] = [thm]
+ all_parts = theorem_name.split('.')
+ thm_start_idx = len(all_parts) - 1
+ thm_found = False
+ while not thm_found and thm_start_idx >= 0:
+ full_name = '.'.join(all_parts[thm_start_idx:])
+ # look for any theorems matching with full_name
+ thm_found = full_name in all_theorems_name_unique_map
+ thm_start_idx -= 1
+ if not thm_found:
+ full_name = '_root_.' + full_name
+ # look for any theorems matching with the full_name
+ thm_found = full_name in all_theorems_name_unique_map
+ if not thm_found:
+ raise ValueError(f"The theorem '{theorem_name}' was not found in the file '{file_path}'")
+ assert thm_found, "The theorem was not found some code bug in finding the theorem names"
+ theorem_name_matches = all_theorems_name_unique_map[full_name]
+ if len(theorem_name_matches) == 1:
+ if len(theorem_name_matches[0].theorem_namespace) == 0:
+ return theorem_name_matches[0].theorem_name
+ else:
+ dict_thm = {"namespace": theorem_name_matches[0].theorem_namespace, "name": theorem_name_matches[0].theorem_name}
+ return json.dumps(dict_thm)
+ else:
+ # We need to find the namespace which matches with the theorem_name
+ for thm in theorem_name_matches:
+ if theorem_name.endswith(thm.theorem_namespace + '.' + thm.theorem_name) or\
+ (theorem_name.strip() == thm.theorem_name and len(thm.theorem_namespace) == 0):
+ dict_thm = {"namespace": thm.theorem_namespace, "name": thm.theorem_name}
+ return json.dumps(dict_thm)
+ raise ValueError(f"The theorem '{theorem_name}' was not found in the file '{file_path}'")
+
+def execute_thm_line_by_line(file_path: str, project_root: str, theorem_name: str, logger: logging.Logger, with_print: bool=False):
+ pprint = lambda msg: print(msg) if with_print else None
+ with SimpleLean4SyncExecutor(main_file=file_path, project_root=project_root, logger=logger) as executor:
+ executor.set_run_exactly()
+ executor._skip_to_theorem(theorem_name)
+ assert executor.proof_context is not None, "Proof context should be present"
+ proof_exec = False
+ while not executor.execution_complete:
+ if executor.proof_context is not None:
+ proof_exec = True
+ for goal in executor.proof_context.all_goals:
+ for hyp in goal.hypotheses:
+ pprint(hyp)
+ pprint('-'*10)
+ pprint(goal.goal)
+ pprint('-'*20)
+ executor.run_next()
+ pprint(f"Current statement: {executor.current_stmt}")
+ if executor.proof_context is None and proof_exec:
+ proof_exec = False
+ pprint("Proof finished")
+ break
+ if executor.lean_error_messages:
+ pprint(f"Error messages:\n{executor.lean_error_messages}")
+
+if __name__ == "__main__":
+ from itp_interface.tools.log_utils import setup_logger
+ import datetime
+ parent = os.path.dirname(os.path.abspath(__file__))
+ root = os.path.dirname(os.path.dirname(parent))
+ os.chdir(root)
+ project_root = 'data/test/Mathlib/'
+ file_path = 'data/test/Mathlib/.lake/packages/mathlib/Mathlib/Computability/TuringMachine.lean'
+ assert os.path.exists(project_root), "Project root does not exist"
+ assert os.path.exists(file_path), "File path does not exist"
+ print("Finding all theorems in the file")
+ all_theorems = get_all_theorems_in_file(file_path, use_cache=True)
+ print(all_theorems)
+ date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+ lean_exec_log_folder = f'.log/lean4_sync_executor/{date_time}'
+ os.makedirs(lean_exec_log_folder, exist_ok=True)
+ lean_exec_log_file = os.path.join(lean_exec_log_folder, "lean4_sync_executor.log")
+ logger = setup_logger("Lean4SyncExecutor", lean_exec_log_file, level=logging.DEBUG, format='')
+ theorems_similar_to_test = get_theorem_name_resembling(file_path, "Turing.TM1to1.tr_supports", use_cache=True)
+ assert theorems_similar_to_test is not None, "Theorem similar to test should not be None"
+ print("Theorem similar to ", "Turing.TM1to1.tr_supports", " is ", theorems_similar_to_test)
+ project_root = 'data/test/lean4_proj/'
+ file_path = 'data/test/lean4_proj/Lean4Proj/Basic.lean'
+ theorem_name = "Lean4Proj2.test3"
+ theorems_similar_to_test = get_theorem_name_resembling(file_path, theorem_name, use_cache=True)
+ assert theorems_similar_to_test is not None, "Theorem similar to test should not be None"
+ print("Theorem similar to ", "Lean4Proj2.test3", " is ", theorems_similar_to_test)
+ execute_thm_line_by_line(file_path, project_root, theorems_similar_to_test, logger, with_print=True)
+ mathlib_test_file = 'data/test/Mathlib/.lake/packages/mathlib/Mathlib/Data/Nat/Bits.lean'
+ project_root = 'data/test/Mathlib'
+ assert os.path.exists(mathlib_test_file), "Mathlib test file does not exist"
+ assert os.path.exists(project_root), "Project root does not exist"
+ theorems_similar_to_test = get_theorem_name_resembling(mathlib_test_file, "one_bits", use_cache=True)
+ assert theorems_similar_to_test is not None, "Theorem similar to test should not be None"
+ print("Theorem similar to ", "one_bits", " is ", theorems_similar_to_test)
+ execute_thm_line_by_line(mathlib_test_file, project_root, theorems_similar_to_test, logger, with_print=True)
\ No newline at end of file
diff --git a/src/itp_interface/tools/tactic_parser.py b/src/itp_interface/tools/tactic_parser.py
new file mode 100644
index 0000000..bf9638d
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser.py
@@ -0,0 +1,917 @@
+"""
+Ultra-simple tactic parser for Lean 4 code.
+
+This module provides a lightweight interface to parse tactics from Lean 4 proofs
+without compiling or running any code - just pure syntax parsing.
+
+The parser process runs in the background to avoid restart overhead.
+"""
+
+import base64
+import json
+import os
+import subprocess
+import logging
+import re
+import shutil
+from enum import Enum
+from pydantic import BaseModel, field_validator
+from pathlib import Path
+from typing import List, Dict, Optional, Union
+
+class Position(BaseModel):
+ """Represents a position in the source code."""
+ line: int # Line counting starts from 1
+ column: int # Column counting starts from 0
+
+ @field_validator('line', 'column')
+ def validate_non_negative(cls, v):
+ if v < 0:
+ raise ValueError("Line and column numbers must be non-negative")
+ return v
+
+ def __lt__(self, other: 'Position') -> bool:
+ if self.line == other.line:
+ return self.column < other.column
+ return self.line < other.line
+
+ def __eq__(self, value: object) -> bool:
+ if not isinstance(value, Position):
+ return NotImplemented
+ return self.line == value.line and self.column == value.column
+
+ def __le__(self, other: 'Position') -> bool:
+ return self < other or self == other
+
+ def is_contained_in(self, start: 'Position', end: 'Position') -> bool:
+ """Check if this position is within the range [start, end]."""
+ return start <= self <= end
+
+class TreeNode(BaseModel):
+ """Represents a node in the syntax tree."""
+ decl_type: Optional[str] = None
+ name: Optional[str] = None
+ doc_string: Optional[str] = None
+ start_pos: Optional[Position] = None
+ end_pos: Optional[Position] = None
+ text: Optional[str] = None
+ namespace: Optional[str] = None
+ children: List['TreeNode'] = []
+
+ def __lt__(self, other: 'TreeNode') -> bool:
+ if self.start_pos is None or other.start_pos is None:
+ return False
+ if self.start_pos == other.start_pos:
+ if self.end_pos is None or other.end_pos is None:
+ return False
+ return self.end_pos < other.end_pos
+ return self.start_pos < other.start_pos
+
+ def __le__(self, other: 'TreeNode') -> bool:
+ return self < other or self == other
+
+ def __eq__(self, value: object) -> bool:
+ if not isinstance(value, TreeNode):
+ return NotImplemented
+ return (self.start_pos == value.start_pos and
+ self.end_pos == value.end_pos)
+
+
+ def is_contained_in(self, tree_node: 'TreeNode') -> bool:
+ """Check if this node is contained within another node's position range."""
+ if self.start_pos is None or self.end_pos is None:
+ return False
+ if tree_node.start_pos is None or tree_node.end_pos is None:
+ return False
+ return (self.start_pos.is_contained_in(tree_node.start_pos, tree_node.end_pos) and
+ self.end_pos.is_contained_in(tree_node.start_pos, tree_node.end_pos))
+
+class ErrorInfo(BaseModel):
+ """Represents an error in parsing."""
+ message: str
+ position: Position
+
+ def to_json(self) -> str:
+ # Use pydantic's built-in json method
+ return ErrorInfo.model_dump_json(self)
+
+class LeanLineInfo(BaseModel):
+ """Information about a single tactic."""
+ text: str
+ line: int
+ column: int
+ end_line: int
+ end_column: int
+ decl_type: Optional[str] = None
+ name: Optional[str] = None
+ doc_string: Optional[str] = None
+ namespace: Optional[str] = None
+
+ def __repr__(self) -> str:
+ return f"LeanLineInfo(text={self.text!r}, line={self.line}, column={self.column})"
+
+ def to_dict(self) -> Dict:
+ return {
+ "text": self.text,
+ "line": self.line,
+ "column": self.column,
+ "endLine": self.end_line,
+ "endColumn": self.end_column,
+ "declType": self.decl_type,
+ "name": self.name,
+ "docString": self.doc_string,
+ "namespace": self.namespace
+ }
+
+
+ def to_json(self, indent=0) -> str:
+ if indent == 0:
+ return self.model_dump_json()
+ else:
+ return self.model_dump_json(indent=indent)
+
+ @staticmethod
+ def load_from_file(file_path: str):
+ raise NotImplementedError("load_from_file must be implemented by the child class")
+
+ @staticmethod
+ def load_from_string(json_text: str):
+ raise NotImplementedError("load_from_string must be implemented by the child class")
+
+class DeclarationDependency(BaseModel):
+ """Information about a single dependency reference."""
+ name: str # Fully qualified name (e.g., "Nat.add_zero")
+ namespace: Optional[str] = None # Namespace portion (e.g., "Nat")
+ local_name: str # Local name without namespace (e.g., "add_zero")
+ file_path: Optional[str] = None # Source file if resolvable
+ module_name: Optional[str] = None # Module where defined
+ decl_id: Optional[str] = None # Optional declaration ID for linking
+
+ def __repr__(self) -> str:
+ module_info = f" (from {self.module_name})" if self.module_name else ""
+ return f"{self.name}{module_info}"
+
+class DeclWithDependencies(BaseModel):
+ """Declaration with its dependencies - designed for merging into larger collections."""
+ decl_info: LeanLineInfo # Declaration metadata
+ dependencies: List[DeclarationDependency]
+ unresolved_names: List[str] = [] # Names we couldn't resolve
+ decl_id: Optional[str] = None # Optional unique ID for this declaration
+
+ def __repr__(self) -> str:
+ id_info = f" [ID: {self.decl_id}]" if self.decl_id else ""
+ return f"[{self.decl_info.decl_type}] {self.decl_info.name} ({len(self.dependencies)} deps){id_info}"
+
+ def set_decl_id(self, decl_id: str) -> 'DeclWithDependencies':
+ """Set declaration ID (useful for chaining)."""
+ self.decl_id = decl_id
+ return self
+
+ def to_json(self, indent=0) -> str:
+ if indent == 0:
+ return self.model_dump_json()
+ else:
+ return self.model_dump_json(indent=indent)
+
+ @staticmethod
+ def from_dict(decl_data: Dict) -> 'DeclWithDependencies':
+ """Create from dictionary (e.g., from JSON dependency parser output)."""
+ decl_dict = decl_data['declaration']
+
+ # Create LeanLineInfo from declaration data
+ declaration = LeanLineInfo(
+ text=decl_dict.get('text', ''),
+ line=decl_dict.get('startPos', 0),
+ column=0,
+ end_line=decl_dict.get('endPos', 0),
+ end_column=0,
+ decl_type=decl_dict.get('declType'),
+ name=decl_dict.get('name'),
+ doc_string=decl_dict.get('docString'),
+ namespace=decl_dict.get('namespace')
+ )
+
+ # Parse dependencies
+ dependencies = []
+ for dep_data in decl_data.get('dependencies', []):
+ dependencies.append(DeclarationDependency(
+ name=dep_data['name'],
+ namespace=dep_data.get('namespace'),
+ local_name=dep_data['localName'],
+ file_path=dep_data.get('filePath'),
+ module_name=dep_data.get('moduleName'),
+ decl_id=dep_data.get('declId')
+ ))
+
+ return DeclWithDependencies(
+ decl_info=declaration,
+ dependencies=dependencies,
+ unresolved_names=decl_data.get('unresolvedNames', []),
+ decl_id=decl_data.get('declId')
+ )
+
+ @staticmethod
+ def from_dependency_analysis(analysis_dict: Dict) -> List['DeclWithDependencies']:
+ """
+ Extract list of declarations from dependency parser output.
+
+ Args:
+ analysis_dict: Dict returned by parse_file with PARSE_DEPENDS
+
+ Returns:
+ List of DeclWithDependencies that can be merged into larger collections
+ """
+ return [
+ DeclWithDependencies.from_dict(decl_data)
+ for decl_data in analysis_dict.get('declarations', [])
+ ]
+
+ @staticmethod
+ def load_from_file(file_path: str) -> 'DeclWithDependencies':
+ raise NotImplementedError("load_from_file must be implemented by the child class")
+
+ @staticmethod
+ def load_from_string(json_text: str) -> 'DeclWithDependencies':
+ raise NotImplementedError("load_from_string must be implemented by the child class")
+
+class FileDependencyAnalysis(BaseModel):
+ """
+ File-level dependency analysis output from the Lean dependency parser.
+
+ This model is used to parse the JSON output from the dependency-parser executable.
+ For merging into larger collections, extract the declarations list.
+ """
+ file_path: str
+ module_name: str
+ imports: List[Dict] # Raw import info
+ declarations: List[DeclWithDependencies]
+
+ def __repr__(self) -> str:
+ return f"FileDependencyAnalysis({self.module_name}, {len(self.declarations)} decls)"
+
+ # def to_dict(self) -> Dict:
+ # """Convert to dictionary matching the JSON format."""
+ # return {
+ # "filePath": self.file_path,
+ # "moduleName": self.module_name,
+ # "imports": self.imports,
+ # "declarations": [
+ # {
+ # "declaration": decl.declaration.to_dict(),
+ # "dependencies": [dep.model_dump(exclude_none=True) for dep in decl.dependencies],
+ # "unresolvedNames": decl.unresolved_names,
+ # **({"declId": decl.decl_id} if decl.decl_id else {})
+ # }
+ # for decl in self.declarations
+ # ]
+ # }
+
+ # @staticmethod
+ # def from_dict(data: Dict) -> 'FileDependencyAnalysis':
+ # """Parse from the JSON dict returned by dependency-parser executable."""
+ # declarations = DeclWithDependencies.from_dependency_analysis(data)
+
+ # return FileDependencyAnalysis(
+ # file_path=data['filePath'],
+ # module_name=data['moduleName'],
+ # imports=data.get('imports', []),
+ # declarations=declarations
+ # )
+
+# Create an enum for parsing request type
+class RequestType(Enum):
+ PARSE_TACTICS = "parse_tactics"
+ PARSE_THEOREM = "parse_theorem"
+ CHKPT_TACTICS = "chkpt_tactics"
+ BREAK_CHCKPNT = "break_chckpnt"
+ PARSE_DEPENDS = "parse_depends" # 13 chars - for dependency analysis
+
+def get_path_to_tactic_parser_project() -> str:
+ """Get the path to the tactic parser project directory."""
+ tools_dir = os.path.dirname(__file__)
+ tactic_parser_path = os.path.join(tools_dir, "tactic_parser")
+ abs_path = os.path.abspath(tactic_parser_path)
+ return abs_path
+
+def get_path_to_tactic_parser_executable() -> str:
+ """Get the path to the tactic parser executable."""
+ abs_path = get_path_to_tactic_parser_project()
+ tactic_parser_bin_path = os.path.join(abs_path, ".lake", "build", "bin", "tactic-parser")
+ return tactic_parser_bin_path
+
+def is_tactic_parser_built() -> bool:
+ """Check if the tactic parser executable exists."""
+ path_to_exec = get_path_to_tactic_parser_executable()
+ if not os.path.isfile(path_to_exec):
+ return False
+ else:
+ lean_version_needed = os.getenv("LEAN_VERSION", None)
+ if lean_version_needed is None:
+ return True
+ tactic_parser_project = get_path_to_tactic_parser_project()
+ # Check the version of the built parser
+ toolchain_file = os.path.join(tactic_parser_project, "lean-toolchain")
+ assert os.path.isfile(toolchain_file), f"lean-toolchain file not found at {toolchain_file}, something is wrong."
+ with open(toolchain_file, 'r') as f:
+ toolchain_content = f.read()
+ toolchain_content = toolchain_content.strip()
+ if toolchain_content.endswith(lean_version_needed):
+ return True
+ else:
+ # Replace the version in the toolchain file
+ # The version should be like 4.x.y
+ pattern = r'^4\.\d+\.\d+$'
+ if not re.match(pattern, lean_version_needed):
+ raise RuntimeError(f"Tactic parser built with Lean version {toolchain_content}, but version {lean_version_needed} is required." +
+ "Don't know how to build Lean which is not of the form 4.x.y. " +
+ "Please rebuild the tactic parser.")
+ toolchain_final = f"leanprover/lean4:v{lean_version_needed}"
+ with open(toolchain_file, 'w') as f:
+ f.write(toolchain_final)
+ return False
+
+def build_lean4_project(project_folder, logger: Optional[logging.Logger] = None, has_executable: bool = False):
+ """Build the Lean4 project at the given folder."""
+
+ logger = logger if logger else logging.getLogger(__name__)
+ lake_folder = os.path.join(project_folder, ".lake")
+ if os.path.exists(lake_folder):
+ logger.info(f"Cleaning existing .lake folder at {lake_folder} before build.")
+ shutil.rmtree(lake_folder)
+ # Define the command
+ if has_executable:
+ command = f"cd {project_folder} && lake build"
+ else:
+ command = f"cd {project_folder} && lake exe cache get && lake build"
+
+ logging.info(f"Building Lean4 project {project_folder}...")
+
+ # Run the command
+ # - shell=True is needed to process 'cd' and '&&'
+ # - capture_output=True captures stdout and stderr
+ # - text=True decodes stdout/stderr as text (using default encoding)
+ result = subprocess.run(command, shell=True, capture_output=True, text=True)
+
+ # Print the build logs from stdout
+ logging.info('-'*15 + f'Build Logs from {project_folder}' + '-'*15)
+ logging.info(result.stdout)
+
+ # Optionally print error logs if any exist
+ if result.stderr:
+ logging.error('-'*15 + f'Error Logs from {project_folder}' + '-'*15)
+ logging.error(result.stderr)
+
+ logging.info('-'*15 + f'End Build Logs from {project_folder}' + '-'*15)
+
+ # --- Here is how you check the exit code ---
+ exit_code = result.returncode
+ logging.info(f"Process finished with exit code: {exit_code}")
+
+ # You can now act on the exit code
+ if exit_code == 0:
+ logging.info("Build successful!")
+ else:
+ logging.error("Build FAILED!")
+ raise Exception(f"Build failed with code {exit_code}")
+
+def build_tactic_parser_if_needed(logger: Optional[logging.Logger] = None):
+ """Build the tactic parser if not already built."""
+ if not is_tactic_parser_built():
+ build_lean4_project(get_path_to_tactic_parser_project(), logger, has_executable=True)
+
+def get_path_to_dependency_parser_executable() -> str:
+ """Get the path to the dependency parser executable."""
+ abs_path = get_path_to_tactic_parser_project()
+ dependency_parser_bin_path = os.path.join(abs_path, ".lake", "build", "bin", "dependency-parser")
+ return dependency_parser_bin_path
+
+def analyze_lean_file_dependencies(
+ full_lean_file_path: str,
+ json_output_path: str,
+ working_dir: Optional[str] = None,
+ logger: Optional[logging.Logger] = None
+) -> tuple[List[FileDependencyAnalysis], List[ErrorInfo]]:
+ """
+ Analyze dependencies in a Lean file and export to JSON.
+
+ Args:
+ full_lean_file_path: Path to the Lean file to analyze (relative to working_dir)
+ json_output_path: Path where JSON output will be written (relative to working_dir)
+ working_dir: Working directory (Lean project root). If None, uses current directory.
+ logger: Optional logger for debugging
+
+ Returns:
+ tuple: (FileDependencyAnalysis, List[ErrorInfo]) - analysis results and any errors
+
+ Raises:
+ FileNotFoundError: If the executable or input file doesn't exist
+ subprocess.CalledProcessError: If the analysis fails
+ """
+ if logger is None:
+ logger = logging.getLogger(__name__)
+
+ if working_dir is None:
+ working_dir = os.getcwd()
+
+ # Ensure the dependency parser is built
+ build_tactic_parser_if_needed(logger)
+
+ # Get the executable path
+ exec_path = get_path_to_dependency_parser_executable()
+ if not os.path.isfile(exec_path):
+ raise FileNotFoundError(
+ f"Dependency parser executable not found at {exec_path}. "
+ "Please build it first with 'lake build'"
+ )
+
+ # Verify the input file exists (relative to working_dir)
+ # full_lean_path = Path(working_dir) / lean_file_path
+ if not os.path.exists(full_lean_file_path):
+ raise FileNotFoundError(f"Lean file not found: {full_lean_file_path}")
+
+ # Build the command
+ cmds = ["lake", "env", str(exec_path), full_lean_file_path, json_output_path]
+
+ logger.debug(f"Running dependency analysis: {' '.join(cmds)}")
+ logger.debug(f"Working directory: {working_dir}")
+
+ # Execute the command
+ result = subprocess.run(
+ cmds,
+ cwd=working_dir,
+ capture_output=True,
+ text=True,
+ check=False # Don't raise on error, handle it ourselves
+ )
+
+ logger.debug(f"Dependency analysis stdout: {result.stdout}")
+ if result.stderr:
+ logger.warning(f"Dependency analysis stderr: {result.stderr}")
+
+ # Check for errors
+ errors: List[ErrorInfo] = []
+ if result.returncode != 0 or result.stderr:
+ error_msg = f"Dependency parser failed with code {result.returncode}"
+ if result.stderr:
+ error_msg += f": {result.stderr}"
+ errors.append(ErrorInfo(message=error_msg, position=Position(line=0, column=0)))
+
+ # Read and parse the JSON output
+ full_json_path = Path(working_dir) / json_output_path
+ if full_json_path.exists():
+ with open(full_json_path, 'r') as f:
+ data = json.load(f)
+ analysis = FileDependencyAnalysis.model_validate(data)
+ else:
+ # Create empty analysis if file doesn't exist
+ analysis = FileDependencyAnalysis(
+ file_path=full_lean_file_path,
+ module_name="",
+ imports=[],
+ declarations=[]
+ )
+ if not errors:
+ errors.append(ErrorInfo(
+ message="Output file was not created",
+ position=Position(line=0, column=0)
+ ))
+
+ return [analysis], errors
+
+def get_from_original_text(code: str, lean_info: LeanLineInfo, relative_line_num : int = 1) -> str:
+ """Extract the text corresponding to a LeanLineInfo from the code."""
+ lines = code.splitlines()
+ start_line_idx = lean_info.line - relative_line_num
+ end_line_idx = lean_info.end_line - relative_line_num
+
+ if start_line_idx < 0 or end_line_idx >= len(lines):
+ raise ValueError("LeanLineInfo line numbers are out of bounds")
+
+ if start_line_idx == end_line_idx:
+ # Single line case
+ return lines[start_line_idx][lean_info.column:lean_info.end_column]
+ else:
+ # Multi-line case
+ extracted_lines = []
+ # First line
+ extracted_lines.append(lines[start_line_idx][lean_info.column:])
+ # Middle lines
+ for i in range(start_line_idx + 1, end_line_idx):
+ extracted_lines.append(lines[i])
+ # Last line
+ extracted_lines.append(lines[end_line_idx][:lean_info.end_column])
+ return '\n'.join(extracted_lines)
+
+theorem_name_regex = r"(((theorem|lemma)[\s]+([^\s:]*))|example)"
+theorem_name_match = re.compile(theorem_name_regex, re.MULTILINE)
+
+def parse_theorem_name(thm_stmt: str) -> Optional[str]:
+ match = theorem_name_match.search(thm_stmt)
+ if match:
+ theorem_name = match.group(4)
+ return theorem_name
+ return None
+
+class TacticParser:
+ """Parse tactics from Lean 4 code without compilation.
+
+ The parser process runs in the background and is reused across multiple requests.
+
+ If you want to parse tactics that use mathlib or other dependencies, provide a
+ project_path when initializing the parser. The process will run from that directory
+ and automatically find the project's .lake/build with all dependencies.
+ """
+
+ def __init__(self, parser_path: Optional[str] = None, project_path: Optional[str] = None, logger: Optional[logging.Logger] = None):
+ """
+ Initialize the tactic parser.
+
+ Args:
+ parser_path: Path to the tactic-parser executable. If None, uses the default path.
+ project_path: Path to a Lean project directory (contains lakefile.toml and .lake/build).
+ If provided, the parser will run from this directory and can use the
+ project's dependencies (like mathlib). If None, uses minimal environment.
+ logger: Optional logger for debugging
+ """
+ if parser_path is None:
+ # Default path relative to this file
+ default_path = Path(__file__).parent / "tactic_parser" / ".lake" / "build" / "bin" / "tactic-parser"
+ self.parser_path = str(default_path)
+ else:
+ self.parser_path = parser_path
+
+ self.project_path = project_path
+ self.logger = logger if logger else logging.getLogger(__name__)
+ self.process: Optional[subprocess.Popen] = None
+ self._start()
+
+ def _start(self):
+ """Start the tactic parser process."""
+ try:
+ # Determine working directory:
+ # - If project_path provided: use project directory (finds .lake/build automatically)
+ # - Otherwise: use tactic_parser directory (minimal environment)
+ if self.project_path:
+ working_dir = self.project_path
+ self.logger.debug(f"Starting parser in project mode from: {working_dir}")
+ else:
+ working_dir = Path(self.parser_path).parent.parent.parent
+ self.logger.debug(f"Starting parser in standalone mode from: {working_dir}")
+ # Ensure the parser is built
+ build_tactic_parser_if_needed(self.logger)
+ path_to_tactic_parser_exec = get_path_to_tactic_parser_executable()
+ assert os.path.isfile(path_to_tactic_parser_exec), f"Tactic parser executable not found at {path_to_tactic_parser_exec}, please build it first."
+ cmds = ["lake", "env", path_to_tactic_parser_exec]
+ self.process = subprocess.Popen(
+ cmds,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ bufsize=1, # Line buffered
+ cwd=str(working_dir)
+ )
+ self.logger.debug(f"Started tactic parser process (PID: {self.process.pid})")
+ except FileNotFoundError:
+ raise RuntimeError(
+ f"Tactic parser not found at {self.parser_path}. "
+ f"Please build it first with: cd {Path(self.parser_path).parent.parent.parent} && lake build"
+ )
+
+ def _ensure_running(self):
+ """Ensure the process is running, restart if needed."""
+ if self.process is None or self.process.poll() is not None:
+ self.logger.warning("Tactic parser process died, restarting...")
+ self._start()
+
+ def _is_tactic_request(self, parse_type: RequestType) -> bool:
+ return parse_type == RequestType.PARSE_TACTICS or parse_type == RequestType.CHKPT_TACTICS or parse_type == RequestType.BREAK_CHCKPNT
+
+ def parse(self, lean_code: str, fail_on_error: bool = True, parse_type: RequestType = RequestType.PARSE_TACTICS) -> tuple[List[LeanLineInfo], List[ErrorInfo]]:
+ """
+ Parse tactics from Lean 4 code.
+
+ Args:
+ lean_code: Lean 4 source code as a string
+
+ Returns:
+ List of leanInfo objects
+
+ Raises:
+ RuntimeError: If parsing fails
+ """
+ self._ensure_running()
+
+ retry_cnt = 5
+ succeeded = False
+
+ while retry_cnt > 0 and not succeeded:
+ # Encode Lean code as base64
+ final_code = parse_type.value + lean_code
+ b64_input = base64.b64encode(final_code.encode('utf-8')).decode('ascii')
+ self.logger.debug(f"Sending {len(final_code)} bytes of Lean code")
+ self.logger.debug(f"Base64 encoded input length: {len(b64_input)}")
+ self.logger.debug(f"Input (base64): {b64_input}")
+
+ # Send to parser
+ try:
+ self.process.stdin.write(b64_input + '\n')
+ self.process.stdin.flush()
+ succeeded = True
+ except BrokenPipeError:
+ self.logger.error("Broken pipe, restarting process")
+ self._start()
+ retry_cnt -= 1
+
+ # Read JSON response (one line)
+ try:
+ response_line = self.process.stdout.readline()
+ self.logger.debug(f"Response: {response_line.strip()}")
+ if not response_line:
+ # Check stderr for error messages
+ stderr_output = self.process.stderr.read() if self.process.stderr else ""
+ raise RuntimeError(f"Parser process died unexpectedly. Stderr: {stderr_output}")
+ except Exception as e:
+ raise RuntimeError(f"Failed to read response: {e}")
+
+ # Parse JSON response
+ try:
+ response = json.loads(response_line.strip())
+ except json.JSONDecodeError as e:
+ raise RuntimeError(f"Failed to parse JSON response: {e}\nOutput: {response_line}")
+
+ # Check for errors
+ errors : List[ErrorInfo] = []
+ if response.get("errors"):
+ if fail_on_error:
+ raise RuntimeError(f"Parse error: {response['errors']}")
+ else:
+ for err in response["errors"]:
+ error_info = ErrorInfo.model_validate(err)
+ errors.append(error_info)
+ self.logger.debug(f"Parse error: {error_info}")
+
+ # Convert tree to leanInfo objects
+ trees : list[TreeNode] = []
+ tactics = []
+ for t in response.get("trees", []):
+ if t is not None:
+ tree = TreeNode.model_validate(t)
+ trees.append(tree)
+ for t in trees:
+ assert t.start_pos is not None
+ assert t.end_pos is not None
+ if t.decl_type is not None and (t.decl_type == "theorem" or t.decl_type == "lemma"):
+ # TODO: Fix the incorrect theorem/lemma name parsing from the underlying lean tool
+ actual_name = parse_theorem_name(t.text if t.text else "")
+ assert actual_name is not None, "Theorem/lemma name should not be None"
+ if t.name != actual_name:
+ t.name = actual_name
+ tactics.append(
+ LeanLineInfo(
+ text=t.text if t.text else "",
+ line=t.start_pos.line,
+ column=t.start_pos.column,
+ end_line=t.end_pos.line,
+ end_column=t.end_pos.column,
+ decl_type=t.decl_type,
+ name=t.name,
+ doc_string=t.doc_string,
+ namespace=t.namespace
+ )
+ )
+ self.logger.debug(f"Parsed {len(tactics)} tactics")
+
+ return tactics, errors
+
+ def parse_file(self, file_path: str, parse_type: RequestType = RequestType.PARSE_THEOREM, json_output_path: Optional[str] = None) -> tuple[Union[List[LeanLineInfo], List[FileDependencyAnalysis]], List[ErrorInfo]]:
+ """
+ Parse tactics from a Lean 4 file or analyze its dependencies.
+
+ Args:
+ file_path: Path to the Lean 4 file
+ parse_type: Type of parsing to perform
+ json_output_path: For PARSE_DEPENDS only - path where JSON output will be written.
+ If None, generates a temporary path.
+
+ Returns:
+ - For PARSE_TACTICS/PARSE_THEOREM/CHKPT_TACTICS/BREAK_CHCKPNT: tuple of (List[LeanLineInfo], List[ErrorInfo])
+ - For PARSE_DEPENDS: tuple of (FileDependencyAnalysis, List[ErrorInfo])
+ """
+ if parse_type == RequestType.PARSE_DEPENDS:
+ # Use dependency parser executable
+ if json_output_path is None:
+ # Generate a temporary output path
+ json_output_file_path = Path(file_path).with_suffix('.deps.json')
+ json_output_path = str(json_output_file_path)
+
+ # Determine working directory
+ if self.project_path:
+ working_dir = self.project_path
+ # Make file_path relative to working_dir if it's absolute
+ file_path_obj = Path(file_path)
+ # Make sure that path is absolute
+ if not file_path_obj.is_absolute():
+ file_path = str(file_path_obj.resolve())
+ working_dir = str(Path(working_dir).resolve())
+ json_output_path = str(Path(json_output_path).resolve())
+ else:
+ working_dir = str(Path(file_path).parent.resolve())
+ file_path = str(Path(file_path).resolve())
+ json_output_path = str(Path(json_output_path).resolve())
+
+ return analyze_lean_file_dependencies(
+ full_lean_file_path=file_path,
+ json_output_path=json_output_path,
+ working_dir=working_dir,
+ logger=self.logger
+ )
+ else:
+ # Use normal tactic parser
+ with open(file_path, 'r', encoding='utf-8') as f:
+ lean_code = f.read()
+ return self.parse(lean_code, parse_type=parse_type)
+
+ def close(self):
+ """Close the parser process."""
+ if self.process:
+ try:
+ # Send exit command
+ self.process.stdin.write('\n')
+ self.process.stdin.flush()
+ self.process.wait(timeout=1)
+ except:
+ self.process.kill()
+ self.process = None
+ self.logger.debug("Closed tactic parser process")
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def __del__(self):
+ self.close()
+
+# Example usage
+def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] = None):
+ for tactic in tactics:
+ msg = f"Line {tactic.line}, Col {tactic.column} to Line {tactic.end_line}, Col {tactic.end_column}: {tactic.text}"
+ if logger:
+ logger.info(msg)
+ else:
+ print(msg)
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.DEBUG)
+ project_path = str(Path(__file__).parent.parent.parent / "data" / "test" / "lean4_proj")
+
+ with TacticParser() as parser:
+ # Example 1: Simple proof
+ lean_code = "example : True := by trivial"
+
+ print("Parsing example 1...")
+ tactics, errors = parser.parse(lean_code)
+ print_tactics(tactics)
+ if errors:
+ print(f"Error: {errors}")
+
+ with TacticParser() as parser:
+ # Example 1b: Simple have proofs
+ # p \implies q and q \implies r then have p \implies r
+ lean_code = """
+example (p q r: Prop) (h1: p → q) (h2: q → r) : p → r := by
+ have h3: p → r := by
+ try simp
+ wrong_tactic
+ done
+"""
+
+ print("Parsing example 1b...")
+ tactics, errors = parser.parse(lean_code, fail_on_error=False)
+ print_tactics(tactics)
+ if errors:
+ print(f"Error: {errors}")
+
+
+
+ with TacticParser(project_path=project_path) as parser:
+ # Example 2: Multiline with params
+ lean_code2 = "example (r: Nat) (p q : Prop) (hp : p) (hq : q) : p ∧ q := by\n apply And.intro\n exact hp\n exact hq"
+
+ print("\nParsing example 2...")
+ tactics2, errors = parser.parse(lean_code2)
+ print_tactics(tactics2)
+ if errors:
+ print(f"Error: {errors}")
+
+ # Check if linarith is parsed correctly
+ lean_code3 = """
+import Mathlib
+
+example (a b : Nat)
+(h1: a + b = 10)
+(h2: a = 5) :
+b = 5:= by
+ rw [h2] at h1
+ linarith
+"""
+ print("\nParsing example 3...")
+ tactics3, errors = parser.parse(lean_code3)
+ print_tactics(tactics3)
+ if errors:
+ print(f"Error: {errors}")
+
+ file_path = str(Path(__file__).parent.parent.parent / "data" / "test" / "lean4_proj" / "Lean4Proj" / "Basic.lean")
+
+ with TacticParser(project_path=project_path) as parser:
+ # Example 4: Parse from file
+ print("\nParsing example 4 (from file)...")
+ tactics4, errors = parser.parse_file(file_path)
+ print_tactics(tactics4)
+ if errors:
+ print(f"Error: {errors}")
+
+ with TacticParser(project_path=project_path) as parser:
+ # Example 2: Multiline with params
+ lean_code4 = "example (r: ℕ) (p q : Prop) (hp : p) (hq : q) : p ∧ q := by grind"
+
+ print("\nParsing example 5...")
+ tactics5, errors = parser.parse(lean_code4)
+ print_tactics(tactics5)
+ if errors:
+ print(f"Error: {errors}")
+
+ with TacticParser(project_path=project_path) as parser:
+ # Example 6: Parse tactics from file with multiple theorems
+ print("\nParsing example 6 (theorem parsing from file)...")
+ tactics6, errors = parser.parse(lean_code3 + "\n" + lean_code4, parse_type=RequestType.PARSE_TACTICS)
+ print_tactics(tactics6)
+ if errors:
+ print(f"Error: {errors}")
+
+ with TacticParser(project_path=project_path) as parser:
+ # Example 7: Parse tactics which are wrong
+ print("\nParsing example 7 (theorem declaration parsing from file)...")
+ lean_code5 = "theorem wrong_decl : Nat := by assdfadfs"
+ tactics7, errors = parser.parse(lean_code5, fail_on_error=False)
+ print_tactics(tactics7)
+ if errors:
+ print(f"Error: {errors}")
+
+ with TacticParser(project_path=project_path) as parser:
+ # Example 8: Parse tactics just before `by`
+ print("\nParsing example 8 (theorem with just before `by`...)")
+ lean_code8 = "theorem temp: 1 + 2 = 3 :=\nby"
+ tactics8, errors = parser.parse(lean_code8, fail_on_error=False)
+ print_tactics(tactics8)
+ if errors:
+ print(f"Error: {errors}")
+
+ with TacticParser(project_path=project_path) as parser:
+ # Example 9: Parse tactics just before `by`
+ print("\nParsing example 9 (theorem with just before `by`...)")
+ lean_code9 = "import Mathlib\ntheorem temp: 1 + 2 = 3 :=\nby\n have h1: 1 + 1 = 2 := by\n linarith\n done"
+ tactics9, errors = parser.parse(lean_code9, fail_on_error=False)
+ print_tactics(tactics9)
+ if errors:
+ print(f"Error: {errors}")
+
+ with TacticParser(project_path=project_path) as parser:
+ # Example 10: Test checkpointing
+ print("\nParsing example 10 (checkpointing...)")
+ lean_code10 = """import Mathlib
+
+theorem temp: 1 + 2 = 3 :=
+by
+linarith
+
+theorem temp1: 3 = 1 + 1 + 1 :=
+by
+linarith
+"""
+ tactics10, errors = parser.parse(lean_code10, fail_on_error=True, parse_type=RequestType.CHKPT_TACTICS)
+ print_tactics(tactics10)
+ if errors:
+ print(f"Error: {errors}")
+ # Now just execute from the checkpoint
+ lean_code10b = """
+theorem temp2: 1 + 2 = 3 :=
+by
+have h_temp := temp1
+"""
+ print("\nContinuing from checkpoint...")
+ tactics10b, errors = parser.parse(lean_code10b, fail_on_error=False, parse_type=RequestType.PARSE_TACTICS)
+ print_tactics(tactics10b)
+ if errors:
+ # The error should contain h_temp
+ print(f"Error: {errors}")
+
+ print("\nBreaking checkpoint...")
+ new_lean_code10c = lean_code10 + lean_code10b
+ tactics10c, errors = parser.parse(new_lean_code10c, fail_on_error=False, parse_type=RequestType.BREAK_CHCKPNT)
+ # ^This will reimport everything all run all theorems from scratch
+ print_tactics(tactics10c)
+ if errors:
+ print(f"Error: {errors}")
\ No newline at end of file
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser.lean b/src/itp_interface/tools/tactic_parser/TacticParser.lean
new file mode 100644
index 0000000..da652e7
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/TacticParser.lean
@@ -0,0 +1,4 @@
+import TacticParser.Base64
+import TacticParser.Types
+import TacticParser.SyntaxWalker
+import TacticParser.Main
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/Base64.lean b/src/itp_interface/tools/tactic_parser/TacticParser/Base64.lean
new file mode 100644
index 0000000..7e8f02b
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/TacticParser/Base64.lean
@@ -0,0 +1,79 @@
+/-
+Base64 decoder for receiving Lean code from Python.
+-/
+
+namespace TacticParser.Base64
+
+/-- Helper to convert Option to Except -/
+def Option.toExcept {α : Type} (o : Option α) (err : String) : Except String α :=
+ match o with
+ | some a => .ok a
+ | none => .error err
+
+/-- Base64 character to 6-bit value mapping -/
+def charToValue (c : Char) : Option UInt8 :=
+ if 'A' ≤ c ∧ c ≤ 'Z' then
+ some (c.toNat - 'A'.toNat).toUInt8
+ else if 'a' ≤ c ∧ c ≤ 'z' then
+ some (c.toNat - 'a'.toNat + 26).toUInt8
+ else if '0' ≤ c ∧ c ≤ '9' then
+ some (c.toNat - '0'.toNat + 52).toUInt8
+ else if c = '+' then
+ some 62
+ else if c = '/' then
+ some 63
+ else
+ none
+
+/-- Decode a base64 string to bytes -/
+def decodeBytes (s : String) : Except String ByteArray :=
+ let chars := s.trim.toList.filter (· ≠ '=')
+ let rec loop (i : Nat) (result : ByteArray) : Except String ByteArray :=
+ if i >= chars.length then
+ .ok result
+ else
+ -- Get 4 characters (or remaining)
+ let c1 := chars[i]!
+ let c2 := if i + 1 < chars.length then chars[i + 1]! else '='
+ let c3 := if i + 2 < chars.length then chars[i + 2]! else '='
+ let c4 := if i + 3 < chars.length then chars[i + 3]! else '='
+
+ match Option.toExcept (charToValue c1) "Invalid base64 character" with
+ | .error e => .error e
+ | .ok v1 =>
+ match Option.toExcept (charToValue c2) "Invalid base64 character" with
+ | .error e => .error e
+ | .ok v2 =>
+ -- First byte is always present
+ let b1 := (v1.toNat <<< 2 ||| (v2.toNat >>> 4)).toUInt8
+ let result := result.push b1
+
+ -- Second byte if c3 exists
+ if c3 ≠ '=' then
+ match Option.toExcept (charToValue c3) "Invalid base64 character" with
+ | .error e => .error e
+ | .ok v3 =>
+ let b2 := ((v2.toNat &&& 0xF) <<< 4 ||| (v3.toNat >>> 2)).toUInt8
+ let result := result.push b2
+
+ -- Third byte if c4 exists
+ if c4 ≠ '=' then
+ match Option.toExcept (charToValue c4) "Invalid base64 character" with
+ | .error e => .error e
+ | .ok v4 =>
+ let b3 := ((v3.toNat &&& 0x3) <<< 6 ||| v4.toNat).toUInt8
+ loop (i + 4) (result.push b3)
+ else
+ loop (i + 4) result
+ else
+ loop (i + 4) result
+ loop 0 ByteArray.empty
+
+/-- Decode a base64 string to UTF-8 string -/
+def decode (s : String) : Except String String := do
+ let bytes ← decodeBytes s
+ match String.fromUTF8? bytes with
+ | some str => return str
+ | none => throw "Invalid UTF-8 encoding"
+
+end TacticParser.Base64
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean b/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean
new file mode 100644
index 0000000..5cb9479
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean
@@ -0,0 +1,486 @@
+import Lean
+import Lean.Data.Json
+import Lean.Elab.Frontend
+import TacticParser.Types
+import TacticParser.LineParser
+
+namespace TacticParser
+
+open Lean
+open Lean.Parser
+open Lean.Elab
+
+/-- Extract module name from import syntax -/
+partial def extractModuleName (stx : Syntax) : String :=
+ match stx with
+ | Syntax.ident _ _ name _ => name.toString
+ | Syntax.node _ _ args =>
+ -- Search through arguments to find identifiers and combine them
+ let parts := args.filterMap fun arg =>
+ let name := extractModuleName arg
+ if name != "" then some name else none
+ String.intercalate "." parts.toList
+ | _ => ""
+
+/-- Convert file path to module name -/
+def filePathToModuleName (filepath : System.FilePath) : String :=
+ let pathStr := filepath.toString
+ -- Remove .lean extension
+ let withoutExt := if pathStr.endsWith ".lean" then
+ pathStr.dropRight 5
+ else
+ pathStr
+ -- Replace path separators with dots
+ let modulePath := withoutExt.replace "/" "."
+ -- Remove leading ./ if present
+ if modulePath.startsWith ".." then
+ modulePath.drop 2
+ else if modulePath.startsWith "." then
+ modulePath.drop 1
+ else
+ modulePath
+
+/-- Split a fully qualified name into namespace and local name -/
+def splitNamespace (fullName : Name) : Option String × String :=
+ let str := fullName.toString
+ match str.splitOn "." with
+ | [] => (none, str)
+ | [single] => (none, single)
+ | parts =>
+ let ns := String.intercalate "." (parts.dropLast)
+ let localPart := parts.getLast!
+ (some ns, localPart)
+
+/-- Get all constants used in an expression -/
+def getConstantsFromExpr (e : Expr) : NameSet :=
+ e.foldConsts {} fun c s => s.insert c
+
+/-- Recursively collect all identifier names from a syntax tree (fallback for non-elaborated syntax) -/
+partial def collectIdentifiers (stx : Syntax) : List Name :=
+ match stx with
+ | Syntax.ident _ _ name _ => [name]
+ | Syntax.node _ _ args =>
+ args.toList.flatMap collectIdentifiers
+ | _ => []
+
+/-- Resolve a constant name to its source information using the environment -/
+def resolveConstant (env : Environment) (constName : Name) (moduleMap : Std.HashMap String String := {}) : Option DeclarationDependency := do
+ -- Try to find the constant in the environment
+ let constInfo? := env.find? constName
+ match constInfo? with
+ | none => none
+ | some _ =>
+ -- Get the module index where this constant is defined
+ let moduleIdx? := env.getModuleIdx? constName
+ let moduleName? := moduleIdx?.bind fun idx =>
+ if h : idx.toNat < env.header.moduleNames.size then
+ some env.header.moduleNames[idx.toNat]
+ else
+ none
+ let moduleStr := moduleName?.map Name.toString
+
+ -- Split into namespace and local name
+ let (namespc, localPart) := splitNamespace constName
+
+ -- Try to determine file path from module name
+ let filePath? := match moduleStr with
+ | some modName =>
+ -- First check if it's in our import map (project-local imports)
+ match moduleMap.get? modName with
+ | some path => some path
+ | none =>
+ -- Otherwise use standard conversion for external libraries
+ let path := modName.replace "." "/"
+ some s!"{path}.lean"
+ | none =>
+ -- No module name from getModuleIdx, try to infer from namespace
+ -- Since we can't definitively match, just return none for now
+ none
+
+ -- Also update module name if we inferred a file path but don't have module name yet
+ let moduleStr := match (moduleStr, filePath?) with
+ | (none, some fp) =>
+ -- Try to reverse lookup module name from file path
+ let entries := moduleMap.toList
+ entries.find? (fun (modName, path) => path == fp) |>.map (·.1)
+ | (some m, _) => some m
+ | (none, none) =>
+ -- No module info at all - only apply fallback heuristic if namespace
+ -- doesn't look like stdlib (Lean, Init, Std, Nat, List, etc.)
+ match namespc with
+ | some ns =>
+ let isStdlib := ns.startsWith "Lean" || ns.startsWith "Init" ||
+ ns.startsWith "Std" || ns == "Nat" || ns == "List" ||
+ ns == "Eq" || ns == "And" || ns == "Or" || ns == "String"
+ if isStdlib then
+ none
+ else
+ -- Find first non-Lean/non-Init import as best guess
+ moduleMap.toList.find? (fun (modName, _) =>
+ !modName.startsWith "Lean" && !modName.startsWith "Init"
+ ) |>.map (·.1)
+ | none => none
+
+ -- If we have a module name but no file path, get it from the map
+ let filePath? := match (filePath?, moduleStr) with
+ | (none, some modName) => moduleMap.get? modName
+ | (some fp, _) => some fp
+ | _ => none
+
+ some {
+ name := constName.toString
+ namespc := namespc
+ localName := localPart
+ filePath := filePath?
+ moduleName := moduleStr
+ }
+
+/-- Extract dependencies from a declaration's syntax -/
+def extractDependenciesFromSyntax (env : Environment) (stx : Syntax) : Array DeclarationDependency × Array String :=
+ -- Collect all identifiers from the syntax
+ let identifiers := collectIdentifiers stx
+
+ -- Remove duplicates by converting to NameSet then back
+ let uniqueNames := identifiers.foldl (fun set n => set.insert n) ({} : Lean.NameSet)
+
+ -- Resolve each name and partition into resolved and unresolved
+ uniqueNames.toArray.foldl (fun (deps, unres) name =>
+ match resolveConstant env name with
+ | some dep => (deps.push dep, unres)
+ | none => (deps, unres.push name.toString)
+ ) (#[], #[])
+
+/-- Parse imports and namespaces from a Lean 4 file -/
+def parseImports (filepath : System.FilePath) : IO DependencyInfo := do
+ let content ← IO.FS.readFile filepath
+ let inputCtx := mkInputContext content filepath.toString
+
+ -- Parse header which contains imports
+ let (headerStx, parserState, _) ← parseHeader inputCtx
+
+ let mut imports : Array ImportInfo := #[]
+ let mut namespaces : Array NamespaceInfo := #[]
+
+ -- Extract the underlying syntax from TSyntax
+ let headerSyn : Syntax := headerStx
+
+ -- Parse the header recursively to find all import commands
+ let rec findImports (stx : Syntax) : IO (Array ImportInfo) → IO (Array ImportInfo)
+ | accIO => do
+ let acc ← accIO
+ match stx with
+ | Syntax.node _ kind args =>
+ if kind == `Lean.Parser.Module.import then
+ -- Found an import
+ match stx.getRange? with
+ | some range =>
+ let moduleName := extractModuleName stx
+ let text := content.extract range.start range.stop
+ let info : ImportInfo := {
+ moduleName := moduleName
+ startPos := range.start.byteIdx
+ endPos := range.stop.byteIdx
+ text := text
+ }
+ return acc.push info
+ | none => return acc
+ else
+ -- Recursively search children
+ args.foldlM (fun acc child => findImports child (pure acc)) acc
+ | _ => return acc
+
+ imports ← findImports headerSyn (pure imports)
+
+ -- Now parse the rest of the file to find namespace declarations
+ let env ← Lean.importModules #[] {} 0
+ let opts := {}
+ let pmctx : ParserModuleContext := {
+ env := env
+ options := opts
+ }
+
+ let mut pstate := parserState
+ let mut done := false
+
+ while !done do
+ let startPos := pstate.pos
+ let (stx, pstate', _msgs) := parseCommand inputCtx pmctx pstate {}
+ pstate := pstate'
+
+ -- Check if this is a namespace declaration
+ if stx.getKind == `Lean.Parser.Command.namespace then
+ match stx.getRange? with
+ | some range =>
+ let namespaceName := extractModuleName stx
+ let text := content.extract range.start range.stop
+ let info : NamespaceInfo := {
+ name := namespaceName
+ startPos := range.start.byteIdx
+ endPos := range.stop.byteIdx
+ text := text
+ }
+ namespaces := namespaces.push info
+ | none => pure ()
+
+ -- Check if we made progress or reached end
+ if pstate.pos == startPos || inputCtx.input.atEnd pstate.pos then
+ done := true
+
+ -- Derive module name from file path
+ let moduleName := filePathToModuleName filepath
+
+ let result : DependencyInfo := {
+ filePath := filepath.toString
+ moduleName := moduleName
+ imports := imports
+ namespaces := namespaces
+ }
+ return result
+
+/-- Extract constants from a ConstantInfo (type + value) -/
+def extractConstantsFromConstInfo (cinfo : ConstantInfo) : NameSet :=
+ let fromType := getConstantsFromExpr cinfo.type
+ match cinfo.value? with
+ | some val =>
+ let fromVal := getConstantsFromExpr val
+ -- Merge the two sets by converting to arrays and combining
+ let combined := fromType.toArray ++ fromVal.toArray
+ combined.foldl (fun s n => s.insert n) {}
+ | none => fromType
+
+/-- Analyze all declarations in a file and extract their dependencies -/
+unsafe def analyzeFileDependencies (filepath : System.FilePath) : IO FileDependencyAnalysis := do
+ -- Read file content
+ let content ← IO.FS.readFile filepath
+
+ -- Parse basic file structure (imports and namespaces)
+ let depInfo ← parseImports filepath
+
+ -- Parse all declarations using LineParser
+ let decls ← parseFile filepath
+
+ -- Load environment with all imports and elaborate the file
+ Lean.initSearchPath (← Lean.findSysroot)
+ Lean.enableInitializersExecution
+
+ let inputCtx := Parser.mkInputContext content filepath.toString
+ let (header, parserState, messages) ← Parser.parseHeader inputCtx
+ let (env, _) ← processHeader header {} messages inputCtx
+
+ -- Elaborate the entire file to get all declarations in the environment
+ let commandState := Command.mkState env messages {}
+ let finalState ← IO.processCommands inputCtx parserState commandState <&> Frontend.State.commandState
+ let elaboratedEnv := finalState.env
+
+ -- Build a map of module names to file paths from imports
+ let mut moduleToFilePath : Std.HashMap String String := {}
+ for imp in depInfo.imports do
+ let filePath := imp.moduleName.replace "." "/" ++ ".lean"
+ moduleToFilePath := moduleToFilePath.insert imp.moduleName filePath
+
+ -- Analyze each declaration
+ let mut declsWithDeps : Array DeclWithDependencies := #[]
+ for declInfo in decls do
+ -- Construct the fully qualified name for this declaration
+ let declName := match declInfo.namespc with
+ | some ns =>
+ -- Build the name from namespace parts
+ let namespaceParts := ns.splitOn "."
+ let baseName := namespaceParts.foldl (fun n part => Name.mkStr n part) Name.anonymous
+ Name.mkStr baseName declInfo.name
+ | none => Name.mkStr Name.anonymous declInfo.name
+
+ -- Try to find this declaration in the elaborated environment
+ match elaboratedEnv.find? declName with
+ | some constInfo =>
+ -- Extract all constants from the type and value
+ let allConstants := extractConstantsFromConstInfo constInfo
+
+ -- Remove the declaration itself from its dependencies
+ let allConstants := allConstants.erase declName
+
+ -- Resolve each constant
+ let mut dependencies : Array DeclarationDependency := #[]
+ let mut unresolved : Array String := #[]
+
+ for constName in allConstants.toArray do
+ -- Check if constant exists in pre-elaboration env (means it's imported)
+ match env.find? constName with
+ | some _ =>
+ -- It's from an import, use pre-elaboration env for module info
+ match resolveConstant env constName moduleToFilePath with
+ | some dep => dependencies := dependencies.push dep
+ | none => unresolved := unresolved.push constName.toString
+ | none =>
+ -- It's a local declaration from this file, use elaborated env
+ match resolveConstant elaboratedEnv constName moduleToFilePath with
+ | some dep => dependencies := dependencies.push dep
+ | none => unresolved := unresolved.push constName.toString
+
+ let declWithDeps : DeclWithDependencies := {
+ declInfo := declInfo
+ dependencies := dependencies
+ unresolvedNames := unresolved
+ }
+ declsWithDeps := declsWithDeps.push declWithDeps
+
+ | none =>
+ -- Declaration not found in elaborated environment (might be namespace, end, etc.)
+ -- Fall back to syntax-based extraction
+ let declInputCtx := Parser.mkInputContext declInfo.text ""
+ let (_, declParserState, _) ← Parser.parseHeader declInputCtx
+ let pmctx : ParserModuleContext := { env := env, options := {} }
+ let (declStx, _, _) := Parser.parseCommand declInputCtx pmctx declParserState {}
+ let (dependencies, unresolvedNames) := extractDependenciesFromSyntax env declStx
+
+ let declWithDeps : DeclWithDependencies := {
+ declInfo := declInfo
+ dependencies := dependencies
+ unresolvedNames := unresolvedNames
+ }
+ declsWithDeps := declsWithDeps.push declWithDeps
+
+ let result : FileDependencyAnalysis := {
+ filePath := filepath.toString
+ moduleName := depInfo.moduleName
+ imports := depInfo.imports
+ declarations := declsWithDeps
+ }
+ return result
+
+/-- Export dependency info to JSON file -/
+def exportDependenciesToJson (info : DependencyInfo) (outputPath : System.FilePath) : IO Unit := do
+ let json := toJson info
+ let jsonStr := json.compress
+ IO.FS.writeFile outputPath jsonStr
+ IO.println s!"Exported dependencies to {outputPath}"
+
+/-- Print import info -/
+def printImportInfo (info : ImportInfo) : IO Unit := do
+ IO.println s!" - {info.moduleName}"
+ IO.println s!" Position: {info.startPos} - {info.endPos}"
+ IO.println s!" Text: {info.text}"
+
+/-- Print namespace info -/
+def printNamespaceInfo (info : NamespaceInfo) : IO Unit := do
+ IO.println s!" - {info.name}"
+ IO.println s!" Position: {info.startPos} - {info.endPos}"
+ IO.println s!" Text: {info.text}"
+
+/-- Parse and print dependencies -/
+def parseAndPrintDependencies (filepath : System.FilePath) : IO Unit := do
+ IO.println s!"Analyzing dependencies for: {filepath}"
+ IO.println (String.mk (List.replicate 50 '='))
+
+ let depInfo ← parseImports filepath
+
+ IO.println s!"Module: {depInfo.moduleName}"
+ IO.println ""
+
+ if depInfo.imports.isEmpty then
+ IO.println "No imports found."
+ else
+ IO.println s!"Imports ({depInfo.imports.size}):"
+ for imp in depInfo.imports do
+ printImportInfo imp
+
+ IO.println ""
+
+ if depInfo.namespaces.isEmpty then
+ IO.println "No namespaces found."
+ else
+ IO.println s!"Namespaces ({depInfo.namespaces.size}):"
+ for ns in depInfo.namespaces do
+ printNamespaceInfo ns
+
+/-- Parse and export dependencies -/
+def parseDependenciesAndExport (filepath : System.FilePath) (jsonOutput : Option System.FilePath := none) : IO Unit := do
+ let depInfo ← parseImports filepath
+
+ -- Print to console
+ IO.println s!"Analyzing dependencies for: {filepath}"
+ IO.println (String.mk (List.replicate 50 '='))
+
+ IO.println s!"Module: {depInfo.moduleName}"
+ IO.println ""
+
+ if depInfo.imports.isEmpty then
+ IO.println "No imports found."
+ else
+ IO.println s!"Imports ({depInfo.imports.size}):"
+ for imp in depInfo.imports do
+ printImportInfo imp
+
+ IO.println ""
+
+ if depInfo.namespaces.isEmpty then
+ IO.println "No namespaces found."
+ else
+ IO.println s!"Namespaces ({depInfo.namespaces.size}):"
+ for ns in depInfo.namespaces do
+ printNamespaceInfo ns
+
+ -- Export to JSON if output path provided
+ match jsonOutput with
+ | some outPath => exportDependenciesToJson depInfo outPath
+ | none => pure ()
+
+/-- Export FileDependencyAnalysis to JSON file -/
+def exportFileDependencyAnalysisToJson (analysis : FileDependencyAnalysis) (outputPath : System.FilePath) : IO Unit := do
+ let json := toJson analysis
+ let jsonStr := json.compress
+ IO.FS.writeFile outputPath jsonStr
+ IO.println s!"Exported file dependency analysis to {outputPath}"
+
+/-- Print declaration dependencies -/
+def printDeclDependencies (decl : DeclWithDependencies) : IO Unit := do
+ IO.println s!"[{decl.declInfo.declType}] {decl.declInfo.name}"
+ if !decl.dependencies.isEmpty then
+ IO.println s!" Dependencies ({decl.dependencies.size}):"
+ for dep in decl.dependencies do
+ let modInfo := match dep.moduleName with
+ | some m => s!" (from {m})"
+ | none => ""
+ IO.println s!" - {dep.name}{modInfo}"
+ if !decl.unresolvedNames.isEmpty then
+ IO.println s!" Unresolved ({decl.unresolvedNames.size}): {String.intercalate ", " decl.unresolvedNames.toList}"
+ IO.println ""
+
+/-- Analyze and print file dependencies with per-declaration tracking -/
+unsafe def analyzeAndPrintFileDependencies (filepath : System.FilePath) : IO Unit := do
+ IO.println s!"Analyzing file dependencies: {filepath}"
+ IO.println (String.mk (List.replicate 50 '='))
+
+ let analysis ← analyzeFileDependencies filepath
+
+ IO.println s!"Module: {analysis.moduleName}"
+ IO.println s!"File: {analysis.filePath}"
+ IO.println ""
+
+ if !analysis.imports.isEmpty then
+ IO.println s!"Imports ({analysis.imports.size}):"
+ for imp in analysis.imports do
+ IO.println s!" - {imp.moduleName}"
+ IO.println ""
+
+ IO.println s!"Declarations with Dependencies ({analysis.declarations.size}):"
+ IO.println ""
+ for decl in analysis.declarations do
+ printDeclDependencies decl
+
+/-- Analyze and export file dependencies -/
+unsafe def analyzeAndExportFileDependencies (filepath : System.FilePath) (jsonOutput : Option System.FilePath := none) : IO Unit := do
+ let analysis ← analyzeFileDependencies filepath
+
+ -- Print to console
+ IO.println s!"Analyzing file dependencies: {filepath}"
+ IO.println (String.mk (List.replicate 50 '='))
+ IO.println s!"Module: {analysis.moduleName}"
+ IO.println s!"Declarations analyzed: {analysis.declarations.size}"
+
+ -- Export to JSON if output path provided
+ match jsonOutput with
+ | some outPath => exportFileDependencyAnalysisToJson analysis outPath
+ | none => pure ()
+
+end TacticParser
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParserMain.lean b/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParserMain.lean
new file mode 100644
index 0000000..1250afe
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParserMain.lean
@@ -0,0 +1,39 @@
+import TacticParser.DependencyParser
+
+open TacticParser
+
+/-- Print usage information -/
+def printUsage : IO Unit := do
+ IO.println "Usage: dependency_parser "
+ IO.println ""
+ IO.println "Arguments:"
+ IO.println " Path to the Lean file to analyze"
+ IO.println " Path where JSON output will be written"
+ IO.println ""
+ IO.println "Example:"
+ IO.println " lake env .lake/build/bin/dependency_parser MyFile.lean output.json"
+
+unsafe def main (args : List String) : IO UInt32 := do
+ match args with
+ | [leanFilePath, jsonOutputPath] =>
+ try
+ let filepath : System.FilePath := leanFilePath
+ let jsonPath : System.FilePath := jsonOutputPath
+
+ -- Check if input file exists
+ if !(← filepath.pathExists) then
+ IO.eprintln s!"Error: Input file not found: {filepath}"
+ return 1
+
+ -- Analyze the file and export to JSON
+ analyzeAndExportFileDependencies filepath (some jsonPath)
+ return 0
+ catch e =>
+ IO.eprintln s!"Error: {e}"
+ return 1
+
+ | _ =>
+ IO.eprintln "Error: Invalid number of arguments"
+ IO.eprintln ""
+ printUsage
+ return 1
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/LineParser.lean b/src/itp_interface/tools/tactic_parser/TacticParser/LineParser.lean
new file mode 100644
index 0000000..3ff3b63
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/TacticParser/LineParser.lean
@@ -0,0 +1,487 @@
+import Lean
+import Lean.Data.Json
+import TacticParser.Types
+namespace TacticParser
+
+open Lean
+open Lean.Parser
+
+/-- Identify the type of declaration from syntax -/
+partial def identifySomeDeclType (stx : Syntax) : Option (DeclType × Nat) :=
+ let kind := stx.getKind
+ -- Check if this is a declaration wrapper, if so, look inside
+ if kind == `Lean.Parser.Command.declaration then
+ match stx with
+ | Syntax.node _ _ args =>
+ -- Look for the actual declaration type in the children
+ let idx := args.findIdx (fun a => (identifySomeDeclType a).isSome);
+ if idx = args.size then
+ some (.unknown, idx)
+ else
+ let decl := args[idx]!
+ let declTypeOpt := identifySomeDeclType decl
+ match declTypeOpt with
+ | some dt => some (dt.1, idx)
+ | none => some (.unknown, idx)
+ | _ => none
+ else if kind == `Lean.Parser.Command.end then some (.end, 0)
+ else if kind == `Lean.Parser.Command.namespace then some (.namespace, 0)
+ else if kind == `Lean.Parser.Command.inductive then some (.inductive, 0)
+ else if kind == `Lean.Parser.Command.theorem then some (.theorem, 0)
+ else if kind == `Lean.Parser.Command.definition then some (.def, 0)
+ else if kind == `Lean.Parser.Command.axiom then some (.axiom, 0)
+ else if kind == `Lean.Parser.Command.structure then some (.structure, 0)
+ else if kind == `Lean.Parser.Command.classDecl then some (.class_decl, 0)
+ else if kind == `Lean.Parser.Command.instance then some (.instance, 0)
+ else if kind == `Lean.Parser.Command.example then some (.example, 0)
+ else if kind == `Lean.Parser.Command.otherDecl then some (.other, 0)
+ else none
+
+/-- Identify the type of declaration from syntax -/
+unsafe def identifyDeclType (stx : Syntax) : DeclType :=
+ match identifySomeDeclType stx with
+ | some dt => dt.1
+ | none => .unknown
+
+/-- Check if a syntax node is an attribute/modifier that should be skipped -/
+def isModifierOrAttribute (stx : Syntax) : Bool :=
+ let kind := stx.getKind
+ -- Skip attributes (@[...]), docstrings, and other modifiers
+ kind == `Lean.Parser.Term.attrInstance ||
+ kind == `Lean.Parser.Command.docComment ||
+ kind == `Lean.Parser.Term.attributes ||
+ kind.toString.startsWith "Lean.Parser.Command.declModifiers"
+
+/-- Extract the name of the declaration from syntax tree -/
+partial def extractDeclName (stx : Syntax) : String :=
+ match stx with
+ | Syntax.ident _ _ name _ => name.toString
+ | Syntax.node _ kind args =>
+ -- For declaration nodes, skip modifiers/attributes and keywords to find the name
+ if kind == `Lean.Parser.Command.declaration ||
+ kind == `Lean.Parser.Command.theorem ||
+ kind == `Lean.Parser.Command.definition ||
+ kind == `Lean.Parser.Command.inductive ||
+ kind == `Lean.Parser.Command.structure ||
+ kind == `Lean.Parser.Command.classDecl ||
+ kind == `Lean.Parser.Command.instance ||
+ kind == `Lean.Parser.Command.axiom then
+ -- Skip attributes and find the first identifier that's not a keyword
+ (args.findSome? fun arg =>
+ if isModifierOrAttribute arg then
+ none
+ else
+ let result := extractDeclName arg
+ -- Skip keywords like "theorem", "def", etc.
+ if result != Name.anonymous.toString &&
+ result != "theorem" &&
+ result != "def" &&
+ result != "lemma" &&
+ result != "inductive" &&
+ result != "structure" &&
+ result != "class" &&
+ result != "instance" &&
+ result != "axiom" then
+ some result
+ else
+ none
+ ).getD Name.anonymous.toString
+ else
+ -- For other nodes, search through arguments
+ (args.findSome? fun arg =>
+ let result := extractDeclName arg
+ if result != Name.anonymous.toString then some result else none
+ ).getD Name.anonymous.toString
+ | _ => Name.anonymous.toString
+
+/-- Comment parsing state machine, can parse nested comments too -/
+partial def trimComment (text : String) (state : Nat := 0) (depth : Nat := 0) : Nat :=
+ if text.startsWith "--" ∧ state == 0 then
+ -- we are not inside a block comment, so this is a line comment
+ let newState := 0
+ -- Go till the end of line
+ let endOfLine := text.find (fun c => c == '\n')
+ let remaining := text.drop endOfLine.byteIdx
+ let ep := trimComment remaining newState depth
+ endOfLine.byteIdx + ep
+ else if text.startsWith "/-" ∧ state == 0 then
+ -- starting of a block comment
+ let newState := 1
+ let remaining := text.drop 2
+ let ep := trimComment remaining newState (depth + 1)
+ ep + 2
+ else if text.startsWith "-/" ∧ state == 1 then
+ -- ending of a block comment
+ let newDepth := depth - 1
+ let newState := if newDepth == 0 then 0 else 1
+ let remaining := text.drop 2
+ let ep := trimComment remaining newState newDepth
+ ep + 2
+ else if text.length == 0 then
+ 0
+ else
+ -- consume one character and continue
+ if state == 0 ∧ text.trimLeft.length == text.length then
+ -- not in comment and no leading spaces, stop
+ 0
+ else
+ let remaining := text.drop 1
+ let ep := trimComment remaining state depth
+ ep + 1
+
+def comment_testcase := "
+/- This is a /* nested */ comment -/
+/--This is an example lemma-/ --- let's see how it works
+def exampleLemma : Nat := 42
+"
+
+#eval trimComment comment_testcase -- should return length of comment part
+#eval comment_testcase.drop (trimComment comment_testcase) -- should return " rest of code"
+
+def no_comment_testcase := "def noComment : Nat := 100"
+
+#eval trimComment no_comment_testcase -- should return 0
+#eval no_comment_testcase.drop (trimComment no_comment_testcase) -- should return "def noComment : Nat := 100"
+
+def postProcess (text : String) : String × List Nat :=
+ -- Replace lines with `^lemma ` with `theorem `
+ let lines := text.splitOn "\n"
+ let processedLines := lines.mapIdx fun i line =>
+ if line.trimLeft.startsWith "lemma " then
+ let leadingSpaces := line.takeWhile (fun c => c == ' ' ∨ c == '\t')
+ let newLine := leadingSpaces ++ "theorem " ++ line.trimLeft.drop "lemma ".length
+ (newLine, lines.length)
+ --(newLine, lines.length)
+ else
+ (line, i)
+ let linesOnly := processedLines.map Prod.fst
+ let lineNumbers := processedLines.map Prod.snd
+ let filteredLineNums := lineNumbers.filter (fun n => n != lines.length)
+ (String.intercalate "\n" linesOnly, filteredLineNums)
+
+unsafe def parseCommon
+ (originalContent : String)
+ (parserState : ModuleParserState)
+ (pmctx : ParserModuleContext)
+ (inputCtx : InputContext)
+ : IO (Array DeclInfo) := do
+ -- First pass: parse all commands and collect their positions
+ -- We parse the ORIGINAL content to find declaration boundaries
+ let mut commands : Array (String.Pos × Syntax) := #[]
+ let mut pstate := parserState
+ let mut done := false
+
+ while !done do
+ let startPos := pstate.pos
+ -- IO.println s!"Parsing at position: {pstate.pos}, kind will be: ..."
+
+ -- Try to parse a command from original content
+ let (stx, pstate', msgs) := parseCommand inputCtx pmctx pstate {}
+ pstate := pstate'
+
+ -- IO.println s!" Got kind: {stx.getKind}"
+ -- IO.println s!" New position: {pstate.pos}, atEnd: {inputCtx.atEnd pstate.pos}, messages: {msgs.toList.length}"
+
+ -- Store command with its start position
+ commands := commands.push (startPos, stx)
+
+ -- Check if we made progress or reached end
+ if pstate.pos == startPos then --|| inputCtx.atEnd pstate.pos then
+ -- IO.println s!" Stopping: pos unchanged={pstate.pos == startPos}, atEnd={inputCtx.atEnd pstate.pos}"
+ done := true
+
+ -- Second pass: extract and re-parse declarations
+ let mut decls : Array DeclInfo := #[]
+ let mut openNamespaces : List String := []
+ for i in [:commands.size] do
+ let (parsePos, stx) := commands[i]!
+
+ -- Get position range for this command
+ let realStart := match stx.getRange? with
+ | some range => range.start
+ | none => parsePos
+ let endPos := if i + 1 < commands.size then
+ let (nextParsePos, nextStx) := commands[i + 1]!
+ let nextRealStart := match nextStx.getRange? with
+ | some range => range.start
+ | none => nextParsePos
+ ⟨nextRealStart.byteIdx - 1⟩
+ else
+ ⟨originalContent.endPos.byteIdx⟩
+
+ -- Extract text from ORIGINAL content
+ let text := originalContent.extract realStart endPos
+
+ -- Strip comments to check if this starts with "lemma"
+ let commentEnd := trimComment text
+ let docStringStr := (text.take commentEnd).trim
+ let mut docString := none
+ if !docStringStr.isEmpty then
+ docString := some docStringStr
+ let textWithoutComments := text.drop commentEnd
+ let isLemma := textWithoutComments.startsWith "lemma "
+
+ -- Print the docstring and the text without comments for debugging
+ -- IO.println s!"Docstring part:\n{docString}\n--- End of docstring ---"
+ -- IO.println s!"Text without comments:\n{textWithoutComments}\n--- End of text without comments ---"
+ -- -- print if it's identified as lemma
+ -- IO.println s!"Is lemma: {isLemma}"
+
+ let mut textToParse := text
+ if isLemma then
+ -- If it's a lemma, preprocess it for parsing
+ -- replace the "lemma" at the end position of the comment with "theorem"
+ textToParse := text.take commentEnd ++
+ "theorem " ++
+ textWithoutComments.drop "lemma ".length
+
+ let declInputCtx := mkInputContext textToParse ""
+ let (_, declParserState, _) ← parseHeader declInputCtx
+ let (declStx, _, _) := parseCommand declInputCtx pmctx declParserState {}
+
+ -- Now identify the declaration type from the (possibly preprocessed) syntax
+ let declType := identifyDeclType declStx
+ let name := extractDeclName declStx
+
+ if declType == .namespace then
+ openNamespaces := openNamespaces.append [name]
+ else if declType == .end then
+ -- Pop the last opened namespace if any
+ openNamespaces := openNamespaces.dropLast
+
+ -- If we preprocessed it and it parsed as theorem, it's actually a lemma
+ let actualDeclType := if isLemma && declType == .theorem then .lemma else declType
+ let namespc :=
+ if openNamespaces.isEmpty then
+ none
+ else
+ some (String.intercalate "." openNamespaces)
+
+ let start_pos := get_position_from_char_pos originalContent realStart.byteIdx
+ let end_pos := get_position_from_char_pos originalContent endPos.byteIdx
+
+ let info : DeclInfo := {
+ declType := actualDeclType
+ name := name
+ startPos := start_pos
+ endPos := end_pos
+ text := textWithoutComments -- Store text after extracting docstring
+ docString := docString -- Store extracted docstring
+ namespc := namespc
+ }
+ decls := decls.push info
+
+ return decls
+
+/-- Parse a Lean 4 file and extract declaration information -/
+unsafe def parseDecls (originalContent : String) : IO (Array DeclInfo) := do
+ let (postProcessedContent, modifiedLineIdx) := postProcess originalContent
+ let inputCtx := mkInputContext postProcessedContent ""
+
+ -- Parse the header (using original content)
+ let (_, parserState, _) ← parseHeader inputCtx
+
+ -- Create a minimal parser context with empty environment
+ let env ← Lean.importModules #[] {} 0
+ let opts := {}
+ let pmctx : ParserModuleContext := {
+ env := env
+ options := opts
+ }
+ let decls ← parseCommon postProcessedContent parserState pmctx inputCtx
+
+ return decls
+
+/-- Parse a Lean 4 file and extract declaration information -/
+unsafe def parseFile (filepath : System.FilePath) : IO (Array DeclInfo) := do
+ let originalContent ← IO.FS.readFile filepath
+ let (postProcessedContent, modifiedLineIdx) := postProcess originalContent
+ let inputCtx := mkInputContext postProcessedContent filepath.toString
+
+ -- Parse the header (using original content)
+ let (_, parserState, _) ← parseHeader inputCtx
+
+ -- Create a minimal parser context with empty environment
+ let env ← Lean.importModules #[] {} 0
+ let opts := {}
+ let pmctx : ParserModuleContext := {
+ env := env
+ options := opts
+ }
+
+ let decls ← parseCommon postProcessedContent parserState pmctx inputCtx
+
+ return decls
+
+/-- Convert DeclInfo to JSON -/
+def declInfoToJson (info : DeclInfo) : Json :=
+ let baseFields := [
+ ("declType", Json.str (toString info.declType)),
+ ("name", Json.str info.name),
+ ("startPos", toJson info.startPos),
+ ("endPos", toJson info.endPos),
+ ("text", Json.str info.text)
+ ]
+ let withDocString := match info.docString with
+ | some doc => baseFields ++ [("docString", Json.str doc)]
+ | none => baseFields
+ Json.mkObj withDocString
+
+/-- Simple helper to print declaration info -/
+def printDeclInfo (info : DeclInfo) : IO Unit := do
+ IO.println s!"[{info.declType}] {info.name}"
+ IO.println s!" Position: {toJson info.startPos} - {toJson info.endPos}"
+ let preview := if info.text.length > 100 then
+ info.text.take 50 ++ "\n ... more text ... \n" ++ info.text.drop (info.text.length - 50)
+ else
+ info.text
+ IO.println s!" Text: {preview}"
+ IO.println ""
+
+/-- Export declarations to JSON file -/
+def exportToJson (decls : Array DeclInfo) (outputPath : System.FilePath) : IO Unit := do
+ let jsonArray := Json.arr (decls.map declInfoToJson)
+ let jsonStr := jsonArray.pretty
+ IO.FS.writeFile outputPath jsonStr
+ IO.println s!"Exported {decls.size} declaration(s) to {outputPath}"
+
+/-- Parse and print all declarations in a file -/
+unsafe def parseAndPrint (filepath : System.FilePath) : IO Unit := do
+ IO.println s!"Parsing file: {filepath}"
+ IO.println (String.mk (List.replicate 50 '='))
+
+ let decls ← parseFile filepath
+
+ if decls.isEmpty then
+ IO.println "No declarations found."
+ else
+ IO.println s!"Found {decls.size} declaration(s):"
+ IO.println ""
+ for decl in decls do
+ printDeclInfo decl
+
+/-- Parse file and export to both console and JSON -/
+unsafe def parseAndExport (filepath : System.FilePath) (jsonOutput : Option System.FilePath := none) : IO Unit := do
+ let decls ← parseFile filepath
+
+ -- Print to console
+ IO.println s!"Parsing file: {filepath}"
+ IO.println (String.mk (List.replicate 50 '='))
+
+ if decls.isEmpty then
+ IO.println "No declarations found."
+ else
+ IO.println s!"Found {decls.size} declaration(s):"
+ IO.println ""
+ for decl in decls do
+ printDeclInfo decl
+
+ -- Export to JSON if output path provided
+ match jsonOutput with
+ | some outPath => exportToJson decls outPath
+ | none => pure ()
+
+def test_str := "import Mathlib
+namespace Lean4Proj1
+
+def hello := \"world\"
+theorem test (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+exact hp
+apply And.intro
+exact hq
+exact hp
+
+
+theorem test2 : p -> q -> p ∧ q ∧ p := fun hp hq => ⟨hp, ⟨hq, hp⟩⟩
+
+
+-- show a proof which uses calc
+theorem test_calc (n: Nat) : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
+calc
+ _ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n]
+ _ = n^2 + (n + n) + 1 := by rw [Nat.mul_two]
+ _ = n^2 + n + n + 1 := by rw [←Nat.add_assoc]
+ _ = n*n + n + n + 1 := by rw [Nat.pow_two]
+ _ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n]
+ _ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1]
+ _ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc]
+ _ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]
+ _ = (n + 1)*(n + 1) := by rw [Nat.right_distrib n 1 (n + 1)]
+done
+
+end Lean4Proj1
+
+namespace Lean4Proj2
+
+example : p -> q -> p ∧ q ∧ p := fun hp hq => ⟨hp, ⟨hq, hp⟩⟩
+
+/-- This is a test theorem -/
+theorem test (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+exact hp
+apply And.intro
+exact hq
+exact hp
+done
+
+@[simp]
+theorem test3 (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+ apply And.intro
+ exact hp
+ apply And.intro
+ exact hq
+ exact hp
+
+theorem imo_1959_p1
+ (n : ℕ)
+ (h₀ : 0 < n) :
+ Nat.gcd (21*n + 4) (14*n + 3) = 1 := by
+rw [Nat.gcd_rec]
+rw [Nat.mod_eq_of_lt (by linarith)]
+rw [Nat.gcd_rec]
+rw [Nat.gcd_rec]
+have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by
+ have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring
+ rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]
+ have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith
+ rw [Nat.mod_eq_of_lt]
+ rw [Nat.mod_eq_of_lt]
+ exact h₂
+ rw [Nat.mod_eq_of_lt]
+ exact h₂
+ exact h₂
+rw [eq₂]
+sorry
+
+
+lemma pow_dvd_pow (a : α) (h : m ≤ n) : a ^ m ∣ a ^ n :=
+ ⟨a ^ (n - m), by rw [← pow_add, Nat.add_comm, Nat.sub_add_cancel h]⟩
+
+lemma dvd_pow (hab : a ∣ b) : ∀ {n : ℕ} (_ : n ≠ 0), a ∣ b ^ n
+ | 0, hn => (hn rfl).elim
+ | n + 1, _ => by rw [pow_succ']; exact hab.mul_right _
+
+alias Dvd.dvd.pow := dvd_pow
+
+lemma dvd_pow_self (a : α) {n : ℕ} (hn : n ≠ 0) : a ∣ a ^ n := dvd_rfl.pow hn
+
+end Lean4Proj2
+"
+
+#eval parseDecls test_str
+
+#eval (test_str.extract ⟨15⟩ ⟨37⟩)
+
+#eval (test_str.extract ⟨37⟩ ⟨58⟩)
+
+#eval (test_str.extract ⟨298⟩ ⟨912⟩)
+
+#eval get_position_from_char_pos test_str 57 -- expect line 4, column 20
+
+#eval get_position_from_char_pos test_str 299 -- line 18, column 1
+
+end TacticParser
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/Main.lean b/src/itp_interface/tools/tactic_parser/TacticParser/Main.lean
new file mode 100644
index 0000000..a6950c4
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/TacticParser/Main.lean
@@ -0,0 +1,215 @@
+/-
+Main executable: read base64 from stdin, output JSON to stdout.
+Runs in a loop to avoid restart overhead.
+
+The process should be started from the project directory for project-specific parsing.
+-/
+import TacticParser.Base64
+import TacticParser.Types
+import TacticParser.SyntaxWalker
+import TacticParser.LineParser
+import Lean
+
+open Lean
+open TacticParser
+
+/-- Result of parsing tactics -/
+class FromStr (α : Type) where
+fromStr : String → Option α
+
+inductive ParseRequestType
+ | parseTactics
+ | parseTheorem
+ | chkptTactics
+ | breakChckpnt
+deriving Inhabited, Repr, BEq
+
+instance : ToString ParseRequestType where
+ toString
+ | .breakChckpnt => "break_chckpnt"
+ | .chkptTactics => "chkpt_tactics"
+ | .parseTactics => "parse_tactics"
+ | .parseTheorem => "parse_theorem"
+
+-- define the representation of ParseRequestType
+def parse_request_names := [
+ "break_chckpnt",
+ "chkpt_tactics",
+ "parse_tactics",
+ "parse_theorem"
+]
+
+def parse_max_pad := parse_request_names.map String.length |>.foldl Nat.max 0
+
+#eval parse_max_pad
+
+instance : FromStr ParseRequestType where
+ fromStr s :=
+ match s with
+ | "break_chckpnt" => some ParseRequestType.breakChckpnt
+ | "chkpt_tactics" => some ParseRequestType.chkptTactics
+ | "parse_tactics" => some ParseRequestType.parseTactics
+ | "parse_theorem" => some ParseRequestType.parseTheorem
+ | _ => none
+
+structure UserParseRequest where
+ requestType : ParseRequestType
+ content : String -- Lean code
+ deriving Inhabited, Repr, BEq
+
+instance : FromStr UserParseRequest where
+ fromStr s :=
+ if s.length < parse_max_pad + 1 then
+ none
+ else
+ let pref := s.take parse_max_pad
+ let content := s.drop parse_max_pad
+ match FromStr.fromStr pref with
+ | some reqType => some { requestType := reqType, content := content }
+ | none => none
+
+#eval (FromStr.fromStr "parse_tactics" : Option ParseRequestType)
+
+def some_lean_code : String := "parse_tactics
+theorem test1 (p q : Prop) (hp : p) (hq : q) : p ∧ q := by
+ apply And.intro
+ exact hp
+ exact hq
+"
+#eval! (FromStr.fromStr some_lean_code : Option UserParseRequest)
+
+/-- Process a single request and output JSON -/
+unsafe def processRequest (b64Input : String) (chkptState : Option CheckpointedParseResult := none) : IO (Option CheckpointedParseResult) := do
+ try
+ -- Decode base64 to Lean code
+ let parse_request_raw ← match Base64.decode b64Input with
+ | .ok correct_parse_request => pure correct_parse_request
+ | .error msg =>
+ -- Output error as JSON
+ let errorInfo := ErrorInfo.mk (s!"Base64 decode error: {msg}") { line := 0, column := 0 }
+ let result : ParseResult := { trees := #[], errors := #[errorInfo] }
+ IO.println (toJson result).compress
+ return none
+
+ let user_parse_request : Option UserParseRequest ← pure (FromStr.fromStr parse_request_raw)
+
+ if user_parse_request.isNone then
+ -- Output error as JSON
+ let errorInfo := ErrorInfo.mk (s!"Invalid parse request format.") { line := 0, column := 0 }
+ let result : ParseResult := { trees := #[], errors := #[errorInfo] }
+ IO.println (toJson result).compress
+ return none
+
+ let parse_request := user_parse_request.get!
+
+ let mut result : ParseResult := { trees := #[], errors := #[] }
+ -- Initialize new checkpoint state to the current one
+ let mut newchkptState : Option CheckpointedParseResult := chkptState
+ let is_of_tactics_type :=
+ parse_request.requestType == ParseRequestType.parseTactics ∨
+ parse_request.requestType == ParseRequestType.chkptTactics ∨
+ parse_request.requestType == ParseRequestType.breakChckpnt
+ let is_checkpoint_request :=
+ parse_request.requestType == ParseRequestType.chkptTactics
+ let is_break_checkpoint_request :=
+ parse_request.requestType == ParseRequestType.breakChckpnt
+ if is_of_tactics_type then
+ if is_break_checkpoint_request then
+ -- First check if it is a breaking request, clear the last state
+ newchkptState := none
+ -- Parse tactics from Lean code
+ let cmdState :=
+ match newchkptState with
+ | some chkpt => chkpt.chkptState
+ | none => none
+ let chkpointParseResult ← parseTactics parse_request.content none cmdState
+ result := chkpointParseResult.parseResult
+ --IO.println s!"Parsed tactics with {result.trees.size} trees and {repr result.errors} errors."
+ if is_checkpoint_request then
+ -- Only changes if the checkpoint is to be updated
+ let line_num := chkpointParseResult.lineNum.getD 0
+ let prev_line_num :=
+ match newchkptState with
+ | some chkpt => chkpt.lineNum.getD 0
+ | none => 0
+ -- Adjust line number based on previous checkpoint
+ newchkptState := some {
+ parseResult := chkpointParseResult.parseResult,
+ lineNum := some (line_num + prev_line_num),
+ chkptState := chkpointParseResult.chkptState
+ }
+ -- Additionally, adjust error positions based on previous checkpoint
+ let prev_line_num :=
+ match newchkptState with
+ | some chkpt => chkpt.lineNum.getD 0
+ | none => 0
+ if prev_line_num > 0 then
+ -- Adjust error line numbers
+ let adjusted_errors := result.errors.map (fun err =>
+ { err with
+ position := {
+ line := err.position.line + prev_line_num,
+ column := err.position.column
+ }
+ })
+ -- Adjust tree line numbers
+ let adjusted_trees := result.trees.map (fun tree =>
+ let rec adjust_tree (node : InfoNodeStruct) : InfoNodeStruct :=
+ {
+ node with
+ startPos := {
+ line := node.startPos.line + prev_line_num,
+ column := node.startPos.column
+ },
+ endPos := {
+ line := node.endPos.line + prev_line_num,
+ column := node.endPos.column
+ },
+ children := node.children.map adjust_tree
+ }
+ adjust_tree tree
+ )
+ result := { trees := adjusted_trees, errors := adjusted_errors }
+ else
+ -- Unsupported request type
+ let temp_result ← parseDecls parse_request.content
+ let mut tree_list : Array InfoNodeStruct := #[]
+ for decl in temp_result do
+ let info_tree ← pure (InfoNodeStruct.mk decl.declType decl.name decl.docString decl.text decl.startPos decl.endPos decl.namespc #[])
+ tree_list := tree_list.push info_tree
+ result := { trees := tree_list, errors := #[] }
+
+ -- Output result as JSON
+ IO.println (toJson result).compress
+ return newchkptState
+ catch e =>
+ -- Output error as JSON
+ let errorInfo := ErrorInfo.mk (s!"Unexpected error: {e}") { line := 0, column := 0 }
+ let result : ParseResult := { trees := #[], errors := #[errorInfo] }
+ IO.println (toJson result).compress
+ return none
+
+/-- Loop to process requests -/
+unsafe def loop (stdin : IO.FS.Stream) (stdout : IO.FS.Stream) (chkptState : Option CheckpointedParseResult := none) : IO Unit := do
+ -- Read input from stdin (base64)
+ let line ← stdin.getLine
+ let line := line.trim
+
+ -- Exit on empty line or "exit" command
+ if line.isEmpty || line = "exit" then
+ return
+
+ -- Process the request
+ let mut newchkptState ← processRequest line chkptState
+
+ -- Flush output to ensure Python can read it
+ stdout.flush
+
+ -- Continue loop
+ loop stdin stdout newchkptState
+
+/-- Start processing -/
+unsafe def main (args : List String) : IO Unit := do
+ let stdin ← IO.getStdin
+ let stdout ← IO.getStdout
+ loop stdin stdout none
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean b/src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean
new file mode 100644
index 0000000..fb9f9fc
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean
@@ -0,0 +1,418 @@
+/-
+Syntax walker to extract tactics from Lean code with lightweight elaboration.
+Uses InfoTrees (like REPL) which requires elaboration but NOT compilation!
+
+Can work in two modes:
+1. Standalone: Parse simple tactics without dependencies (minimal environment)
+2. Project mode: Use a project's search path to enable mathlib/dependencies
+-/
+import Lean
+import Lean.Elab.Frontend
+import TacticParser.Types
+
+open Lean Elab
+
+namespace Lean.Elab.IO
+
+/--
+Wrapper for `IO.processCommands` that enables info states, and returns
+* the new command state
+* messages
+* info trees
+-/
+def processCommandsWithInfoTrees
+ (inputCtx : Parser.InputContext) (parserState : Parser.ModuleParserState)
+ (commandState : Command.State) : IO (Command.State × Array Message × Array InfoTree) := do
+ let commandState := { commandState with infoState.enabled := true }
+ let s ← IO.processCommands inputCtx parserState commandState <&> Frontend.State.commandState
+ pure (s, s.messages.toArray, s.infoState.trees.toArray)
+
+/--
+Process some text input, with or without an existing command state.
+If there is no existing environment, we parse the input for headers (e.g. import statements),
+and create a new environment.
+Otherwise, we add to the existing environment.
+
+Returns:
+1. The header-only command state (only useful when cmdState? is none)
+2. The resulting command state after processing the entire input
+3. List of messages
+4. List of info trees
+-/
+def processInput (input : String) (cmdState? : Option Command.State)
+ (opts : Options := {}) (fileName : Option String := none) :
+ IO (Command.State × Command.State × Array Message × Array InfoTree) := unsafe do
+ Lean.initSearchPath (← Lean.findSysroot)
+ enableInitializersExecution
+ let fileName := fileName.getD ""
+ let inputCtx := Parser.mkInputContext input fileName
+
+ match cmdState? with
+ | none => do
+ -- Split the processing into two phases to prevent self-reference in proofs in tactic mode
+ let (header, parserState, messages) ← Parser.parseHeader inputCtx
+ let (env, messages) ← processHeader header opts messages inputCtx
+ let headerOnlyState := Command.mkState env messages opts
+ let (cmdState, messages, trees) ← processCommandsWithInfoTrees inputCtx parserState headerOnlyState
+ return (headerOnlyState, cmdState, messages, trees)
+
+ | some cmdStateBefore => do
+ let parserState : Parser.ModuleParserState := {}
+ let (cmdStateAfter, messages, trees) ← processCommandsWithInfoTrees inputCtx parserState cmdStateBefore
+ return (cmdStateBefore, cmdStateAfter, messages, trees)
+
+end Lean.Elab.IO
+
+namespace TacticParser
+
+open Lean
+open Lean.Elab
+open Lean.Parser
+open Lean.Syntax
+
+/-- Convert a String.Pos to line and column numbers -/
+def posToLineColumn (input : String) (pos : String.Pos) : Position :=
+ let lines := input.extract 0 pos |>.splitOn "\n"
+ let line := lines.length
+ let column := (lines.getLast!).length
+ { line, column }
+
+/-- Extract the source text for a syntax node -/
+def syntaxToString (stx : Syntax) : String :=
+ stx.reprint.getD (toString stx)
+
+/-- Pretty print InfoTree structure for debugging -/
+partial def printInfoTree (input : String) (tree : InfoTree) (indent : Nat := 0) : IO Unit := do
+ let spaces := String.pushn "" ' ' indent
+ match tree with
+ | .context _ t =>
+ --IO.println s!"{spaces}Context"
+ printInfoTree input t (indent)
+ | .node info children =>
+ match info with
+ | .ofTacticInfo tacInfo =>
+ -- Extract actual text from source using byte positions
+ let startByte := tacInfo.stx.getPos?.getD 0
+ let endByte := tacInfo.stx.getTailPos?.getD 0
+ let actualText := input.extract startByte endByte |>.trim
+
+ let startPos := posToLineColumn input startByte
+ let endPos := posToLineColumn input endByte
+ let preview := actualText
+ IO.println s!"{spaces}TacticInfo: L{startPos.line}:C{startPos.column}-L{endPos.line}:C{endPos.column} | {preview.replace "\n" "\\n"}"
+ for child in children do
+ printInfoTree input child (indent + 2)
+ | _ =>
+ --IO.println s!"{spaces}Other"
+ for child in children do
+ printInfoTree input child (indent)
+ | .hole _ =>
+ IO.println s!"{spaces}Hole"
+
+/-- Convert InfoTree to InfoTreeNode -/
+partial def infoTreeToNode (input : String) (tree : InfoTree) : InfoTreeNode :=
+ match tree with
+ | .context _ t =>
+ infoTreeToNode input t
+ | .node info children =>
+ let childNodes := (children.map (infoTreeToNode input)).toArray
+ -- filter all children that are .hole and .other
+ let filteredChildren := childNodes.filter fun
+ | .hole => false
+ | .other arr => arr.isEmpty
+ | _ => true
+ match info with
+ | .ofTacticInfo tacInfo =>
+ let text := tacInfo.stx.reprint.getD (toString tacInfo.stx) |>.trim
+ let startPos := posToLineColumn input (tacInfo.stx.getPos?.getD 0)
+ let endPos := posToLineColumn input (tacInfo.stx.getTailPos?.getD 0)
+ InfoTreeNode.leanInfo DeclType.tactic none none text startPos endPos none filteredChildren
+ | _ => .other childNodes
+ | .hole _ => .hole
+
+partial def removeOtherAndHoles (node : InfoTreeNode) : Option InfoTreeNode :=
+ match node with
+ | .context child =>
+ match removeOtherAndHoles child with
+ | some newChild => some (.context newChild)
+ | none => none
+ | InfoTreeNode.leanInfo decType name docString text startPos endPos namespc children =>
+ let newChildren := children.map removeOtherAndHoles |>.filterMap id
+ some (InfoTreeNode.leanInfo decType name docString text startPos endPos namespc newChildren)
+ | .other children =>
+ let newChildren := children.map removeOtherAndHoles |>.filterMap id
+ if newChildren.isEmpty then
+ none
+ else
+ some (.other newChildren)
+ | .hole => none
+
+partial def filterChildrenAtLevel (node : InfoTreeNode) (level : Nat) : Option InfoTreeNode :=
+ match node with
+ | .context child =>
+ match filterChildrenAtLevel child level with
+ | some newChild => some (.context newChild)
+ | none => none
+ | InfoTreeNode.leanInfo decType name docString text startPos endPos namespc children =>
+ if level == 0 then
+ some (InfoTreeNode.leanInfo decType name docString text startPos endPos namespc #[])
+ else
+ let newChildren := children.map fun child =>
+ filterChildrenAtLevel child (level - 1)
+ let filteredChildren := newChildren.filterMap id
+ some (InfoTreeNode.leanInfo decType name docString text startPos endPos namespc filteredChildren)
+ | .other children =>
+ let newChildren := children.map fun child =>
+ filterChildrenAtLevel child level
+ let filteredChildren := newChildren.filterMap id
+ if filteredChildren.isEmpty then
+ none
+ else
+ some (.other filteredChildren)
+ | .hole => none
+
+def nodeIsHole (node : InfoTreeNode) : Bool :=
+ match node with
+ | .hole => true
+ | _ => false
+
+def nodeEndPos (node : InfoTreeNode) : Option Position :=
+ match node with
+ | InfoTreeNode.leanInfo _ _ _ _ _ endPos _ _ => some endPos
+ | _ => none
+
+def filterAllNodesWhichDontStartAndEndOnLine (node : InfoTreeNode) (line_num: Nat) : Array InfoTreeNode :=
+match node with
+| .context child =>
+ filterAllNodesWhichDontStartAndEndOnLine child line_num
+| InfoTreeNode.leanInfo decType name docString text startPos endPos namespc children =>
+ let newChildren := children.flatMap fun child =>
+ filterAllNodesWhichDontStartAndEndOnLine child line_num
+ if startPos.line != line_num ∨ endPos.line != line_num then
+ newChildren
+ else
+ -- Add self to the front of the list
+ #[InfoTreeNode.leanInfo decType name docString text startPos endPos namespc #[]] ++ newChildren
+| .other children =>
+ children.flatMap fun child =>
+ filterAllNodesWhichDontStartAndEndOnLine child line_num
+| .hole => #[]
+
+def getMaxLineExtent (node : InfoTreeNode) (line_num: Nat) : InfoTreeNode × Nat :=
+let all_possible_nodes := filterAllNodesWhichDontStartAndEndOnLine node line_num
+let arg_max := all_possible_nodes.foldl (fun (acc_node, acc_len) n =>
+ let len := (match n with
+ | InfoTreeNode.leanInfo _ _ _ _ startPos endPos _ _ =>
+ endPos.column - startPos.column
+ | _ => 0
+ )
+ if len > acc_len then
+ (n, len)
+ else
+ (acc_node, acc_len)
+) (InfoTreeNode.hole, 0)
+arg_max
+
+def getAllLinesInTree (node : InfoTreeNode) : Std.HashSet Nat :=
+ match node with
+ | .context child =>
+ getAllLinesInTree child
+ | InfoTreeNode.leanInfo _ _ _ _ startPos endPos _ children =>
+ let childrenLines := children.map getAllLinesInTree
+ (childrenLines.foldl (init := ({}: Std.HashSet Nat)) (fun acc lines =>
+ acc.union lines)).union {startPos.line, endPos.line}
+ | .other children =>
+ let childrenLines := children.map getAllLinesInTree
+ childrenLines.foldl (init := {}) (fun acc lines =>
+ acc.union lines)
+ | .hole => {}
+
+def getAllLineNumsFromTrees (trees : Array InfoTreeNode) : Array Nat :=
+(trees.foldl (init := ({}: Std.HashSet Nat)) (fun acc tree =>
+acc.union (getAllLinesInTree tree))).toArray.insertionSort
+
+def getAllExtents (trees : Array InfoTreeNode) : Array InfoTreeNode :=
+let line_nums := getAllLineNumsFromTrees trees
+(line_nums.flatMap (
+ fun line_num =>
+ (trees.foldl (fun acc tree =>
+ let (n, _) := getMaxLineExtent tree line_num
+ acc.push n
+ ) (#[] : Array InfoTreeNode)).insertionSort (fun n1 n2 =>
+ match (nodeEndPos n1, nodeEndPos n2) with
+ | (some pos1, some pos2) => pos1.column < pos2.column
+ | _ => false
+ )
+)).filter fun n => ¬ nodeIsHole n
+
+def getTextFromPosition (input : String) (startPos : Position) (endPos : Position) : String :=
+ let lines := input.splitOn "\n"
+ if startPos.line > lines.length ∨ startPos.line == 0 ∨ endPos.line == 0 then
+ ""
+ else
+ let relevantLines := (lines.take endPos.line).drop (startPos.line - 1)
+ let firstLine := relevantLines[0]!.drop startPos.column--.extract ⟨startPos.column⟩ ⟨relevantLines[0]!.length⟩
+ let lastLine := relevantLines[relevantLines.length - 1]!.take endPos.column
+ let middleLines := (relevantLines.take (relevantLines.length - 1)).drop 1
+ let actualLines := if relevantLines.length > 1 then [firstLine] ++ middleLines ++ [lastLine] else [firstLine]
+ String.intercalate "\n" actualLines
+
+def dropNewLineAndCountSpaces (s : String) : String × String :=
+ let strWithoutSpace := s.dropRightWhile (fun c => c == '\t' || c == ' ')
+ let rightSpace := s.takeRightWhile (fun c => c == '\t' || c == ' ')
+ (strWithoutSpace.trimRight, rightSpace)
+
+/-- Helper: parse tactics in the current context -/
+unsafe def parseInCurrentContext (input : String) (filePath : Option String := none) (chkptState : Option Command.State := none) : IO CheckpointedParseResult := do
+ try
+ --let inputCtx := Parser.mkInputContext input ""
+ let (initialCmdState, cmdState, messages, trees) ← try
+ IO.processInput input chkptState Options.empty filePath
+ catch e =>
+ let errorInfo := ErrorInfo.mk (s!"Error during processing input: {e}") { line := 0, column := 0 }
+ let parseResult : ParseResult := { trees := #[], errors := #[errorInfo] }
+ return { parseResult := parseResult, chkptState := chkptState , lineNum := none }
+
+
+ -- Print any messages
+ -- IO.println "\n=== Elaboration Messages ==="
+ let mut errorInfos : Array ErrorInfo := #[]
+ for msg in messages do
+ if msg.severity == .error then
+ let msgPos := Position.mk msg.pos.line msg.pos.column
+ let errorInfo := ErrorInfo.mk (← msg.data.toString) msgPos
+ errorInfos := errorInfos.push errorInfo
+ -- IO.println s!"[ERROR] {← msg.data.toString} {msg.pos}"
+
+ -- IO.println s!"[{severity}] {← msg.data.toString}"
+ -- IO.println "=== End Messages ===\n"
+
+ -- Print the cmdState environment
+ -- IO.println "\n=== cmdState Environment Messages ==="
+ -- for msg in cmdState.messages.toArray do
+ -- let severity := match msg.severity with
+ -- | .error => "ERROR"
+ -- | .warning => "WARNING"
+ -- | .information => "INFO"
+ -- IO.println s!"[{severity}] {← msg.data.toString}"
+ -- IO.println "=== End cmdState Messages ===\n"
+
+ let level := 0 -- Only keep direct children of tactics
+ -- let transformed_trees := trees.map (fun t =>
+ -- let ans := removeOtherAndHoles (infoTreeToNode input t)
+ -- let ans_d := ans.getD (.other #[])
+ -- filterChildrenAtLevel ans_d level)
+ let transformed_trees := trees.map (fun t => removeOtherAndHoles (infoTreeToNode input t))
+ let t_trees := transformed_trees.map (fun t => t.getD (.other #[]))
+ let lineExtents := getAllExtents t_trees
+ let extentStruct := lineExtents.map getInfoNodeStruct
+ -- Go over all line extents and reassign the end_pos of the next node
+ let mut adjusted_trees : Array InfoNodeStruct := #[]
+ for i in [1:lineExtents.size] do
+ let prev_node := extentStruct[i - 1]!.getD default
+ let curr_node := extentStruct[i]!.getD default
+ let new_prev_node := {prev_node with endPos := curr_node.startPos}
+ adjusted_trees := adjusted_trees.push new_prev_node
+
+ let mut last_node := extentStruct[extentStruct.size - 1]!.getD default
+ let lines := input.splitOn "\n"
+ let lineCount := lines.length
+ last_node := {last_node with endPos := { line := lineCount, column := lines.getLast!.length }}
+ adjusted_trees := adjusted_trees.push last_node
+ -- Fix the text fields based on updated positions
+ adjusted_trees := adjusted_trees.map fun node =>
+ let new_text := getTextFromPosition input node.startPos node.endPos
+ { node with text := new_text }
+ let mut (prev_text, right_space) := dropNewLineAndCountSpaces adjusted_trees[0]!.text
+ adjusted_trees := adjusted_trees.set! 0 {adjusted_trees[0]! with text := prev_text}
+ for i in [1:adjusted_trees.size] do
+ let curr_node := adjusted_trees[i]!
+ let mut (curr_text, curr_right_space) := dropNewLineAndCountSpaces curr_node.text
+ curr_text := right_space ++ curr_text
+ right_space := curr_right_space
+ adjusted_trees := adjusted_trees.set! i {curr_node with text := curr_text}
+
+ -- let new_prev_node := {prev_node with endPos := nodeEndPos curr_node.getD prev_node}
+ let parseResult : ParseResult := { trees := adjusted_trees, errors := errorInfos }
+ return { parseResult := parseResult, chkptState := cmdState , lineNum := lineCount }
+ catch e =>
+ let errorInfo := ErrorInfo.mk (s!"Error in parseInCurrentContext: {e}") { line := 0, column := 0 }
+ let parseResult : ParseResult := { trees := #[], errors := #[errorInfo] }
+ return { parseResult := parseResult, chkptState := chkptState , lineNum := none }
+
+/-- Parse Lean code WITH elaboration to get InfoTrees (lightweight, no compilation!)
+
+ Initializes Lean from current working directory (finds .lake/build automatically).
+ For project-specific parsing, start the process from the project directory.
+-/
+unsafe def parseTacticsWithElaboration (input : String) (filePath : Option String := none) (chkptState : Option Command.State := none) : IO CheckpointedParseResult := do
+ try
+ -- Initialize Lean from current directory (finds .lake/build if present)
+ Lean.initSearchPath (← Lean.findSysroot)
+ Lean.enableInitializersExecution
+ return ← parseInCurrentContext input filePath chkptState
+ catch e =>
+ let errorInfo := ErrorInfo.mk (s!"Error in parseTacticsWithElaboration: {e}") { line := 0, column := 0 }
+ let parseResult : ParseResult := { trees := #[], errors := #[errorInfo] }
+ return { parseResult := parseResult, chkptState := chkptState , lineNum := none }
+
+/-- Parse Lean code and extract all tactics (uses elaboration-based approach) -/
+@[implemented_by parseTacticsWithElaboration]
+opaque parseTactics (input : String) (filePath : Option String := none) (chkptState : Option Command.State := none) : IO CheckpointedParseResult
+
+-- -- Test case 1: Simple proof with apply and exact
+def simple_example := "theorem test (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+exact hp
+have h1 : p ∧ q := by
+ sorry
+apply And.intro
+exact hq
+exact hp
+"
+
+def more_complex_example := "theorem test3 (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+have htemp : p ∧ q
+:= by
+ apply And.intro
+ exact hp
+ exact hq
+simp [htemp]
+rw [hp]
+"
+
+def import_example := "import Lean
+
+theorem test_import (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+exact hp
+apply And.intro
+exact hq
+
+"
+
+def wrong_tactic_example := "theorem test_wrong (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+applly And.intro
+exact hp
+apply And.intro
+"
+
+def wrong_tactic_example2 := "theorem wrong_decl : Nat := by assdfadfs"
+
+
+def temp := (parseTactics more_complex_example)
+
+#eval temp
+
+
+#eval parseTactics import_example
+
+#eval parseTactics wrong_tactic_example
+
+#eval parseTactics wrong_tactic_example2
+
+end TacticParser
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/Types.lean b/src/itp_interface/tools/tactic_parser/TacticParser/Types.lean
new file mode 100644
index 0000000..a4ef181
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/TacticParser/Types.lean
@@ -0,0 +1,334 @@
+/-
+Types for tactic information.
+-/
+import Lean
+import Lean.Elab.Frontend
+
+namespace TacticParser
+
+open Lean
+open Lean.Elab
+
+/-- Represents different types of Lean declarations -/
+inductive DeclType where
+ | inductive
+ | theorem
+ | def
+ | axiom
+ | structure
+ | class_decl
+ | instance
+ | other
+ | example
+ | lemma
+ | unknown
+ | tactic
+ | namespace
+ | end
+ deriving Repr, BEq
+
+instance : ToString DeclType where
+ toString
+ | .inductive => "inductive"
+ | .theorem => "theorem"
+ | .def => "def"
+ | .axiom => "axiom"
+ | .structure => "structure"
+ | .class_decl => "class"
+ | .instance => "instance"
+ | .other => "other"
+ | .example => "example"
+ | .lemma => "lemma" -- lemma is treated as theorem
+ | .unknown => "unknown"
+ | .tactic => "tactic"
+ | .namespace => "namespace"
+ | .end => "end"
+
+/-- Position information for a tactic -/
+structure Position where
+ line : Nat
+ column : Nat
+ deriving Inhabited, Repr, BEq
+
+instance : ToJson Position where
+ toJson p := Json.mkObj [
+ ("line", toJson p.line),
+ ("column", toJson p.column)
+ ]
+
+instance : FromJson Position where
+ fromJson? j := do
+ let line ← j.getObjValAs? Nat "line"
+ let column ← j.getObjValAs? Nat "column"
+ return { line, column }
+
+/-- Information extracted from a declaration -/
+structure DeclInfo where
+ declType : DeclType
+ name : String
+ startPos : Position
+ endPos : Position
+ text : String
+ docString : Option String -- Extracted documentation comment
+ namespc : Option String -- Current namespace
+ deriving Repr
+
+instance : ToJson DeclInfo where
+ toJson d := Json.mkObj [
+ ("decl_type", toJson (ToString.toString d.declType)),
+ ("name", toJson d.name),
+ ("line", d.startPos.line),
+ ("column", d.startPos.column),
+ ("end_line", d.endPos.line),
+ ("end_column", d.endPos.column),
+ ("text", toJson d.text),
+ ("doc_string", toJson d.docString),
+ ("namespace", toJson d.namespc)
+ ]
+
+/-- Information about an import statement -/
+structure ImportInfo where
+ moduleName : String
+ startPos : Nat
+ endPos : Nat
+ text : String
+ deriving Repr
+
+instance : ToJson ImportInfo where
+ toJson i := Json.mkObj [
+ ("module_name", toJson i.moduleName),
+ ("start_pos", toJson i.startPos),
+ ("end_pos", toJson i.endPos),
+ ("text", toJson i.text)
+ ]
+
+/-- Information about a namespace declaration -/
+structure NamespaceInfo where
+ name : String
+ startPos : Nat
+ endPos : Nat
+ text : String
+ deriving Repr
+
+instance : ToJson NamespaceInfo where
+ toJson n := Json.mkObj [
+ ("name", toJson n.name),
+ ("start_pos", toJson n.startPos),
+ ("end_pos", toJson n.endPos),
+ ("text", toJson n.text)
+ ]
+
+/-- Information about file dependencies and module structure -/
+structure DependencyInfo where
+ filePath : String
+ moduleName : String -- The module name derived from file path
+ imports : Array ImportInfo
+ namespaces : Array NamespaceInfo
+ deriving Repr
+
+instance : ToJson DependencyInfo where
+ toJson d := Json.mkObj [
+ ("file_path", toJson d.filePath),
+ ("module_name", toJson d.moduleName),
+ ("imports", toJson d.imports),
+ ("namespaces", toJson d.namespaces)
+ ]
+
+/-- Information about a single dependency reference -/
+structure DeclarationDependency where
+ name : String -- Fully qualified name (e.g., "Nat.add_zero")
+ namespc : Option String -- Namespace portion (e.g., "Nat")
+ localName : String -- Local name without namespace
+ filePath : Option String -- Source file if resolvable
+ moduleName : Option String -- Module where defined
+ deriving Repr
+
+instance : ToJson DeclarationDependency where
+ toJson d := Json.mkObj [
+ ("name", toJson d.name),
+ ("namespace", toJson d.namespc),
+ ("local_name", toJson d.localName),
+ ("file_path", toJson d.filePath),
+ ("module_name", toJson d.moduleName)
+ ]
+
+/-- Declaration with its dependencies -/
+structure DeclWithDependencies where
+ declInfo : DeclInfo -- From LineParser
+ dependencies : Array DeclarationDependency
+ unresolvedNames : Array String -- Names we couldn't resolve
+ deriving Repr
+
+instance : ToJson DeclWithDependencies where
+ toJson d := Json.mkObj [
+ ("decl_info", toJson d.declInfo),
+ ("dependencies", toJson d.dependencies),
+ ("unresolved_names", toJson d.unresolvedNames)
+ ]
+
+/-- Complete file dependency analysis with per-declaration tracking -/
+structure FileDependencyAnalysis where
+ filePath : String
+ moduleName : String
+ imports : Array ImportInfo
+ declarations : Array DeclWithDependencies
+ deriving Repr
+
+instance : ToJson FileDependencyAnalysis where
+ toJson f := Json.mkObj [
+ ("file_path", toJson f.filePath),
+ ("module_name", toJson f.moduleName),
+ ("imports", toJson f.imports),
+ ("declarations", toJson f.declarations)
+ ]
+
+/-- InfoTree node representation -/
+inductive InfoTreeNode where
+ | context : InfoTreeNode → InfoTreeNode
+ | leanInfo
+ (declType: DeclType)
+ (name: Option String)
+ (docString: Option String)
+ (text: String)
+ (startPos: Position)
+ (endPos: Position)
+ (namespc: Option String)
+ (children: Array InfoTreeNode) : InfoTreeNode
+ | other : Array InfoTreeNode → InfoTreeNode
+ | hole : InfoTreeNode
+ deriving Inhabited, Repr
+
+partial def InfoTreeNode.toJson : InfoTreeNode → Json
+ | .context child =>
+ Json.mkObj [
+ ("type", "context"),
+ ("children", child.toJson)
+ ]
+ | leanInfo declType name docString text startPos endPos namespc children =>
+ Json.mkObj [
+ ("type", "leanInfo"),
+ ("decl_type", ToString.toString declType),
+ ("name", Lean.ToJson.toJson name),
+ ("doc_string", Lean.ToJson.toJson docString),
+ ("text", Lean.ToJson.toJson text),
+ ("start_pos", Lean.ToJson.toJson startPos),
+ ("end_pos", Lean.ToJson.toJson endPos),
+ ("namespace", Lean.ToJson.toJson namespc),
+ ("children", Json.arr (children.map InfoTreeNode.toJson))
+ ]
+ | .other children =>
+ Json.mkObj [
+ ("type", "other"),
+ ("children", Json.arr (children.map InfoTreeNode.toJson))
+ ]
+ | .hole =>
+ Json.mkObj [("type", "hole")]
+
+instance : ToJson InfoTreeNode where
+ toJson := InfoTreeNode.toJson
+
+structure InfoNodeStruct where
+ declType : DeclType
+ name : Option String
+ docString : Option String
+ text : String
+ startPos : Position
+ endPos : Position
+ namespc : Option String
+ children : Array InfoNodeStruct
+deriving Repr
+
+def defaultInfoNodeStruct : InfoNodeStruct :=
+ {
+ declType := .unknown,
+ name := none,
+ docString := none,
+ text := "",
+ startPos := { line := 0, column := 0 },
+ endPos := { line := 0, column := 0 },
+ namespc := none,
+ children := #[]
+ }
+
+instance : Inhabited InfoNodeStruct where
+ default := defaultInfoNodeStruct
+
+/-- toJson for InfoNodeStruct -/
+partial def InfoNodeStruct.toJson (n: InfoNodeStruct) : Json :=
+ Json.mkObj [
+ ("decl_type", ToString.toString n.declType),
+ ("name", Lean.ToJson.toJson n.name),
+ ("doc_string", Lean.ToJson.toJson n.docString),
+ ("text", Lean.ToJson.toJson n.text),
+ ("start_pos", Lean.ToJson.toJson n.startPos),
+ ("end_pos", Lean.ToJson.toJson n.endPos),
+ ("namespace", Lean.ToJson.toJson n.namespc),
+ ("children", Json.arr (n.children.map InfoNodeStruct.toJson))
+ ]
+
+instance : ToJson InfoNodeStruct where
+ toJson := InfoNodeStruct.toJson
+
+def getInfoNodeStruct (node : InfoTreeNode) : Option InfoNodeStruct :=
+ match node with
+ | .leanInfo declType name docString text startPos endPos namespc children =>
+ let childStructs := children.map getInfoNodeStruct
+ let filterSomes := childStructs.filterMap id
+ some {
+ declType,
+ name,
+ docString,
+ text,
+ startPos,
+ endPos,
+ namespc,
+ children := filterSomes
+ }
+ | _ => none
+
+structure ErrorInfo where
+ message : String
+ position : Position
+ deriving Inhabited, Repr
+
+instance : ToJson ErrorInfo where
+ toJson e := Json.mkObj [
+ ("message", toJson e.message),
+ ("position", toJson e.position)
+ ]
+
+/-- Result of parsing tactics from Lean code -/
+structure ParseResult where
+ trees : Array InfoNodeStruct := #[]
+ errors : Array ErrorInfo := #[]
+ deriving Inhabited, Repr
+
+instance : ToJson ParseResult where
+ toJson r := Json.mkObj [
+ ("trees", toJson r.trees),
+ ("errors", toJson r.errors)
+ ]
+
+/-- Storing ParseResult with a checkpointed state -/
+structure CheckpointedParseResult where
+ parseResult : ParseResult
+ chkptState : Option Command.State := none
+ lineNum : Option Nat := none
+ deriving Inhabited
+
+/-- Custom Repr instance for CheckpointedParseResult -/
+instance : Repr CheckpointedParseResult where
+ reprPrec r _ := (repr r.parseResult)
+
+def get_position_from_char_pos (content : String) (charPos : Nat) : Position :=
+ let before := content.extract ⟨0⟩ ⟨charPos⟩
+ let lines := before.splitOn "\n"
+ let lineCount := lines.length
+ if lineCount == 0 then
+ { line := 0, column := 0 }
+ else
+ let lastLine := lines[lineCount - 1]!
+ -- let byteLen := lastLine.endPos.byteIdx
+ { line := lineCount, column := lastLine.length }
+
+end TacticParser
diff --git a/src/itp_interface/tools/tactic_parser/lake-manifest.json b/src/itp_interface/tools/tactic_parser/lake-manifest.json
new file mode 100644
index 0000000..57ec5c7
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/lake-manifest.json
@@ -0,0 +1,5 @@
+{"version": "1.1.0",
+ "packagesDir": ".lake/packages",
+ "packages": [],
+ "name": "TacticParser",
+ "lakeDir": ".lake"}
diff --git a/src/itp_interface/tools/tactic_parser/lakefile.toml b/src/itp_interface/tools/tactic_parser/lakefile.toml
new file mode 100644
index 0000000..16e7605
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/lakefile.toml
@@ -0,0 +1,15 @@
+name = "TacticParser"
+defaultTargets = ["tactic-parser", "dependency-parser"]
+
+[[lean_lib]]
+name = "TacticParser"
+
+[[lean_exe]]
+name = "tactic-parser"
+root = "TacticParser.Main"
+supportInterpreter = true
+
+[[lean_exe]]
+name = "dependency-parser"
+root = "TacticParser.DependencyParserMain"
+supportInterpreter = true
diff --git a/src/itp_interface/tools/tactic_parser/lean-toolchain b/src/itp_interface/tools/tactic_parser/lean-toolchain
new file mode 100644
index 0000000..58ae245
--- /dev/null
+++ b/src/itp_interface/tools/tactic_parser/lean-toolchain
@@ -0,0 +1 @@
+leanprover/lean4:v4.24.0
\ No newline at end of file
diff --git a/src/itp_interface/tools/training_data.py b/src/itp_interface/tools/training_data.py
index ce8353f..2de2567 100644
--- a/src/itp_interface/tools/training_data.py
+++ b/src/itp_interface/tools/training_data.py
@@ -12,7 +12,15 @@
import logging
import time
import threading
-from itp_interface.tools.training_data_format import LemmaRefWithScore, LemmaReferencesCollection, MergableCollection, TrainingDataCollection, TrainingDataFormat, TrainingDataMetadataFormat
+from enum import Enum
+from itp_interface.tools.training_data_format import (
+ LemmaRefWithScore, LemmaReferencesCollection, MergableCollection,
+ TrainingDataCollection,
+ TheoremProvingTrainingDataCollection,
+ ExtractionDataCollection,
+ TrainingDataFormat,
+ TheoremProvingTrainingDataFormat,
+ TrainingDataMetadataFormat)
# Conditional Ray import
try:
@@ -33,6 +41,20 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
return False
+# Define Enum for DataLayoutFormat
+class DataLayoutFormat(Enum):
+ THEOREM_PROVING = "theorem_proving"
+ DECLARATION_EXTRACTION = "declaration_extraction"
+ LEMMA_REF_EXTRACTION = "lemma_ref_extraction"
+
+
+def get_training_data_collection(layout: DataLayoutFormat) -> type[TrainingDataCollection]:
+ if layout == DataLayoutFormat.THEOREM_PROVING:
+ return TheoremProvingTrainingDataCollection
+ elif layout == DataLayoutFormat.LEMMA_REF_EXTRACTION:
+ return LemmaReferencesCollection
+ else:
+ return ExtractionDataCollection
class TrainingData(MergableCollection):
def __init__(
@@ -43,7 +65,8 @@ def __init__(
max_parallelism: int = 4,
remove_from_store_after_loading: bool = True,
logger: logging.Logger = None,
- use_ray: bool = None):
+ use_ray: bool = None,
+ layout: DataLayoutFormat = DataLayoutFormat.THEOREM_PROVING):
assert os.path.exists(folder), f"Folder {folder} does not exist"
assert os.path.isdir(folder), f"Folder {folder} is not a directory"
assert training_meta_filename is not None, "Training meta filename cannot be None"
@@ -59,6 +82,8 @@ def __init__(
self.logger = logger if logger is not None else logging.getLogger(__name__)
self.remove_from_store_after_loading = remove_from_store_after_loading
self._meta_loaded = False
+ self._layout = layout
+ self._training_data_collection = get_training_data_collection(layout)
# Determine if Ray should be used
if use_ray is None:
self._use_ray = HAS_RAY
@@ -78,7 +103,7 @@ def __init__(
def __len__(self) -> int:
assert self.meta is not None, "Training meta is not set"
- return self.meta.total_proof_step_cnt
+ return self.meta.total_data_count
@property
def is_readonly(self) -> bool:
@@ -139,7 +164,7 @@ def _create_remote(filenames):
for i, filename in enumerate(filenames):
self.logger.info(f"[TrainingData] Starting the loading of [{base_idx + i}] {filename}...")
collection_fn = TrainingData._get_lemma_ref_collection if filename == self._lemma_ref_filename else TrainingData._get_training_data_collection
- remotes.append(collection_fn.remote(base_idx + i, self.folder, filename))
+ remotes.append(collection_fn.remote(base_idx + i, self.folder, filename, self._layout))
return remotes
def _transform_remote(results):
@@ -182,7 +207,7 @@ def _load_sequential(self, files_to_load):
self.lemma_ref_collection = LemmaReferencesCollection.load_from_file(file_path)
self.logger.info(f"[TrainingData] Finished loading {self._lemma_ref_filename}")
else:
- tdc = TrainingDataCollection.load_from_file(file_path)
+ tdc = self._training_data_collection.load_from_file(file_path)
self.training_data_collections[idx - 2] = tdc
self.logger.info(f"[TrainingData] Finished loading {idx}")
@@ -196,6 +221,9 @@ def unload(self):
# Reload the metadata
self.load_meta()
+ def undo_merge(self, size: int = 1, start_idx=0) -> object:
+ return NotImplementedError("undo_merge is not implemented yet")
+
def merge(self, __o: object, new_lemma_ref_idx: typing.List[int] = None):
with self._lock:
assert isinstance(__o, TrainingDataFormat) or \
@@ -219,7 +247,7 @@ def clone_skeleton(self, training_data, lemma_ref_collection: LemmaReferencesCol
assert isinstance(training_data, TrainingData), "Invalid type"
assert not self._meta_loaded, "Training metadata is already loaded"
self.meta.training_data_buffer_size = training_data.meta.training_data_buffer_size
- self.meta.total_proof_step_cnt = training_data.meta.total_proof_step_cnt
+ self.meta.total_data_count = training_data.meta.total_data_count
self.meta.external_theorems_used_cnt = training_data.meta.external_theorems_used_cnt
self.meta.local_theorems_used_cnt = training_data.meta.local_theorems_used_cnt
self.meta.last_proof_id = training_data.meta.last_proof_id
@@ -250,6 +278,29 @@ def clone_skeleton(self, training_data, lemma_ref_collection: LemmaReferencesCol
assert len(self._training_data_filenames) == len(self.training_data_collections), "Invalid length"
assert len(self._training_data_filenames) == len(training_data.training_data_collections), "Invalid length"
+ def _return_tdp(self, tdp: TrainingDataFormat) -> TrainingDataFormat:
+ training_data = copy.deepcopy(tdp)
+ if isinstance(training_data, TheoremProvingTrainingDataFormat):
+ lemma_refs : typing.Set[int] = set()
+ for goal in training_data.start_goals:
+ lemma_refs.update([ref.lemma_idx for ref in goal.relevant_defns])
+ lemma_refs.update([ref.lemma_idx for ref in goal.used_theorems_local])
+ lemma_refs.update([ref.lemma_idx for ref in goal.used_theorems_external])
+ lemma_refs.update([ref.lemma_idx for ref in goal.possible_useful_theorems_local])
+ lemma_refs.update([ref.lemma_idx for ref in goal.possible_useful_theorems_external])
+ ordered_lemma_refs = sorted(list(lemma_refs))
+ lemma_ref_map = {lemma_ref: idx for idx, lemma_ref in enumerate(ordered_lemma_refs)}
+ training_data.all_useful_defns_theorems = [self.lemma_ref_collection.training_data[lemma_idx].clone(idx) for idx, lemma_idx in enumerate(ordered_lemma_refs)]
+ # Change the lemma references
+ for goal in training_data.start_goals:
+ goal.relevant_defns = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.relevant_defns]
+ goal.used_theorems_local = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.used_theorems_local]
+ goal.used_theorems_external = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.used_theorems_external]
+ goal.possible_useful_theorems_local = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.possible_useful_theorems_local]
+ goal.possible_useful_theorems_external = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.possible_useful_theorems_external]
+ return training_data
+ return training_data
+
def __getitem__(self, idx: int) -> TrainingDataFormat:
tdc_idx = idx // self.meta.training_data_buffer_size
idx_in_tdc = idx % self.meta.training_data_buffer_size
@@ -258,24 +309,7 @@ def __getitem__(self, idx: int) -> TrainingDataFormat:
tdc = self.training_data_collections[tdc_idx]
if idx_in_tdc >= len(tdc):
raise IndexError(f"Index out of range (len(self.training_data_collections)={len(self.training_data_collections)},buffer={self.meta.training_data_buffer_size}, range idx={idx}, tdc_idx={tdc_idx}, idx_in_tdc={idx_in_tdc}, len(tdc)={len(tdc)})")
- training_data = copy.deepcopy(tdc.training_data[idx_in_tdc])
- lemma_refs : typing.Set[int] = set()
- for goal in training_data.start_goals:
- lemma_refs.update([ref.lemma_idx for ref in goal.relevant_defns])
- lemma_refs.update([ref.lemma_idx for ref in goal.used_theorems_local])
- lemma_refs.update([ref.lemma_idx for ref in goal.used_theorems_external])
- lemma_refs.update([ref.lemma_idx for ref in goal.possible_useful_theorems_local])
- lemma_refs.update([ref.lemma_idx for ref in goal.possible_useful_theorems_external])
- ordered_lemma_refs = sorted(list(lemma_refs))
- lemma_ref_map = {lemma_ref: idx for idx, lemma_ref in enumerate(ordered_lemma_refs)}
- training_data.all_useful_defns_theorems = [self.lemma_ref_collection.lemma_references[lemma_idx].clone(idx) for idx, lemma_idx in enumerate(ordered_lemma_refs)]
- # Change the lemma references
- for goal in training_data.start_goals:
- goal.relevant_defns = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.relevant_defns]
- goal.used_theorems_local = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.used_theorems_local]
- goal.used_theorems_external = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.used_theorems_external]
- goal.possible_useful_theorems_local = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.possible_useful_theorems_local]
- goal.possible_useful_theorems_external = [LemmaRefWithScore(lemma_ref_map[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.possible_useful_theorems_external]
+ training_data = self._return_tdp(tdc.training_data[idx_in_tdc])
return training_data
def save(self) -> str:
@@ -381,68 +415,93 @@ def _merge_training_data_format(self, other: TrainingDataFormat, new_lemma_ref_i
assert isinstance(other, TrainingDataFormat), "other must be a TrainingDataFormat"
assert self.lemma_ref_collection is not None, "Lemma ref collection is None"
if new_lemma_ref_idx is None:
- new_lemma_ref_idx : typing.List[int] = self.lemma_ref_collection.merge(other.all_useful_defns_theorems)
- assert len(new_lemma_ref_idx) == len(other.all_useful_defns_theorems), "Invalid lemma ref idx"
+ if isinstance(other, TheoremProvingTrainingDataFormat):
+ new_lemma_ref_idx : typing.List[int] = self.lemma_ref_collection.merge(other.all_useful_defns_theorems)
+ assert len(new_lemma_ref_idx) == len(other.all_useful_defns_theorems), "Invalid lemma ref idx"
+ else:
+ new_lemma_ref_idx : typing.List[int] = []
if len(self.training_data_collections) == 0:
- self.training_data_collections.append(TrainingDataCollection())
+ self.training_data_collections.append(self._training_data_collection())
last_training_data_collection = self.training_data_collections[-1]
if len(last_training_data_collection) + 1 > self.meta.training_data_buffer_size:
- self.training_data_collections.append(TrainingDataCollection())
+ self.training_data_collections.append(self._training_data_collection())
last_training_data_collection = self.training_data_collections[-1]
TrainingData._merge_training_data_collection(last_training_data_collection, [other], new_lemma_ref_idx)
# Update the metadata
- self.meta.last_proof_id = other.proof_id
- self.meta.last_training_data += 1
- self.meta.external_theorems_used_cnt += sum([len(goal.used_theorems_external) for goal in other.start_goals])
- self.meta.local_theorems_used_cnt += sum([len(goal.used_theorems_local) for goal in other.start_goals])
- self.meta.total_proof_step_cnt += len(other.proof_steps)
+ self._update_meta(other)
+
+ def _update_meta(self, other: TrainingDataFormat):
+ assert isinstance(other, TrainingDataFormat), "other must be a TrainingDataFormat"
+ if isinstance(other, TheoremProvingTrainingDataFormat):
+ self.meta.last_proof_id = other.proof_id
+ self.meta.last_training_data += 1
+ self.meta.external_theorems_used_cnt += sum([len(goal.used_theorems_external) for goal in other.start_goals])
+ self.meta.local_theorems_used_cnt += sum([len(goal.used_theorems_local) for goal in other.start_goals])
+ self.meta.total_data_count += len(other.proof_steps)
+ else:
+ self.meta.last_training_data += 1
+ self.meta.total_data_count += 1
# Define Ray remote methods conditionally
if HAS_RAY:
@staticmethod
@ray.remote(max_retries=-1)
- def _get_training_data_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, typing.Any]:
+ def _get_training_data_collection(
+ idx : int, folder: str, filename: str,
+ data_layout: DataLayoutFormat) -> typing.Tuple[int, typing.Any]:
file_path = os.path.join(folder, filename)
start_time = time.time()
ray.logger.info(f"[TrainingData] Trying to load {file_path}")
- tdc = TrainingDataCollection.load_from_file(file_path)
+ training_data_collection = get_training_data_collection(data_layout)
+ tdc = training_data_collection.load_from_file(file_path)
end_time = time.time()
ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
return idx, tdc
@staticmethod
@ray.remote(max_retries=-1)
- def _get_lemma_ref_collection(idx : int, folder: str, filename: str) -> typing.Tuple[int, typing.Any]:
+ def _get_lemma_ref_collection(
+ idx : int,
+ folder: str,
+ filename: str,
+ data_layout: DataLayoutFormat) -> typing.Tuple[int, typing.Any]:
file_path = os.path.join(folder, filename)
start_time = time.time()
ray.logger.info(f"[TrainingData] Trying to load {file_path}")
- res = LemmaReferencesCollection.load_from_file(file_path)
+ lemma_ref_collection = get_training_data_collection(DataLayoutFormat.LEMMA_REF_EXTRACTION)
+ res = lemma_ref_collection.load_from_file(file_path)
end_time = time.time()
ray.logger.info(f"[TrainingData] Loaded {file_path} in {end_time - start_time} seconds")
return idx, res
@staticmethod
@ray.remote(max_retries=-1)
- def _save_object(i : int, obj: typing.Union[TrainingDataCollection, TrainingDataMetadataFormat, LemmaReferencesCollection], filepath: str):
+ def _save_object(i : int, obj: typing.Union[
+ TrainingDataCollection,
+ TrainingDataMetadataFormat,
+ LemmaReferencesCollection,
+ ExtractionDataCollection], filepath: str):
save_start_time = time.time()
ray.logger.info(f"[TrainingData] Saving {filepath}")
with open(filepath, 'w') as f:
# serialize the current metadata
- json_str = obj.to_json()
+ if isinstance(obj, ExtractionDataCollection):
+ # TODO [HACK]: The dynamic dispatching does not work well with Ray remote functions
+ json_str = ExtractionDataCollection.to_json(obj)
+ else:
+ ray.logger.info(f"[TrainingData] Serializing object of type {type(obj)}")
+ json_str = obj.to_json()
# update the metadata in the file
f.write(json_str)
save_end_time = time.time()
ray.logger.info(f"[TrainingData] Saved {filepath} in {save_end_time - save_start_time}s")
return i, filepath
- def _merge_training_data_collection(other: TrainingDataCollection, training_data_points: typing.List[TrainingDataFormat], new_lemma_ref_idx: typing.List[int]):
- assert isinstance(other, TrainingDataCollection), "other must be a TrainingDataFormat or TrainingDataCollection"
- assert isinstance(training_data_points, list), "training_data_points must be a list"
- assert isinstance(new_lemma_ref_idx, list), "new_lemma_ref_idx must be a list"
- new_tdps : typing.List[TrainingDataFormat] = []
- for tdp in training_data_points:
- assert isinstance(tdp, TrainingDataFormat), "training_data_points must contain TrainingDataFormat objects"
- new_tdp = TrainingDataFormat(
+ @staticmethod
+ def _clone_tdp(tdp: TrainingDataFormat, new_lemma_ref_idx: typing.List[int]) -> TrainingDataFormat:
+ assert isinstance(tdp, TrainingDataFormat), "training_data_points must contain TrainingDataFormat objects"
+ if isinstance(tdp, TheoremProvingTrainingDataFormat):
+ new_tdp = TheoremProvingTrainingDataFormat(
tdp.proof_id,
all_useful_defns_theorems=[],
start_goals=tdp.start_goals,
@@ -460,5 +519,17 @@ def _merge_training_data_collection(other: TrainingDataCollection, training_data
goal.used_theorems_external = [LemmaRefWithScore(new_lemma_ref_idx[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.used_theorems_external]
goal.possible_useful_theorems_local = [LemmaRefWithScore(new_lemma_ref_idx[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.possible_useful_theorems_local]
goal.possible_useful_theorems_external = [LemmaRefWithScore(new_lemma_ref_idx[lemma_ref.lemma_idx], lemma_ref.score) for lemma_ref in goal.possible_useful_theorems_external]
+ else:
+ new_tdp = copy.deepcopy(tdp)
+ return new_tdp
+
+ @staticmethod
+ def _merge_training_data_collection(other: TrainingDataCollection, training_data_points: typing.List[TrainingDataFormat], new_lemma_ref_idx: typing.List[int]):
+ assert isinstance(other, TrainingDataCollection), "other must be a TrainingDataFormat or TrainingDataCollection"
+ assert isinstance(training_data_points, list), "training_data_points must be a list"
+ assert isinstance(new_lemma_ref_idx, list), "new_lemma_ref_idx must be a list"
+ new_tdps : typing.List[TrainingDataFormat] = []
+ for tdp in training_data_points:
+ new_tdp = TrainingData._clone_tdp(tdp, new_lemma_ref_idx)
new_tdps.append(new_tdp)
other.training_data.extend(new_tdps)
\ No newline at end of file
diff --git a/src/itp_interface/tools/training_data_format.py b/src/itp_interface/tools/training_data_format.py
index 6f16620..e7dabfa 100644
--- a/src/itp_interface/tools/training_data_format.py
+++ b/src/itp_interface/tools/training_data_format.py
@@ -6,15 +6,31 @@
sys.path.append(root_dir)
import copy
import os
-import jsonlines
-import typing
import logging
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json
from collections import OrderedDict
-from typing import List, Optional, Tuple
+from typing import List, Optional, Union, runtime_checkable, Protocol
+from pydantic import BaseModel
+from itp_interface.tools.tactic_parser import DeclWithDependencies
-class MergableCollection(object):
+@runtime_checkable
+class TrainingDataFormat(Protocol):
+
+ def to_json(self, indent=0) -> str:
+ raise NotImplementedError("to_json must be implemented by the child class")
+
+ @staticmethod
+ def load_from_file(file_path: str):
+ raise NotImplementedError("load_from_file must be implemented by the child class")
+
+ @staticmethod
+ def load_from_string(json_text: str):
+ raise NotImplementedError("load_from_string must be implemented by the child class")
+
+
+@runtime_checkable
+class MergableCollection(Protocol):
def merge(self, __o: object):
raise NotImplementedError("merge must be implemented by the child class")
@@ -24,6 +40,32 @@ def undo_merge(self, size: int = 1, start_idx = 0) -> object:
def __len__(self) -> int:
raise NotImplementedError("__len__ must be implemented by the child class")
+
+@runtime_checkable
+class TrainingDataCollection(Protocol):
+ training_data: list
+
+ def merge(self, __o: object):
+ raise NotImplementedError("merge must be implemented by the child class")
+
+ def undo_merge(self, size: int = 1, start_idx = 0) -> object:
+ raise NotImplementedError("undo_merge must be implemented by the child class")
+
+ def __len__(self) -> int:
+ return len(self.training_data)
+
+ def to_json(self, indent=0) -> str:
+ raise NotImplementedError("to_json must be implemented by the child class")
+
+ @staticmethod
+ def load_from_file(file_path: str, logger: logging.Logger = None):
+ raise NotImplementedError("load_from_file must be implemented by the child class")
+
+ @staticmethod
+ def load_from_string(json_text: str, logger: logging.Logger = None):
+ raise NotImplementedError("load_from_string must be implemented by the child class")
+
+
@dataclass_json
@dataclass
class LemmaRefWithScore(object):
@@ -81,6 +123,9 @@ def __lt__(self, __o: object) -> bool:
def __gt__(self, __o: object) -> bool:
return self != __o and self >= __o
+ def to_json(self, indent=0) -> str:
+ return Goal.schema().dumps(self, indent=indent)
+
@staticmethod
def load_from_file(file_path: str):
assert os.path.exists(file_path), "file_path must be a valid path to a file"
@@ -115,86 +160,31 @@ def __str__(self) -> str:
return f"{self.lemma_name} : {self.lemma_defn}"
# return f"{self.lemma_defn} : {self.lemma_name}"
- def clone(self, idx : typing.Optional[int] = None):
+ def clone(self, idx : Optional[int] = None):
new_copy = copy.deepcopy(self)
if idx is not None:
new_copy.lemma_idx = idx
return new_copy
-
-@dataclass_json
-@dataclass
-class LemmaReferencesCollection(MergableCollection):
- """Class to store the lemma references."""
- lemma_references: typing.List[LemmaReferences] = field(default_factory=list)
-
- def __post_init__(self):
- self._lemma_ref_to_idx = {lemma_ref: idx for idx, lemma_ref in enumerate(self.lemma_references)}
-
- def merge(self, __o: object):
- """
- Merge the lemma references with another lemma references collection.
- Returns the merged lemma references collection index map.
- """
- if not isinstance(__o, LemmaReferencesCollection) and not isinstance(__o, LemmaReferences) and not isinstance(__o, list):
- raise TypeError(f"Cannot merge LemmaReferenceCollection with {type(__o)}")
- if isinstance(__o, list) and not all(isinstance(x, LemmaReferences) for x in __o):
- raise TypeError(f"Cannot merge LemmaReferenceCollection with list of {type(__o)}")
- if isinstance(__o, LemmaReferences):
- __o = [__o]
- elif isinstance(__o, LemmaReferencesCollection):
- __o = __o.lemma_references
- to_take_cnt = len(__o)
- new_idx_map = [-1] * to_take_cnt
- for idx in range(to_take_cnt):
- lemma_ref = __o[idx]
- assert 0 <= lemma_ref.lemma_idx < len(__o), f"lemma_idx must be in range [0, {len(__o)}"
- if lemma_ref not in self._lemma_ref_to_idx:
- self._lemma_ref_to_idx[lemma_ref] = len(self.lemma_references)
- lemma_ref_copy = copy.deepcopy(lemma_ref)
- lemma_ref_copy.lemma_idx = len(self.lemma_references)
- self.lemma_references.append(lemma_ref_copy)
- new_idx_map[idx] = lemma_ref_copy.lemma_idx
- else:
- lemma_idx = self._lemma_ref_to_idx[lemma_ref]
- new_idx_map[idx] = lemma_idx
- self.lemma_references[lemma_idx].ref_count += lemma_ref.ref_count
- assert all(idx != -1 for idx in new_idx_map), "new_idx_map must not contain any -1 values"
- return new_idx_map
-
- def __len__(self) -> int:
- return len(self.lemma_references)
- def __iter__(self):
- return iter(self.lemma_references)
-
- def __getitem__(self, idx: int) -> LemmaReferences:
- return self.lemma_references[idx]
+ def to_json(self, indent=0) -> str:
+ return LemmaReferences.schema().dumps(self, indent=indent)
@staticmethod
- def load_from_file(file_path: str, logger: logging.Logger = None):
- assert os.path.exists(file_path), f"file_path:{file_path} must be a valid path to a file"
+ def load_from_file(file_path: str):
+ assert os.path.exists(file_path), "file_path must be a valid path to a file"
json_text = None
- if logger is not None:
- logger.info(f"Loading json data from {file_path}")
with open(file_path, "r") as f:
json_text = f.read()
- if logger is not None:
- logger.info(f"Loaded json data from {file_path}")
- return LemmaReferencesCollection.load_from_string(json_text, logger)
-
+ return LemmaReferences.load_from_string(json_text)
+
@staticmethod
- def load_from_string(json_text: str, logger: logging.Logger = None):
+ def load_from_string(json_text: str):
assert json_text is not None, "json_text cannot be None"
- if logger is not None:
- logger.info(f"Deseiralizing json data from string of length {len(json_text)} characters")
- deserialized = LemmaReferencesCollection.schema().loads(json_text)
- if logger is not None:
- logger.info(f"Deseiralized json data from string of length {len(json_text)} characters")
- return deserialized
+ return LemmaReferences.schema().loads(json_text)
@dataclass_json
@dataclass
-class TrainingDataFormat(object):
+class TheoremProvingTrainingDataFormat(object):
"""Class to format the training data for coq based automatic theorem provers.
This class is responsible for formatting the training data for coq based automatic theorem provers.
"""
@@ -211,7 +201,7 @@ class TrainingDataFormat(object):
theorem_name: Optional[str] = None # The name of the theorem.
def __eq__(self, __o: object) -> bool:
- if not isinstance(__o, TrainingDataFormat):
+ if not isinstance(__o, TheoremProvingTrainingDataFormat):
return False
goal_set_a = set([goal.goal for goal in self.start_goals])
goal_set_b = set([goal.goal for goal in __o.start_goals])
@@ -251,7 +241,7 @@ def __eq__(self, __o: object) -> bool:
def __le__(self, __o: object) -> bool:
# TrainingDataFormat 'a' is less (hard) than TrainingDataFormat 'b' iff all goals in 'a' are subset of goals in 'b'
- if not isinstance(__o, TrainingDataFormat):
+ if not isinstance(__o, TheoremProvingTrainingDataFormat):
raise TypeError(f"Cannot compare TrainingDataFormat with {type(__o)}")
goal_set_a = set([goal.goal for goal in self.start_goals])
goal_set_b = set([goal.goal for goal in __o.start_goals])
@@ -307,7 +297,7 @@ def __le__(self, __o: object) -> bool:
def __ge__(self, __o: object) -> bool:
# TrainingDataFormat 'a' is more (hard) than TrainingDataFormat 'b' iff all goals in 'b' are subset of goals in 'a'
- if not isinstance(__o, TrainingDataFormat):
+ if not isinstance(__o, TheoremProvingTrainingDataFormat):
raise TypeError(f"Cannot compare TrainingDataFormat with {type(__o)}")
goal_set_a = set([goal.goal for goal in self.start_goals])
goal_set_b = set([goal.goal for goal in __o.start_goals])
@@ -364,7 +354,7 @@ def __hash__(self) -> int:
return hash(tuple(goal_set))
def have_same_proof_steps(self, __o: object) -> bool:
- if not isinstance(__o, TrainingDataFormat):
+ if not isinstance(__o, TheoremProvingTrainingDataFormat):
raise TypeError(f"Cannot compare TrainingDataFormat with {type(__o)}")
return len(self.proof_steps) == len(__o.proof_steps) and all([p_a == p_b for p_a, p_b in zip(self.proof_steps, __o.proof_steps)])
@@ -377,27 +367,81 @@ def get_human_readable_serialized_goal(self, idx: int, skip_special_tokens: bool
{hyps}
"""
+ def to_json(self, indent=0) -> str:
+ return TheoremProvingTrainingDataFormat.schema().dumps(self, indent=indent)
+
@staticmethod
def load_from_file(file_path: str):
assert os.path.exists(file_path), "file_path must be a valid path to a file"
json_text = None
with open(file_path, "r") as f:
json_text = f.read()
- return TrainingDataFormat.load_from_string(json_text)
+ return TheoremProvingTrainingDataFormat.load_from_string(json_text)
@staticmethod
def load_from_string(json_text: str):
assert json_text is not None, "json_text cannot be None"
- return TrainingDataFormat.schema().loads(json_text)
+ return TheoremProvingTrainingDataFormat.schema().loads(json_text)
@dataclass_json
@dataclass
-class TrainingDataCollection(MergableCollection):
- training_data: List[TrainingDataFormat] = field(default_factory=list) # The list of training data.
+class LemmaReferencesCollection(TrainingDataCollection):
+ """Class to store the lemma references."""
+ training_data: list[LemmaReferences] = field(default_factory=list)
+
+ def __post_init__(self):
+ self._lemma_ref_to_idx = {lemma_ref: idx for idx, lemma_ref in enumerate(self.training_data)}
+
+ def merge(self, __o: object):
+ """
+ Merge the lemma references with another lemma references collection.
+ Returns the merged lemma references collection index map.
+ """
+ if not isinstance(__o, LemmaReferencesCollection) and not isinstance(__o, LemmaReferences) and not isinstance(__o, list):
+ raise TypeError(f"Cannot merge LemmaReferenceCollection with {type(__o)}")
+ if isinstance(__o, list) and not all(isinstance(x, LemmaReferences) for x in __o):
+ raise TypeError(f"Cannot merge LemmaReferenceCollection with list of {type(__o)}")
+ if isinstance(__o, LemmaReferences):
+ __o = [__o]
+ elif isinstance(__o, LemmaReferencesCollection):
+ __o = __o.training_data
+ to_take_cnt = len(__o)
+ new_idx_map = [-1] * to_take_cnt
+ for idx in range(to_take_cnt):
+ lemma_ref = __o[idx]
+ assert 0 <= lemma_ref.lemma_idx < len(__o), f"lemma_idx must be in range [0, {len(__o)}"
+ if lemma_ref not in self._lemma_ref_to_idx:
+ self._lemma_ref_to_idx[lemma_ref] = len(self.training_data)
+ lemma_ref_copy = copy.deepcopy(lemma_ref)
+ lemma_ref_copy.lemma_idx = len(self.training_data)
+ self.training_data.append(lemma_ref_copy)
+ new_idx_map[idx] = lemma_ref_copy.lemma_idx
+ else:
+ lemma_idx = self._lemma_ref_to_idx[lemma_ref]
+ new_idx_map[idx] = lemma_idx
+ self.training_data[lemma_idx].ref_count += lemma_ref.ref_count
+ assert all(idx != -1 for idx in new_idx_map), "new_idx_map must not contain any -1 values"
+ return new_idx_map
+
+ def undo_merge(self, size: int = 1, start_idx=0) -> object:
+ assert size >= 1, "size must be greater than equal to 1"
+ assert start_idx >= 0, "start_idx must be greater than zero"
+ assert start_idx < len(self.training_data), f"can only cut-down from idx < {len(self.training_data)}"
+ fraction = self.training_data[start_idx: size]
+ return LemmaReferencesCollection(training_data=fraction)
+
+ def __iter__(self):
+ return iter(self.training_data)
+
+ def __getitem__(self, idx: int) -> LemmaReferences:
+ return self.training_data[idx]
+
+ def to_json(self, indent=0) -> str:
+ return LemmaReferencesCollection.schema().dumps(self, indent=indent)
@staticmethod
def load_from_file(file_path: str, logger: logging.Logger = None):
- assert os.path.exists(file_path), f"file_path: {file_path} must be a valid path to a file"
+ assert os.path.exists(file_path), f"file_path:{file_path} must be a valid path to a file"
json_text = None
if logger is not None:
logger.info(f"Loading json data from {file_path}")
@@ -405,21 +449,95 @@ def load_from_file(file_path: str, logger: logging.Logger = None):
json_text = f.read()
if logger is not None:
logger.info(f"Loaded json data from {file_path}")
- return TrainingDataCollection.load_from_string(json_text, logger)
+ return LemmaReferencesCollection.load_from_string(json_text, logger)
@staticmethod
def load_from_string(json_text: str, logger: logging.Logger = None):
assert json_text is not None, "json_text cannot be None"
if logger is not None:
- logger.info(f"Deseiralizing json data from string of length {len(json_text)} characters")
- deserialized = TrainingDataCollection.schema().loads(json_text)
+ logger.info(f"Deserializing json data from string of length {len(json_text)} characters")
+ deserialized = LemmaReferencesCollection.schema().loads(json_text)
if logger is not None:
- logger.info(f"Deseiralized json data from string of length {len(json_text)} characters")
+ logger.info(f"Deserialized json data from string of length {len(json_text)} characters")
return deserialized
+
+@dataclass_json
+@dataclass
+class TheoremProvingTrainingDataCollection(TrainingDataCollection):
+ training_data: list[TheoremProvingTrainingDataFormat] = field(default_factory=list) # The list of training data.
+
+ def merge(self, __o: object):
+ assert isinstance(__o, TheoremProvingTrainingDataCollection)
+ self.training_data.extend(__o.training_data)
+ def undo_merge(self, size: int = 1, start_idx=0) -> object:
+ assert size >= 1, "size must be greater than equal to 1"
+ assert start_idx >= 0, "start_idx must be greater than zero"
+ assert start_idx < len(self.training_data), f"can only cut-down from idx < {len(self.training_data)}"
+ fraction = self.training_data[start_idx: size]
+ return TheoremProvingTrainingDataCollection(training_data=fraction)
+
+ def to_json(self, indent=0) -> str:
+ return TheoremProvingTrainingDataCollection.schema().dumps(self, indent=indent)
+
+ @staticmethod
+ def load_from_file(file_path: str, logger: logging.Logger = None):
+ assert os.path.exists(file_path), f"file_path: {file_path} must be a valid path to a file"
+ json_text = None
+ if logger is not None:
+ logger.info(f"Loading json data from {file_path}")
+ with open(file_path, "r") as f:
+ json_text = f.read()
+ if logger is not None:
+ logger.info(f"Loaded json data from {file_path}")
+ return TheoremProvingTrainingDataCollection.load_from_string(json_text, logger)
+
+ @staticmethod
+ def load_from_string(json_text: str, logger: logging.Logger = None):
+ assert json_text is not None, "json_text cannot be None"
+ if logger is not None:
+ logger.info(f"Deserializing json data from string of length {len(json_text)} characters")
+ deserialized = TheoremProvingTrainingDataCollection.schema().loads(json_text)
+ if logger is not None:
+ logger.info(f"Deserialized json data from string of length {len(json_text)} characters")
+ return deserialized
+
+class ExtractionDataCollection(BaseModel):
+ training_data: list[DeclWithDependencies] = []
+
def __len__(self) -> int:
return len(self.training_data)
+ def to_json(self, indent=0) -> str:
+ if indent == 0:
+ return self.model_dump_json()
+ else:
+ return self.model_dump_json(indent=indent)
+
+ def merge(self, __o: object):
+ assert isinstance(__o, ExtractionDataCollection)
+ self.training_data.extend(__o.training_data)
+
+ def undo_merge(self, size: int = 1, start_idx=0) -> object:
+ assert size >= 1, "size must be greater than equal to 1"
+ assert start_idx >= 0, "start_idx must be greater than zero"
+ assert start_idx < len(self.training_data), f"can only cut-down from idx < {len(self.training_data)}"
+ fraction = self.training_data[start_idx: size]
+ return ExtractionDataCollection(training_data=fraction)
+
+ @staticmethod
+ def load_from_string(json_text: str, logger: logging.Logger = None):
+ assert json_text is not None, "json_text cannot be None"
+ return ExtractionDataCollection.model_validate_json(json_text)
+
+ @staticmethod
+ def load_from_file(file_path: str, logger: logging.Logger = None):
+ assert os.path.exists(file_path), "file_path must be a valid path to a file"
+ json_text = None
+ with open(file_path, "r") as f:
+ json_text = f.read()
+ return ExtractionDataCollection.load_from_string(json_text, logger=logger)
+
@dataclass_json
@dataclass
class TrainingDataMetadataFormat(MergableCollection):
@@ -432,7 +550,7 @@ class TrainingDataMetadataFormat(MergableCollection):
last_proof_id: Optional[str] = None
external_theorems_used_cnt: int = 0
local_theorems_used_cnt: int = 0
- total_proof_step_cnt: int = 0
+ total_data_count: int = 0
data_filename_prefix: str = "full_data"
data_filename_suffix: str = ".json"
lemma_ref_filename_prefix: str = "full_data_lemma_ref"
@@ -445,15 +563,20 @@ def merge(self, __o: object):
self.training_data_buffer_size = max(__o.training_data_buffer_size, self.training_data_buffer_size)
self.last_training_data = __o.last_training_data
self.last_proof_id = __o.last_proof_id
- self.total_proof_step_cnt += __o.total_proof_step_cnt
+ self.total_data_count += __o.total_data_count
self.external_theorems_used_cnt += __o.external_theorems_used_cnt
self.local_theorems_used_cnt += __o.local_theorems_used_cnt
self.num_theorems += __o.num_theorems
-
+
+ def undo_merge(self, size: int = 1, start_idx=0) -> object:
+ raise NotImplementedError("undo_merge is not implemented for TrainingDataMetadataFormat")
def __len__(self) -> int:
return 0
+ def to_json(self, indent=0) -> str:
+ return TrainingDataMetadataFormat.schema().dumps(self, indent=indent)
+
@staticmethod
def load_from_file(file_path: str):
assert os.path.exists(file_path), "file_path must be a valid path to a file"
@@ -476,15 +599,15 @@ def __init__(self, with_labels: bool = False):
def get_layout_format_name(self) -> str:
raise NotImplementedError("get_layout_format_name must be implemented in derived classes")
- def layout_training_data(self, training_data_format: TrainingDataFormat) -> typing.Union[str, typing.Tuple[str, str]]:
+ def layout_training_data(self, training_data_format: TheoremProvingTrainingDataFormat) -> Union[str, tuple[str, str]]:
raise NotImplementedError("get_formatted_training_data must be implemented in derived classes")
- def get_training_data_from_layout(self, formatted_training_data: str) -> TrainingDataFormat:
+ def get_training_data_from_layout(self, formatted_training_data: str) -> TheoremProvingTrainingDataFormat:
raise NotImplementedError("get_training_data_format must be implemented in derived classes")
if __name__ == "__main__":
# Test the training data collection
- training_data_format1 = TrainingDataFormat(
+ training_data_format1 = TheoremProvingTrainingDataFormat(
proof_id="proof_id",
start_goals=[
Goal(hypotheses=[], goal="forall e : expr, size (constant_fold e) <= size e"),
@@ -500,7 +623,7 @@ def get_training_data_from_layout(self, formatted_training_data: str) -> Trainin
addition_state_info={}
)
- training_data_format2 = TrainingDataFormat(
+ training_data_format2 = TheoremProvingTrainingDataFormat(
proof_id="proof_id",
start_goals=[
Goal(hypotheses=[], goal="forall e : expr, size (constant_fold e) <= size e"),
@@ -516,7 +639,7 @@ def get_training_data_from_layout(self, formatted_training_data: str) -> Trainin
addition_state_info={}
)
- training_data_format3 = TrainingDataFormat(
+ training_data_format3 = TheoremProvingTrainingDataFormat(
proof_id="proof_id",
start_goals=[
Goal(hypotheses=[], goal="forall e : expr, size (constant_fold e) <= size e"),
@@ -531,7 +654,7 @@ def get_training_data_from_layout(self, formatted_training_data: str) -> Trainin
addition_state_info={}
)
- training_data_format4 = TrainingDataFormat(
+ training_data_format4 = TheoremProvingTrainingDataFormat(
proof_id="proof_id",
start_goals=[
Goal(hypotheses=[], goal="forall e : expr, size (constant_fold e) <= size e"),
@@ -546,7 +669,7 @@ def get_training_data_from_layout(self, formatted_training_data: str) -> Trainin
addition_state_info={}
)
- training_data_format5 = TrainingDataFormat(
+ training_data_format5 = TheoremProvingTrainingDataFormat(
proof_id="proof_id",
start_goals=[
Goal(hypotheses=[], goal="forall e : expr, size (constant_fold e) <= size e"),
@@ -560,7 +683,7 @@ def get_training_data_from_layout(self, formatted_training_data: str) -> Trainin
addition_state_info={}
)
- training_data_format6 = TrainingDataFormat(
+ training_data_format6 = TheoremProvingTrainingDataFormat(
proof_id="proof_id",
start_goals=[
Goal(hypotheses=[], goal="forall e : expr, size (constant_fold e) <= size e"),
@@ -573,7 +696,7 @@ def get_training_data_from_layout(self, formatted_training_data: str) -> Trainin
addition_state_info={}
)
- training_data_format7 = TrainingDataFormat(
+ training_data_format7 = TheoremProvingTrainingDataFormat(
proof_id="proof_id",
start_goals=[
Goal(hypotheses=[], goal="forall e : expr, size (constant_fold e) <= size e"),
diff --git a/src/test/simple_data_extract_test.py b/src/test/simple_data_extract_test.py
new file mode 100644
index 0000000..36f1dca
--- /dev/null
+++ b/src/test/simple_data_extract_test.py
@@ -0,0 +1,97 @@
+import unittest
+import os
+import subprocess
+try:
+ import ray
+ from itp_interface.tools.ray_utils import RayResourcePoolActor, TimedRayExec, RayUtils
+ HAS_RAY = True
+except ImportError:
+ HAS_RAY = False
+ ray = None
+ RayResourcePoolActor = None
+ TimedRayExec = None
+ RayUtils = None
+
+
+def pretty_print_file_contents(dir_path):
+ print(f"Printing all files in the directory: {dir_path}")
+ for f in os.listdir(dir_path):
+ file_path = os.path.join(dir_path, f)
+ if os.path.isfile(file_path):
+ print('-'*50)
+ print(f"Contents of {file_path}:")
+ with open(file_path, "r") as file:
+ print(file.read())
+
+class TestExtract(unittest.TestCase):
+ def test_lean_data_extract(self):
+ """
+ Test that the 'run-itp-data-gen' command runs successfully with the given configuration.
+ """
+ # Construct the command as a single string.
+ command = (
+ "run-itp-data-gen --config-dir=src/itp_interface/main/configs "
+ "--config-name=simple_lean_data_extract.yaml"
+ )
+
+ try:
+ # Run the command using shell=True so that the shell does the PATH lookup.
+ result = subprocess.run(
+ command,
+ shell=True,
+ capture_output=True,
+ text=True,
+ timeout=700
+ )
+ except subprocess.TimeoutExpired as e:
+ self.fail(f"'run-itp-data-gen' command timed out: {e}")
+ except Exception as e:
+ self.fail(f"'run-itp-data-gen' failed with unknown exception: {e}")
+
+ # Check that the command exited with a return code of 0.
+ self.assertEqual(
+ result.returncode, 0,
+ msg=f"'run-itp-data-gen' failed with return code {result.returncode}. Stderr: {result.stderr}"
+ )
+
+ # Print all the files in the .log/data_generation/benchmark/simple_benchmark_lean
+ # directory to see what was generated.
+ # Do a list and pick the last folder in the list as per the sorted order
+ dirs = sorted(os.listdir(".log/data_generation/benchmark/simple_benchmark_lean_ext"))
+ print("Directories:", dirs)
+ last_dir = dirs[-1]
+ # Print the directory contents
+ last_dir_path = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean_ext", last_dir)
+ print("Last Directory Contents:", os.listdir(last_dir_path))
+ train_data = os.path.join(last_dir_path, "train")
+ list_files = os.listdir(train_data)
+ print("Train Directory Contents:", list_files)
+ data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")]
+ print("Data Files:", data_files)
+ if len(data_files) == 0:
+ # Print the last directory contents again
+ pretty_print_file_contents(last_dir_path)
+ print('='*50)
+ # Open all the files in the train directory and print their contents for debugging
+ pretty_print_file_contents(train_data)
+
+ assert len(data_files) == 1, f"No files found in the train directory. Expected one file. Found: {data_files}"
+ print(data_files[0])
+ data_gen_file = os.path.join(train_data, data_files[0])
+ print("Data Gen File:", data_gen_file)
+ with open(data_gen_file, "r") as f:
+ print(f.read())
+
+def main():
+ unittest.main()
+
+if __name__ == '__main__':
+ if HAS_RAY:
+ os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1"
+ object_store_memory_in_gb = 0.15
+ memory_in_gb = 0.25
+ ray_dashboard = RayUtils.init_ray(
+ num_of_cpus=2,
+ object_store_memory_in_gb=object_store_memory_in_gb,
+ memory_in_gb=memory_in_gb)
+ main()
\ No newline at end of file
diff --git a/src/test/simple_data_gen_test.py b/src/test/simple_data_gen_test.py
index 7c101b3..e39ce93 100644
--- a/src/test/simple_data_gen_test.py
+++ b/src/test/simple_data_gen_test.py
@@ -1,6 +1,26 @@
import unittest
import os
import subprocess
+try:
+ import ray
+ from itp_interface.tools.ray_utils import RayResourcePoolActor, TimedRayExec, RayUtils
+ HAS_RAY = True
+except ImportError:
+ HAS_RAY = False
+ ray = None
+ RayResourcePoolActor = None
+ TimedRayExec = None
+ RayUtils = None
+
+def pretty_print_file_contents(dir_path):
+ print(f"Printing all files in the directory: {dir_path}")
+ for f in os.listdir(dir_path):
+ file_path = os.path.join(dir_path, f)
+ if os.path.isfile(file_path):
+ print('-'*50)
+ print(f"Contents of {file_path}:")
+ with open(file_path, "r") as file:
+ print(file.read())
class TestDataGen(unittest.TestCase):
def test_proof_step_data_gen(self):
@@ -47,6 +67,13 @@ def test_proof_step_data_gen(self):
print("Train Directory Contents:", list_files)
data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")]
print("Data Files:", data_files)
+ if len(data_files) == 0:
+ # Print the last directory contents again
+ pretty_print_file_contents(last_dir_path)
+ print('='*50)
+ # Open all the files in the train directory and print their contents for debugging
+ pretty_print_file_contents(train_data)
+
assert len(data_files) == 1, f"No files found in the train directory. Expected one file. Found: {data_files}"
print(data_files[0])
data_gen_file = os.path.join(train_data, data_files[0])
@@ -58,4 +85,12 @@ def main():
unittest.main()
if __name__ == '__main__':
+ if HAS_RAY:
+ os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1"
+ object_store_memory_in_gb = 0.15
+ memory_in_gb = 0.25
+ ray_dashboard = RayUtils.init_ray(
+ num_of_cpus=2,
+ object_store_memory_in_gb=object_store_memory_in_gb,
+ memory_in_gb=memory_in_gb)
main()
\ No newline at end of file
diff --git a/src/test/simple_env_test.py b/src/test/simple_env_test.py
index dd37b52..1005f54 100644
--- a/src/test/simple_env_test.py
+++ b/src/test/simple_env_test.py
@@ -1,20 +1,47 @@
import unittest
+import os
+from itp_interface.tools.tactic_parser import build_lean4_project, build_tactic_parser_if_needed
+
+def pretty_print(s1, s2, proof_step, done):
+ print(f"Current Goal:")
+ print('-'*30)
+ for goal in s1.training_data_format.start_goals:
+ hyps = '\n'.join([hyp for hyp in goal.hypotheses])
+ print(hyps)
+ print('|- ', end='')
+ print(goal.goal)
+ print(f'*'*30)
+ print(f"="*30)
+ print(f"Action: {proof_step}")
+ print(f"="*30)
+ print(f"Next Goal:")
+ print('-'*30)
+ if s2 is not None:
+ for goal in s2.training_data_format.start_goals:
+ hyps = '\n'.join([hyp for hyp in goal.hypotheses])
+ print(hyps)
+ print('|- ', end='')
+ print(goal.goal)
+ print(f'*'*30)
+ print(f"="*30)
+ print(f"DONE: {done}")
+ print('-'*30)
+ if s2 is None and done:
+ print("No more goals. Proof Finished!")
class Helper():
def __init__(self):
self.current_switch = None
def build_lean4_project(self, project_folder):
- import os
+ build_tactic_parser_if_needed()
# Build the project
- with os.popen(f"cd {project_folder} && lake exe cache get && lake build") as proc:
- print("Building Lean4 project...")
- print('-'*15 + 'Build Logs' + '-'*15)
- print(proc.read())
- print('-'*15 + 'End Build Logs' + '-'*15)
+ path_to_lake_folder = os.path.join(project_folder, ".lake")
+ if not os.path.exists(path_to_lake_folder):
+ build_lean4_project(project_folder)
+
def build_coq_project(self, project_folder):
- import os
try:
with os.popen("opam switch show") as proc:
self.current_switch = proc.read().strip()
@@ -44,7 +71,6 @@ def build_coq_project(self, project_folder):
print('-'*15 + 'End Build Logs' + '-'*15)
def switch_to_current_switch(self):
- import os
if self.current_switch is not None:
try:
proc = os.popen(f"opam switch {self.current_switch} && eval $(opam env)")
@@ -66,7 +92,7 @@ def test_simple_lean4(self):
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
- theorem_name = "test3"
+ theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
# theorem test3 (p q : Prop) (hp : p) (hq : q)
# : p ∧ q ∧ p :=
proof_exec_callback = ProofExecutorCallback(
@@ -107,24 +133,7 @@ def test_simple_lean4(self):
else:
s1 : ProofState = state
s2 : ProofState = next_state
- print(f"Current Goal:")
- print('-'*30)
- for goal in s1.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
- print(f"Action: {proof_step}")
- print(f"="*30)
- print(f"Next Goal:")
- print('-'*30)
- for goal in s2.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
+ pretty_print(s1, s2, proof_step, done)
assert proof_was_finished, "Proof was not finished"
def test_lean4_backtracking(self):
@@ -140,7 +149,7 @@ def test_lean4_backtracking(self):
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
- theorem_name = "test3"
+ theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
# theorem test3 (p q : Prop) (hp : p) (hq : q)
# : p ∧ q ∧ p :=
proof_exec_callback = ProofExecutorCallback(
@@ -237,24 +246,7 @@ def test_simple_coq(self):
else:
s1 : ProofState = state
s2 : ProofState = next_state
- print(f"Current Goal:")
- print('-'*30)
- for goal in s1.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
- print(f"Action: {proof_step}")
- print(f"="*30)
- print(f"Next Goal:")
- print('-'*30)
- for goal in s2.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
+ pretty_print(s1, s2, proof_step, done)
helper.switch_to_current_switch()
def test_simple_lean_calc(self):
@@ -270,7 +262,7 @@ def test_simple_lean_calc(self):
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
- theorem_name = "test_calc"
+ theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
# theorem test_calc (n: Nat) : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
proof_exec_callback = ProofExecutorCallback(
project_folder=project_folder,
@@ -309,41 +301,80 @@ def test_simple_lean_calc(self):
print('-'*30)
if done:
s1 : ProofState = state
- print(f"Current Goal:")
- print('-'*30)
- for goal in s1.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
- print(f"Action: {proof_step}")
- print(f"="*30)
- print("Proof Finished!!")
+ pretty_print(s1, None, proof_step, done)
proof_was_finished = True
else:
s1 : ProofState = state
s2 : ProofState = next_state
- print(f"Current Goal:")
- print('-'*30)
- for goal in s1.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
- print(f"Action: {proof_step}")
- print(f"="*30)
- print(f"Next Goal:")
- print('-'*30)
- for goal in s2.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
+ pretty_print(s1, s2, proof_step, done)
assert proof_was_finished, "Proof was not finished"
+ def test_simple_lean_calc_with_validation(self):
+ from itp_interface.rl.proof_state import ProofState
+ from itp_interface.rl.proof_action import ProofAction
+ from itp_interface.rl.simple_proof_env import ProofEnv
+ from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
+ from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
+ project_folder = "src/data/test/lean4_proj"
+ file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
+ # Build the project
+ # cd src/data/test/lean4_proj && lake build
+ helper = Helper()
+ helper.build_lean4_project(project_folder)
+ language = ProofAction.Language.LEAN4
+ theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
+ # theorem test_calc (n: Nat) : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder=project_folder,
+ file_path=file_path,
+ language=language,
+ always_use_retrieval=False,
+ keep_local_context=True
+ )
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
+ env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
+ proof_steps = [
+"""calc
+_ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n]
+_ = n^2 + (n + n) + 1 := by rw [Nat.mul_two]
+_ = n^2 + n + n + 1 := by rw [←Nat.add_assoc]
+_ = n*n + n + n + 1 := by rw [Nat.pow_two]
+_ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n]
+_ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1]
+_ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc]
+_ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]""",
+"_ = (n + 1)*(n + 1) := by \n rw [Nat.right_distrib n 1 (n + 1)]"
+]
+ with env:
+ proof_was_finished = False
+ for proof_step in proof_steps:
+ state, _, next_state, _, done, info = env.step(ProofAction(
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
+ tactics=[proof_step]))
+ if info.error_message is not None:
+ print(f"Error: {info.error_message}")
+ # This prints StateChanged, StateUnchanged, Failed, or Done
+ print(f"DONE: {done}")
+ print(info.progress)
+ print('-'*30)
+ if done:
+ s1 : ProofState = state
+ pretty_print(s1, None, proof_step, done)
+ proof_was_finished = True
+ else:
+ s1 : ProofState = state
+ s2 : ProofState = next_state
+ pretty_print(s1, s2, proof_step, done)
+ assert proof_was_finished, "Proof was not finished"
+ # Run the validation
+ val_result = env.validate_proof_completion(timeout_in_secs=60, keep_validation_file=False)
+ print("Validation Result:")
+ print(val_result)
+ assert val_result.get('success', False), f"Proof validation failed:\n{val_result.get('error_message', '')}"
+ assert val_result.get('compilation_ok', False), f"Proof validation failed:\n{val_result.get('error_message', '')}"
+
def test_simple_lean_enforce_done_test(self):
from itp_interface.rl.proof_state import ProofState
from itp_interface.rl.proof_action import ProofAction
@@ -357,7 +388,7 @@ def test_simple_lean_enforce_done_test(self):
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
- theorem_name = "test_calc"
+ theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
# theorem test_calc (n: Nat) : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
proof_exec_callback = ProofExecutorCallback(
project_folder=project_folder,
@@ -399,39 +430,12 @@ def test_simple_lean_enforce_done_test(self):
if done:
assert proof_step == "done", "Proof can only finish with done"
s1 : ProofState = state
- print(f"Current Goal:")
- print('-'*30)
- for goal in s1.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
- print(f"Action: {proof_step}")
- print(f"="*30)
- print("Proof Finished!!")
+ pretty_print(s1, None, proof_step, done)
proof_finished = True
else:
s1 : ProofState = state
s2 : ProofState = next_state
- print(f"Current Goal:")
- print('-'*30)
- for goal in s1.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
- print(f"Action: {proof_step}")
- print(f"="*30)
- print(f"Next Goal:")
- print('-'*30)
- for goal in s2.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
+ pretty_print(s1, s2, proof_step, done)
assert proof_finished, "Proof was not finished"
def test_simple_lean4_done_test(self):
@@ -447,7 +451,7 @@ def test_simple_lean4_done_test(self):
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
- theorem_name = "test3"
+ theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
# theorem test3 (p q : Prop) (hp : p) (hq : q)
# : p ∧ q ∧ p :=
proof_exec_callback = ProofExecutorCallback(
@@ -484,24 +488,7 @@ def test_simple_lean4_done_test(self):
else:
s1 : ProofState = state
s2 : ProofState = next_state
- print(f"Current Goal:")
- print('-'*30)
- for goal in s1.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
- print(f"Action: {proof_step}")
- print(f"="*30)
- print(f"Next Goal:")
- print('-'*30)
- for goal in s2.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f"="*30)
+ pretty_print(s1, s2, proof_step, done)
def test_simple_lean4_have_test(self):
from itp_interface.rl.proof_state import ProofState
@@ -516,7 +503,7 @@ def test_simple_lean4_have_test(self):
helper = Helper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
- theorem_name = "imo_1959_p1"
+ theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"imo_1959_p1\"}'
# theorem test3 (p q : Prop) (hp : p) (hq : q)
# : p ∧ q ∧ p :=
proof_exec_callback = ProofExecutorCallback(
@@ -563,31 +550,81 @@ def test_simple_lean4_have_test(self):
else:
s1 : ProofState = state
s2 : ProofState = next_state
- print(f"Current Goal:")
- print('-'*30)
- for goal in s1.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f'*'*30)
- print(f"="*30)
- print(f"Action: {proof_step}")
- print(f"="*30)
- print(f"Next Goal:")
- print('-'*30)
- for goal in s2.training_data_format.start_goals:
- hyps = '\n'.join([hyp for hyp in goal.hypotheses])
- print(hyps)
- print('|- ', end='')
- print(goal.goal)
- print(f'*'*30)
- print(f"="*30)
- print(f"DONE: {done}")
- print('-'*30)
+ pretty_print(s1, s2, proof_step, done)
+
+ def test_simple_lean4_with_error(self):
+ from itp_interface.rl.proof_state import ProofState
+ from itp_interface.rl.proof_action import ProofAction
+ from itp_interface.rl.simple_proof_env import ProofEnv
+ from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
+ from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
+ project_folder = "src/data/test/lean4_proj"
+ file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
+ # Build the project
+ # cd src/data/test/lean4_proj && lake build
+ helper = Helper()
+ helper.build_lean4_project(project_folder)
+ language = ProofAction.Language.LEAN4
+ theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
+ # theorem test3 (p q : Prop) (hp : p) (hq : q)
+ # : p ∧ q ∧ p :=
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder=project_folder,
+ file_path=file_path,
+ language=language,
+ always_use_retrieval=False,
+ keep_local_context=True
+ )
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
+ env = ProofEnv("test_lean4", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
+ proof_steps = [
+ 'apply And.intro',
+ 'exact hpx', # Error here
+ 'exact hp', # This should automatically work
+ 'apply And.intro',
+ 'exact hq',
+ 'exact hp'
+ ]
+ proof_finished = False
+ with env:
+ for i, proof_step in enumerate(proof_steps):
+ state, _, next_state, _, done, info = env.step(ProofAction(
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
+ tactics=[proof_step]))
+ if info.error_message is not None:
+ print(f"Error: {info.error_message}")
+ print(f"Proof step {i + 1} failed")
+ if i == 1:
+ assert info.error_message is not None, "Error was expected at step 2"
+ else:
+ assert info.error_message is None, f"Error was not expected at step {i + 1}"
+ # This prints StateChanged, StateUnchanged, Failed, or Done
+ print(info.progress)
+ print('-'*30)
+ if done:
+ print("Proof Finished!!")
+ proof_finished = True
+ else:
+ s1 : ProofState = state
+ s2 : ProofState = next_state
+ pretty_print(s1, s2, proof_step, done)
+ assert proof_finished, "Proof was not finished"
def main():
unittest.main()
+ # Run only the Lean 4 tests
+ # t = Lean4Test()
+ # t.test_simple_lean4()
+ # t.test_lean4_backtracking()
+ # t.test_simple_lean4_done_test()
+ # t.test_simple_lean_calc()
+ # t.test_simple_lean_calc_with_validation()
+ # t.test_simple_lean4_with_error()
+ # t.test_simple_lean4_have_test()
+ # t.test_simple_lean_enforce_done_test()
+
if __name__ == '__main__':
main()
\ No newline at end of file
diff --git a/src/test/test_tactic_parser.py b/src/test/test_tactic_parser.py
new file mode 100644
index 0000000..cb79825
--- /dev/null
+++ b/src/test/test_tactic_parser.py
@@ -0,0 +1,316 @@
+"""
+Test cases for the Lean 4 tactic parser.
+
+Tests the tactic parser with various theorems from Basic.lean to ensure
+it correctly extracts atomic tactics from Lean 4 proofs.
+"""
+
+import sys
+from pathlib import Path
+
+# Add parent directory to path to import tactic_parser
+sys.path.insert(0, str(Path(__file__).parent.parent / "itp_interface" / "tools"))
+
+from itp_interface.tools.tactic_parser import TacticParser, print_tactics
+
+project_path = str(Path(__file__).parent.parent / "data" / "test" / "lean4_proj")
+
+class TestTacticParser:
+ """Test suite for the Lean 4 tactic parser."""
+
+ def test_simple_proof_with_apply_and_exact(self):
+ """Test parsing a simple proof with apply and exact tactics."""
+ lean_code = """theorem test (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+exact hp
+apply And.intro
+exact hq
+exact hp"""
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Should extract 5 atomic tactics
+ assert len(tactics) == 5, f"Expected 5 tactics, got {len(tactics)}"
+
+ # Check each tactic
+ assert tactics[0].text == "apply And.intro", f"Expected 'apply And.intro', got '{tactics[0].text}'"
+ assert tactics[0].line == 3
+
+ assert tactics[1].text == "exact hp", f"Expected 'exact hp', got '{tactics[1].text}'"
+ assert tactics[1].line == 4
+
+ assert tactics[2].text == "apply And.intro", f"Expected 'apply And.intro', got '{tactics[2].text}'"
+ assert tactics[2].line == 5
+
+ assert tactics[3].text == "exact hq", f"Expected 'exact hq', got '{tactics[3].text}'"
+ assert tactics[3].line == 6
+
+ assert tactics[4].text == "exact hp", f"Expected 'exact hp', got '{tactics[4].text}'"
+ assert tactics[4].line == 7
+
+ def test_simple_partial_proof(self):
+ """Test parsing a simple proof with apply and exact tactics."""
+ lean_code = """theorem test (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+"""
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code, fail_on_error=False)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Should extract 4 atomic tactics
+ assert len(tactics) == 1, f"Expected 1 tactic, got {len(tactics)}"
+
+ # Check each tactic
+ assert tactics[0].text == "apply And.intro", f"Expected 'apply And.intro', got '{tactics[0].text}'"
+ assert tactics[0].line == 3
+
+ def test_calc_proof(self):
+ """Test parsing a calc-based proof with multiple rewrite steps."""
+ lean_code = """
+import Mathlib
+
+theorem test_calc (n: Nat) : n^2 + 2*n + 1 = (n + 1)*(n + 1) := by
+calc
+ _ = n^2 + n*2 + 1 := by rw [Nat.mul_comm 2 n]
+ _ = n^2 + (n + n) + 1 := by rw [Nat.mul_two]
+ _ = n^2 + n + n + 1 := by rw [←Nat.add_assoc]
+ _ = n*n + n + n + 1 := by rw [Nat.pow_two]
+ _ = n*n + n*1 + n + 1 := by rw [Nat.mul_one n]
+ _ = n*(n + 1) + n + 1 := by rw [Nat.left_distrib n n 1]
+ _ = n*(n + 1) + (n + 1) := by rw [Nat.add_assoc]
+ _ = n*(n + 1) + 1*(n + 1) := by rw (config := { occs := .pos [2]}) [←Nat.mul_one (n + 1), Nat.mul_comm]
+ _ = (n + 1)*(n + 1) := by rw [Nat.right_distrib n 1 (n + 1)]
+done"""
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Should extract tactics from calc steps (currently extracts lemma arguments)
+ # This is a known limitation - rw tactics are parsed deeply
+ assert len(tactics) > 0, "Expected tactics from calc proof"
+
+ def test_proof_with_indentation(self):
+ """Test parsing a proof with indented tactics."""
+ lean_code = """
+import Mathlib
+
+theorem test10 (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+ apply And.intro
+ exact hp
+ apply And.intro
+ exact hq
+ exact hp
+"""
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Should extract 5 atomic tactics
+ assert len(tactics) == 5, f"Expected 5 tactics, got {len(tactics)}"
+
+ # Check tactics are correctly extracted
+ tactic_texts = [t.text for t in tactics]
+ expected_tactics = ["apply And.intro", "exact hp", "apply And.intro", "exact hq", "exact hp"]
+
+ for expected, actual in zip(expected_tactics, tactic_texts):
+ assert actual == expected, f"Expected '{expected}', got '{actual}'"
+
+ def test_complex_proof_with_have(self):
+ """Test parsing a complex proof with have, rw, ring, and linarith."""
+ lean_code = """
+import Mathlib
+
+theorem imo_1959_p1
+ (n : ℕ)
+ (h₀ : 0 < n) :
+ Nat.gcd (21*n + 4) (14*n + 3) = 1 := by
+rw [Nat.gcd_rec]
+rw [Nat.mod_eq_of_lt (by linarith)]
+rw [Nat.gcd_rec]
+rw [Nat.gcd_rec]
+have eq₂ : (21 * n + 4) % (14 * n + 3) = 7 * n + 1 := by
+ have eq₁ : 21 * n + 4 = (14 * n + 3) + (7 * n + 1) := by ring
+ rw [eq₁, Nat.add_mod, Nat.mod_self, zero_add]
+ have h₂ : 7 * n + 1 < 14 * n + 3 := by linarith
+ rw [Nat.mod_eq_of_lt]
+ rw [Nat.mod_eq_of_lt]
+ exact h₂
+ rw [Nat.mod_eq_of_lt]
+ exact h₂
+ exact h₂
+rw [eq₂]
+sorry"""
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Should extract multiple tactics
+ assert len(tactics) == 7, f"Expected 7 tactics, got {len(tactics)}"
+
+ # Check that various tactic types are present
+ tactic_texts = [t.text for t in tactics]
+ has_rw = any("rw" in t for t in tactic_texts)
+ has_have = any("have" in t for t in tactic_texts)
+ has_exact = any("exact" in t for t in tactic_texts)
+ has_sorry = any("sorry" in t for t in tactic_texts)
+
+ assert has_rw, "Expected to find 'rw' tactics"
+ assert has_have, "Expected to find 'have' tactics"
+ assert has_exact, "Expected to find 'exact' tactics"
+ assert has_sorry, "Expected to find 'sorry' tactic"
+
+ def test_simple_one_liner(self):
+ """Test parsing a simple one-line proof."""
+ lean_code = "example : True := by trivial"
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Should extract at least the trivial tactic
+ assert len(tactics) > 0, "Expected at least one tactic"
+
+ # Check that trivial is present
+ tactic_texts = [t.text for t in tactics]
+ has_trivial = any("trivial" in t for t in tactic_texts)
+ assert has_trivial, f"Expected to find 'trivial' tactic, got {tactic_texts}"
+
+ def test_term_mode_proof_no_tactics(self):
+ """Test that term-mode proofs (no tactics) return empty list."""
+ lean_code = "theorem test2 : p -> q -> p ∧ q ∧ p := fun hp hq => ⟨hp, ⟨hq, hp⟩⟩"
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Term-mode proof should have no tactics
+ assert len(tactics) == 0, f"Expected 0 tactics for term-mode proof, got {len(tactics)}"
+
+ def test_proof_with_done(self):
+ """Test parsing a proof ending with 'done'."""
+ lean_code = """theorem test (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+exact hp
+apply And.intro
+exact hq
+exact hp
+done"""
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Should extract 6 atomic tactics (done is filtered out)
+ assert len(tactics) == 6, f"Expected 6 tactics, got {len(tactics)}"
+
+ # 'done' should not be in the tactics
+ tactic_texts = [t.text for t in tactics]
+ assert "done" in tactic_texts, "Expected 'done' in tactic texts"
+
+ def test_no_duplicate_tactics(self):
+ """Test that duplicate tactics are filtered out."""
+ lean_code = """theorem test (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+exact hp
+apply And.intro
+exact hq
+exact hp"""
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Check that there are no exact duplicates (same text, same position)
+ seen = set()
+ for tactic in tactics:
+ key = (tactic.text, tactic.line, tactic.column)
+ assert key not in seen, f"Found duplicate tactic: {tactic.text} at line {tactic.line}"
+ seen.add(key)
+
+ def test_tactic_positions(self):
+ """Test that tactic positions are correctly reported."""
+ lean_code = """theorem test (p q : Prop) (hp : p) (hq : q)
+: p ∧ q ∧ p := by
+apply And.intro
+exact hp
+apply And.intro
+exact hq
+exact hp"""
+
+ with TacticParser(project_path=project_path) as parser:
+ tactics, error_str = parser.parse(lean_code)
+ print_tactics(tactics)
+ if error_str:
+ print(f"Error: {error_str}")
+
+ # Check that positions are reasonable
+ for tactic in tactics:
+ assert tactic.line > 0, f"Line number should be positive, got {tactic.line}"
+ assert tactic.column >= 0, f"Column number should be non-negative, got {tactic.column}"
+ assert tactic.end_line >= tactic.line, \
+ f"End line {tactic.end_line} should be >= start line {tactic.line}"
+ assert tactic.end_column > 0, f"End column should be positive, got {tactic.end_column}"
+
+
+if __name__ == "__main__":
+ # Run tests without pytest
+ import traceback
+
+ test_suite = TestTacticParser()
+ test_methods = [
+ method for method in dir(test_suite)
+ if method.startswith("test_")
+ ]
+
+ passed = 0
+ failed = 0
+
+ for test_method in test_methods:
+ try:
+ print(f"\n{'='*60}")
+ print(f"Running: {test_method}")
+ print('='*60)
+ getattr(test_suite, test_method)()
+ print(f"✓ PASSED: {test_method}")
+ passed += 1
+ except Exception as e:
+ print(f"✗ FAILED: {test_method}")
+ print(f"Error: {e}")
+ traceback.print_exc()
+ failed += 1
+
+ print(f"\n{'='*60}")
+ print(f"Test Results: {passed} passed, {failed} failed")
+ print('='*60)
+
+ if failed > 0:
+ sys.exit(1)