Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/deps/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ async def main() -> None:

# Use the package immediately
print("\n3. Using the installed package:")
result = await session.run('''
result = await session.run("""
import art
print(art.text2art("Hello!"))
''')
""")
print(result.stdout)

# Add another package with version specifier
Expand Down
9 changes: 7 additions & 2 deletions src/py_code_mode/execution/container/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,15 @@ def to_docker_config(
}

# Add port binding
bind_ip: str | None = None
if self.auth_disabled and not self.auth_token:
# Safer default: when auth is disabled, don't publish on all interfaces.
bind_ip = "127.0.0.1"

if self.port > 0:
config["ports"] = {"8080/tcp": self.port}
config["ports"] = {"8080/tcp": (bind_ip, self.port) if bind_ip else self.port}
else:
config["ports"] = {"8080/tcp": None} # Auto-assign
config["ports"] = {"8080/tcp": (bind_ip, None) if bind_ip else None} # Auto-assign

# Add volumes from storage access
volumes = {}
Expand Down
6 changes: 3 additions & 3 deletions src/py_code_mode/execution/container/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ async def api_search_tools(query: str, limit: int = 10) -> list[dict[str, Any]]:
async def api_list_workflows() -> list[dict[str, Any]]:
"""Return all workflows."""
if _state.workflow_library is None:
return []
raise HTTPException(status_code=503, detail="Workflow library not initialized")

workflows = _state.workflow_library.list()
return [
Expand All @@ -809,7 +809,7 @@ async def api_list_workflows() -> list[dict[str, Any]]:
async def api_search_workflows(query: str, limit: int = 5) -> list[dict[str, Any]]:
"""Search workflows."""
if _state.workflow_library is None:
return []
raise HTTPException(status_code=503, detail="Workflow library not initialized")

workflows = _state.workflow_library.search(query, limit=limit)
return [
Expand All @@ -825,7 +825,7 @@ async def api_search_workflows(query: str, limit: int = 5) -> list[dict[str, Any
async def api_get_workflow(name: str) -> dict[str, Any] | None:
"""Get workflow by name with full source."""
if _state.workflow_library is None:
return None
raise HTTPException(status_code=503, detail="Workflow library not initialized")

workflow = _state.workflow_library.get(name)
if workflow is None:
Expand Down
25 changes: 18 additions & 7 deletions src/py_code_mode/tools/adapters/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from dataclasses import dataclass, field
from typing import Any
from urllib.parse import quote

from py_code_mode.errors import ToolCallError, ToolNotFoundError
from py_code_mode.tools.types import Tool, ToolCallable, ToolParameter
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
self,
base_url: str,
headers: dict[str, str] | None = None,
timeout: float | None = 30.0,
) -> None:
"""Initialize adapter with base URL.

Expand All @@ -63,6 +65,7 @@ def __init__(
"""
self.base_url = base_url.rstrip("/")
self.headers = headers or {}
self.timeout = timeout
self._endpoints: dict[str, Endpoint] = {}

@property
Expand Down Expand Up @@ -148,13 +151,21 @@ async def call_tool(
) from e

# Build URL with path parameters
url = self._build_url(endpoint.path, args)
try:
url = self._build_url(endpoint.path, args)
except Exception as e:
raise ToolCallError(name, tool_args=args, cause=e) from e

# Separate path params from body params
path_params = self._extract_path_params(endpoint.path)
body_params = {k: v for k, v in args.items() if k not in path_params}

async with aiohttp.ClientSession(headers=self.headers) as session:
session_timeout = (
aiohttp.ClientTimeout(total=self.timeout)
if self.timeout is not None
else aiohttp.ClientTimeout(total=None)
)
async with aiohttp.ClientSession(headers=self.headers, timeout=session_timeout) as session:
try:
if endpoint.method.upper() in ("POST", "PUT", "PATCH"):
response = await session.request(
Expand All @@ -179,7 +190,7 @@ async def call_tool(

return await response.json()

except aiohttp.ClientError as e:
except (aiohttp.ClientError, TimeoutError) as e:
raise ToolCallError(name, tool_args=args, cause=e) from e

def _build_url(self, path: str, args: dict[str, Any]) -> str:
Expand All @@ -192,12 +203,12 @@ def _build_url(self, path: str, args: dict[str, Any]) -> str:
Returns:
Full URL with parameters substituted.
"""
# Substitute path parameters
url = self.base_url + path
for param_name, param_value in args.items():
for param_name in self._extract_path_params(path):
if param_name not in args:
raise ValueError(f"Missing required path parameter: {param_name}")
placeholder = "{" + param_name + "}"
if placeholder in url:
url = url.replace(placeholder, str(param_value))
url = url.replace(placeholder, quote(str(args[param_name]), safe=""))
return url

def _extract_path_params(self, path: str) -> set[str]:
Expand Down
6 changes: 2 additions & 4 deletions src/py_code_mode/workflows/vector_stores/redis_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,9 @@ def _knn_search(
try:
results = self._redis.ft(self._index_name).search(q, query_params={"vec": query_bytes})
except redis.exceptions.ResponseError as e:
logger.error(f"RediSearch query failed: {e}")
return {}
raise RuntimeError(f"RediSearch query failed: {e}") from e
except redis.exceptions.ConnectionError as e:
logger.error(f"Redis connection failed during search: {e}")
return {}
raise RuntimeError(f"Redis connection failed during search: {e}") from e

scores: dict[str, float] = {}
for doc in results.docs:
Expand Down
25 changes: 20 additions & 5 deletions src/py_code_mode/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,22 @@ class WorkflowParameter:
default: Any = None


# Special parameters that are injected, not user-provided
_INJECTED_PARAMS = {"tools", "workflows", "artifacts", "deps"}
_NAMESPACE_GLOBALS = frozenset({"tools", "workflows", "artifacts", "deps"})


def _validate_run_does_not_take_namespace_params(run_func: ast.AsyncFunctionDef) -> None:
"""Workflows access namespaces via globals, not run() parameters."""
args = [
*run_func.args.posonlyargs,
*run_func.args.args,
*run_func.args.kwonlyargs,
]
for arg in args:
if arg.arg in _NAMESPACE_GLOBALS:
raise ValueError(
f"Workflow run() must not declare parameter {arg.arg!r}; "
f"use the global {arg.arg} namespace instead."
)


def _annotation_to_type_str(annotation: ast.expr | None) -> str:
Expand Down Expand Up @@ -133,9 +147,6 @@ def _extract_parameters_from_ast(run_func: ast.AsyncFunctionDef) -> list[Workflo
parameters: list[WorkflowParameter] = []

def _add_param(arg_node: ast.arg, default_node: ast.expr | None) -> None:
if arg_node.arg in _INJECTED_PARAMS:
return

if default_node is not None:
default_val = _default_expr_to_value(default_node)
has_default = True
Expand Down Expand Up @@ -228,6 +239,8 @@ def from_source(
if has_sync_run:
raise ValueError("Workflow must define 'async def run()', not 'def run()'")

_validate_run_does_not_take_namespace_params(run_node)

# Extract description from source if not provided
if not description:
# Try module docstring first
Expand Down Expand Up @@ -275,6 +288,8 @@ def from_file(cls, path: Path) -> PythonWorkflow:
if has_sync_run:
raise ValueError(f"Workflow {path} must define 'async def run()', not 'def run()'")

_validate_run_does_not_take_namespace_params(run_node)

# Extract description from module or function docstring
description = ast.get_docstring(tree) or ast.get_docstring(run_node) or ""
description = description.strip().split("\n")[0]
Expand Down
30 changes: 30 additions & 0 deletions tests/container/test_container_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,36 @@ def test_search_workflows_requires_auth(self, auth_client) -> None:
response = client.get("/api/workflows/search", params={"query": "fetch"})
assert response.status_code == 401

def test_workflows_endpoints_return_503_if_library_not_initialized(self, tmp_path) -> None:
"""Workflows endpoints should not silently return empty results if init failed."""
try:
from fastapi.testclient import TestClient
except ImportError:
pytest.skip("FastAPI not installed")

from py_code_mode.execution.container.config import SessionConfig
from py_code_mode.execution.container.server import create_app

# Force workflow library init failure: workflows_path exists as a FILE.
workflows_path = tmp_path / "workflows"
workflows_path.write_text("not a directory")

config = SessionConfig(
artifacts_path=tmp_path / "artifacts",
workflows_path=workflows_path,
)
config.auth_token = "test-token"

app = create_app(config)
with TestClient(app) as client:
headers = {"Authorization": "Bearer test-token"}

resp = client.get("/api/workflows", headers=headers)
assert resp.status_code == 503

resp = client.get("/api/workflows/search", params={"query": "x"}, headers=headers)
assert resp.status_code == 503

def test_get_workflow_returns_none_when_not_found(self, auth_client) -> None:
"""GET /api/workflows/{name} returns null when workflow not found."""
client, token = auth_client
Expand Down
22 changes: 22 additions & 0 deletions tests/container/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,28 @@ async def test_to_docker_config_with_volumes(self, tmp_path) -> None:
)


class TestContainerExecutorNetworking:
"""Tests for networking configuration."""

def test_auth_disabled_binds_published_port_to_localhost(self) -> None:
"""When auth is disabled, published ports should not bind 0.0.0.0."""
config = ContainerConfig(port=9000, auth_disabled=True)
docker_config = config.to_docker_config()
assert docker_config["ports"]["8080/tcp"] == ("127.0.0.1", 9000)

def test_auth_disabled_binds_auto_port_to_localhost(self) -> None:
"""When auth is disabled and port=0, auto-assigned ports should bind localhost."""
config = ContainerConfig(port=0, auth_disabled=True)
docker_config = config.to_docker_config()
assert docker_config["ports"]["8080/tcp"] == ("127.0.0.1", None)

def test_auth_enabled_keeps_default_port_binding(self) -> None:
"""With auth enabled, preserve existing port binding behavior."""
config = ContainerConfig(port=9000, auth_token="secret")
docker_config = config.to_docker_config()
assert docker_config["ports"]["8080/tcp"] == 9000


class TestContainerExecutorEnvironment:
"""Tests for environment configuration."""

Expand Down
34 changes: 34 additions & 0 deletions tests/test_http_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,40 @@ async def test_call_tool_handles_http_error(self, adapter) -> None:
with pytest.raises(ToolCallError):
await adapter.call_tool("get_user", None, {"user_id": 42})

@pytest.mark.asyncio
async def test_call_tool_urlencodes_path_parameters(self, adapter) -> None:
"""Path parameters are URL-encoded to avoid accidental path traversal."""
with patch("aiohttp.ClientSession") as mock_client_session:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={"ok": True})

mock_session = AsyncMock()
mock_session.request = AsyncMock(return_value=mock_response)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
mock_client_session.return_value = mock_session

await adapter.call_tool("get_user", None, {"user_id": "a/b"})

call_args = mock_session.request.call_args
assert "/users/a%2Fb" in call_args[0][1]

@pytest.mark.asyncio
async def test_call_tool_timeout_is_wrapped(self, adapter) -> None:
"""Timeout errors are raised as ToolCallError (not left as bare TimeoutError)."""
from py_code_mode.errors import ToolCallError

with patch("aiohttp.ClientSession") as mock_client_session:
mock_session = AsyncMock()
mock_session.request = AsyncMock(side_effect=TimeoutError())
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
mock_client_session.return_value = mock_session

with pytest.raises(ToolCallError):
await adapter.call_tool("get_user", None, {"user_id": 42})


class TestHTTPAdapterWithRegistry:
"""Tests for HTTPAdapter integration with ToolRegistry."""
Expand Down
27 changes: 16 additions & 11 deletions tests/test_skills.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,29 @@ def test_workflow_with_tools_access(self, tmp_path: Path) -> None:
dedent('''
"""Scan a network target."""

async def run(target: str, tools) -> str:
"""Run a scan using tools.

Args:
target: Target to scan
tools: Tools namespace (injected)
"""
async def run(target: str) -> str:
"""Run a scan using tools."""
# In real use, would call tools.call(...)
_ = tools # namespace is available as a global
return f"Scanning {target}"
''').strip()
)

workflow = PythonWorkflow.from_file(workflow_path)

# tools parameter should be recognized as special, not a user param
user_params = [p for p in workflow.parameters if p.name != "tools"]
assert len(user_params) == 1
assert user_params[0].name == "target"
assert [p.name for p in workflow.parameters] == ["target"]


class TestPythonWorkflowNamespaceParamValidation:
def test_from_source_rejects_namespace_params(self) -> None:
"""run() must not accept tools/workflows/artifacts/deps as parameters."""
source = dedent("""
async def run(x: int, tools) -> int:
return x
""").strip()

with pytest.raises(ValueError, match=r"must not declare parameter 'tools'"):
PythonWorkflow.from_source(name="bad", source=source)


class TestPythonWorkflowFromSource:
Expand Down
Loading