Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions tests/dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
DSL for building Vyper contracts in tests.

Example usage:
from tests.dsl import CodeModel

# create a model
model = CodeModel()

# define storage variables
balance = model.storage_var('balance: uint256')
owner = model.storage_var('owner: address')

# build a simple contract
code = (model
.function('__init__()')
.deploy()
.body(f'{owner} = msg.sender')
.done()
.function('deposit()')
.external()
.payable()
.body(f'{balance} += msg.value')
.done()
.function('get_balance() -> uint256')
.external()
.view()
.body(f'return {balance}')
.done()
.build())

# The generated code will be:
# balance: uint256
# owner: address
#
# @deploy
# def __init__():
# self.owner = msg.sender
#
# @external
# @payable
# def deposit():
# self.balance += msg.value
#
# @external
# @view
# def get_balance() -> uint256:
# return self.balance
"""

from tests.dsl.code_model import CodeModel, VarRef

__all__ = [CodeModel, VarRef]
224 changes: 224 additions & 0 deletions tests/dsl/code_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""
Code model for building Vyper contracts programmatically.

This module provides a fluent API for constructing Vyper contracts
with proper formatting and structure.
"""

import textwrap
from typing import Optional, List, Dict, Any, Union

from vyper.ast import parse_to_ast
from vyper.ast.nodes import FunctionDef


class VarRef:
"""Reference to a variable with type and location information."""

def __init__(self, name: str, typ: str, location: str, visibility: Optional[str] = None):
self.name = name
self.typ = typ
self.location = location
self.visibility = visibility

def __str__(self) -> str:
"""Return the variable name for use in expressions."""
# storage and transient vars need self prefix
if self.location in ("storage", "transient"):
return f"self.{self.name}"
return self.name


class FunctionBuilder:
"""Builder for function definitions."""

def __init__(self, signature: str, parent: "CodeModel"):
self.signature = signature
self.parent = parent
self.decorators: List[str] = []
self.body_code: Optional[str] = None
self.is_internal = True # functions are internal by default

# parse just the name from the signature
paren_idx = signature.find('(')
if paren_idx == -1:
raise ValueError(f"Invalid function signature: {signature}")
self.name = signature[:paren_idx].strip()

def __str__(self) -> str:
"""Return the function name for use in expressions."""
if self.is_internal:
return f"self.{self.name}"
return self.name

def external(self) -> "FunctionBuilder":
"""Add @external decorator."""
self.decorators.append("@external")
self.is_internal = False
return self

def internal(self) -> "FunctionBuilder":
"""Add @internal decorator."""
self.decorators.append("@internal")
self.is_internal = True
return self

def deploy(self) -> "FunctionBuilder":
"""Add @deploy decorator."""
self.decorators.append("@deploy")
self.is_internal = False # deploy functions are not called with self
return self

def view(self) -> "FunctionBuilder":
"""Add @view decorator."""
self.decorators.append("@view")
return self

def pure(self) -> "FunctionBuilder":
"""Add @pure decorator."""
self.decorators.append("@pure")
return self

def payable(self) -> "FunctionBuilder":
"""Add @payable decorator."""
self.decorators.append("@payable")
return self

def nonreentrant(self) -> "FunctionBuilder":
"""Add @nonreentrant decorator."""
self.decorators.append("@nonreentrant")
return self

def body(self, code: str) -> "FunctionBuilder":
"""Set the function body."""
# dedent the code to handle multi-line strings nicely
self.body_code = textwrap.dedent(code).strip()
return self

def done(self) -> "CodeModel":
"""Finish building the function and return to parent CodeModel."""
lines = []

lines.extend(self.decorators)
lines.append(f"def {self.signature}:")

if self.body_code:
indented_body = "\n".join(f" {line}" for line in self.body_code.split("\n"))
lines.append(indented_body)
else:
lines.append(" pass")

self.parent._functions.append("\n".join(lines))
return self.parent


class CodeModel:
"""Model for building a Vyper contract."""

def __init__(self):
self._storage_vars: List[str] = []
self._transient_vars: List[str] = []
self._constants: List[str] = []
self._immutables: List[str] = []
self._events: List[str] = []
self._structs: List[str] = []
self._flags: List[str] = []
self._functions: List[str] = []
self._imports: List[str] = []
self._local_vars: Dict[str, VarRef] = {}

def storage_var(self, declaration: str) -> VarRef:
"""Add a storage variable."""
name, typ = self._parse_declaration(declaration)
self._storage_vars.append(declaration)
return VarRef(name, typ, "storage", "public")

def transient_var(self, declaration: str) -> VarRef:
"""Add a transient storage variable."""
name, typ = self._parse_declaration(declaration)
self._transient_vars.append(f"{name}: transient({typ})")
return VarRef(name, typ, "transient", "public")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vars will be public by default?


def constant(self, declaration: str) -> VarRef:
"""Add a constant."""
# constants have format: "NAME: constant(type) = value"
parts = declaration.split(":", 1)
name = parts[0].strip()
# extract type from constant(...) = value
type_start = parts[1].find("constant(") + 9
type_end = parts[1].find(")", type_start)
typ = parts[1][type_start:type_end].strip()

self._constants.append(declaration)
return VarRef(name, typ, "constant", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does the visibility for constants differ?


def immutable(self, declaration: str) -> VarRef:
"""Add an immutable variable."""
name, typ = self._parse_declaration(declaration)
self._immutables.append(f"{name}: immutable({typ})")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for immutables we don't have to declare with "immutable" but for constants "constant" is required?

return VarRef(name, typ, "immutable", "public")

def local_var(self, name: str, typ: str) -> VarRef:
"""Register a local variable (used in function bodies)."""
ref = VarRef(name, typ, "memory", None)
self._local_vars[name] = ref
return ref

def event(self, definition: str) -> None:
"""Add an event definition."""
self._events.append(f"event {definition}")

def struct(self, definition: str) -> None:
"""Add a struct definition."""
self._structs.append(f"struct {definition}")

def flag(self, definition: str) -> None:
"""Add a flag (enum) definition."""
self._flags.append(f"flag {definition}")

def function(self, signature: str) -> FunctionBuilder:
"""Start building a function."""
return FunctionBuilder(signature, self)

def build(self) -> str:
"""Build the complete contract code."""
sections = []

if self._imports:
sections.append("\n".join(self._imports))

if self._events:
sections.append("\n".join(self._events))

if self._structs:
sections.append("\n".join(self._structs))

if self._flags:
sections.append("\n".join(self._flags))

if self._constants:
sections.append("\n".join(self._constants))

if self._storage_vars:
sections.append("\n".join(self._storage_vars))

if self._transient_vars:
sections.append("\n".join(self._transient_vars))

if self._immutables:
sections.append("\n".join(self._immutables))

if self._functions:
sections.append("\n\n".join(self._functions))

return "\n\n".join(sections)

def _parse_declaration(self, declaration: str) -> tuple[str, str]:
"""Parse a variable declaration of form 'name: type' into (name, type)."""
parts = declaration.split(":", 1)
if len(parts) != 2:
raise ValueError(f"Invalid declaration format: {declaration}")

name = parts[0].strip()
typ = parts[1].strip()
return name, typ
Loading
Loading