Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ pip install itp-interface
```

2. Run the following command to prepare the REPL for Lean 4. The default version is 4.24.0. You can change the version by setting the `LEAN_VERSION` environment variable. If no version is set, then 4.24.0 is used.
>NOTE: The Lean 4 version must match the version of the Lean 4 project you are working with.
>NOTE: The Lean 4 version must match the version of the Lean 4 project you are working with. `itp-interface` **supports Lean 4 version >= 4.15.0 and <= 4.24.0**. (It has been tested till version 4.24.0, but might as well work for future versions too, if the future versions are completely backwards-compatible).

```bash
install-lean-repl
# To use a different Lean version, set LEAN_VERSION before running:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
version = "1.1.15"
version = "1.1.16"
authors = [
{ name="Amitayush Thakur", email="[email protected]" },
]
Expand Down
12 changes: 11 additions & 1 deletion src/itp_interface/main/install.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import os
import random
import string
import logging
import traceback
from itp_interface.tools.tactic_parser import build_tactic_parser_if_needed

file_path = os.path.abspath(__file__)

logging.basicConfig(level=logging.INFO)
# Create a console logger
logger = logging.getLogger(__name__)

def generate_random_string(length, allowed_chars=None):
if allowed_chars is None:
Expand All @@ -25,7 +30,12 @@ def install_itp_interface():
print("Lean toolchain version for tactic_parser: ", lean_toolchain_content)
print(f"LEAN_VERSION set: {os.environ.get('LEAN_VERSION', 'Not Set')}")
print("Building itp_interface")
build_tactic_parser_if_needed()
try:
build_tactic_parser_if_needed(logger)
except Exception:
# print the stack trace
traceback.print_exc()
raise

def install_lean_repl():
print("Updating Lean")
Expand Down
6 changes: 3 additions & 3 deletions src/itp_interface/tools/dynamic_lean4_proof_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,12 @@ def cancel_tactic_till_line(self, tactic_line_num: int, no_backtracking: bool =
if line_num >= tactic_line_num:
del self.run_state.line_tactic_map[line_num]
line_proof_context_map_keys = list(self.run_state.line_proof_context_map.keys())
for line_num in line_proof_context_map_keys:
if line_num >= tactic_line_num:
del self.run_state.line_proof_context_map[line_num]
if not no_backtracking:
self.proof_context = self.run_state.line_proof_context_map[tactic_line_num]
self.line_num = tactic_line_num
cancelled_some_tactics = self._backtrack_tactic_line(tactic_line_num)
self._proof_running = self.proof_context is not None
for line_num in line_proof_context_map_keys:
if line_num >= tactic_line_num:
del self.run_state.line_proof_context_map[line_num]
return cancelled_some_tactics
29 changes: 0 additions & 29 deletions src/itp_interface/tools/tactic_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,35 +249,6 @@ class FileDependencyAnalysis(BaseModel):
def __repr__(self) -> str:
return f"FileDependencyAnalysis({self.module_name}, {len(self.declarations)} decls)"

# def to_dict(self) -> Dict:
# """Convert to dictionary matching the JSON format."""
# return {
# "filePath": self.file_path,
# "moduleName": self.module_name,
# "imports": self.imports,
# "declarations": [
# {
# "declaration": decl.declaration.to_dict(),
# "dependencies": [dep.model_dump(exclude_none=True) for dep in decl.dependencies],
# "unresolvedNames": decl.unresolved_names,
# **({"declId": decl.decl_id} if decl.decl_id else {})
# }
# for decl in self.declarations
# ]
# }

# @staticmethod
# def from_dict(data: Dict) -> 'FileDependencyAnalysis':
# """Parse from the JSON dict returned by dependency-parser executable."""
# declarations = DeclWithDependencies.from_dependency_analysis(data)

# return FileDependencyAnalysis(
# file_path=data['filePath'],
# module_name=data['moduleName'],
# imports=data.get('imports', []),
# declarations=declarations
# )

# Create an enum for parsing request type
class RequestType(Enum):
PARSE_TACTICS = "parse_tactics"
Expand Down
79 changes: 40 additions & 39 deletions src/itp_interface/tools/tactic_parser/TacticParser/Base64.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,49 +25,50 @@ def charToValue (c : Char) : Option UInt8 :=
else
none

/-- Decode a base64 string to bytes -/
def decodeBytes (s : String) : Except String ByteArray :=
let chars := s.trim.toList.filter (· ≠ '=')
let rec loop (i : Nat) (result : ByteArray) : Except String ByteArray :=
if i >= chars.length then
.ok result
else
-- Get 4 characters (or remaining)
let c1 := chars[i]!
let c2 := if i + 1 < chars.length then chars[i + 1]! else '='
let c3 := if i + 2 < chars.length then chars[i + 2]! else '='
let c4 := if i + 3 < chars.length then chars[i + 3]! else '='
partial def loop (i : Nat) (result : ByteArray) (chars : List Char) : Except String ByteArray :=
if i >= chars.length then
.ok result
else
-- Get 4 characters (or remaining)
let c1 := chars[i]!
let c2 := if i + 1 < chars.length then chars[i + 1]! else '='
let c3 := if i + 2 < chars.length then chars[i + 2]! else '='
let c4 := if i + 3 < chars.length then chars[i + 3]! else '='

match Option.toExcept (charToValue c1) "Invalid base64 character" with
match Option.toExcept (charToValue c1) "Invalid base64 character" with
| .error e => .error e
| .ok v1 =>
match Option.toExcept (charToValue c2) "Invalid base64 character" with
| .error e => .error e
| .ok v1 =>
match Option.toExcept (charToValue c2) "Invalid base64 character" with
| .error e => .error e
| .ok v2 =>
-- First byte is always present
let b1 := (v1.toNat <<< 2 ||| (v2.toNat >>> 4)).toUInt8
let result := result.push b1
| .ok v2 =>
-- First byte is always present
let b1 := (v1.toNat <<< 2 ||| (v2.toNat >>> 4)).toUInt8
let result := result.push b1

-- Second byte if c3 exists
if c3 ≠ '=' then
match Option.toExcept (charToValue c3) "Invalid base64 character" with
| .error e => .error e
| .ok v3 =>
let b2 := ((v2.toNat &&& 0xF) <<< 4 ||| (v3.toNat >>> 2)).toUInt8
let result := result.push b2
-- Second byte if c3 exists
if c3 ≠ '=' then
match Option.toExcept (charToValue c3) "Invalid base64 character" with
| .error e => .error e
| .ok v3 =>
let b2 := ((v2.toNat &&& 0xF) <<< 4 ||| (v3.toNat >>> 2)).toUInt8
let result := result.push b2

-- Third byte if c4 exists
if c4 ≠ '=' then
match Option.toExcept (charToValue c4) "Invalid base64 character" with
| .error e => .error e
| .ok v4 =>
let b3 := ((v3.toNat &&& 0x3) <<< 6 ||| v4.toNat).toUInt8
loop (i + 4) (result.push b3)
else
loop (i + 4) result
else
loop (i + 4) result
loop 0 ByteArray.empty
-- Third byte if c4 exists
if c4 ≠ '=' then
match Option.toExcept (charToValue c4) "Invalid base64 character" with
| .error e => .error e
| .ok v4 =>
let b3 := ((v3.toNat &&& 0x3) <<< 6 ||| v4.toNat).toUInt8
loop (i + 4) (result.push b3) chars
else
loop (i + 4) result chars
else
loop (i + 4) result chars

/-- Decode a base64 string to bytes -/
def decodeBytes (s : String) : Except String ByteArray :=
let chars := s.trim.toList.filter (· ≠ '=')
loop 0 ByteArray.empty chars

/-- Decode a base64 string to UTF-8 string -/
def decode (s : String) : Except String String := do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,31 @@ def extractDependenciesFromSyntax (env : Environment) (stx : Syntax) : Array Dec
| none => (deps, unres.push name.toString)
) (#[], #[])

/-- Parse the header recursively to find all import commands -/
partial def findImports (stx : Syntax) (content: String) : IO (Array ImportInfo) → IO (Array ImportInfo)
| accIO => do
let acc ← accIO
match stx with
| Syntax.node _ kind args =>
if kind == `Lean.Parser.Module.import then
-- Found an import
match stx.getRange? with
| some range =>
let moduleName := extractModuleName stx
let text := content.extract range.start range.stop
let info : ImportInfo := {
moduleName := moduleName
startPos := range.start.byteIdx
endPos := range.stop.byteIdx
text := text
}
return acc.push info
| none => return acc
else
-- Recursively search children
args.foldlM (fun acc child => findImports child content (pure acc)) acc
| _ => return acc

/-- Parse imports and namespaces from a Lean 4 file -/
def parseImports (filepath : System.FilePath) : IO DependencyInfo := do
let content ← IO.FS.readFile filepath
Expand All @@ -164,32 +189,9 @@ def parseImports (filepath : System.FilePath) : IO DependencyInfo := do
-- Extract the underlying syntax from TSyntax
let headerSyn : Syntax := headerStx

-- Parse the header recursively to find all import commands
let rec findImports (stx : Syntax) : IO (Array ImportInfo) → IO (Array ImportInfo)
| accIO => do
let acc ← accIO
match stx with
| Syntax.node _ kind args =>
if kind == `Lean.Parser.Module.import then
-- Found an import
match stx.getRange? with
| some range =>
let moduleName := extractModuleName stx
let text := content.extract range.start range.stop
let info : ImportInfo := {
moduleName := moduleName
startPos := range.start.byteIdx
endPos := range.stop.byteIdx
text := text
}
return acc.push info
| none => return acc
else
-- Recursively search children
args.foldlM (fun acc child => findImports child (pure acc)) acc
| _ => return acc

imports ← findImports headerSyn (pure imports)


imports ← findImports headerSyn content (pure imports)

-- Now parse the rest of the file to find namespace declarations
let env ← Lean.importModules #[] {} 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ partial def identifySomeDeclType (stx : Syntax) : Option (DeclType × Nat) :=
match stx with
| Syntax.node _ _ args =>
-- Look for the actual declaration type in the children
let idx := args.findIdx (fun a => (identifySomeDeclType a).isSome);
-- NOTE: using findIdx? for backward compatibility to Lean 4.15
-- (as.findIdx? p).getD as.size
let idx := (args.findIdx? (fun a => (identifySomeDeclType a).isSome)).getD args.size;
if idx = args.size then
some (.unknown, idx)
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def nodeEndPos (node : InfoTreeNode) : Option Position :=
| InfoTreeNode.leanInfo _ _ _ _ _ endPos _ _ => some endPos
| _ => none

def filterAllNodesWhichDontStartAndEndOnLine (node : InfoTreeNode) (line_num: Nat) : Array InfoTreeNode :=
partial def filterAllNodesWhichDontStartAndEndOnLine (node : InfoTreeNode) (line_num: Nat) : Array InfoTreeNode :=
match node with
| .context child =>
filterAllNodesWhichDontStartAndEndOnLine child line_num
Expand Down Expand Up @@ -213,7 +213,7 @@ let arg_max := all_possible_nodes.foldl (fun (acc_node, acc_len) n =>
) (InfoTreeNode.hole, 0)
arg_max

def getAllLinesInTree (node : InfoTreeNode) : Std.HashSet Nat :=
partial def getAllLinesInTree (node : InfoTreeNode) : Std.HashSet Nat :=
match node with
| .context child =>
getAllLinesInTree child
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ partial def InfoNodeStruct.toJson (n: InfoNodeStruct) : Json :=
instance : ToJson InfoNodeStruct where
toJson := InfoNodeStruct.toJson

def getInfoNodeStruct (node : InfoTreeNode) : Option InfoNodeStruct :=
partial def getInfoNodeStruct (node : InfoTreeNode) : Option InfoNodeStruct :=
match node with
| .leanInfo declType name docString text startPos endPos namespc children =>
let childStructs := children.map getInfoNodeStruct
Expand Down
1 change: 0 additions & 1 deletion src/test/simple_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def test_simple_lean4(self):
assert proof_was_finished, "Proof was not finished"

def test_lean4_backtracking(self):
from itp_interface.rl.proof_state import ProofState
from itp_interface.rl.proof_action import ProofAction
from itp_interface.rl.simple_proof_env import ProofEnv
from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
Expand Down