Skip to content

Commit aa46006

Browse files
feat: enhance provider initialization and tool execution with improved error handling and timeout management
1 parent 87c7e1e commit aa46006

2 files changed

Lines changed: 189 additions & 33 deletions

File tree

justllms/core/client.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
2+
3+
import logging
24

35
from justllms.config import Config
46
from justllms.core.base import BaseProvider, BaseResponse
57
from justllms.core.completion import Completion, CompletionResponse
68
from justllms.core.models import Message, ProviderConfig
7-
from justllms.exceptions import ProviderError
9+
from justllms.exceptions import ConfigurationError, ProviderError
810
from justllms.routing import Router
911

12+
logger = logging.getLogger(__name__)
13+
1014
if TYPE_CHECKING:
1115
from justllms.core.streaming import AsyncStreamResponse, SyncStreamResponse
1216

@@ -69,31 +73,53 @@ def _initialize_providers(self) -> None:
6973
"""Initialize providers based on configuration settings.
7074
7175
Creates provider instances for all enabled providers in the configuration
72-
that have valid API keys. Silently skips providers that fail to initialize
73-
to allow partial functionality when some providers are misconfigured.
76+
that have valid API keys (or do not require one). Logs warnings for
77+
misconfigured providers and raises if every eligible provider fails.
7478
7579
Raises:
80+
ConfigurationError: If all eligible providers fail to initialize.
7681
ImportError: If required provider class cannot be imported.
7782
"""
7883
from justllms.providers import get_provider_class
7984

85+
init_failures: List[Tuple[str, Exception]] = []
86+
attempted_providers: List[str] = []
87+
8088
for provider_name, provider_config in self.config.providers.items():
8189
if not provider_config.get("enabled", True):
8290
continue
8391

8492
provider_class = get_provider_class(provider_name)
8593
if not provider_class:
94+
logger.warning("Unknown provider '%s' in config, skipping", provider_name)
8695
continue
8796

8897
requires_key = getattr(provider_class, "requires_api_key", True)
8998
if requires_key and not provider_config.get("api_key"):
9099
continue
91100

101+
attempted_providers.append(provider_name)
102+
92103
try:
93104
config = ProviderConfig(name=provider_name, **provider_config)
94105
self.providers[provider_name] = provider_class(config)
95-
except Exception:
96-
pass
106+
except Exception as exc:
107+
logger.warning(
108+
"Failed to initialize provider '%s': %s",
109+
provider_name,
110+
exc,
111+
exc_info=logger.isEnabledFor(logging.DEBUG),
112+
)
113+
init_failures.append((provider_name, exc))
114+
115+
if attempted_providers and not self.providers:
116+
failure_details = "; ".join(
117+
f"{name}: {error}" for name, error in init_failures
118+
)
119+
raise ConfigurationError(
120+
"Failed to initialize any configured providers. "
121+
f"{failure_details or 'No failure details available.'}"
122+
)
97123

98124
def add_provider(self, name: str, provider: BaseProvider) -> None:
99125
"""Add a provider instance to the client.
@@ -227,14 +253,20 @@ def _create_completion(
227253
CompletionResponse or StreamResponse depending on stream parameter.
228254
229255
Raises:
230-
ValueError: If model is not specified and no fallback is configured.
256+
ValueError: If model is not specified and no fallback is configured,
257+
or if tools are requested together with stream=True.
231258
ProviderError: If the specified provider is not available or if the
232259
completion request fails, or if streaming is requested
233260
but the provider doesn't support it.
234261
"""
235262
# Check if tools are provided
236263
tools = kwargs.pop("tools", None)
237-
if tools and not stream:
264+
if tools and stream:
265+
raise ValueError(
266+
"Tool calling is not supported with stream=True. "
267+
"Use stream=False to enable tool execution, or omit tools for streaming."
268+
)
269+
if tools:
238270
# Route to tool-enabled completion
239271
# Determine provider/model for tools
240272
if not provider:

justllms/tools/executor.py

Lines changed: 149 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,29 @@
11
import json
2+
import logging
3+
import multiprocessing as mp
4+
import pickle
25
import threading
36
import time
4-
from typing import Any, Dict, List, Optional
7+
from typing import Any, Callable, Dict, List, Optional, Tuple
58

69
from justllms.core.base import BaseResponse
710
from justllms.tools.models import Tool, ToolCall, ToolExecutionEntry, ToolResult, ToolResultStatus
811
from justllms.tools.utils import validate_tool_arguments
912

13+
logger = logging.getLogger(__name__)
14+
15+
16+
def _run_tool_worker(
17+
result_queue: "mp.Queue[Tuple[str, Any]]",
18+
func: Callable[..., Any],
19+
kwargs: Dict[str, Any],
20+
) -> None:
21+
"""Run a tool callable in an isolated process and return the outcome."""
22+
try:
23+
result_queue.put(("success", func(**kwargs)))
24+
except Exception as exc:
25+
result_queue.put(("error", str(exc)))
26+
1027

1128
class ToolExecutor:
1229
"""Executes tools sequentially with error handling.
@@ -19,6 +36,11 @@ class ToolExecutor:
1936
tools: Dictionary mapping tool names to Tool instances.
2037
timeout: Maximum execution time per tool in seconds.
2138
execute_in_parallel: Always False (no parallel execution).
39+
40+
Note:
41+
Timeouts terminate picklable tools in a subprocess. Non-picklable
42+
callables fall back to a daemon thread where timeout only stops
43+
waiting and cannot guarantee cancellation.
2244
"""
2345

2446
def __init__(
@@ -74,53 +96,155 @@ def execute_tool_call(self, tool_call: ToolCall) -> ToolResult:
7496
)
7597

7698
# Execute with timeout
77-
result_container: Dict[str, Any] = {}
78-
79-
def execute_tool() -> None:
80-
"""Execute tool in separate thread for timeout control."""
81-
try:
82-
result_container["result"] = tool.callable(**validated_args)
83-
result_container["success"] = True
84-
except Exception as e:
85-
result_container["error"] = str(e)
86-
result_container["success"] = False
87-
88-
# Run with timeout
89-
thread = threading.Thread(target=execute_tool, daemon=True)
90-
thread.start()
91-
thread.join(timeout=self.timeout)
92-
99+
result, error, timed_out = self._execute_callable_with_timeout(
100+
tool.callable, validated_args, tool_name=tool_call.name
101+
)
93102
execution_time_ms = (time.time() - start_time) * 1000
94103

95-
# Check timeout
96-
if thread.is_alive():
104+
if timed_out:
97105
return ToolResult(
98106
tool_call_id=tool_call.id,
99107
result=None,
100-
error=f"Tool execution timed out after {self.timeout}s",
108+
error=error,
101109
execution_time_ms=execution_time_ms,
102110
status=ToolResultStatus.TIMEOUT,
103111
)
104112

105-
# Check for errors
106-
if not result_container.get("success", False):
113+
if error is not None:
107114
return ToolResult(
108115
tool_call_id=tool_call.id,
109116
result=None,
110-
error=result_container.get("error", "Unknown error"),
117+
error=error,
111118
execution_time_ms=execution_time_ms,
112119
status=ToolResultStatus.ERROR,
113120
)
114121

115-
# Success
116122
return ToolResult(
117123
tool_call_id=tool_call.id,
118-
result=result_container.get("result"),
124+
result=result,
119125
error=None,
120126
execution_time_ms=execution_time_ms,
121127
status=ToolResultStatus.SUCCESS,
122128
)
123129

130+
def _execute_callable_with_timeout(
131+
self,
132+
callable_fn: Callable[..., Any],
133+
validated_args: Dict[str, Any],
134+
tool_name: str,
135+
) -> Tuple[Any, Optional[str], bool]:
136+
"""Execute a callable with timeout enforcement.
137+
138+
Returns:
139+
Tuple of (result, error_message, timed_out).
140+
"""
141+
try:
142+
pickle.dumps(callable_fn)
143+
except (pickle.PicklingError, TypeError):
144+
logger.debug(
145+
"Tool '%s' is not picklable; using thread-based timeout fallback",
146+
tool_name,
147+
)
148+
return self._execute_in_thread(callable_fn, validated_args, tool_name)
149+
150+
return self._execute_in_process(callable_fn, validated_args, tool_name)
151+
152+
def _execute_in_process(
153+
self,
154+
callable_fn: Callable[..., Any],
155+
validated_args: Dict[str, Any],
156+
tool_name: str,
157+
) -> Tuple[Any, Optional[str], bool]:
158+
"""Execute a picklable callable in a subprocess that can be terminated."""
159+
ctx = mp.get_context("spawn")
160+
result_queue: "mp.Queue[Tuple[str, Any]]" = ctx.Queue()
161+
process = ctx.Process(
162+
target=_run_tool_worker,
163+
args=(result_queue, callable_fn, validated_args),
164+
)
165+
process.start()
166+
process.join(timeout=self.timeout)
167+
168+
if process.is_alive():
169+
self._terminate_process(process)
170+
logger.warning(
171+
"Tool '%s' exceeded timeout of %ss and was terminated",
172+
tool_name,
173+
self.timeout,
174+
)
175+
return (
176+
None,
177+
f"Tool execution timed out after {self.timeout}s",
178+
True,
179+
)
180+
181+
if not result_queue.empty():
182+
status, payload = result_queue.get_nowait()
183+
if status == "success":
184+
return payload, None, False
185+
return None, payload, False
186+
187+
exit_code = process.exitcode
188+
return (
189+
None,
190+
f"Tool process exited without returning a result (exit code: {exit_code})",
191+
False,
192+
)
193+
194+
def _execute_in_thread(
195+
self,
196+
callable_fn: Callable[..., Any],
197+
validated_args: Dict[str, Any],
198+
tool_name: str,
199+
) -> Tuple[Any, Optional[str], bool]:
200+
"""Best-effort timeout for callables that cannot run in a subprocess."""
201+
result_container: Dict[str, Any] = {}
202+
203+
def execute_tool() -> None:
204+
try:
205+
result_container["result"] = callable_fn(**validated_args)
206+
result_container["success"] = True
207+
except Exception as exc:
208+
result_container["error"] = str(exc)
209+
result_container["success"] = False
210+
211+
thread = threading.Thread(target=execute_tool, daemon=True)
212+
thread.start()
213+
thread.join(timeout=self.timeout)
214+
215+
if thread.is_alive():
216+
logger.warning(
217+
"Tool '%s' exceeded timeout of %ss; execution may continue in background "
218+
"because the callable is not picklable for process termination",
219+
tool_name,
220+
self.timeout,
221+
)
222+
return (
223+
None,
224+
(
225+
f"Tool execution timed out after {self.timeout}s "
226+
"(best-effort; execution may continue in background)"
227+
),
228+
True,
229+
)
230+
231+
if not result_container.get("success", False):
232+
return None, result_container.get("error", "Unknown error"), False
233+
234+
return result_container.get("result"), None, False
235+
236+
@staticmethod
237+
def _terminate_process(process: mp.Process) -> None:
238+
"""Terminate a subprocess, escalating to kill if needed."""
239+
if not process.is_alive():
240+
return
241+
242+
process.terminate()
243+
process.join(timeout=1)
244+
if process.is_alive():
245+
process.kill()
246+
process.join()
247+
124248
def _extract_tool_calls(self, response: BaseResponse) -> List[ToolCall]:
125249
"""Extract tool calls from a response.
126250

0 commit comments

Comments
 (0)