diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..094d7cac --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +@ehhuang @ashwinb @raghotham @reluctantfuturist diff --git a/src/llama_stack_client/__init__.py b/src/llama_stack_client/__init__.py index 3c952a95..4fdd36f2 100644 --- a/src/llama_stack_client/__init__.py +++ b/src/llama_stack_client/__init__.py @@ -39,6 +39,12 @@ from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging +from .lib.agents.agent import Agent +from .lib.agents.event_logger import EventLogger as AgentEventLogger +from .lib.inference.event_logger import EventLogger as InferenceEventLogger +from .types.agents.turn_create_params import Document +from .types.shared_params.document import Document as RAGDocument + __all__ = [ "types", "__version__", diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py index 9496091b..409d8f5c 100644 --- a/src/llama_stack_client/_client.py +++ b/src/llama_stack_client/_client.py @@ -1,6 +1,7 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from __future__ import annotations +import json import os from typing import Any, Union, Mapping @@ -126,6 +127,7 @@ def __init__( # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + provider_data: Mapping[str, Any] | None = None, ) -> None: """Construct a new synchronous LlamaStackClient client instance. 
@@ -140,13 +142,18 @@ def __init__( if base_url is None: base_url = f"http://any-hosted-llama-stack.com" + custom_headers = default_headers or {} + custom_headers["X-LlamaStack-Client-Version"] = __version__ + if provider_data is not None: + custom_headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data) + super().__init__( version=__version__, base_url=base_url, max_retries=max_retries, timeout=timeout, http_client=http_client, - custom_headers=default_headers, + custom_headers=custom_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, ) @@ -344,6 +351,7 @@ def __init__( # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + provider_data: Mapping[str, Any] | None = None, ) -> None: """Construct a new async AsyncLlamaStackClient client instance. @@ -358,13 +366,18 @@ def __init__( if base_url is None: base_url = f"http://any-hosted-llama-stack.com" + custom_headers = default_headers or {} + custom_headers["X-LlamaStack-Client-Version"] = __version__ + if provider_data is not None: + custom_headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data) + super().__init__( version=__version__, base_url=base_url, max_retries=max_retries, timeout=timeout, http_client=http_client, - custom_headers=default_headers, + custom_headers=custom_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, ) diff --git a/src/llama_stack_client/_utils/_logs.py b/src/llama_stack_client/_utils/_logs.py index 49f3ee8c..77e8dc24 100644 --- a/src/llama_stack_client/_utils/_logs.py +++ b/src/llama_stack_client/_utils/_logs.py @@ -1,5 +1,6 @@ import os import logging +from rich.logging import RichHandler logger: logging.Logger = logging.getLogger("llama_stack_client") httpx_logger: logging.Logger = logging.getLogger("httpx") @@ -10,6 +11,7 @@ def _basic_config() -> None: logging.basicConfig( 
format="[%(asctime)s - %(name)s:%(lineno)d - %(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", + handlers=[RichHandler(rich_tracebacks=True)], ) diff --git a/src/llama_stack_client/_version.py b/src/llama_stack_client/_version.py index 90b5d94d..41313065 100644 --- a/src/llama_stack_client/_version.py +++ b/src/llama_stack_client/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "llama_stack_client" -__version__ = "0.1.0-alpha.2" # x-release-please-version +__version__ = "0.2.12" diff --git a/src/llama_stack_client/lib/.keep b/src/llama_stack_client/lib/.keep index 5e2c99fd..7554f8b2 100644 --- a/src/llama_stack_client/lib/.keep +++ b/src/llama_stack_client/lib/.keep @@ -1,4 +1,4 @@ File generated from our OpenAPI spec by Stainless. This directory can be used to store custom files to expand the SDK. -It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. \ No newline at end of file +It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. diff --git a/src/llama_stack_client/lib/__init__.py b/src/llama_stack_client/lib/__init__.py new file mode 100644 index 00000000..6bc5d151 --- /dev/null +++ b/src/llama_stack_client/lib/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .tools.mcp_oauth import get_oauth_token_for_mcp_server + +__all__ = ["get_oauth_token_for_mcp_server"] diff --git a/src/llama_stack_client/lib/agents/__init__.py b/src/llama_stack_client/lib/agents/__init__.py new file mode 100644 index 00000000..756f351d --- /dev/null +++ b/src/llama_stack_client/lib/agents/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py new file mode 100644 index 00000000..ebdc4abd --- /dev/null +++ b/src/llama_stack_client/lib/agents/agent.py @@ -0,0 +1,601 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import logging +from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Tuple, Union + +from llama_stack_client import LlamaStackClient +from llama_stack_client.types import ToolResponseMessage, ToolResponseParam, UserMessage +from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.agents.agent_turn_response_stream_chunk import ( + AgentTurnResponseStreamChunk, +) +from llama_stack_client.types.agents.turn import CompletionMessage, Turn +from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup +from llama_stack_client.types.shared.tool_call import ToolCall +from llama_stack_client.types.shared_params.agent_config import ToolConfig +from llama_stack_client.types.shared_params.response_format import ResponseFormat +from llama_stack_client.types.shared_params.sampling_params import SamplingParams + +from ..._types import Headers +from .client_tool import ClientTool, client_tool +from .tool_parser import ToolParser + +DEFAULT_MAX_ITER = 10 + +logger = logging.getLogger(__name__) + + +class AgentUtils: + @staticmethod + def get_client_tools( + tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]], + ) -> List[ClientTool]: + if not tools: + return [] + + # Wrap any function in client_tool decorator + tools = [client_tool(tool) if (callable(tool) and not isinstance(tool, ClientTool)) else tool for tool in tools] + 
return [tool for tool in tools if isinstance(tool, ClientTool)] + + @staticmethod + def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]: + if chunk.event.payload.event_type not in { + "turn_complete", + "turn_awaiting_input", + }: + return [] + + message = chunk.event.payload.turn.output_message + if message.stop_reason == "out_of_tokens": + return [] + + if tool_parser: + return tool_parser.get_tool_calls(message) + + return message.tool_calls + + @staticmethod + def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]: + if chunk.event.payload.event_type not in [ + "turn_complete", + "turn_awaiting_input", + ]: + return None + + return chunk.event.payload.turn.turn_id + + @staticmethod + def get_agent_config( + model: Optional[str] = None, + instructions: Optional[str] = None, + tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, + tool_config: Optional[ToolConfig] = None, + sampling_params: Optional[SamplingParams] = None, + max_infer_iters: Optional[int] = None, + input_shields: Optional[List[str]] = None, + output_shields: Optional[List[str]] = None, + response_format: Optional[ResponseFormat] = None, + enable_session_persistence: Optional[bool] = None, + ) -> AgentConfig: + # Create a minimal valid AgentConfig with required fields + if model is None or instructions is None: + raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided") + + agent_config = { + "model": model, + "instructions": instructions, + "toolgroups": [], + "client_tools": [], + } + + # Add optional parameters if provided + if enable_session_persistence is not None: + agent_config["enable_session_persistence"] = enable_session_persistence + if max_infer_iters is not None: + agent_config["max_infer_iters"] = max_infer_iters + if input_shields is not None: + agent_config["input_shields"] = input_shields + if output_shields is not None: + 
agent_config["output_shields"] = output_shields + if response_format is not None: + agent_config["response_format"] = response_format + if sampling_params is not None: + agent_config["sampling_params"] = sampling_params + if tool_config is not None: + agent_config["tool_config"] = tool_config + if tools is not None: + toolgroups: List[Toolgroup] = [] + for tool in tools: + if isinstance(tool, str) or isinstance(tool, dict): + toolgroups.append(tool) + + agent_config["toolgroups"] = toolgroups + agent_config["client_tools"] = [tool.get_tool_definition() for tool in AgentUtils.get_client_tools(tools)] + + agent_config = AgentConfig(**agent_config) + return agent_config + + +class Agent: + def __init__( + self, + client: LlamaStackClient, + # begin deprecated + agent_config: Optional[AgentConfig] = None, + client_tools: Tuple[ClientTool, ...] = (), + # end deprecated + tool_parser: Optional[ToolParser] = None, + model: Optional[str] = None, + instructions: Optional[str] = None, + tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, + tool_config: Optional[ToolConfig] = None, + sampling_params: Optional[SamplingParams] = None, + max_infer_iters: Optional[int] = None, + input_shields: Optional[List[str]] = None, + output_shields: Optional[List[str]] = None, + response_format: Optional[ResponseFormat] = None, + enable_session_persistence: Optional[bool] = None, + extra_headers: Headers | None = None, + ): + """Construct an Agent with the given parameters. + + :param client: The LlamaStackClient instance. + :param agent_config: The AgentConfig instance. + ::deprecated: use other parameters instead + :param client_tools: A tuple of ClientTool instances. + ::deprecated: use tools instead + :param tool_parser: Custom logic that parses tool calls from a message. + :param model: The model to use for the agent. + :param instructions: The instructions for the agent. + :param tools: A list of tools for the agent. 
Values can be one of the following: + - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}} + - a python function with a docstring. See @client_tool for more details. + - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search" + - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent + - an instance of ClientTool: A client tool object. + :param tool_config: The tool configuration for the agent. + :param sampling_params: The sampling parameters for the agent. + :param max_infer_iters: The maximum number of inference iterations. + :param input_shields: The input shields for the agent. + :param output_shields: The output shields for the agent. + :param response_format: The response format for the agent. + :param enable_session_persistence: Whether to enable session persistence. + :param extra_headers: Extra headers to add to all requests sent by the agent. + """ + self.client = client + + if agent_config is not None: + logger.warning("`agent_config` is deprecated. Use inlined parameters instead.") + if client_tools != (): + logger.warning("`client_tools` is deprecated. 
Use `tools` instead.") + + # Construct agent_config from parameters if not provided + if agent_config is None: + agent_config = AgentUtils.get_agent_config( + model=model, + instructions=instructions, + tools=tools, + tool_config=tool_config, + sampling_params=sampling_params, + max_infer_iters=max_infer_iters, + input_shields=input_shields, + output_shields=output_shields, + response_format=response_format, + enable_session_persistence=enable_session_persistence, + ) + client_tools = AgentUtils.get_client_tools(tools) + + self.agent_config = agent_config + self.client_tools = {t.get_name(): t for t in client_tools} + self.sessions = [] + self.tool_parser = tool_parser + self.builtin_tools = {} + self.extra_headers = extra_headers + self.initialize() + + def initialize(self) -> None: + agentic_system_create_response = self.client.agents.create( + agent_config=self.agent_config, + extra_headers=self.extra_headers, + ) + self.agent_id = agentic_system_create_response.agent_id + for tg in self.agent_config["toolgroups"]: + toolgroup_id = tg if isinstance(tg, str) else tg.get("name") + for tool in self.client.tools.list(toolgroup_id=toolgroup_id, extra_headers=self.extra_headers): + self.builtin_tools[tool.identifier] = tg.get("args", {}) if isinstance(tg, dict) else {} + + def create_session(self, session_name: str) -> str: + agentic_system_create_session_response = self.client.agents.session.create( + agent_id=self.agent_id, + session_name=session_name, + extra_headers=self.extra_headers, + ) + self.session_id = agentic_system_create_session_response.session_id + self.sessions.append(self.session_id) + return self.session_id + + def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]: + responses = [] + for tool_call in tool_calls: + responses.append(self._run_single_tool(tool_call)) + return responses + + def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: + # custom client tools + if tool_call.tool_name in 
self.client_tools: + tool = self.client_tools[tool_call.tool_name] + # NOTE: tool.run() expects a list of messages, we only pass in last message here + # but we could pass in the entire message history + result_message = tool.run( + [ + CompletionMessage( + role="assistant", + content=tool_call.tool_name, + tool_calls=[tool_call], + stop_reason="end_of_turn", + ) + ] + ) + return result_message + + # builtin tools executed by tool_runtime + if tool_call.tool_name in self.builtin_tools: + tool_result = self.client.tool_runtime.invoke_tool( + tool_name=tool_call.tool_name, + kwargs={ + **tool_call.arguments, + **self.builtin_tools[tool_call.tool_name], + }, + extra_headers=self.extra_headers, + ) + return ToolResponseParam( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=tool_result.content, + ) + + # cannot find tools + return ToolResponseParam( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=f"Unknown tool `{tool_call.tool_name}` was called.", + ) + + def create_turn( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + stream: bool = True, + # TODO: deprecate this + extra_headers: Headers | None = None, + ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: + if stream: + return self._create_turn_streaming( + messages, session_id, toolgroups, documents, extra_headers=extra_headers or self.extra_headers + ) + else: + chunks = [ + x + for x in self._create_turn_streaming( + messages, + session_id, + toolgroups, + documents, + extra_headers=extra_headers or self.extra_headers, + ) + ] + if not chunks: + raise Exception("Turn did not complete") + + last_chunk = chunks[-1] + if hasattr(last_chunk, "error"): + if "message" in last_chunk.error: + error_msg = last_chunk.error["message"] + else: + error_msg = str(last_chunk.error) + raise RuntimeError(f"Turn did not complete. 
Error: {error_msg}") + try: + return last_chunk.event.payload.turn + except AttributeError: + raise RuntimeError(f"Turn did not complete. Output: {last_chunk}") from None + + def _create_turn_streaming( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + # TODO: deprecate this + extra_headers: Headers | None = None, + ) -> Iterator[AgentTurnResponseStreamChunk]: + n_iter = 0 + + # 1. create an agent turn + turn_response = self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + extra_headers=extra_headers or self.extra_headers, + ) + + # 2. process turn and resume if there's a tool call + is_turn_complete = False + while not is_turn_complete: + is_turn_complete = True + for chunk in turn_response: + if hasattr(chunk, "error"): + yield chunk + return + tool_calls = AgentUtils.get_tool_calls(chunk, self.tool_parser) + if not tool_calls: + yield chunk + else: + is_turn_complete = False + # End of turn is reached, do not resume even if there's a tool call + # We only check for this if tool_parser is not set, because otherwise + # tool call will be parsed on client side, and server will always return "end_of_turn" + if not self.tool_parser and chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}: + yield chunk + break + + turn_id = AgentUtils.get_turn_id(chunk) + if n_iter == 0: + yield chunk + + # run the tools + tool_responses = self._run_tool_calls(tool_calls) + + # pass it to next iteration + turn_response = self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=tool_responses, + stream=True, + extra_headers=extra_headers or 
self.extra_headers, + ) + n_iter += 1 + + if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER): + raise Exception("Max inference iterations reached") + + +class AsyncAgent: + def __init__( + self, + client: LlamaStackClient, + # begin deprecated + agent_config: Optional[AgentConfig] = None, + client_tools: Tuple[ClientTool, ...] = (), + # end deprecated + tool_parser: Optional[ToolParser] = None, + model: Optional[str] = None, + instructions: Optional[str] = None, + tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, + tool_config: Optional[ToolConfig] = None, + sampling_params: Optional[SamplingParams] = None, + max_infer_iters: Optional[int] = None, + input_shields: Optional[List[str]] = None, + output_shields: Optional[List[str]] = None, + response_format: Optional[ResponseFormat] = None, + enable_session_persistence: Optional[bool] = None, + extra_headers: Headers | None = None, + ): + """Construct an Agent with the given parameters. + + :param client: The LlamaStackClient instance. + :param agent_config: The AgentConfig instance. + ::deprecated: use other parameters instead + :param client_tools: A tuple of ClientTool instances. + ::deprecated: use tools instead + :param tool_parser: Custom logic that parses tool calls from a message. + :param model: The model to use for the agent. + :param instructions: The instructions for the agent. + :param tools: A list of tools for the agent. Values can be one of the following: + - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}} + - a python function with a docstring. See @client_tool for more details. + - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search" + - str representing a toolgroup_id: e.g. 
"builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent + - an instance of ClientTool: A client tool object. + :param tool_config: The tool configuration for the agent. + :param sampling_params: The sampling parameters for the agent. + :param max_infer_iters: The maximum number of inference iterations. + :param input_shields: The input shields for the agent. + :param output_shields: The output shields for the agent. + :param response_format: The response format for the agent. + :param enable_session_persistence: Whether to enable session persistence. + :param extra_headers: Extra headers to add to all requests sent by the agent. + """ + self.client = client + + if agent_config is not None: + logger.warning("`agent_config` is deprecated. Use inlined parameters instead.") + if client_tools != (): + logger.warning("`client_tools` is deprecated. Use `tools` instead.") + + # Construct agent_config from parameters if not provided + if agent_config is None: + agent_config = AgentUtils.get_agent_config( + model=model, + instructions=instructions, + tools=tools, + tool_config=tool_config, + sampling_params=sampling_params, + max_infer_iters=max_infer_iters, + input_shields=input_shields, + output_shields=output_shields, + response_format=response_format, + enable_session_persistence=enable_session_persistence, + ) + client_tools = AgentUtils.get_client_tools(tools) + + self.agent_config = agent_config + self.client_tools = {t.get_name(): t for t in client_tools} + self.sessions = [] + self.tool_parser = tool_parser + self.builtin_tools = {} + self.extra_headers = extra_headers + self._agent_id = None + + if isinstance(client, LlamaStackClient): + raise ValueError("AsyncAgent must be initialized with an AsyncLlamaStackClient") + + @property + def agent_id(self) -> str: + if not self._agent_id: + raise RuntimeError("Agent ID not initialized. 
Call initialize() first.") + return self._agent_id + + async def initialize(self) -> None: + if self._agent_id: + return + + agentic_system_create_response = await self.client.agents.create( + agent_config=self.agent_config, + ) + self._agent_id = agentic_system_create_response.agent_id + for tg in self.agent_config["toolgroups"]: + for tool in await self.client.tools.list(toolgroup_id=tg, extra_headers=self.extra_headers): + self.builtin_tools[tool.identifier] = tg.get("args", {}) if isinstance(tg, dict) else {} + + async def create_session(self, session_name: str) -> str: + await self.initialize() + agentic_system_create_session_response = await self.client.agents.session.create( + agent_id=self.agent_id, + session_name=session_name, + extra_headers=self.extra_headers, + ) + self.session_id = agentic_system_create_session_response.session_id + self.sessions.append(self.session_id) + return self.session_id + + async def create_turn( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + stream: bool = True, + ) -> AsyncIterator[AgentTurnResponseStreamChunk] | Turn: + if stream: + return self._create_turn_streaming(messages, session_id, toolgroups, documents) + else: + chunks = [x async for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] + if not chunks: + raise Exception("Turn did not complete") + return chunks[-1].event.payload.turn + + async def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]: + responses = [] + for tool_call in tool_calls: + responses.append(await self._run_single_tool(tool_call)) + return responses + + async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: + # custom client tools + if tool_call.tool_name in self.client_tools: + tool = self.client_tools[tool_call.tool_name] + result_message = await tool.async_run( + [ + 
CompletionMessage( + role="assistant", + content=tool_call.tool_name, + tool_calls=[tool_call], + stop_reason="end_of_turn", + ) + ] + ) + return result_message + + # builtin tools executed by tool_runtime + if tool_call.tool_name in self.builtin_tools: + tool_result = await self.client.tool_runtime.invoke_tool( + tool_name=tool_call.tool_name, + kwargs={ + **tool_call.arguments, + **self.builtin_tools[tool_call.tool_name], + }, + extra_headers=self.extra_headers, + ) + return ToolResponseParam( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=tool_result.content, + ) + + # cannot find tools + return ToolResponseParam( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=f"Unknown tool `{tool_call.tool_name}` was called.", + ) + + async def _create_turn_streaming( + self, + messages: List[Union[UserMessage, ToolResponseMessage]], + session_id: Optional[str] = None, + toolgroups: Optional[List[Toolgroup]] = None, + documents: Optional[List[Document]] = None, + ) -> AsyncIterator[AgentTurnResponseStreamChunk]: + n_iter = 0 + + # 1. create an agent turn + turn_response = await self.client.agents.turn.create( + agent_id=self.agent_id, + # use specified session_id or last session created + session_id=session_id or self.session_id[-1], + messages=messages, + stream=True, + documents=documents, + toolgroups=toolgroups, + extra_headers=self.extra_headers, + ) + + # 2. 
process turn and resume if there's a tool call + is_turn_complete = False + while not is_turn_complete: + is_turn_complete = True + async for chunk in turn_response: + if hasattr(chunk, "error"): + yield chunk + return + + tool_calls = AgentUtils.get_tool_calls(chunk, self.tool_parser) + if not tool_calls: + yield chunk + else: + is_turn_complete = False + # End of turn is reached, do not resume even if there's a tool call + if not self.tool_parser and chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}: + yield chunk + break + + turn_id = AgentUtils.get_turn_id(chunk) + if n_iter == 0: + yield chunk + + # run the tools + tool_responses = await self._run_tool_calls(tool_calls) + + # pass it to next iteration + turn_response = await self.client.agents.turn.resume( + agent_id=self.agent_id, + session_id=session_id or self.session_id[-1], + turn_id=turn_id, + tool_responses=tool_responses, + stream=True, + extra_headers=self.extra_headers, + ) + n_iter += 1 + + if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER): + raise Exception("Max inference iterations reached") diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py new file mode 100644 index 00000000..c199b211 --- /dev/null +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import inspect +import json +from abc import abstractmethod +from typing import ( + Any, + Callable, + Dict, + get_args, + get_origin, + get_type_hints, + List, + TypeVar, + Union, +) + +from llama_stack_client.types import CompletionMessage, Message, ToolResponse +from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam + + +class ClientTool: + """ + Developers can define their custom tools that models can use + by extending this class. + + Developers need to provide + - name + - description + - params_definition + - implement tool's behavior in `run_impl` method + + NOTE: The return of the `run` method needs to be json serializable + """ + + @abstractmethod + def get_name(self) -> str: + raise NotImplementedError + + @abstractmethod + def get_description(self) -> str: + raise NotImplementedError + + @abstractmethod + def get_params_definition(self) -> Dict[str, Parameter]: + raise NotImplementedError + + def get_instruction_string(self) -> str: + return f"Use the function '{self.get_name()}' to: {self.get_description()}" + + def parameters_for_system_prompt(self) -> str: + return json.dumps( + { + "name": self.get_name(), + "description": self.get_description(), + "parameters": {name: definition for name, definition in self.get_params_definition().items()}, + } + ) + + def get_tool_definition(self) -> ToolDefParam: + return ToolDefParam( + name=self.get_name(), + description=self.get_description(), + parameters=list(self.get_params_definition().values()), + metadata={}, + tool_prompt_format="python_list", + ) + + def run( + self, + message_history: List[Message], + ) -> ToolResponse: + # NOTE: we could override this method to use the entire message history for advanced tools + last_message = message_history[-1] + assert isinstance(last_message, CompletionMessage), "Expected CompletionMessage" + assert len(last_message.tool_calls) == 1, "Expected single tool call" + tool_call = last_message.tool_calls[0] + + metadata = {} + try: + if 
tool_call.arguments_json is not None: + params = json.loads(tool_call.arguments_json) + elif isinstance(tool_call.arguments, str): + params = json.loads(tool_call.arguments) + else: + params = tool_call.arguments + + response = self.run_impl(**params) + if isinstance(response, dict) and "content" in response: + content = json.dumps(response["content"], ensure_ascii=False) + metadata = response.get("metadata", {}) + else: + content = json.dumps(response, ensure_ascii=False) + except Exception as e: + content = f"Error when running tool: {e}" + return ToolResponse( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=content, + metadata=metadata, + ) + + async def async_run( + self, + message_history: List[Message], + ) -> ToolResponse: + last_message = message_history[-1] + + assert len(last_message.tool_calls) == 1, "Expected single tool call" + tool_call = last_message.tool_calls[0] + metadata = {} + try: + response = await self.async_run_impl(**tool_call.arguments) + if isinstance(response, dict) and "content" in response: + content = json.dumps(response["content"], ensure_ascii=False) + metadata = response.get("metadata", {}) + else: + content = json.dumps(response, ensure_ascii=False) + except Exception as e: + content = f"Error when running tool: {e}" + + return ToolResponse( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=content, + metadata=metadata, + ) + + @abstractmethod + def run_impl(self, **kwargs) -> Any: + """ + Can return any json serializable object. + To return metadata along with the response, return a dict with a "content" key, and a "metadata" key, where the "content" is the response that'll + be serialized and passed to the model, and the "metadata" will be logged as metadata in the tool execution step within the Agent execution trace. 
+ """ + raise NotImplementedError + + @abstractmethod + def async_run_impl(self, **kwargs): + raise NotImplementedError + + +T = TypeVar("T", bound=Callable) + + +def client_tool(func: T) -> ClientTool: + """ + Decorator to convert a function into a ClientTool. + Usage: + @client_tool + def add(x: int, y: int) -> int: + '''Add 2 integer numbers + + :param x: integer 1 + :param y: integer 2 + :returns: sum of x + y + ''' + return x + y + + Note that you must use RST-style docstrings with :param tags for each parameter. These will be used for prompting model to use tools correctly. + :returns: tags in the docstring is optional as it would not be used for the tool's description. + + Your function can return any json serializable object. + To return metadata along with the response, return a dict with a "content" key, and a "metadata" key, where the "content" is the response that'll + be serialized and passed to the model, and the "metadata" will be logged as metadata in the tool execution step within the Agent execution trace. + """ + + class _WrappedTool(ClientTool): + __name__ = func.__name__ + __doc__ = func.__doc__ + __module__ = func.__module__ + + def get_name(self) -> str: + return func.__name__ + + def get_description(self) -> str: + doc = inspect.getdoc(func) + if doc: + # Get everything before the first :param + return doc.split(":param")[0].strip() + else: + raise ValueError( + f"No description found for client tool {__name__}. Please provide a RST-style docstring with description and :param tags for each parameter." 
+ ) + + def get_params_definition(self) -> Dict[str, Parameter]: + hints = get_type_hints(func) + # Remove return annotation if present + hints.pop("return", None) + + # Get parameter descriptions from docstring + params = {} + sig = inspect.signature(func) + doc = inspect.getdoc(func) or "" + + for name, type_hint in hints.items(): + # Look for :param name: in docstring + param_doc = "" + for line in doc.split("\n"): + if line.strip().startswith(f":param {name}:"): + param_doc = line.split(":", 2)[2].strip() + break + + if param_doc == "": + raise ValueError(f"No parameter description found for parameter {name}") + + param = sig.parameters[name] + is_optional_type = get_origin(type_hint) is Union and type(None) in get_args(type_hint) + is_required = param.default == inspect.Parameter.empty and not is_optional_type + params[name] = Parameter( + name=name, + description=param_doc or f"Parameter {name}", + parameter_type=type_hint.__name__, + default=(param.default if param.default != inspect.Parameter.empty else None), + required=is_required, + ) + + return params + + def run_impl(self, **kwargs) -> Any: + if inspect.iscoroutinefunction(func): + raise NotImplementedError("Tool is async but run_impl is not async") + return func(**kwargs) + + async def async_run_impl(self, **kwargs): + if inspect.iscoroutinefunction(func): + return await func(**kwargs) + else: + return func(**kwargs) + + return _WrappedTool() diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py new file mode 100644 index 00000000..731c7b2f --- /dev/null +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 

from typing import Any, Iterator, Optional, Tuple

from termcolor import cprint

from llama_stack_client.types import InterleavedContent

# Event types that carry turn-level (not step-level) payloads.
_TURN_EVENT_TYPES = {"turn_start", "turn_complete", "turn_awaiting_input"}


def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
    """Flatten interleaved content (a string or a list of text/image items) into one string."""

    def _as_str(item: Any) -> str:
        if isinstance(item, str):
            return item
        if hasattr(item, "type"):
            if item.type == "text":
                return item.text
            if item.type == "image":
                return ""
            raise ValueError(f"Unexpected type {item}")
        raise ValueError(f"Unsupported content type: {type(item)}")

    if isinstance(content, list):
        return sep.join(_as_str(item) for item in content)
    return _as_str(content)


class TurnStreamPrintableEvent:
    """One printable line of agent-turn output, with optional role prefix and color."""

    def __init__(
        self,
        role: Optional[str] = None,
        content: str = "",
        end: Optional[str] = "\n",
        color: str = "white",
    ) -> None:
        self.role = role
        self.content = content
        self.color = color
        # An explicit None terminator is treated the same as the default newline.
        self.end = "\n" if end is None else end

    def __str__(self) -> str:
        return f"{self.role}> {self.content}" if self.role is not None else f"{self.content}"

    def print(self, flush: bool = True) -> None:
        cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)


class TurnStreamEventPrinter:
    """Stateful translator from raw turn-stream chunks into printable events."""

    def __init__(self) -> None:
        # Remembered across chunks so step transitions can be detected.
        self.previous_event_type: Optional[str] = None
        self.previous_step_type: Optional[str] = None

    def yield_printable_events(self, chunk: Any) -> Iterator[TurnStreamPrintableEvent]:
        yield from self._yield_printable_events(chunk, self.previous_event_type, self.previous_step_type)
        # Error chunks carry no event payload; leave the tracked state untouched.
        if not hasattr(chunk, "error"):
            self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk)

    def _yield_printable_events(
        self, chunk: Any, previous_event_type: Optional[str] = None, previous_step_type: Optional[str] = None
    ) -> Iterator[TurnStreamPrintableEvent]:
        if hasattr(chunk, "error"):
            yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red")
            return

        event = chunk.event
        event_type = event.payload.event_type

        if event_type in _TURN_EVENT_TYPES:
            # Turn-level events currently produce no visible output.
            yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey")
            return

        step_type = event.payload.step_type

        # Safety shields: report the violation (or its absence) once the step completes.
        if step_type == "shield_call" and event_type == "step_complete":
            violation = event.payload.step_details.violation
            if violation:
                yield TurnStreamPrintableEvent(
                    role=step_type,
                    content=f"{violation.metadata} {violation.user_message}",
                    color="red",
                )
            else:
                yield TurnStreamPrintableEvent(role=step_type, content="No Violation", color="magenta")

        # Inference: stream deltas as they arrive.
        if step_type == "inference":
            if event_type == "step_start":
                yield TurnStreamPrintableEvent(role=step_type, content="", end="", color="yellow")
            elif event_type == "step_progress":
                delta = event.payload.delta
                # Only string-typed tool_call deltas are printable mid-stream.
                if delta.type == "tool_call" and isinstance(delta.tool_call, str):
                    yield TurnStreamPrintableEvent(role=None, content=delta.tool_call, end="", color="cyan")
                elif delta.type == "text":
                    yield TurnStreamPrintableEvent(role=None, content=delta.text, end="", color="yellow")
            else:
                # step_complete: terminate the streamed line with a newline.
                yield TurnStreamPrintableEvent(role=None, content="")

        # Tool execution: print calls and responses only at step completion.
        if step_type == "tool_execution" and event_type == "step_complete":
            details = event.payload.step_details
            for call in details.tool_calls:
                yield TurnStreamPrintableEvent(
                    role=step_type,
                    content=f"Tool:{call.tool_name} Args:{call.arguments}",
                    color="green",
                )
            for resp in details.tool_responses:
                if resp.tool_name == "query_from_memory":
                    fetched = interleaved_content_as_str(resp.content)
                    yield TurnStreamPrintableEvent(
                        role=step_type,
                        content=f"fetched {len(fetched)} bytes from memory",
                        color="cyan",
                    )
                else:
                    yield TurnStreamPrintableEvent(
                        role=step_type,
                        content=f"Tool:{resp.tool_name} Response:{resp.content}",
                        color="green",
                    )

    def _get_event_type_step_type(self, chunk: Any) -> Tuple[Optional[str], Optional[str]]:
        """Extract (event_type, step_type) from a chunk; step_type is None for turn-level events."""
        if not hasattr(chunk, "event"):
            return None, None
        event_type = chunk.event.payload.event_type
        step_type = chunk.event.payload.step_type if event_type not in _TURN_EVENT_TYPES else None
        return event_type, step_type


class EventLogger:
    """Consumes a stream of turn chunks and yields a printable event for each."""

    def log(self, event_generator: Iterator[Any]) -> Iterator[TurnStreamPrintableEvent]:
        printer = TurnStreamEventPrinter()
        for chunk in event_generator:
            yield from printer.yield_printable_events(chunk)


# --- src/llama_stack_client/lib/agents/react/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# --- src/llama_stack_client/lib/agents/react/agent.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Callable, List, Optional, Tuple, Union

from llama_stack_client import LlamaStackClient
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn_create_params import Toolgroup
from llama_stack_client.types.shared_params.agent_config import ToolConfig
from llama_stack_client.types.shared_params.response_format import ResponseFormat
from llama_stack_client.types.shared_params.sampling_params import SamplingParams

from ..agent import Agent, AgentUtils
from ..client_tool import ClientTool
from ..tool_parser import ToolParser
from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE
from .tool_parser import ReActOutput, ReActToolParser

logger = logging.getLogger(__name__)


def get_tool_defs(
    client: LlamaStackClient, builtin_toolgroups: Tuple[Toolgroup] = (), client_tools: Tuple[ClientTool] = ()
):
    """Collect {name, description, parameters} dicts for all builtin-toolgroup tools and client tools.

    :param client: client used to resolve toolgroup ids into tool listings.
    :param builtin_toolgroups: toolgroup ids (str) or dicts carrying a "name" key.
    :param client_tools: locally defined ClientTool instances.
    :returns: list of tool-definition dicts.
    """
    tool_defs = []
    for x in builtin_toolgroups:
        # A toolgroup may be given either as a plain id string or as a dict with args.
        toolgroup_id = x if isinstance(x, str) else x["name"]
        tool_defs.extend(
            {
                "name": tool.identifier,
                "description": tool.description,
                "parameters": tool.parameters,
            }
            for tool in client.tools.list(toolgroup_id=toolgroup_id)
        )

    tool_defs.extend(
        {
            "name": tool.get_name(),
            "description": tool.get_description(),
            "parameters": tool.get_params_definition(),
        }
        for tool in client_tools
    )
    return tool_defs


def get_default_react_instructions(
    client: LlamaStackClient, builtin_toolgroups: Tuple[str] = (), client_tools: Tuple[ClientTool] = ()
):
    """Build the default ReAct system prompt by substituting the available tools into the template.

    NOTE(review): both template placeholders appear here as "<>", so the first
    .replace() consumes every occurrence and the second is a no-op; this looks
    like garbled source (likely "<<tool_names>>" / "<<tool_descriptions>>").
    Confirm against upstream before relying on the second substitution.
    """
    tool_defs = get_tool_defs(client, builtin_toolgroups, client_tools)
    tool_names = ", ".join(x["name"] for x in tool_defs)
    tool_descriptions = "\n".join(f"- {x['name']}: {x}" for x in tool_defs)
    instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<>", tool_names).replace(
        "<>", tool_descriptions
    )
    return instruction


def get_agent_config_DEPRECATED(
    client: LlamaStackClient,
    model: str,
    builtin_toolgroups: Tuple[str] = (),
    client_tools: Tuple[ClientTool] = (),
    json_response_format: bool = False,
    custom_agent_config: Optional[AgentConfig] = None,
) -> AgentConfig:
    """Build an AgentConfig for the deprecated ReActAgent constructor parameters.

    :param custom_agent_config: when provided, used as-is (aside from the optional
        response_format override below) instead of the generated default config.
    """
    if custom_agent_config is None:
        instruction = get_default_react_instructions(client, builtin_toolgroups, client_tools)

        # Use default toolgroups plus the caller-supplied client tools.
        agent_config = AgentConfig(
            model=model,
            instructions=instruction,
            toolgroups=builtin_toolgroups,
            client_tools=[client_tool.get_tool_definition() for client_tool in client_tools],
            tool_config={
                "tool_choice": "auto",
                "system_message_behavior": "replace",
            },
            input_shields=[],
            output_shields=[],
            enable_session_persistence=False,
        )
    else:
        agent_config = custom_agent_config

    if json_response_format:
        agent_config["response_format"] = {
            "type": "json_schema",
            "json_schema": ReActOutput.model_json_schema(),
        }

    return agent_config


class ReActAgent(Agent):
    """ReAct agent.

    Simple wrapper around Agent that prepares the ReAct system prompt from a list of tools.
    """

    def __init__(
        self,
        client: LlamaStackClient,
        model: str,
        tool_parser: Optional[ToolParser] = None,
        instructions: Optional[str] = None,
        tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None,
        tool_config: Optional[ToolConfig] = None,
        sampling_params: Optional[SamplingParams] = None,
        max_infer_iters: Optional[int] = None,
        input_shields: Optional[List[str]] = None,
        output_shields: Optional[List[str]] = None,
        response_format: Optional[ResponseFormat] = None,
        enable_session_persistence: Optional[bool] = None,
        json_response_format: bool = False,
        builtin_toolgroups: Tuple[str] = (),  # DEPRECATED
        client_tools: Tuple[ClientTool] = (),  # DEPRECATED
        custom_agent_config: Optional[AgentConfig] = None,  # DEPRECATED
    ):
        """Construct a ReAct Agent with the given parameters.

        :param client: The LlamaStackClient instance.
        :param custom_agent_config: The AgentConfig instance.
            ::deprecated: use other parameters instead
        :param client_tools: A tuple of ClientTool instances.
            ::deprecated: use tools instead
        :param builtin_toolgroups: A tuple of Toolgroup instances.
            ::deprecated: use tools instead
        :param tool_parser: Custom logic that parses tool calls from a message.
            Defaults to a fresh ReActToolParser per agent (a shared mutable default
            instance across all agents is deliberately avoided).
        :param model: The model to use for the agent.
        :param instructions: The instructions for the agent.
        :param tools: A list of tools for the agent. Values can be one of the following:
            - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}}
            - a python function with a docstring. See @client_tool for more details.
            - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search"
            - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent
            - an instance of ClientTool: A client tool object.
        :param tool_config: The tool configuration for the agent.
        :param sampling_params: The sampling parameters for the agent.
        :param max_infer_iters: The maximum number of inference iterations.
        :param input_shields: The input shields for the agent.
        :param output_shields: The output shields for the agent.
        :param response_format: The response format for the agent.
        :param enable_session_persistence: Whether to enable session persistence.
        :param json_response_format: Whether to use the json response format with default ReAct output schema.
            ::deprecated: use response_format instead
        """
        # Avoid a shared mutable default-argument instance across agents.
        if tool_parser is None:
            tool_parser = ReActToolParser()

        use_deprecated_params = False
        if custom_agent_config is not None:
            logger.warning("`custom_agent_config` is deprecated. Use inlined parameters instead.")
            use_deprecated_params = True
        if client_tools != ():
            logger.warning("`client_tools` is deprecated. Use `tools` instead.")
            use_deprecated_params = True
        if builtin_toolgroups != ():
            logger.warning("`builtin_toolgroups` is deprecated. Use `tools` instead.")
            use_deprecated_params = True

        if use_deprecated_params:
            # Forward custom_agent_config: previously it was dropped here, so a
            # caller-supplied config was silently ignored and rebuilt from defaults.
            agent_config = get_agent_config_DEPRECATED(
                client=client,
                model=model,
                builtin_toolgroups=builtin_toolgroups,
                client_tools=client_tools,
                json_response_format=json_response_format,
                custom_agent_config=custom_agent_config,
            )
            super().__init__(
                client=client,
                agent_config=agent_config,
                client_tools=client_tools,
                tool_parser=tool_parser,
            )

        else:
            if not tool_config:
                tool_config = {
                    "tool_choice": "auto",
                    "system_message_behavior": "replace",
                }

            if json_response_format:
                if instructions is not None:
                    # Fixed implicit string concatenation: "are" + "compatible"
                    # previously rendered as "arecompatible" in the log message.
                    logger.warning(
                        "Using custom instructions, but json_response_format is set. Please make sure instructions "
                        "are compatible with the default ReAct output format."
                    )
                response_format = {
                    "type": "json_schema",
                    "json_schema": ReActOutput.model_json_schema(),
                }

            # Build the ReAct instructions from the resolved tool set.
            client_tools = AgentUtils.get_client_tools(tools)
            builtin_toolgroups = [x for x in tools if isinstance(x, (str, dict))]
            if not instructions:
                instructions = get_default_react_instructions(client, builtin_toolgroups, client_tools)

            super().__init__(
                client=client,
                model=model,
                tool_parser=tool_parser,
                instructions=instructions,
                tools=tools,
                tool_config=tool_config,
                sampling_params=sampling_params,
                max_infer_iters=max_infer_iters,
                input_shields=input_shields,
                output_shields=output_shields,
                response_format=response_format,
                enable_session_persistence=enable_session_persistence,
            )


# --- src/llama_stack_client/lib/agents/react/prompts.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE = """
You are an expert assistant who can solve any task using tool calls. You will be given a task to solve as best you can.
To do so, you have been given access to the following tools: <>

You must always respond in the following JSON format:
{
    "thought": $THOUGHT_PROCESS,
    "action": {
        "tool_name": $TOOL_NAME,
        "tool_params": $TOOL_PARAMS
    },
    "answer": $ANSWER
}

Specifically, this json should have a `thought` key, an `action` key and an `answer` key.

The `action` key should specify the $TOOL_NAME the name of the tool to use and the `tool_params` key should specify the parameters key as input to the tool.

Make sure to have the $TOOL_PARAMS as a list of dictionaries in the right format for the tool you are using, and do not put variable names as input if you can find the right values.

You should always think about one action to take, and have the `thought` key contain your thought process about this action.
If the tool responds, the tool will return an observation containing result of the action.
... (this Thought/Action/Observation can repeat N times, you should take several steps when needed. The action key must only use a SINGLE tool at a time.)

You can use the result of the previous action as input for the next action.
The observation will always be the response from calling the tool: it can represent a file, like "image_1.jpg". You do not need to generate them, it will be provided to you.
Then you can use it as input for the next action. You can do it for instance as follows:

Observation: "image_1.jpg"
{
    "thought": "I need to transform the image that I received in the previous observation to make it green.",
    "action": {
        "tool_name": "image_transformer",
        "tool_params": [{"name": "image"}, {"value": "image_1.jpg"}]
    },
    "answer": null
}


To provide the final answer to the task, use the `answer` key. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this:
Observation: "your observation"

{
    "thought": "your thought process",
    "action": null,
    "answer": "insert your final answer here"
}

Here are a few examples using notional tools:
---
Task: "Generate an image of the oldest person in this document."

Your Response:
{
    "thought": "I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.",
    "action": {
        "tool_name": "document_qa",
        "tool_params": [{"name": "document"}, {"value": "document.pdf"}, {"name": "question"}, {"value": "Who is the oldest person mentioned?"}]
    },
    "answer": null
}

Your Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."

Your Response:
{
    "thought": "I will now generate an image showcasing the oldest person.",
    "action": {
        "tool_name": "image_generator",
        "tool_params": [{"name": "prompt"}, {"value": "A portrait of John Doe, a 55-year-old man living in Canada."}]
    },
    "answer": null
}
Your Observation: "image.png"

{
    "thought": "I will now return the generated image.",
    "action": null,
    "answer": "image.png"
}

---
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"

Your Response:
{
    "thought": "I will use python code evaluator to compute the result of the operation and then return the final answer using the `final_answer` tool",
    "action": {
        "tool_name": "python_interpreter",
        "tool_params": [{"name": "code"}, {"value": "5 + 3 + 1294.678"}]
    },
    "answer": null
}
Your Observation: 1302.678

{
    "thought": "Now that I know the result, I will now return it.",
    "action": null,
    "answer": 1302.678
}

---
Task: "Which city has the highest population, Guangzhou or Shanghai?"

Your Response:
{
    "thought": "I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.",
    "action": {
        "tool_name": "search",
        "tool_params": [{"name": "query"}, {"value": "Population Guangzhou"}]
    },
    "answer": null
}
Your Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.']

Your Response:
{
    "thought": "Now let's get the population of Shanghai using the tool 'search'.",
    "action": {
        "tool_name": "search",
        "tool_params": [{"name": "query"}, {"value": "Population Shanghai"}]
    },
    "answer": null
}
Your Observation: "26 million (2019)"

Your Response:
{
    "thought": "Now I know that Shanghai has a larger population. Let's return the result.",
    "action": null,
    "answer": "Shanghai"
}

Above examples were using notional tools that might not exist for you. You only have access to these tools:
<>

Here are the rules you should always follow to solve your task:
1. ALWAYS answer in the JSON format with keys "thought", "action", "answer", else you will fail.
2. Always use the right arguments for the tools. Never use variable names in the 'tool_params' field, use the value instead.
3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
4. Never re-do a tool call that you previously did with the exact same parameters.
5. Observations will be provided to you, no need to generate them

Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
"""

# --- src/llama_stack_client/lib/agents/react/tool_parser.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging
import uuid
from typing import List, Optional, Union

from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared.tool_call import ToolCall

from pydantic import BaseModel, ValidationError

from ..tool_parser import ToolParser

logger = logging.getLogger(__name__)


class Param(BaseModel):
    # One named tool parameter as emitted by the model.
    name: str
    value: Union[str, int, float, bool]


class Action(BaseModel):
    # A single tool invocation requested by the model.
    tool_name: str
    tool_params: List[Param]


class ReActOutput(BaseModel):
    # Schema of the model's JSON response (also used as the response_format json_schema).
    # NOTE(review): the prompt examples show numeric answers (e.g. 1302.678), but
    # `answer` only admits str — confirm whether a wider type is intended.
    thought: str
    action: Optional[Action]
    answer: Optional[str]


class ReActToolParser(ToolParser):
    """Parses a ReAct-formatted JSON completion into ToolCall objects."""

    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
        """Return the tool calls requested by *output_message*, or [] when the
        message is a final answer or fails to parse."""
        tool_calls: List[ToolCall] = []
        response_text = str(output_message.content)
        try:
            react_output = ReActOutput.model_validate_json(response_text)
        except ValidationError as e:
            # Log (rather than print) malformed model output; an empty list means "no tool call".
            logger.warning("Error parsing action: %s", e)
            return tool_calls

        if react_output.answer:
            # Final answer: nothing left to execute.
            return tool_calls

        if react_output.action:
            tool_name = react_output.action.tool_name
            params = {param.name: param.value for param in react_output.action.tool_params}
            # Require only a tool name: a tool may legitimately take zero
            # parameters, so an empty params list must not suppress the call
            # (the previous `tool_name and tool_params` check dropped such calls).
            if tool_name:
                tool_calls = [
                    ToolCall(
                        call_id=str(uuid.uuid4()),
                        tool_name=tool_name,
                        arguments=params,
                        arguments_json=json.dumps(params),
                    )
                ]

        return tool_calls
# --- src/llama_stack_client/lib/agents/tool_parser.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import abstractmethod
from typing import List

from llama_stack_client.types.agents.turn import CompletionMessage
from llama_stack_client.types.shared.tool_call import ToolCall


class ToolParser:
    """
    Abstract base class for parsing agent responses into tool calls. Implement this class to customize how
    agent outputs are processed and transformed into executable tool calls.

    To use this class:
    1. Create a subclass of ToolParser
    2. Implement the `get_tool_calls` method
    3. Pass your parser instance to the Agent's constructor

    Example:
        class MyCustomParser(ToolParser):
            def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
                # Add your custom parsing logic here
                return extracted_tool_calls

    Methods:
        get_tool_calls(output_message: CompletionMessage) -> List[ToolCall]:
            Abstract method that must be implemented by subclasses to process
            the agent's response and extract tool calls.

            Args:
                output_message (CompletionMessage): The response message from agent turn

            Returns:
                Optional[List[ToolCall]]: A list of parsed tool calls, or None if no tools should be called
    """

    @abstractmethod
    def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
        raise NotImplementedError


# --- src/llama_stack_client/lib/cli/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Silence tqdm's experimental warning before any CLI module triggers it.
import warnings

from tqdm import TqdmExperimentalWarning

warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)

# --- src/llama_stack_client/lib/cli/common/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# --- src/llama_stack_client/lib/cli/common/utils.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from functools import wraps

from rich.console import Console
from rich.panel import Panel
from rich.table import Table


def create_bar_chart(data, labels, title=""):
    """Create a bar chart using Rich Table.

    :param data: list of numeric counts, one per label.
    :param labels: list of row labels, parallel to ``data``.
    :param title: optional table title.
    """
    console = Console()
    table = Table(title=title)
    table.add_column("Score")
    table.add_column("Count")

    # Guard against empty input: max() on an empty sequence raises ValueError.
    if not data:
        console.print(table)
        return

    max_value = max(data)
    total_count = sum(data)

    # Define a list of colors to cycle through
    colors = ["green", "blue", "red", "yellow", "magenta", "cyan"]

    for i, (label, value) in enumerate(zip(labels, data)):
        # Guard against an all-zero dataset (would divide by zero).
        bar_length = int((value / max_value) * 20) if max_value else 0  # adjust bar length as needed
        bar = "█" * bar_length + " " * (20 - bar_length)
        color = colors[i % len(colors)]
        table.add_row(label, f"[{color}]{bar}[/] {value}/{total_count}")

    console.print(table)


def handle_client_errors(operation_name):
    """Decorator factory: run the command and, on any exception, print a formatted
    error panel (naming *operation_name*) instead of letting a traceback escape."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                console = Console()
                console.print(
                    Panel.fit(
                        f"[bold red]Failed to {operation_name}[/bold red]\n\n"
                        f"[yellow]Error Type:[/yellow] {e.__class__.__name__}\n"
                        f"[yellow]Details:[/yellow] {str(e)}"
                    )
                )
                # Fail "softly": the error has been reported, return None.
                return None

        return wrapper

    return decorator


# --- src/llama_stack_client/lib/cli/configure.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import click
import yaml
from prompt_toolkit import prompt
from prompt_toolkit.validation import Validator
from urllib.parse import urlparse

from llama_stack_client.lib.cli.constants import LLAMA_STACK_CLIENT_CONFIG_DIR, get_config_file_path


def get_config():
    """Load the saved CLI config as a dict, or None when no config file exists."""
    config_file = get_config_file_path()
    if not config_file.exists():
        return None
    with open(config_file, "r") as f:
        return yaml.safe_load(f)


def _is_valid_url(text: str) -> bool:
    """True when *text* is non-empty and parses to a URL with both scheme and host."""
    parsed = urlparse(text)
    return bool(len(text) > 0 and parsed.scheme and parsed.netloc)


@click.command()
@click.help_option("-h", "--help")
@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="")
@click.option("--api-key", type=str, help="Llama Stack distribution API key", default="")
def configure(endpoint: str | None, api_key: str | None):
    """Configure Llama Stack Client CLI."""
    os.makedirs(LLAMA_STACK_CLIENT_CONFIG_DIR, exist_ok=True)
    config_path = get_config_file_path()

    final_endpoint = endpoint
    if not final_endpoint:
        final_endpoint = prompt(
            "> Enter the endpoint of the Llama Stack distribution server: ",
            validator=Validator.from_callable(
                _is_valid_url,
                error_message="Endpoint cannot be empty and must be a valid URL, please enter a valid endpoint",
            ),
        )

    final_api_key = api_key
    if not final_api_key:
        final_api_key = prompt(
            "> Enter the API key (leave empty if no key is needed): ",
        )

    # Assemble the config dict before touching the file.
    config_dict = {"endpoint": final_endpoint}
    if final_api_key != "":
        config_dict["api_key"] = final_api_key

    with open(config_path, "w") as f:
        f.write(yaml.dump(config_dict, sort_keys=True))

    print(f"Done! You can now use the Llama Stack Client CLI with endpoint {final_endpoint}")


# --- src/llama_stack_client/lib/cli/constants.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from pathlib import Path

# Per-user directory where the CLI stores its configuration.
LLAMA_STACK_CLIENT_CONFIG_DIR = Path(os.path.expanduser("~/.llama/client"))


def get_config_file_path():
    """Path of the CLI's YAML config file."""
    return LLAMA_STACK_CLIENT_CONFIG_DIR / "config.yaml"


# --- src/llama_stack_client/lib/cli/datasets/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .datasets import datasets

__all__ = ["datasets"]

# --- src/llama_stack_client/lib/cli/datasets/datasets.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import click

from .list import list_datasets
from .register import register
from .unregister import unregister


@click.group()
@click.help_option("-h", "--help")
def datasets():
    """Manage datasets."""


# Register subcommands on the group.
for _command in (list_datasets, register, unregister):
    datasets.add_command(_command)

# --- src/llama_stack_client/lib/cli/datasets/list.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import click
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors

# Dataset attributes rendered as table columns, in display order.
_COLUMNS = ["identifier", "provider_id", "metadata", "type", "purpose"]


@click.command("list")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("list datasets")
def list_datasets(ctx):
    """Show available datasets on distribution endpoint"""
    client = ctx.obj["client"]
    console = Console()

    response = client.datasets.list()
    if response:
        table = Table()
        for column in _COLUMNS:
            table.add_column(column)
        for item in response:
            table.add_row(*(str(getattr(item, column)) for column in _COLUMNS))
        console.print(table)


# --- src/llama_stack_client/lib/cli/datasets/register.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+import base64 +import json +import mimetypes +import os +from typing import Optional, Literal + +import click +import yaml + +from ..common.utils import handle_client_errors + + +def data_url_from_file(file_path: str) -> str: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + return data_url + + +@click.command("register") +@click.help_option("-h", "--help") +@click.option("--dataset-id", required=True, help="Id of the dataset") +@click.option( + "--purpose", + type=click.Choice(["post-training/messages", "eval/question-answer", "eval/messages-answer"]), + help="Purpose of the dataset", + required=True, +) +@click.option("--metadata", type=str, help="Metadata of the dataset") +@click.option("--url", type=str, help="URL of the dataset", required=False) +@click.option( + "--dataset-path", required=False, help="Local file path to the dataset. 
If specified, upload dataset via URL" +) +@click.pass_context +@handle_client_errors("register dataset") +def register( + ctx, + dataset_id: str, + purpose: Literal["post-training/messages", "eval/question-answer", "eval/messages-answer"], + metadata: Optional[str], + url: Optional[str], + dataset_path: Optional[str], +): + """Create a new dataset""" + client = ctx.obj["client"] + + if metadata: + try: + metadata = json.loads(metadata) + except json.JSONDecodeError as err: + raise click.BadParameter("Metadata must be valid JSON") from err + + if dataset_path: + url = data_url_from_file(dataset_path) + else: + if not url: + raise click.BadParameter("URL is required when dataset path is not specified") + + response = client.datasets.register( + dataset_id=dataset_id, + source={"uri": url}, + metadata=metadata, + purpose=purpose, + ) + if response: + click.echo(yaml.dump(response.dict())) diff --git a/src/llama_stack_client/lib/cli/datasets/unregister.py b/src/llama_stack_client/lib/cli/datasets/unregister.py new file mode 100644 index 00000000..8ca7cceb --- /dev/null +++ b/src/llama_stack_client/lib/cli/datasets/unregister.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+import click + +from ..common.utils import handle_client_errors + + +@click.command("unregister") +@click.help_option("-h", "--help") +@click.argument("dataset-id", required=True) +@click.pass_context +@handle_client_errors("unregister dataset") +def unregister(ctx, dataset_id: str): + """Remove a dataset""" + client = ctx.obj["client"] + client.datasets.unregister(dataset_id=dataset_id) + click.echo(f"Dataset '{dataset_id}' unregistered successfully") diff --git a/src/llama_stack_client/lib/cli/eval/__init__.py b/src/llama_stack_client/lib/cli/eval/__init__.py new file mode 100644 index 00000000..503994e9 --- /dev/null +++ b/src/llama_stack_client/lib/cli/eval/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .eval import eval + +__all__ = ["eval"] diff --git a/src/llama_stack_client/lib/cli/eval/eval.py b/src/llama_stack_client/lib/cli/eval/eval.py new file mode 100644 index 00000000..dd162809 --- /dev/null +++ b/src/llama_stack_client/lib/cli/eval/eval.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import click + +from .run_benchmark import run_benchmark +from .run_scoring import run_scoring + + +@click.group() +@click.help_option("-h", "--help") +def eval(): + """Run evaluation tasks.""" + + +# Register subcommands +eval.add_command(run_benchmark) +eval.add_command(run_scoring) diff --git a/src/llama_stack_client/lib/cli/eval/run_benchmark.py b/src/llama_stack_client/lib/cli/eval/run_benchmark.py new file mode 100644 index 00000000..e088137e --- /dev/null +++ b/src/llama_stack_client/lib/cli/eval/run_benchmark.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
from typing import Optional

import click
from rich import print as rprint
from tqdm.rich import tqdm

from ..common.utils import create_bar_chart
from .utils import (
    aggregate_accuracy,
    aggregate_average,
    aggregate_weighted_average,
    aggregate_categorical_count,
    aggregate_median,
)


@click.command("run-benchmark")
@click.help_option("-h", "--help")
@click.argument("benchmark-ids", nargs=-1, required=True)
@click.option(
    "--model-id",
    required=True,
    help="model id to run the benchmark eval on",
    default=None,
    type=str,
)
@click.option(
    "--output-dir",
    required=True,
    help="Path to the dump eval results output directory",
)
@click.option(
    "--num-examples",
    required=False,
    help="Number of examples to evaluate on, useful for debugging",
    default=None,
    type=int,
)
@click.option(
    "--temperature",
    required=False,
    help="temperature in the sampling params to run generation",
    default=0.0,
    type=float,
)
@click.option(
    "--max-tokens",
    required=False,
    help="max-tokens in the sampling params to run generation",
    default=4096,
    type=int,
)
@click.option(
    "--top-p",
    required=False,
    help="top-p in the sampling params to run generation",
    default=0.9,
    type=float,
)
@click.option(
    "--repeat-penalty",
    required=False,
    help="repeat-penalty in the sampling params to run generation",
    default=1.0,
    type=float,
)
@click.option(
    "--visualize",
    is_flag=True,
    default=False,
    help="Visualize evaluation results after completion",
)
@click.pass_context
def run_benchmark(
    ctx,
    benchmark_ids: tuple[str, ...],
    model_id: str,
    output_dir: str,
    num_examples: Optional[int],
    temperature: float,
    max_tokens: int,
    top_p: float,
    repeat_penalty: float,
    visualize: bool,
):
    """Run an evaluation benchmark task.

    For each benchmark id: fetch the benchmark's dataset, run row-by-row model
    evaluation, append per-scoring-function aggregate results, and dump
    everything to ``<output_dir>/<benchmark_id>_results.json``. With
    ``--visualize``, render a bar chart for categorical_count aggregations.
    """
    client = ctx.obj["client"]

    for benchmark_id in benchmark_ids:
        benchmark = client.benchmarks.retrieve(benchmark_id=benchmark_id)
        scoring_functions = benchmark.scoring_functions
        dataset_id = benchmark.dataset_id

        results = client.datasets.iterrows(dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples)

        # Fetch each scoring function's aggregation config once, up front.
        # The original re-queried the server inside the row loop, i.e.
        # N_rows * N_scoring_fns retrieve calls for data that never changes.
        aggregation_functions_by_fn = {
            scoring_fn: client.scoring_functions.retrieve(scoring_fn_id=scoring_fn).params.aggregation_functions
            for scoring_fn in scoring_functions
        }

        # Column-oriented accumulator: input columns, generation columns, and
        # one column of per-row score dicts per scoring function.
        output_res = {}

        for r in tqdm(results.data):
            eval_res = client.eval.evaluate_rows(
                benchmark_id=benchmark_id,
                input_rows=[r],
                scoring_functions=scoring_functions,
                benchmark_config={
                    "type": "benchmark",
                    "eval_candidate": {
                        "type": "model",
                        "model": model_id,
                        "sampling_params": {
                            "temperature": temperature,
                            "max_tokens": max_tokens,
                            "top_p": top_p,
                            "repeat_penalty": repeat_penalty,
                        },
                    },
                },
            )
            for k in r.keys():
                output_res.setdefault(k, []).append(r[k])

            for k in eval_res.generations[0].keys():
                output_res.setdefault(k, []).append(eval_res.generations[0][k])

            for scoring_fn in scoring_functions:
                output_res.setdefault(scoring_fn, []).append(eval_res.scores[scoring_fn].score_rows[0])

        # Append aggregate results after all rows are scored. Snapshot the
        # per-row scores first: the original re-read output_res[scoring_fn]
        # for each aggregation function, so the second aggregation's input
        # included the first aggregation's appended result dict (a bug).
        for scoring_fn in scoring_functions:
            row_scores = list(output_res.get(scoring_fn, []))
            if not row_scores:
                # Empty dataset: nothing to aggregate (matches prior behavior,
                # which only aggregated on the last row of a non-empty dataset).
                continue
            for aggregation_function in aggregation_functions_by_fn[scoring_fn]:
                if aggregation_function == "categorical_count":
                    output_res[scoring_fn].append(aggregate_categorical_count(row_scores))
                elif aggregation_function == "average":
                    output_res[scoring_fn].append(aggregate_average(row_scores))
                elif aggregation_function == "weighted_average":
                    output_res[scoring_fn].append(aggregate_weighted_average(row_scores))
                elif aggregation_function == "median":
                    output_res[scoring_fn].append(aggregate_median(row_scores))
                elif aggregation_function == "accuracy":
                    output_res[scoring_fn].append(aggregate_accuracy(row_scores))
                else:
                    raise NotImplementedError(f"Aggregation function {aggregation_function} is not supported yet")

        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        # Save results to JSON file
        output_file = os.path.join(output_dir, f"{benchmark_id}_results.json")
        with open(output_file, "w") as f:
            json.dump(output_res, f, indent=2)

        rprint(f"[green]✓[/green] Results saved to: [blue]{output_file}[/blue]!\n")

        if visualize:
            for scoring_fn in scoring_functions:
                for aggregation_function in aggregation_functions_by_fn[scoring_fn]:
                    res = output_res[scoring_fn]
                    # Only per-row entries carry a "score" key; the aggregate
                    # dicts appended above do not, so filter them out before
                    # charting (the original indexed r["score"] on every entry
                    # and raised KeyError on the aggregate rows).
                    per_row = [r for r in res if isinstance(r, dict) and "score" in r]
                    assert per_row
                    if aggregation_function == "categorical_count":
                        scores = [str(r["score"]) for r in per_row]
                        unique_scores = sorted(set(scores))
                        counts = [scores.count(s) for s in unique_scores]
                        create_bar_chart(
                            counts,
                            unique_scores,
                            title=f"{scoring_fn}-{aggregation_function}",
                        )
                    else:
                        raise NotImplementedError(
                            f"Aggregation function {aggregation_function} is not supported for visualization yet"
                        )
import json
import os
from typing import Optional

import click
import pandas
from rich import print as rprint
from tqdm.rich import tqdm


@click.command("run-scoring")
@click.help_option("-h", "--help")
@click.argument("scoring-function-ids", nargs=-1, required=True)
@click.option(
    "--dataset-id",
    required=False,
    help="Pre-registered dataset_id to score (from llama-stack-client datasets list)",
)
@click.option(
    "--dataset-path",
    required=False,
    help="Path to the dataset file to score",
    type=click.Path(exists=True),
)
@click.option(
    "--scoring-params-config",
    required=False,
    help="Path to the scoring params config file in JSON format",
    type=click.Path(exists=True),
)
@click.option(
    "--num-examples",
    required=False,
    help="Number of examples to evaluate on, useful for debugging",
    default=None,
    type=int,
)
@click.option(
    "--output-dir",
    required=True,
    help="Path to the dump eval results output directory",
)
@click.option(
    "--visualize",
    is_flag=True,
    default=False,
    help="Visualize evaluation results after completion",
)
@click.pass_context
def run_scoring(
    ctx,
    scoring_function_ids: tuple[str, ...],
    dataset_id: Optional[str],
    dataset_path: Optional[str],
    scoring_params_config: Optional[str],
    num_examples: Optional[int],
    output_dir: str,
    visualize: bool,
):
    """Run scoring from application datasets.

    Rows come either from a pre-registered dataset (``--dataset-id``) or a
    local CSV file (``--dataset-path``); each row is scored with the given
    scoring functions and the combined table is written to
    ``<output_dir>/<source>_score_results.csv``.
    """
    # one of dataset_id or dataset_path is required
    if dataset_id is None and dataset_path is None:
        raise click.BadParameter("Specify either dataset_id (pre-registered dataset) or dataset_path (local file)")

    client = ctx.obj["client"]

    # Default: each scoring function runs with its default params; a JSON
    # config file may override this {scoring_fn_id: params} mapping.
    scoring_params = {fn_id: None for fn_id in scoring_function_ids}
    if scoring_params_config:
        with open(scoring_params_config, "r") as f:
            scoring_params = json.load(f)

    output_res = {}

    rows = []
    if dataset_id is not None:
        dataset = client.datasets.retrieve(dataset_id=dataset_id)
        if not dataset:
            # BUG FIX: the exception was previously constructed but never
            # raised, so a missing dataset fell through silently.
            raise click.BadParameter(
                f"Dataset {dataset_id} not found. Please register using llama-stack-client datasets register"
            )

        # TODO: this will eventually be replaced with jobs polling from server via score_batch
        # For now, get all datasets rows via datasets API
        results = client.datasets.iterrows(dataset_id=dataset_id, limit=-1 if num_examples is None else num_examples)
        rows = results.rows

    if dataset_path is not None:
        df = pandas.read_csv(dataset_path)
        rows = df.to_dict(orient="records")
        if num_examples is not None:
            rows = rows[:num_examples]

    for r in tqdm(rows):
        score_res = client.scoring.score(
            input_rows=[r],
            scoring_functions=scoring_params,
        )
        for k in r.keys():
            if k not in output_res:
                output_res[k] = []
            output_res[k].append(r[k])

        for fn_id in scoring_function_ids:
            if fn_id not in output_res:
                output_res[fn_id] = []
            output_res[fn_id].append(score_res.results[fn_id].score_rows[0])

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    # Use only the basename of dataset_path: embedding a relative path such
    # as "data/foo.csv" in the filename would create a nested (and possibly
    # nonexistent) directory under output_dir.
    source_name = os.path.basename(dataset_path) if dataset_path else dataset_id
    output_file = os.path.join(output_dir, f"{source_name}_score_results.csv")
    df = pandas.DataFrame(output_res)
    df.to_csv(output_file, index=False)
    print(df)

    rprint(f"[green]✓[/green] Results saved to: [blue]{output_file}[/blue]!\n")
# Aggregation helpers for eval scoring results.
#
# Each function takes a list of score rows (dicts with at least a "score"
# key, plus optional extras such as "weight") and returns a dict keyed by
# the aggregation's name. Rows whose "score" is None are skipped where that
# is meaningful; empty input yields None rather than raising
# ZeroDivisionError, matching aggregate_median's existing behavior.

import statistics  # BUG FIX: aggregate_median used statistics.median without this import (NameError)
from typing import Any, Dict, List, Union

# One scoring-result row as produced by the scoring API.
ScoringRow = Dict[str, Union[bool, float, str, List[object], object, None]]


def aggregate_categorical_count(
    scoring_results: List[ScoringRow],
) -> Dict[str, Any]:
    """Count occurrences of each distinct (stringified) score value."""
    scores = [str(r["score"]) for r in scoring_results]
    unique_scores = sorted(set(scores))
    return {"categorical_count": {s: scores.count(s) for s in unique_scores}}


def aggregate_average(
    scoring_results: List[ScoringRow],
) -> Dict[str, Any]:
    """Arithmetic mean of the non-None scores; None when there are none."""
    scores = [r["score"] for r in scoring_results if r["score"] is not None]
    return {"average": sum(scores) / len(scores) if scores else None}


def aggregate_weighted_average(
    scoring_results: List[ScoringRow],
) -> Dict[str, Any]:
    """Weight-weighted mean of scores; None when the total weight is zero.

    As in the original: the numerator sums rows where both score and weight
    are present, while the denominator sums every non-None weight.
    """
    weighted_sum = sum(
        result["score"] * result["weight"]
        for result in scoring_results
        if result["score"] is not None and result["weight"] is not None
    )
    total_weight = sum(result["weight"] for result in scoring_results if result["weight"] is not None)
    return {"weighted_average": weighted_sum / total_weight if total_weight else None}


def aggregate_median(
    scoring_results: List[ScoringRow],
) -> Dict[str, Any]:
    """Median of the non-None scores; None for empty input."""
    scores = [r["score"] for r in scoring_results if r["score"] is not None]
    return {"median": statistics.median(scores) if scores else None}


def aggregate_accuracy(
    scoring_results: List[ScoringRow],
) -> Dict[str, Any]:
    """Fraction of correct rows; assumes 0/1-valued "score" entries."""
    num_correct = sum(result["score"] for result in scoring_results)
    num_total = len(scoring_results)
    return {
        # None instead of ZeroDivisionError on empty input.
        "accuracy": num_correct / num_total if num_total else None,
        "num_correct": num_correct,
        "num_total": num_total,
    }
b/src/llama_stack_client/lib/cli/eval_tasks/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .eval_tasks import eval_tasks + +__all__ = ["eval_tasks"] diff --git a/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py b/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py new file mode 100644 index 00000000..183498fb --- /dev/null +++ b/src/llama_stack_client/lib/cli/eval_tasks/eval_tasks.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import json +from typing import Optional + +import click +import yaml + +from ..common.utils import handle_client_errors +from .list import list_eval_tasks + + +@click.group() +@click.help_option("-h", "--help") +def eval_tasks(): + """Manage evaluation tasks.""" + + +@eval_tasks.command() +@click.help_option("-h", "--help") +@click.option("--eval-task-id", required=True, help="ID of the eval task") +@click.option("--dataset-id", required=True, help="ID of the dataset to evaluate") +@click.option("--scoring-functions", required=True, multiple=True, help="Scoring functions to use for evaluation") +@click.option("--provider-id", help="Provider ID for the eval task", default=None) +@click.option("--provider-eval-task-id", help="Provider's eval task ID", default=None) +@click.option("--metadata", type=str, help="Metadata for the eval task in JSON format") +@click.pass_context +@handle_client_errors("register eval task") +def register( + ctx, + eval_task_id: str, + dataset_id: str, + scoring_functions: tuple[str, ...], + provider_id: Optional[str], + provider_eval_task_id: Optional[str], + metadata: Optional[str], +): + """Register a new eval task""" + client = 
ctx.obj["client"] + + if metadata: + try: + metadata = json.loads(metadata) + except json.JSONDecodeError as err: + raise click.BadParameter("Metadata must be valid JSON") from err + + response = client.eval_tasks.register( + eval_task_id=eval_task_id, + dataset_id=dataset_id, + scoring_functions=scoring_functions, + provider_id=provider_id, + provider_eval_task_id=provider_eval_task_id, + metadata=metadata, + ) + if response: + click.echo(yaml.dump(response.dict())) + + +# Register subcommands +eval_tasks.add_command(list_eval_tasks) +eval_tasks.add_command(register) diff --git a/src/llama_stack_client/lib/cli/eval_tasks/list.py b/src/llama_stack_client/lib/cli/eval_tasks/list.py new file mode 100644 index 00000000..d7eb9c53 --- /dev/null +++ b/src/llama_stack_client/lib/cli/eval_tasks/list.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import click +from rich.console import Console +from rich.table import Table + +from ..common.utils import handle_client_errors + + +@click.command("list") +@click.help_option("-h", "--help") +@click.pass_context +@handle_client_errors("list eval tasks") +def list_eval_tasks(ctx): + """Show available eval tasks on distribution endpoint""" + + client = ctx.obj["client"] + console = Console() + headers = [] + eval_tasks_list_response = client.eval_tasks.list() + if eval_tasks_list_response and len(eval_tasks_list_response) > 0: + headers = sorted(eval_tasks_list_response[0].__dict__.keys()) + + if eval_tasks_list_response: + table = Table() + for header in headers: + table.add_column(header) + + for item in eval_tasks_list_response: + table.add_row(*[str(getattr(item, header)) for header in headers]) + console.print(table) diff --git a/src/llama_stack_client/lib/cli/inference/__init__.py b/src/llama_stack_client/lib/cli/inference/__init__.py new file mode 100644 index 00000000..d10d45c4 --- /dev/null +++ b/src/llama_stack_client/lib/cli/inference/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .inference import inference + +__all__ = ["inference"] diff --git a/src/llama_stack_client/lib/cli/inference/inference.py b/src/llama_stack_client/lib/cli/inference/inference.py new file mode 100644 index 00000000..0cc16396 --- /dev/null +++ b/src/llama_stack_client/lib/cli/inference/inference.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Optional, List, Dict +import traceback + +import click +from rich.console import Console + +from ...inference.event_logger import EventLogger +from ..common.utils import handle_client_errors + + +@click.group() +@click.help_option("-h", "--help") +def inference(): + """Inference (chat).""" + + +@click.command("chat-completion") +@click.help_option("-h", "--help") +@click.option("--message", help="Message") +@click.option("--stream", is_flag=True, help="Streaming", default=False) +@click.option("--session", is_flag=True, help="Start a Chat Session", default=False) +@click.option("--model-id", required=False, help="Model ID") +@click.pass_context +@handle_client_errors("inference chat-completion") +def chat_completion(ctx, message: str, stream: bool, session: bool, model_id: Optional[str]): + """Show available inference chat completion endpoints on distribution endpoint""" + if not message and not session: + click.secho( + "you must specify either --message or --session", + fg="red", + ) + raise click.exceptions.Exit(1) + client = ctx.obj["client"] + console = Console() + + if not model_id: + available_models = [model.identifier for model in client.models.list() if model.model_type == "llm"] + model_id = available_models[0] + + messages = [] + if message: + messages.append({"role": "user", "content": message}) + response = client.chat.completions.create( + model=model_id, + messages=messages, + stream=stream, + ) + if not stream: + console.print(response) + else: + for event in EventLogger().log(response): + event.print() + if session: + chat_session(client=client, model_id=model_id, messages=messages, console=console) + + +def chat_session(client, model_id: Optional[str], messages: List[Dict[str, str]], console: Console): + """Run an interactive chat session with the served model""" + while True: + try: + message = input(">>> ") + if message in ["\\q", "quit"]: + console.print("Exiting") + break + messages.append({"role": "user", "content": 
message}) + response = client.chat.completions.create( + model=model_id, + messages=messages, + stream=True, + ) + for event in EventLogger().log(response): + event.print() + except Exception as exc: + traceback.print_exc() + console.print(f"Error in chat session {exc}") + break + except KeyboardInterrupt as exc: + console.print("\nDetected user interrupt, exiting") + break + + +# Register subcommands +inference.add_command(chat_completion) diff --git a/src/llama_stack_client/lib/cli/inspect/__init__.py b/src/llama_stack_client/lib/cli/inspect/__init__.py new file mode 100644 index 00000000..db651969 --- /dev/null +++ b/src/llama_stack_client/lib/cli/inspect/__init__.py @@ -0,0 +1,3 @@ +from .inspect import inspect + +__all__ = ["inspect"] diff --git a/src/llama_stack_client/lib/cli/inspect/inspect.py b/src/llama_stack_client/lib/cli/inspect/inspect.py new file mode 100644 index 00000000..f9c85b1b --- /dev/null +++ b/src/llama_stack_client/lib/cli/inspect/inspect.py @@ -0,0 +1,13 @@ +import click + +from .version import inspect_version + + +@click.group() +@click.help_option("-h", "--help") +def inspect(): + """Inspect server configuration.""" + + +# Register subcommands +inspect.add_command(inspect_version) diff --git a/src/llama_stack_client/lib/cli/inspect/version.py b/src/llama_stack_client/lib/cli/inspect/version.py new file mode 100644 index 00000000..212b9f9d --- /dev/null +++ b/src/llama_stack_client/lib/cli/inspect/version.py @@ -0,0 +1,16 @@ +import click +from rich.console import Console + +from ..common.utils import handle_client_errors + + +@click.command("version") +@click.help_option("-h", "--help") +@click.pass_context +@handle_client_errors("inspect version") +def inspect_version(ctx): + """Show available providers on distribution endpoint""" + client = ctx.obj["client"] + console = Console() + version_response = client.inspect.version() + console.print(version_response) diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py 
b/src/llama_stack_client/lib/cli/llama_stack_client.py new file mode 100644 index 00000000..54c46aaa --- /dev/null +++ b/src/llama_stack_client/lib/cli/llama_stack_client.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from importlib.metadata import version + +import click +import yaml + +from llama_stack_client import LlamaStackClient + +from .configure import configure +from .constants import get_config_file_path +from .datasets import datasets +from .eval import eval +from .eval_tasks import eval_tasks +from .inference import inference +from .inspect import inspect +from .models import models +from .post_training import post_training +from .providers import providers +from .scoring_functions import scoring_functions +from .shields import shields +from .toolgroups import toolgroups +from .vector_dbs import vector_dbs + + +@click.group() +@click.help_option("-h", "--help") +@click.version_option(version=version("llama-stack-client"), prog_name="llama-stack-client") +@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="") +@click.option("--api-key", type=str, help="Llama Stack distribution API key", default="") +@click.option("--config", type=str, help="Path to config file", default=None) +@click.pass_context +def llama_stack_client(ctx, endpoint: str, api_key: str, config: str | None): + """Welcome to the llama-stack-client CLI - a command-line interface for interacting with Llama Stack""" + ctx.ensure_object(dict) + + # If no config provided, check default location + if config and endpoint: + raise ValueError("Cannot use both config and endpoint") + + if config is None: + default_config = get_config_file_path() + if default_config.exists(): + config = str(default_config) + + if config: + try: + with open(config, "r") as f: + config_dict = 
yaml.safe_load(f) + endpoint = config_dict.get("endpoint", endpoint) + api_key = config_dict.get("api_key", "") + except Exception as e: + click.echo(f"Error loading config from {config}: {str(e)}", err=True) + click.echo("Falling back to HTTP client with endpoint", err=True) + + if endpoint == "": + endpoint = "http://localhost:8321" + + default_headers = {} + if api_key != "": + default_headers = { + "Authorization": f"Bearer {api_key}", + } + + client = LlamaStackClient( + base_url=endpoint, + provider_data={ + "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""), + "together_api_key": os.environ.get("TOGETHER_API_KEY", ""), + "openai_api_key": os.environ.get("OPENAI_API_KEY", ""), + }, + default_headers=default_headers, + ) + ctx.obj = {"client": client} + + +# Register all subcommands +llama_stack_client.add_command(models, "models") +llama_stack_client.add_command(vector_dbs, "vector_dbs") +llama_stack_client.add_command(shields, "shields") +llama_stack_client.add_command(eval_tasks, "eval_tasks") +llama_stack_client.add_command(providers, "providers") +llama_stack_client.add_command(datasets, "datasets") +llama_stack_client.add_command(configure, "configure") +llama_stack_client.add_command(scoring_functions, "scoring_functions") +llama_stack_client.add_command(eval, "eval") +llama_stack_client.add_command(inference, "inference") +llama_stack_client.add_command(post_training, "post_training") +llama_stack_client.add_command(inspect, "inspect") +llama_stack_client.add_command(toolgroups, "toolgroups") + + +def main(): + llama_stack_client() + + +if __name__ == "__main__": + main() diff --git a/src/llama_stack_client/lib/cli/models/__init__.py b/src/llama_stack_client/lib/cli/models/__init__.py new file mode 100644 index 00000000..64479669 --- /dev/null +++ b/src/llama_stack_client/lib/cli/models/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
# --- file: src/llama_stack_client/lib/cli/models/__init__.py (tail) ---
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .models import models

__all__ = ["models"]


# --- file: src/llama_stack_client/lib/cli/models/models.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Optional

import click
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
@click.help_option("-h", "--help")
def models():
    """Manage GenAI models."""


@click.command(name="list", help="Show available llama models at distribution endpoint")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("list models")
def list_models(ctx):
    """Render every model served by the distribution endpoint as a rich table."""
    client = ctx.obj["client"]
    console = Console()

    response = client.models.list()
    if response:
        table = Table(
            show_lines=True,  # lines between rows for better readability
            padding=(0, 1),  # horizontal padding
            expand=True,  # allow table to use full width
        )

        # The columns below are the single source of truth for what is shown.
        # (The original also declared an unused `headers` list whose entries
        # disagreed with these columns; it has been removed.)
        table.add_column("model_type", style="blue")
        table.add_column("identifier", style="bold cyan", no_wrap=True, overflow="fold")
        table.add_column("provider_resource_id", style="yellow", no_wrap=True, overflow="fold")
        table.add_column("metadata", style="magenta", max_width=30, overflow="fold")
        table.add_column("provider_id", style="green", max_width=20)

        for item in response:
            table.add_row(
                item.model_type,
                item.identifier,
                item.provider_resource_id,
                str(item.metadata or ""),
                item.provider_id,
            )

        # Title above the table, count below it.
        console.print("\n[bold]Available Models[/bold]\n")
        console.print(table)
        console.print(f"\nTotal models: {len(response)}\n")


@click.command(name="get")
@click.help_option("-h", "--help")
@click.argument("model_id")
@click.pass_context
@handle_client_errors("get model details")
def get_model(ctx, model_id: str):
    """Show details of a specific model at the distribution endpoint"""
    client = ctx.obj["client"]
    console = Console()

    models_get_response = client.models.retrieve(model_id=model_id)

    if not models_get_response:
        console.print(
            f"Model {model_id} is not found at distribution endpoint. "
            "Please ensure endpoint is serving specified model.",
            style="bold red",
        )
        return

    # One column per attribute, sorted for stable output; a single row of values.
    headers = sorted(models_get_response.__dict__.keys())
    table = Table()
    for header in headers:
        table.add_column(header)

    table.add_row(*[str(models_get_response.__dict__[header]) for header in headers])
    console.print(table)


@click.command(name="register", help="Register a new model at distribution endpoint")
@click.help_option("-h", "--help")
@click.argument("model_id")
@click.option("--provider-id", help="Provider ID for the model", default=None)
@click.option("--provider-model-id", help="Provider's model ID", default=None)
@click.option("--metadata", help="JSON metadata for the model", default=None)
@click.pass_context
@handle_client_errors("register model")
def register_model(
    ctx,
    model_id: str,
    provider_id: Optional[str],
    provider_model_id: Optional[str],
    metadata: Optional[str],
):
    """Register a new model at distribution endpoint"""
    client = ctx.obj["client"]
    console = Console()

    # --metadata is documented as JSON; decode it so the API receives a dict
    # rather than a raw string (mirrors how scoring_functions handles --params).
    parsed_metadata = None
    if metadata:
        try:
            parsed_metadata = json.loads(metadata)
        except json.JSONDecodeError as err:
            raise click.BadParameter("Metadata must be valid JSON") from err

    response = client.models.register(
        model_id=model_id,
        provider_id=provider_id,
        provider_model_id=provider_model_id,
        metadata=parsed_metadata,
    )
    if response:
        console.print(f"[green]Successfully registered model {model_id}[/green]")


@click.command(name="unregister", help="Unregister a model from distribution endpoint")
@click.help_option("-h", "--help")
@click.argument("model_id")
@click.pass_context
@handle_client_errors("unregister model")
def unregister_model(ctx, model_id: str):
    """Remove a model registration from the distribution endpoint."""
    client = ctx.obj["client"]
    console = Console()

    response = client.models.unregister(model_id=model_id)
    if response:
        console.print(f"[green]Successfully deleted model {model_id}[/green]")


# Register subcommands
models.add_command(list_models)
models.add_command(get_model)
models.add_command(register_model)
models.add_command(unregister_model)


# --- file: src/llama_stack_client/lib/cli/post_training/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .post_training import post_training

__all__ = ["post_training"]


# --- file: src/llama_stack_client/lib/cli/post_training/post_training.py (header) ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# --- file: src/llama_stack_client/lib/cli/post_training/post_training.py (continued) ---

import json
from typing import Optional

import click
from rich.console import Console

from llama_stack_client.types.post_training_supervised_fine_tune_params import AlgorithmConfigParam, TrainingConfig

from ..common.utils import handle_client_errors


@click.group()
@click.help_option("-h", "--help")
def post_training():
    """Post-training."""


@click.command("supervised_fine_tune")
@click.help_option("-h", "--help")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.option("--model", required=True, help="Model ID")
@click.option("--algorithm-config", required=True, help="Algorithm Config (JSON)")
@click.option("--training-config", required=True, help="Training Config (JSON)")
@click.option("--checkpoint-dir", required=False, help="Checkpoint Config", default=None)
@click.pass_context
@handle_client_errors("post_training supervised_fine_tune")
def supervised_fine_tune(
    ctx,
    job_uuid: str,
    model: str,
    algorithm_config: str,
    training_config: str,
    checkpoint_dir: Optional[str],
):
    """Kick off a supervised fine tune job"""
    client = ctx.obj["client"]
    console = Console()

    # click delivers both configs as raw strings; decode them so the API
    # receives the structured AlgorithmConfigParam / TrainingConfig payloads
    # its signature expects instead of JSON text.
    try:
        parsed_algorithm_config: AlgorithmConfigParam = json.loads(algorithm_config)
        parsed_training_config: TrainingConfig = json.loads(training_config)
    except json.JSONDecodeError as err:
        raise click.BadParameter("Algorithm config and training config must be valid JSON") from err

    post_training_job = client.post_training.supervised_fine_tune(
        job_uuid=job_uuid,
        model=model,
        algorithm_config=parsed_algorithm_config,
        training_config=parsed_training_config,
        checkpoint_dir=checkpoint_dir,
        # logger_config and hyperparam_search_config haven't been used yet
        logger_config={},
        hyperparam_search_config={},
    )
    console.print(post_training_job.job_uuid)


@click.command("list")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("post_training get_training_jobs")
def get_training_jobs(ctx):
    """Show the list of available post training jobs"""
    client = ctx.obj["client"]
    console = Console()

    post_training_jobs = client.post_training.job.list()
    console.print([post_training_job.job_uuid for post_training_job in post_training_jobs])


@click.command("status")
@click.help_option("-h", "--help")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_status")
def get_training_job_status(ctx, job_uuid: str):
    """Show the status of a specific post training job"""
    client = ctx.obj["client"]
    console = Console()

    # (typo fix: was `job_status_reponse`)
    job_status_response = client.post_training.job.status(job_uuid=job_uuid)
    console.print(job_status_response)


@click.command("artifacts")
@click.help_option("-h", "--help")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training get_training_job_artifacts")
def get_training_job_artifacts(ctx, job_uuid: str):
    """Get the training artifacts of a specific post training job"""
    client = ctx.obj["client"]
    console = Console()

    job_artifacts = client.post_training.job.artifacts(job_uuid=job_uuid)
    console.print(job_artifacts)


@click.command("cancel")
@click.help_option("-h", "--help")
@click.option("--job-uuid", required=True, help="Job UUID")
@click.pass_context
@handle_client_errors("post_training cancel_training_job")
def cancel_training_job(ctx, job_uuid: str):
    """Cancel the training job"""
    client = ctx.obj["client"]

    client.post_training.job.cancel(job_uuid=job_uuid)


# Register subcommands
post_training.add_command(supervised_fine_tune)
post_training.add_command(get_training_jobs)
post_training.add_command(get_training_job_status)
post_training.add_command(get_training_job_artifacts)
post_training.add_command(cancel_training_job)


# --- file: src/llama_stack_client/lib/cli/providers/__init__.py ---
from .providers import providers

__all__ = ["providers"]
# --- file: src/llama_stack_client/lib/cli/providers/inspect.py ---
import click
import yaml
from rich.console import Console

from ..common.utils import handle_client_errors


@click.command(name="inspect")
@click.argument("provider_id")
@click.pass_context
@handle_client_errors("inspect providers")
def inspect_provider(ctx, provider_id):
    """Show available providers on distribution endpoint"""
    client = ctx.obj["client"]
    console = Console()

    provider = client.providers.retrieve(provider_id=provider_id)
    if not provider:
        click.secho("Provider not found", fg="red")
        raise click.exceptions.Exit(1)

    # Print the identity fields, then the provider config as indented YAML.
    console.print(f"provider_id={provider.provider_id}")
    console.print(f"provider_type={provider.provider_type}")
    console.print("config:")
    config_yaml = yaml.dump(provider.config, indent=2)
    for config_line in config_yaml.split("\n"):
        console.print(config_line)


# --- file: src/llama_stack_client/lib/cli/providers/list.py ---
import click
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.command("list")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("list providers")
def list_providers(ctx):
    """Show available providers on distribution endpoint"""
    client = ctx.obj["client"]
    console = Console()

    providers_response = client.providers.list()

    # One column per field, one row per provider.
    table = Table()
    for column_name in ("API", "Provider ID", "Provider Type"):
        table.add_column(column_name)

    for provider in providers_response:
        table.add_row(provider.api, provider.provider_id, provider.provider_type)

    console.print(table)
# --- file: src/llama_stack_client/lib/cli/providers/providers.py ---
import click

from .list import list_providers
from .inspect import inspect_provider


@click.group()
@click.help_option("-h", "--help")
def providers():
    """Manage API providers."""


# Register subcommands
providers.add_command(list_providers)
providers.add_command(inspect_provider)


# --- file: src/llama_stack_client/lib/cli/scoring_functions/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .scoring_functions import scoring_functions

__all__ = ["scoring_functions"]


# --- file: src/llama_stack_client/lib/cli/scoring_functions/list.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import click
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.command("list")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("list scoring functions")
def list_scoring_functions(ctx):
    """Show available scoring functions on distribution endpoint"""

    client = ctx.obj["client"]
    console = Console()
    # Attribute names to show, one table column each.
    headers = [
        "identifier",
        "provider_id",
        "description",
        "type",
    ]

    scoring_functions_list_response = client.scoring_functions.list()
    if scoring_functions_list_response:
        table = Table()
        for header in headers:
            table.add_column(header)

        for item in scoring_functions_list_response:
            table.add_row(*[str(getattr(item, header)) for header in headers])
        console.print(table)


# --- file: src/llama_stack_client/lib/cli/scoring_functions/scoring_functions.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Optional

import click
import yaml

from ..common.utils import handle_client_errors
from .list import list_scoring_functions


@click.group()
@click.help_option("-h", "--help")
def scoring_functions():
    """Manage scoring functions."""


@scoring_functions.command()
@click.help_option("-h", "--help")
@click.option("--scoring-fn-id", required=True, help="Id of the scoring function")
@click.option("--description", required=True, help="Description of the scoring function")
@click.option("--return-type", type=str, required=True, help="Return type of the scoring function")
@click.option("--provider-id", type=str, help="Provider ID for the scoring function", default=None)
@click.option("--provider-scoring-fn-id", type=str, help="Provider's scoring function ID", default=None)
@click.option("--params", type=str, help="Parameters for the scoring function in JSON format", default=None)
@click.pass_context
@handle_client_errors("register scoring function")
def register(
    ctx,
    scoring_fn_id: str,
    description: str,
    return_type: str,
    provider_id: Optional[str],
    provider_scoring_fn_id: Optional[str],
    params: Optional[str],
):
    """Register a new scoring function"""
    client = ctx.obj["client"]

    if params:
        try:
            params = json.loads(params)
        except json.JSONDecodeError as err:
            raise click.BadParameter("Parameters must be valid JSON") from err

    # --return-type is JSON too; validate it the same way instead of letting a
    # raw JSONDecodeError escape to the user (the original called json.loads
    # inline in the register() call, unguarded).
    try:
        parsed_return_type = json.loads(return_type)
    except json.JSONDecodeError as err:
        raise click.BadParameter("Return type must be valid JSON") from err

    response = client.scoring_functions.register(
        scoring_fn_id=scoring_fn_id,
        description=description,
        return_type=parsed_return_type,
        provider_id=provider_id,
        provider_scoring_fn_id=provider_scoring_fn_id,
        params=params,
    )
    if response:
        click.echo(yaml.dump(response.dict()))


# Register subcommands
scoring_functions.add_command(list_scoring_functions)
scoring_functions.add_command(register)
# --- file: src/llama_stack_client/lib/cli/shields/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .shields import shields

__all__ = ["shields"]


# --- file: src/llama_stack_client/lib/cli/shields/shields.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Optional

import click
import yaml
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
@click.help_option("-h", "--help")
def shields():
    """Manage safety shield services."""


@click.command("list")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("list shields")
def list_shields(ctx):  # renamed from `list` to avoid shadowing the builtin
    """Show available safety shields on distribution endpoint"""
    client = ctx.obj["client"]
    console = Console()

    shields_list_response = client.shields.list()

    if shields_list_response:
        table = Table(
            show_lines=True,  # lines between rows for better readability
            padding=(0, 1),  # horizontal padding
            expand=True,  # allow table to use full width
        )

        # Column label matches the attribute actually rendered in each row
        # (the original labeled this column "provider_alias" while rendering
        # provider_resource_id, and declared an unused `headers` list).
        table.add_column("identifier", style="bold cyan", no_wrap=True, overflow="fold")
        table.add_column("provider_resource_id", style="yellow", no_wrap=True, overflow="fold")
        table.add_column("params", style="magenta", max_width=30, overflow="fold")
        table.add_column("provider_id", style="green", max_width=20)

        for item in shields_list_response:
            table.add_row(
                item.identifier,
                item.provider_resource_id,
                str(item.params or ""),
                item.provider_id,
            )

        console.print(table)


@shields.command()
@click.help_option("-h", "--help")
@click.option("--shield-id", required=True, help="Id of the shield")
@click.option("--provider-id", help="Provider ID for the shield", default=None)
@click.option("--provider-shield-id", help="Provider's shield ID", default=None)
@click.option(
    "--params",
    type=str,
    help="JSON configuration parameters for the shield",
    default=None,
)
@click.pass_context
@handle_client_errors("register shield")
def register(
    ctx,
    shield_id: str,
    provider_id: Optional[str],
    provider_shield_id: Optional[str],
    params: Optional[str],
):
    """Register a new safety shield"""
    client = ctx.obj["client"]

    # --params is documented as JSON; decode it so the API receives a dict
    # rather than a raw string (mirrors scoring_functions --params handling).
    parsed_params = None
    if params:
        try:
            parsed_params = json.loads(params)
        except json.JSONDecodeError as err:
            raise click.BadParameter("Params must be valid JSON") from err

    response = client.shields.register(
        shield_id=shield_id,
        params=parsed_params,
        provider_id=provider_id,
        provider_shield_id=provider_shield_id,
    )
    if response:
        click.echo(yaml.dump(response.dict()))


# Register subcommands
shields.add_command(list_shields)
shields.add_command(register)


# --- file: src/llama_stack_client/lib/cli/toolgroups/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .toolgroups import toolgroups

__all__ = ["toolgroups"]


# --- file: src/llama_stack_client/lib/cli/toolgroups/toolgroups.py (header) ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# --- file: src/llama_stack_client/lib/cli/toolgroups/toolgroups.py (continued) ---
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Optional

import click
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors
from ....types import toolgroup_register_params
from ...._types import NOT_GIVEN, NotGiven


@click.group()
@click.help_option("-h", "--help")
def toolgroups():
    """Manage available tool groups."""


@click.command(name="list", help="Show available llama toolgroups at distribution endpoint")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("list toolgroups")
def list_toolgroups(ctx):
    """Render every toolgroup registered at the endpoint as a table."""
    client = ctx.obj["client"]
    console = Console()

    headers = ["identifier", "provider_id", "args", "mcp_endpoint"]
    response = client.toolgroups.list()
    if response:
        table = Table()
        for header in headers:
            table.add_column(header)

        for item in response:
            table.add_row(*[str(getattr(item, header)) for header in headers])
        console.print(table)


@click.command(name="get")
@click.help_option("-h", "--help")
@click.argument("toolgroup_id")
@click.pass_context
@handle_client_errors("get toolgroup details")
def get_toolgroup(ctx, toolgroup_id: str):
    """Show details of a specific toolgroup at the distribution endpoint"""
    # (docstring fixed: the original repeated the `list` command's help text)
    client = ctx.obj["client"]
    console = Console()

    # No per-toolgroup lookup is used here: list all tools, then keep only
    # the ones belonging to the requested toolgroup.
    toolgroups_get_response = client.tools.list()
    toolgroups_get_response = [
        toolgroup for toolgroup in toolgroups_get_response if toolgroup.toolgroup_id == toolgroup_id
    ]
    if len(toolgroups_get_response) == 0:
        console.print(
            f"Toolgroup {toolgroup_id} is not found at distribution endpoint. "
            "Please ensure endpoint is serving specified toolgroup.",
            style="bold red",
        )
        return

    headers = sorted(toolgroups_get_response[0].__dict__.keys())
    table = Table()
    for header in headers:
        table.add_column(header)

    for toolgroup in toolgroups_get_response:
        table.add_row(*[str(getattr(toolgroup, header)) for header in headers])
    console.print(table)


@click.command(name="register", help="Register a new toolgroup at distribution endpoint")
@click.help_option("-h", "--help")
@click.argument("toolgroup_id")
@click.option("--provider-id", help="Provider ID for the toolgroup", default=None)
@click.option("--mcp-endpoint", help="JSON mcp_config for the toolgroup", default=None)
@click.option("--args", help="JSON args for the toolgroup", default=None)
@click.pass_context
@handle_client_errors("register toolgroup")
def register_toolgroup(
    ctx,
    toolgroup_id: str,
    provider_id: Optional[str],
    mcp_endpoint: Optional[str],
    args: Optional[str],
):
    """Register a new toolgroup at distribution endpoint"""
    client = ctx.obj["client"]
    console = Console()

    _mcp_endpoint: toolgroup_register_params.McpEndpoint | NotGiven = NOT_GIVEN
    if mcp_endpoint:
        _mcp_endpoint = toolgroup_register_params.McpEndpoint(uri=mcp_endpoint)

    # --args is documented as JSON; decode it so the API receives a dict
    # rather than a raw string.
    parsed_args = None
    if args:
        try:
            parsed_args = json.loads(args)
        except json.JSONDecodeError as err:
            raise click.BadParameter("Args must be valid JSON") from err

    response = client.toolgroups.register(
        toolgroup_id=toolgroup_id,
        provider_id=provider_id,
        args=parsed_args,
        mcp_endpoint=_mcp_endpoint,
    )
    if response:
        console.print(f"[green]Successfully registered toolgroup {toolgroup_id}[/green]")


@click.command(name="unregister", help="Unregister a toolgroup from distribution endpoint")
@click.help_option("-h", "--help")
@click.argument("toolgroup_id")
@click.pass_context
@handle_client_errors("unregister toolgroup")
def unregister_toolgroup(ctx, toolgroup_id: str):
    """Remove a toolgroup registration from the distribution endpoint."""
    client = ctx.obj["client"]
    console = Console()

    response = client.toolgroups.unregister(toolgroup_id=toolgroup_id)
    if response:
        console.print(f"[green]Successfully deleted toolgroup {toolgroup_id}[/green]")


# Register subcommands
toolgroups.add_command(list_toolgroups)
toolgroups.add_command(get_toolgroup)
toolgroups.add_command(register_toolgroup)
toolgroups.add_command(unregister_toolgroup)


# --- file: src/llama_stack_client/lib/cli/vector_dbs/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .vector_dbs import vector_dbs

__all__ = ["vector_dbs"]


# --- file: src/llama_stack_client/lib/cli/vector_dbs/vector_dbs.py (header) ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# --- file: src/llama_stack_client/lib/cli/vector_dbs/vector_dbs.py (continued) ---

from typing import Optional

import click
import yaml
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
@click.help_option("-h", "--help")
def vector_dbs():
    """Manage vector databases."""


@click.command("list")
@click.help_option("-h", "--help")
@click.pass_context
@handle_client_errors("list vector dbs")
def list_vector_dbs(ctx):  # renamed from `list` to avoid shadowing the builtin
    """Show available vector dbs on distribution endpoint"""

    client = ctx.obj["client"]
    console = Console()
    vector_dbs_list_response = client.vector_dbs.list()

    if vector_dbs_list_response:
        table = Table()
        # Main columns first; everything else is folded into "params".
        table.add_column("identifier")
        table.add_column("provider_id")
        table.add_column("provider_resource_id")
        table.add_column("vector_db_type")
        table.add_column("params")

        for item in vector_dbs_list_response:
            # Work on a copy: the original popped from item.__dict__ directly,
            # which silently deleted attributes from the response objects.
            item_dict = dict(item.__dict__)

            identifier = str(item_dict.pop("identifier", ""))
            provider_id = str(item_dict.pop("provider_id", ""))
            provider_resource_id = str(item_dict.pop("provider_resource_id", ""))
            vector_db_type = str(item_dict.pop("vector_db_type", ""))
            # Remaining attributes rendered as YAML in the params column.
            params = yaml.dump(item_dict, default_flow_style=False)

            table.add_row(identifier, provider_id, provider_resource_id, vector_db_type, params)

        console.print(table)


@vector_dbs.command()
@click.help_option("-h", "--help")
@click.argument("vector-db-id")
@click.option("--provider-id", help="Provider ID for the vector db", default=None)
@click.option("--provider-vector-db-id", help="Provider's vector db ID", default=None)
@click.option(
    "--embedding-model",
    type=str,
    help="Embedding model (for vector type)",
    default="all-MiniLM-L6-v2",
)
@click.option(
    "--embedding-dimension",
    type=int,
    help="Embedding dimension (for vector type)",
    default=384,
)
@click.pass_context
@handle_client_errors("register vector db")
def register(
    ctx,
    vector_db_id: str,
    provider_id: Optional[str],
    provider_vector_db_id: Optional[str],
    embedding_model: Optional[str],
    embedding_dimension: Optional[int],
):
    """Create a new vector db"""
    client = ctx.obj["client"]

    response = client.vector_dbs.register(
        vector_db_id=vector_db_id,
        provider_id=provider_id,
        provider_vector_db_id=provider_vector_db_id,
        embedding_model=embedding_model,
        embedding_dimension=embedding_dimension,
    )
    if response:
        click.echo(yaml.dump(response.dict()))


@vector_dbs.command()
@click.help_option("-h", "--help")
@click.argument("vector-db-id")
@click.pass_context
@handle_client_errors("delete vector db")
def unregister(ctx, vector_db_id: str):
    """Delete a vector db"""
    client = ctx.obj["client"]
    client.vector_dbs.unregister(vector_db_id=vector_db_id)
    click.echo(f"Vector db '{vector_db_id}' deleted successfully")


# Register subcommands
vector_dbs.add_command(list_vector_dbs)
vector_dbs.add_command(register)
vector_dbs.add_command(unregister)


# --- file: src/llama_stack_client/lib/inference/__init__.py ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


# --- file: src/llama_stack_client/lib/inference/event_logger.py (header) ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# --- file: src/llama_stack_client/lib/inference/event_logger.py (continued) ---
from typing import Generator
from termcolor import cprint
from llama_stack_client.types import ChatCompletionResponseStreamChunk, ChatCompletionChunk


class InferenceStreamPrintableEvent:
    """One printable fragment of a streaming inference response."""

    def __init__(
        self,
        content: str = "",
        end: str = "\n",
        color="white",
    ):
        self.content = content
        self.color = color
        # Guard against an explicit end=None: fall back to a newline.
        self.end = "\n" if end is None else end

    def print(self, flush=True):
        cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)


class InferenceStreamLogEventPrinter:
    """Translates streaming chunks into printable events, tracking whether the
    model is currently inside a reasoning ("thinking") span so the transition
    markers are emitted exactly once per span."""

    def __init__(self):
        self.is_thinking = False

    def yield_printable_events(
        self, chunk: ChatCompletionResponseStreamChunk | ChatCompletionChunk
    ) -> Generator[InferenceStreamPrintableEvent, None, None]:
        if hasattr(chunk, "event"):
            # llama-stack native streaming chunk
            yield from self._handle_inference_stream_chunk(chunk)
        elif hasattr(chunk, "choices") and len(chunk.choices) > 0:
            # OpenAI-compatible chat-completion chunk
            yield from self._handle_chat_completion_chunk(chunk)

    def _handle_inference_stream_chunk(
        self, chunk: ChatCompletionResponseStreamChunk
    ) -> Generator[InferenceStreamPrintableEvent, None, None]:
        event = chunk.event
        event_type = event.event_type

        if event_type == "start":
            yield InferenceStreamPrintableEvent("Assistant> ", color="cyan", end="")
            return
        if event_type == "complete":
            # Terminating newline for the whole response.
            yield InferenceStreamPrintableEvent("")
            return
        if event_type != "progress":
            return

        delta = event.delta
        if delta.type == "reasoning":
            if not self.is_thinking:
                # Marker emitted once when a reasoning span begins.
                yield InferenceStreamPrintableEvent(" ", color="magenta", end="")
                self.is_thinking = True
            yield InferenceStreamPrintableEvent(delta.reasoning, color="magenta", end="")
        else:
            if self.is_thinking:
                # Marker emitted once when a reasoning span ends.
                yield InferenceStreamPrintableEvent("", color="magenta", end="")
                self.is_thinking = False
            yield InferenceStreamPrintableEvent(delta.text, color="yellow", end="")

    def _handle_chat_completion_chunk(
        self, chunk: ChatCompletionChunk
    ) -> Generator[InferenceStreamPrintableEvent, None, None]:
        first_choice = chunk.choices[0]
        delta = first_choice.delta
        if delta:
            if delta.role:
                yield InferenceStreamPrintableEvent(f"{delta.role}> ", color="cyan", end="")
            if delta.content:
                yield InferenceStreamPrintableEvent(delta.content, color="yellow", end="")
        if first_choice.finish_reason:
            if first_choice.finish_reason == "length":
                yield InferenceStreamPrintableEvent("", color="red", end="")
            # Terminating newline for the whole response.
            yield InferenceStreamPrintableEvent()


class EventLogger:
    """Adapts a stream of chunks into a stream of printable events."""

    def log(self, event_generator):
        printer = InferenceStreamLogEventPrinter()
        for chunk in event_generator:
            yield from printer.yield_printable_events(chunk)


# --- file: src/llama_stack_client/lib/inference/utils.py (header) ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+ +import pathlib +import base64 + + +class MessageAttachment: + # https://developer.mozilla.org/en-US/docs/Glossary/Base64 + @classmethod + def base64(cls, file_path: str) -> str: + path = pathlib.Path(file_path) + return base64.b64encode(path.read_bytes()).decode("utf-8") + + # https://developer.mozilla.org/en-US/docs/Web/URI/Schemes/data + @classmethod + def data_url(cls, media_type: str, file_path: str) -> str: + return f"data:{media_type};base64,{cls.base64(file_path)}" diff --git a/src/llama_stack_client/lib/inline/inline.py b/src/llama_stack_client/lib/inline/inline.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llama_stack_client/lib/stream_printer.py b/src/llama_stack_client/lib/stream_printer.py new file mode 100644 index 00000000..a08d9663 --- /dev/null +++ b/src/llama_stack_client/lib/stream_printer.py @@ -0,0 +1,24 @@ +from .agents.event_logger import TurnStreamEventPrinter +from .inference.event_logger import InferenceStreamLogEventPrinter + + +class EventStreamPrinter: + @classmethod + def gen(cls, event_generator): + inference_printer = None + turn_printer = None + for chunk in event_generator: + if not hasattr(chunk, "event"): + raise ValueError(f"Unexpected chunk without event: {chunk}") + + event = chunk.event + if hasattr(event, "event_type"): + if not inference_printer: + inference_printer = InferenceStreamLogEventPrinter() + yield from inference_printer.yield_printable_events(chunk) + elif hasattr(event, "payload") and hasattr(event.payload, "event_type"): + if not turn_printer: + turn_printer = TurnStreamEventPrinter() + yield from turn_printer.yield_printable_events(chunk) + else: + raise ValueError(f"Unsupported event: {event}") diff --git a/src/llama_stack_client/lib/tools/mcp_oauth.py b/src/llama_stack_client/lib/tools/mcp_oauth.py new file mode 100644 index 00000000..a3c03416 --- /dev/null +++ b/src/llama_stack_client/lib/tools/mcp_oauth.py @@ -0,0 +1,297 @@ +import asyncio +import base64 +import hashlib +import 
import asyncio
import base64
import hashlib
import logging
import os
import socket
import threading
import time
import urllib.parse
import uuid
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer

import fire
import requests

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared timeout for every outbound HTTP call so an unresponsive auth server
# fails fast instead of hanging the whole flow indefinitely.
HTTP_TIMEOUT_SECONDS = 30


class McpOAuthHelper:
    """A simpler helper for OAuth2 authentication with MCP servers with OAuth discovery."""

    def __init__(self, server_url):
        self.server_url = server_url
        self.server_base_url = get_base_url(server_url)
        self.access_token = None

        # For PKCE (Proof Key for Code Exchange)
        self.code_verifier = None
        self.code_challenge = None

        # OAuth client registration
        self.client_id = None
        self.client_secret = None
        self.registered_redirect_uris = []

        # Callback server that receives the browser redirect.
        self.callback_port = find_available_port(8000, 8100)
        self.redirect_uri = f"http://localhost:{self.callback_port}/callback"
        self.auth_code = None
        self.auth_error = None
        self.http_server = None
        self.server_thread = None

        # Software statement for DCR (Dynamic Client Registration)
        self.software_statement = {
            "software_id": "simple-mcp-client",
            "software_version": "1.0.0",
            "software_name": "Simple MCP Client Example",
            "software_description": "A simple MCP client for demonstration purposes",
            "software_uri": "https://github.com/example/simple-mcp-client",
            "redirect_uris": [self.redirect_uri],
            "client_name": "Simple MCP Client",
            "client_uri": "https://example.com/mcp-client",
            "token_endpoint_auth_method": "none",  # Public client
        }

    def discover_auth_endpoints(self):
        """
        Discover the OAuth server metadata according to RFC8414.
        MCP servers MUST support this discovery mechanism.

        Returns:
            dict: the parsed authorization-server metadata.

        Raises:
            Exception: if the well-known endpoint does not return 200.
        """
        well_known_url = f"{self.server_base_url}/.well-known/oauth-authorization-server"
        response = requests.get(well_known_url, timeout=HTTP_TIMEOUT_SECONDS)
        if response.status_code == 200:
            metadata = response.json()
            logger.info("✅ Successfully discovered OAuth metadata")
            return metadata

        raise Exception(f"OAuth metadata discovery failed with status: {response.status_code}")

    def register_client(self, registration_endpoint):
        """Dynamically register this client (RFC 7591) and store its credentials."""
        headers = {"Content-Type": "application/json"}

        registration_request = {
            "client_name": self.software_statement["client_name"],
            "redirect_uris": [self.redirect_uri],
            "token_endpoint_auth_method": "none",  # Public client
            "grant_types": ["authorization_code"],
            "response_types": ["code"],
            "scope": "openid",
            "software_id": self.software_statement["software_id"],
            "software_version": self.software_statement["software_version"],
        }

        response = requests.post(
            registration_endpoint,
            headers=headers,
            json=registration_request,
            timeout=HTTP_TIMEOUT_SECONDS,
        )

        if response.status_code in (201, 200):
            registration_data = response.json()
            self.client_id = registration_data.get("client_id")
            self.client_secret = registration_data.get("client_secret")
            self.registered_redirect_uris = registration_data.get("redirect_uris", [self.redirect_uri])

            logger.info(f"Client ID: {self.client_id}")
            return registration_data

        raise Exception(f"Client registration failed: {response.status_code}")

    def generate_pkce_values(self):
        """Generate PKCE code verifier and challenge (RFC 7636, S256 method)."""
        # Generate a random code verifier
        code_verifier = base64.urlsafe_b64encode(os.urandom(32)).decode("utf-8").rstrip("=")

        # Generate the code challenge using SHA-256
        code_challenge_digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
        code_challenge = base64.urlsafe_b64encode(code_challenge_digest).decode("utf-8").rstrip("=")

        self.code_verifier = code_verifier
        self.code_challenge = code_challenge

        return code_verifier, code_challenge

    def stop_server(self):
        """Shut the callback server down (run from a helper thread).

        The short sleep lets the in-flight callback response finish before
        ``shutdown()`` stops the serve loop.
        """
        time.sleep(1)
        if self.http_server:
            self.http_server.shutdown()

    def start_callback_server(self):
        """Start the local HTTP server that captures the OAuth redirect."""

        def auth_callback(auth_code: str | None, error: str | None):
            logger.info(f"Authorization callback received: auth_code={auth_code}, error={error}")
            self.auth_code = auth_code
            self.auth_error = error
            # Shut down from a separate thread: calling shutdown() from the
            # handler thread itself would deadlock serve_forever().
            threading.Thread(target=self.stop_server).start()

        self.http_server = CallbackServer(("localhost", self.callback_port), auth_callback)

        self.server_thread = threading.Thread(target=self.http_server.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()

        logger.info(f"🌐 Callback server started on port {self.callback_port}")

    def exchange_code_for_token(self, auth_code, token_endpoint):
        """Exchange the authorization code for an access token; returns the token."""
        logger.info("Exchanging authorization code for access token...")

        data = {
            "grant_type": "authorization_code",
            "code": auth_code,
            "redirect_uri": self.redirect_uri,
            "client_id": self.client_id,
            "code_verifier": self.code_verifier,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret

        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        response = requests.post(token_endpoint, data=data, headers=headers, timeout=HTTP_TIMEOUT_SECONDS)
        if response.status_code == 200:
            token_data = response.json()
            self.access_token = token_data.get("access_token")
            logger.info(f"✅ Successfully obtained access token: {self.access_token}")
            return self.access_token

        raise Exception(f"Failed to exchange code for token: {response.status_code}")

    def initiate_auth_flow(self):
        """Run the full flow: discovery → registration → PKCE → browser → token.

        Returns:
            The access token string, or None on failure/timeout.
        """
        auth_metadata = self.discover_auth_endpoints()
        registration_endpoint = auth_metadata.get("registration_endpoint")
        if registration_endpoint and not self.client_id:
            self.register_client(registration_endpoint)

        self.generate_pkce_values()

        self.start_callback_server()

        auth_url = auth_metadata.get("authorization_endpoint")
        if not auth_url:
            raise Exception("No authorization endpoint in metadata")

        token_endpoint = auth_metadata.get("token_endpoint")
        if not token_endpoint:
            raise Exception("No token endpoint in metadata")

        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "response_type": "code",
            "state": str(uuid.uuid4()),  # Random state
            "code_challenge": self.code_challenge,
            "code_challenge_method": "S256",
            "scope": "openid",  # Add appropriate scopes for the target server
        }

        full_auth_url = f"{auth_url}?{urllib.parse.urlencode(params)}"
        logger.info(f"Opening browser to authorize URL: {full_auth_url}")
        logger.info("Flow will continue after you log in")

        webbrowser.open(full_auth_url)
        self.server_thread.join(60)  # Wait up to 1 minute

        if self.auth_code:
            return self.exchange_code_for_token(self.auth_code, token_endpoint)
        elif self.auth_error:
            logger.error(f"Authorization failed: {self.auth_error}")
            return None
        else:
            logger.error("Timed out waiting for authorization")
            return None


def get_base_url(url):
    """Return ``scheme://netloc`` of *url* (drops path, query and fragment)."""
    parsed_url = urllib.parse.urlparse(url)
    return f"{parsed_url.scheme}://{parsed_url.netloc}"


def find_available_port(start_port, end_port):
    """Find an available port within a range."""
    for port in range(start_port, end_port + 1):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("localhost", port))
                return port
            except socket.error:
                continue
    raise RuntimeError(f"No available ports in range {start_port}-{end_port}")


class CallbackServer(HTTPServer):
    """Local HTTP server that receives the OAuth redirect on ``/callback``."""

    class OAuthCallbackHandler(BaseHTTPRequestHandler):
        def _send_html_page(self, title, detail):
            """Send a single well-formed 200 HTML response."""
            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            page = (
                f"<html><head><title>{title}</title></head>"
                f"<body><h1>{title}</h1><p>{detail}</p></body></html>"
            )
            self.wfile.write(page.encode())

        def do_GET(self):
            parsed_path = urllib.parse.urlparse(self.path)
            query_params = urllib.parse.parse_qs(parsed_path.query)

            if parsed_path.path != "/callback":
                self.send_response(404)
                self.end_headers()
                return

            auth_code = query_params.get("code", [None])[0]
            error = query_params.get("error", [None])[0]

            if error:
                self._send_html_page("Authorization Failed", f"Error: {error}")
                self.server.auth_code_callback(None, error)
            elif auth_code:
                self._send_html_page("Authorization Successful", "You can close this window now.")
                # Call the callback with the auth code
                self.server.auth_code_callback(auth_code, None)
            else:
                self._send_html_page("Authorization Failed", "No authorization code received.")
                self.server.auth_code_callback(None, "No authorization code received")

        def log_message(self, format, *args):
            """Override to suppress HTTP server logs."""
            return

    def __init__(self, server_address, auth_code_callback):
        self.auth_code_callback = auth_code_callback
        super().__init__(server_address, self.OAuthCallbackHandler)


def get_oauth_token_for_mcp_server(url: str) -> str | None:
    """Run the interactive OAuth flow against *url*; return a bearer token or None."""
    helper = McpOAuthHelper(url)
    return helper.initiate_auth_flow()


async def run_main(url: str):
    """Authenticate against the MCP server at *url* and list its first few tools."""
    from mcp import ClientSession
    from mcp.client.sse import sse_client

    token = get_oauth_token_for_mcp_server(url)
    if not token:
        return

    headers = {
        "Authorization": f"Bearer {token}",
    }

    async with sse_client(url, headers=headers) as streams:
        async with ClientSession(*streams) as session:
            await session.initialize()
            result = await session.list_tools()

            logger.info(f"Tools: {len(result.tools)}, showing first 5:")
            for t in result.tools[:5]:
                logger.info(f"{t.name}: {t.description}")


def main(url: str):
    asyncio.run(run_main(url))


if __name__ == "__main__":
    fire.Fire(main)
@property
def output_text(self) -> str:
    """Concatenation of every ``output_text`` content item across all message outputs.

    Non-message outputs and non-text content parts are skipped; the pieces are
    joined with no separator, in stream order.
    """
    pieces: List[str] = []
    for item in self.output:
        if item.type != "message":
            continue
        pieces.extend(part.text for part in item.content if part.type == "output_text")
    return "".join(pieces)