From 82962592b92b6ef6ce46c9a5c06c4006b0f5411c Mon Sep 17 00:00:00 2001 From: aribornstein Date: Sun, 2 Feb 2025 17:35:29 +0200 Subject: [PATCH 1/3] Add LlamaCppChatCompletionClient and llama-cpp dependency for enhanced chat capabilities --- python/packages/autogen-ext/pyproject.toml | 5 + .../autogen_ext/models/llama_cpp/__init__.py | 8 + .../llama_cpp/_llama_cpp_completion_client.py | 240 ++++++++++++++++++ python/uv.lock | 26 +- 4 files changed, 269 insertions(+), 10 deletions(-) create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index 40acd71aff15..d5ac3454b8d3 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -31,6 +31,11 @@ file-surfer = [ "autogen-agentchat==0.4.3", "markitdown>=0.0.1a2", ] + +llama-cpp = [ + "llama-cpp-python" +] + graphrag = ["graphrag>=1.0.1"] web-surfer = [ "autogen-agentchat==0.4.3", diff --git a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py new file mode 100644 index 000000000000..c0182d5c0c1d --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py @@ -0,0 +1,8 @@ +try: + from ._llama_cpp_completion_client import LlamaCppChatCompletionClient +except ImportError as e: + raise ImportError( + "Dependencies for Llama Cpp not found. " "Please install llama-cpp-python: " "pip install llama-cpp-python" + ) from e + +__all__ = ["LlamaCppChatCompletionClient"] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py new file mode 100644 index 000000000000..7f2428bd7f32 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py @@ -0,0 +1,240 @@ +import json +import logging # added import +from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union + +from autogen_core import CancellationToken +from autogen_core.models import AssistantMessage, ChatCompletionClient, CreateResult, SystemMessage, UserMessage +from autogen_core.tools import Tool +from llama_cpp import Llama +from pydantic import BaseModel + + +class ComponentModel(BaseModel): + provider: str + component_type: Optional[Literal["model", "agent", "tool", "termination", "token_provider"]] = None + version: Optional[int] = None + component_version: Optional[int] = None + description: Optional[str] = None + config: Dict[str, Any] + + +class LlamaCppChatCompletionClient(ChatCompletionClient): + def __init__( + self, + repo_id: str, + filename: str, + n_gpu_layers: int = -1, + seed: int = 1337, + n_ctx: int = 1000, + verbose: bool = True, + ): + """ + Initialize the LlamaCpp client. 
+ """ + self.logger = logging.getLogger(__name__) # initialize logger + self.logger.setLevel(logging.DEBUG if verbose else logging.INFO) # set level based on verbosity + self.llm = Llama.from_pretrained( + repo_id=repo_id, + filename=filename, + n_gpu_layers=n_gpu_layers, + seed=seed, + n_ctx=n_ctx, + verbose=verbose, + ) + self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0} + + async def create(self, messages: List[Any], tools: List[Any] = None, **kwargs) -> CreateResult: + """ + Generate a response using the model, incorporating tool metadata. + + :param messages: A list of message objects to process. + :param tools: A list of tool objects to register dynamically. + :param kwargs: Additional arguments for the model. + :return: A CreateResult object containing the model's response. + """ + tools = tools or [] + + # Convert LLMMessage objects to dictionaries with 'role' and 'content' + converted_messages = [] + for msg in messages: + if isinstance(msg, SystemMessage): + converted_messages.append({"role": "system", "content": msg.content}) + elif isinstance(msg, UserMessage): + converted_messages.append({"role": "user", "content": msg.content}) + elif isinstance(msg, AssistantMessage): + converted_messages.append({"role": "assistant", "content": msg.content}) + else: + raise ValueError(f"Unsupported message type: {type(msg)}") + + # Add tool descriptions to the system message + tool_descriptions = "\n".join( + [f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools)] + ) + + few_shot_example = """ + Example tool usage: + User: Validate this request: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"} + Assistant: Calling tool 'validate_request' with arguments: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"} + """ + + system_message = ( + "You are an assistant with access to tools. " + "If a user query matches a tool, explicitly invoke it with JSON arguments. " + "Here are the tools available:\n" + f"{tool_descriptions}\n" + f"{few_shot_example}" + ) + converted_messages.insert(0, {"role": "system", "content": system_message}) + + # Debugging outputs + # print(f"DEBUG: System message: {system_message}") + # print(f"DEBUG: Converted messages: {converted_messages}") + + # Generate the model response + response = self.llm.create_chat_completion(messages=converted_messages, stream=False) + self._total_usage["prompt_tokens"] += response.get("usage", {}).get("prompt_tokens", 0) + self._total_usage["completion_tokens"] += response.get("usage", {}).get("completion_tokens", 0) + + # Parse the response + response_text = response["choices"][0]["message"]["content"] + # print(f"DEBUG: Model response: {response_text}") + + # Detect tool usage in the response + tool_call = await self._detect_and_execute_tool(response_text, tools) + if not tool_call: + self.logger.debug("DEBUG: No tool was invoked. Returning raw model response.") + else: + self.logger.debug(f"DEBUG: Tool executed successfully: {tool_call}") + + # Create a CreateResult object + create_result = CreateResult( + content=tool_call if tool_call else response_text, + usage=response.get("usage", {}), + finish_reason=response["choices"][0].get("finish_reason", "unknown"), + cached=False, + ) + return create_result + + async def _detect_and_execute_tool(self, response_text: str, tools: List[Tool]) -> Optional[str]: + """ + Detect if the model is requesting a tool and execute the tool. + + :param response_text: The raw response text from the model. 
+ :param tools: A list of available tools. + :return: The result of the tool execution or None if no tool is called. + """ + for tool in tools: + if tool.name.lower() in response_text.lower(): # Case-insensitive matching + self.logger.debug(f"DEBUG: Detected tool '{tool.name}' in response.") + # Extract arguments (if any) from the response + func_args = self._extract_tool_arguments(response_text) + if func_args: + self.logger.debug(f"DEBUG: Extracted arguments for tool '{tool.name}': {func_args}") + else: + self.logger.debug(f"DEBUG: No arguments found for tool '{tool.name}'.") + return f"Error: No valid arguments provided for tool '{tool.name}'." + + # Ensure arguments match the tool's args_type + try: + args_model = tool.args_type() + if "request" in args_model.__fields__: # Handle nested arguments + func_args = {"request": func_args} + args_instance = args_model(**func_args) + except Exception as e: + return f"Error parsing arguments for tool '{tool.name}': {e}" + + # Execute the tool + try: + result = await tool.run(args=args_instance, cancellation_token=CancellationToken()) + if isinstance(result, dict): + return json.dumps(result) + elif hasattr(result, "model_dump"): # If it's a Pydantic model + return json.dumps(result.model_dump()) + else: + return str(result) + except Exception as e: + return f"Error executing tool '{tool.name}': {e}" + + return None + + def _extract_tool_arguments(self, response_text: str) -> Dict[str, Any]: + """ + Extract tool arguments from the response text. + + :param response_text: The raw response text. + :return: A dictionary of extracted arguments. + """ + try: + args_start = response_text.find("{") + args_end = response_text.find("}") + if args_start != -1 and args_end != -1: + args_str = response_text[args_start : args_end + 1] + return json.loads(args_str) + except json.JSONDecodeError as e: + self.logger.debug(f"DEBUG: Failed to parse arguments: {e}") + return {} + + async def create_stream(self, messages: List[Any], tools: List[Any] = None, **kwargs) -> AsyncGenerator[str, None]: + """ + Generate a streaming response using the model. + + :param messages: A list of messages to process. + :param tools: A list of tool objects to register dynamically. + :param kwargs: Additional arguments for the model. + :return: An asynchronous generator yielding the response stream. 
+ """ + tools = tools or [] + + # Convert LLMMessage objects to dictionaries with 'role' and 'content' + converted_messages = [] + for msg in messages: + if isinstance(msg, SystemMessage): + converted_messages.append({"role": "system", "content": msg.content}) + elif isinstance(msg, UserMessage): + converted_messages.append({"role": "user", "content": msg.content}) + elif isinstance(msg, AssistantMessage): + converted_messages.append({"role": "assistant", "content": msg.content}) + else: + raise ValueError(f"Unsupported message type: {type(msg)}") + + # Add tool descriptions to the system message + tool_descriptions = "\n".join([f"Tool: {tool.name} - {tool.description}" for tool in tools]) + if tool_descriptions: + converted_messages.insert( + 0, {"role": "system", "content": f"The following tools are available:\n{tool_descriptions}"} + ) + + # Convert messages into a plain string prompt + prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in converted_messages) + # Call the model with streaming enabled + response_generator = self.llm(prompt=prompt, stream=True) + + for token in response_generator: + yield token["choices"][0]["text"] + + # Implement abstract methods + def actual_usage(self) -> Dict[str, int]: + return self._total_usage + + @property + def capabilities(self) -> Dict[str, bool]: + return {"chat": True, "stream": True} + + def count_tokens(self, messages: Sequence[Dict[str, Any]], **kwargs) -> int: + return sum(len(msg["content"].split()) for msg in messages) + + @property + def model_info(self) -> Dict[str, Any]: + return { + "name": "llama-cpp", + "capabilities": {"chat": True, "stream": True}, + "context_window": self.llm.n_ctx, + "function_calling": True, + } + + def remaining_tokens(self, messages: Sequence[Dict[str, Any]], **kwargs) -> int: + used_tokens = self.count_tokens(messages) + return max(self.llm.n_ctx - used_tokens, 0) + + def total_usage(self) -> Dict[str, int]: + return self._total_usage diff --git a/python/uv.lock b/python/uv.lock index d01ba4f8d7fa..54fccef3d1c9 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -588,6 +588,9 @@ jupyter-executor = [ langchain = [ { name = "langchain-core" }, ] +llama-cpp = [ + { name = "llama-cpp-python" }, +] magentic-one = [ { name = "autogen-agentchat" }, { name = "markitdown" }, @@ -676,6 +679,7 @@ requires-dist = [ { name = "grpcio", marker = "extra == 'grpc'", specifier = "~=1.62.0" }, { name = "ipykernel", marker = "extra == 'jupyter-executor'", specifier = ">=6.29.5" }, { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" }, + { name = "llama-cpp-python", marker = "extra == 'llama-cpp'" }, { name = "markitdown", marker = "extra == 'file-surfer'", specifier = ">=0.0.1a2" }, { name = "markitdown", marker = "extra == 'magentic-one'", specifier = ">=0.0.1a2" }, { name = "markitdown", marker = "extra == 'web-surfer'", specifier = ">=0.0.1a2" }, @@ -3200,6 +3204,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/0f/af106de1780cf526c96de1ba279edcb55a0376a4484a7dea206f9f038cc4/llama_cloud-0.1.8-py3-none-any.whl", hash = "sha256:1a0c4cf212a04f2375f1d0791ca4e5f196e0fb0567c4ec96cd9dbcad773de60a", size = 247083 }, ] +[[package]] +name = "llama-cpp-python" +version = "0.3.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "diskcache" }, + { name = "jinja2" }, + { name = "numpy" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/a6/38/7a47b1fb1d83eaddd86ca8ddaf20f141cbc019faf7b425283d8e5ef710e5/llama_cpp_python-0.3.7.tar.gz", hash = "sha256:0566a0dcc0f38005c4093309a87f67c2452449522e3e17e15cd735a62957894c", size = 66715891 } + [[package]] name = "llama-index" version = "0.12.11" @@ -4233,7 +4249,6 @@ name = "nvidia-cublas-cu12" version = "12.4.5.8" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 }, { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, ] @@ -4242,7 +4257,6 @@ name = "nvidia-cuda-cupti-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 }, { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, ] @@ -4251,7 +4265,6 @@ name = "nvidia-cuda-nvrtc-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 }, { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, ] @@ -4260,7 +4273,6 @@ name = "nvidia-cuda-runtime-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 }, { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, ] @@ -4283,7 +4295,6 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = 
"sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, ] @@ -4292,7 +4303,6 @@ name = "nvidia-curand-cu12" version = "10.3.5.147" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 }, { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, ] @@ -4306,7 +4316,6 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, ] @@ -4318,7 +4327,6 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, ] @@ -4335,7 +4343,6 @@ name = "nvidia-nvjitlink-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 }, { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, ] @@ -4344,7 +4351,6 @@ name = "nvidia-nvtx-cu12" version = "12.4.127" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 }, { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, ] From 729133fd7de3c94ed358ec86d49539702df7e9fc Mon Sep 17 00:00:00 2001 From: aribornstein Date: Sun, 2 Feb 2025 22:07:24 +0200 Subject: [PATCH 2/3] Added properly typed completion client and updated init and toml --- python/packages/autogen-ext/pyproject.toml | 2 +- .../autogen_ext/models/llama_cpp/__init__.py | 2 +- .../llama_cpp/_llama_cpp_completion_client.py | 233 +++++++++++------- 3 files changed, 145 insertions(+), 92 deletions(-) diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index e10ca0de03a2..830b40ca95b7 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -33,7 +33,7 @@ file-surfer = [ ] llama-cpp = [ - "llama-cpp-python" + "llama-cpp-python>=0.1.9" ] graphrag = ["graphrag>=1.0.1"] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py index c0182d5c0c1d..be3bd2cbf601 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py @@ -2,7 +2,7 @@ from ._llama_cpp_completion_client import LlamaCppChatCompletionClient except ImportError as e: raise ImportError( - "Dependencies for Llama Cpp not found. " "Please install llama-cpp-python: " "pip install llama-cpp-python" + "Dependencies for Llama Cpp not found. 
" "Please install llama-cpp-python: " "pip install autogen-ext[llama-cpp]" ) from e __all__ = ["LlamaCppChatCompletionClient"] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py index 7f2428bd7f32..8c5e61d8a975 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py @@ -1,74 +1,76 @@ import json import logging # added import -from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union +from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, cast from autogen_core import CancellationToken -from autogen_core.models import AssistantMessage, ChatCompletionClient, CreateResult, SystemMessage, UserMessage -from autogen_core.tools import Tool -from llama_cpp import Llama -from pydantic import BaseModel - - -class ComponentModel(BaseModel): - provider: str - component_type: Optional[Literal["model", "agent", "tool", "termination", "token_provider"]] = None - version: Optional[int] = None - component_version: Optional[int] = None - description: Optional[str] = None - config: Dict[str, Any] +from autogen_core.models import ( + AssistantMessage, + ChatCompletionClient, + CreateResult, + FunctionExecutionResultMessage, + ModelInfo, + RequestUsage, + SystemMessage, + UserMessage, +) +from autogen_core.tools import Tool, ToolSchema +from llama_cpp import ( + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestFunctionMessage, + ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestUserMessage, + CreateChatCompletionResponse, + Llama, +) class LlamaCppChatCompletionClient(ChatCompletionClient): def __init__( self, - repo_id: str, filename: str, - n_gpu_layers: int = -1, - seed: int = 1337, - n_ctx: int = 1000, verbose: bool = True, + **kwargs: Any, ): """ Initialize the LlamaCpp client. """ self.logger = logging.getLogger(__name__) # initialize logger self.logger.setLevel(logging.DEBUG if verbose else logging.INFO) # set level based on verbosity - self.llm = Llama.from_pretrained( - repo_id=repo_id, - filename=filename, - n_gpu_layers=n_gpu_layers, - seed=seed, - n_ctx=n_ctx, - verbose=verbose, - ) + self.llm = Llama(model_path=filename, **kwargs) self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0} - async def create(self, messages: List[Any], tools: List[Any] = None, **kwargs) -> CreateResult: - """ - Generate a response using the model, incorporating tool metadata. - - :param messages: A list of message objects to process. - :param tools: A list of tool objects to register dynamically. - :param kwargs: Additional arguments for the model. - :return: A CreateResult object containing the model's response. 
- """ + async def create( + self, + messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage], + tools: Optional[Sequence[Tool | ToolSchema]] = None, + **kwargs: Any, + ) -> CreateResult: tools = tools or [] # Convert LLMMessage objects to dictionaries with 'role' and 'content' - converted_messages = [] + # converted_messages: List[Dict[str, str | Image | list[str | Image] | list[FunctionCall]]] = [] + converted_messages: list[ + ChatCompletionRequestSystemMessage + | ChatCompletionRequestUserMessage + | ChatCompletionRequestAssistantMessage + | ChatCompletionRequestUserMessage + | ChatCompletionRequestToolMessage + | ChatCompletionRequestFunctionMessage + ] = [] for msg in messages: if isinstance(msg, SystemMessage): converted_messages.append({"role": "system", "content": msg.content}) - elif isinstance(msg, UserMessage): + elif isinstance(msg, UserMessage) and isinstance(msg.content, str): converted_messages.append({"role": "user", "content": msg.content}) - elif isinstance(msg, AssistantMessage): + elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str): converted_messages.append({"role": "assistant", "content": msg.content}) else: raise ValueError(f"Unsupported message type: {type(msg)}") # Add tool descriptions to the system message tool_descriptions = "\n".join( - [f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools)] + [f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools) if isinstance(tool, Tool)] ) few_shot_example = """ @@ -91,7 +93,9 @@ async def create(self, messages: List[Any], tools: List[Any] = None, **kwargs) - # print(f"DEBUG: Converted messages: {converted_messages}") # Generate the model response - response = self.llm.create_chat_completion(messages=converted_messages, stream=False) + response = cast( + CreateChatCompletionResponse, self.llm.create_chat_completion(messages=converted_messages, stream=False) + ) self._total_usage["prompt_tokens"] += response.get("usage", {}).get("prompt_tokens", 0) self._total_usage["completion_tokens"] += response.get("usage", {}).get("completion_tokens", 0) @@ -100,17 +104,29 @@ async def create(self, messages: List[Any], tools: List[Any] = None, **kwargs) - # print(f"DEBUG: Model response: {response_text}") # Detect tool usage in the response - tool_call = await self._detect_and_execute_tool(response_text, tools) + if not response_text: + self.logger.debug("DEBUG: No response text found. Returning empty response.") + return CreateResult( + content="", usage=RequestUsage(prompt_tokens=0, completion_tokens=0), finish_reason="stop", cached=False + ) + + tool_call = await self._detect_and_execute_tool( + response_text, [tool for tool in tools if isinstance(tool, Tool)] + ) if not tool_call: self.logger.debug("DEBUG: No tool was invoked. 
Returning raw model response.") else: self.logger.debug(f"DEBUG: Tool executed successfully: {tool_call}") # Create a CreateResult object + finish_reason = response["choices"][0].get("finish_reason") + if finish_reason not in ("stop", "length", "function_calls", "content_filter", "unknown"): + finish_reason = "unknown" + usage = cast(RequestUsage, response.get("usage", {})) create_result = CreateResult( content=tool_call if tool_call else response_text, - usage=response.get("usage", {}), - finish_reason=response["choices"][0].get("finish_reason", "unknown"), + usage=usage, + finish_reason=finish_reason, # type: ignore cached=False, ) return create_result @@ -137,7 +153,7 @@ async def _detect_and_execute_tool(self, response_text: str, tools: List[Tool]) # Ensure arguments match the tool's args_type try: args_model = tool.args_type() - if "request" in args_model.__fields__: # Handle nested arguments + if "request" in args_model.model_fields: # Handle nested arguments func_args = {"request": func_args} args_instance = args_model(**func_args) except Exception as e: @@ -145,13 +161,14 @@ async def _detect_and_execute_tool(self, response_text: str, tools: List[Tool]) # Execute the tool try: - result = await tool.run(args=args_instance, cancellation_token=CancellationToken()) - if isinstance(result, dict): - return json.dumps(result) - elif hasattr(result, "model_dump"): # If it's a Pydantic model - return json.dumps(result.model_dump()) - else: - return str(result) + if callable(getattr(tool, "run", None)): + result = await cast(Any, tool).run(args=args_instance, cancellation_token=CancellationToken()) + if isinstance(result, dict): + return json.dumps(result) + elif callable(getattr(result, "model_dump", None)): # If it's a Pydantic model + return json.dumps(result.model_dump()) + else: + return str(result) except Exception as e: return f"Error executing tool '{tool.name}': {e}" @@ -169,72 +186,108 @@ def _extract_tool_arguments(self, response_text: str) -> Dict[str, Any]: args_end = response_text.find("}") if args_start != -1 and args_end != -1: args_str = response_text[args_start : args_end + 1] - return json.loads(args_str) + args = json.loads(args_str) + if isinstance(args, dict): + return cast(Dict[str, Any], args) + else: + return {} except json.JSONDecodeError as e: self.logger.debug(f"DEBUG: Failed to parse arguments: {e}") return {} - async def create_stream(self, messages: List[Any], tools: List[Any] = None, **kwargs) -> AsyncGenerator[str, None]: - """ - Generate a streaming response using the model. - - :param messages: A list of messages to process. - :param tools: A list of tool objects to register dynamically. - :param kwargs: Additional arguments for the model. - :return: An asynchronous generator yielding the response stream. 
- """ + async def create_stream( + self, + messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage], + tools: Optional[Sequence[Tool | ToolSchema]] = None, + **kwargs: Any, + ) -> AsyncGenerator[str, None]: tools = tools or [] # Convert LLMMessage objects to dictionaries with 'role' and 'content' - converted_messages = [] + converted_messages: list[ + ChatCompletionRequestSystemMessage + | ChatCompletionRequestUserMessage + | ChatCompletionRequestAssistantMessage + | ChatCompletionRequestUserMessage + | ChatCompletionRequestToolMessage + | ChatCompletionRequestFunctionMessage + ] = [] for msg in messages: if isinstance(msg, SystemMessage): converted_messages.append({"role": "system", "content": msg.content}) - elif isinstance(msg, UserMessage): + elif isinstance(msg, UserMessage) and isinstance(msg.content, str): converted_messages.append({"role": "user", "content": msg.content}) - elif isinstance(msg, AssistantMessage): + elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str): converted_messages.append({"role": "assistant", "content": msg.content}) else: raise ValueError(f"Unsupported message type: {type(msg)}") # Add tool descriptions to the system message - tool_descriptions = "\n".join([f"Tool: {tool.name} - {tool.description}" for tool in tools]) - if tool_descriptions: - converted_messages.insert( - 0, {"role": "system", "content": f"The following tools are available:\n{tool_descriptions}"} - ) + tool_descriptions = "\n".join( + [f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools) if isinstance(tool, Tool)] + ) + few_shot_example = """ + Example tool usage: + User: Validate this request: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"} + Assistant: Calling tool 'validate_request' with arguments: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"} + """ + + system_message = ( + "You are an assistant with access to tools. " + "If a user query matches a tool, explicitly invoke it with JSON arguments. 
" + "Here are the tools available:\n" + f"{tool_descriptions}\n" + f"{few_shot_example}" + ) + converted_messages.insert(0, {"role": "system", "content": system_message}) # Convert messages into a plain string prompt - prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in converted_messages) + prompt = "\n".join(f"{msg['role']}: {msg.get('content', '')}" for msg in converted_messages) # Call the model with streaming enabled response_generator = self.llm(prompt=prompt, stream=True) for token in response_generator: - yield token["choices"][0]["text"] + if isinstance(token, dict): + yield token["choices"][0]["text"] + else: + yield token # Implement abstract methods - def actual_usage(self) -> Dict[str, int]: - return self._total_usage + def actual_usage(self) -> RequestUsage: + return RequestUsage( + prompt_tokens=self._total_usage.get("prompt_tokens", 0), + completion_tokens=self._total_usage.get("completion_tokens", 0), + ) @property - def capabilities(self) -> Dict[str, bool]: - return {"chat": True, "stream": True} - - def count_tokens(self, messages: Sequence[Dict[str, Any]], **kwargs) -> int: - return sum(len(msg["content"].split()) for msg in messages) + def capabilities(self) -> ModelInfo: + return self.model_info + def count_tokens( + self, + messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage], + **kwargs: Any, + ) -> int: + total = 0 + for msg in messages: + # Use the Llama model's tokenizer to encode the content + tokens = self.llm.tokenize(str(msg.content).encode("utf-8")) + total += len(tokens) + return total @property - def model_info(self) -> Dict[str, Any]: - return { - "name": "llama-cpp", - "capabilities": {"chat": True, "stream": True}, - "context_window": self.llm.n_ctx, - "function_calling": True, - } - - def remaining_tokens(self, messages: Sequence[Dict[str, Any]], **kwargs) -> int: + def model_info(self) -> ModelInfo: + return ModelInfo(vision=False, json_output=False, family="llama-cpp", function_calling=True) + + def remaining_tokens( + self, + messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage], + **kwargs: Any, + ) -> int: used_tokens = self.count_tokens(messages) - return max(self.llm.n_ctx - used_tokens, 0) + return max(self.llm.n_ctx() - used_tokens, 0) - def total_usage(self) -> Dict[str, int]: - return self._total_usage + def total_usage(self) -> RequestUsage: + return RequestUsage( + prompt_tokens=self._total_usage.get("prompt_tokens", 0), + completion_tokens=self._total_usage.get("completion_tokens", 0), + ) From c708cb133541acdeee011dd02569a04544a9b412 Mon Sep 17 00:00:00 2001 From: aribornstein Date: Sat, 22 Feb 2025 20:23:23 +0200 Subject: [PATCH 3/3] feat: enhance LlamaCppChatCompletionClient with improved initialization and error handling; add unit tests for functionality --- .../autogen_ext/models/llama_cpp/__init__.py | 4 +- .../llama_cpp/_llama_cpp_completion_client.py | 22 +- .../models/test_llama_cpp_model_client.py | 194 ++++++++++++++++++ 3 files changed, 213 insertions(+), 7 deletions(-) create mode 100644 python/packages/autogen-ext/tests/models/test_llama_cpp_model_client.py diff --git a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py index be3bd2cbf601..0324e4005a09 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py @@ -2,7 
+2,9 @@ from ._llama_cpp_completion_client import LlamaCppChatCompletionClient except ImportError as e: raise ImportError( - "Dependencies for Llama Cpp not found. " "Please install llama-cpp-python: " "pip install autogen-ext[llama-cpp]" + "Dependencies for Llama Cpp not found. " + "Please install llama-cpp-python: " + "pip install autogen-ext[llama-cpp]" ) from e __all__ = ["LlamaCppChatCompletionClient"] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py index 8c5e61d8a975..45d91c29a1f4 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py @@ -37,7 +37,12 @@ def __init__( """ self.logger = logging.getLogger(__name__) # initialize logger self.logger.setLevel(logging.DEBUG if verbose else logging.INFO) # set level based on verbosity - self.llm = Llama(model_path=filename, **kwargs) + self.llm = ( + Llama.from_pretrained(filename=filename, repo_id=kwargs.pop("repo_id"), **kwargs) # type: ignore + # The partially unknown type is in the `llama_cpp` package + if "repo_id" in kwargs + else Llama(model_path=filename, **kwargs) + ) self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0} async def create( @@ -75,8 +80,8 @@ async def create( few_shot_example = """ Example tool usage: - User: Validate this request: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"} - Assistant: Calling tool 'validate_request' with arguments: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"} + User: Add two numbers: {"num1": 5, "num2": 10} + Assistant: Calling tool 'add' with arguments: {"num1": 5, "num2": 10} """ system_message = ( @@ -224,13 +229,17 @@ async def create_stream( # Add tool descriptions to the system message tool_descriptions = "\n".join( - [f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools) if isinstance(tool, Tool)] + [ + f"Tool: {i+1}. 
{tool.name}({tool.schema['parameters'] if 'parameters' in tool.schema else ''}) - {tool.description}" + for i, tool in enumerate(tools) + if isinstance(tool, Tool) + ] ) few_shot_example = """ Example tool usage: - User: Validate this request: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"} - Assistant: Calling tool 'validate_request' with arguments: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"} + User: Add two numbers: {"num1": 5, "num2": 10} + Assistant: Calling tool 'add' with arguments: {"num1": 5, "num2": 10} """ system_message = ( @@ -262,6 +271,7 @@ def actual_usage(self) -> RequestUsage: @property def capabilities(self) -> ModelInfo: return self.model_info + def count_tokens( self, messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage], diff --git a/python/packages/autogen-ext/tests/models/test_llama_cpp_model_client.py b/python/packages/autogen-ext/tests/models/test_llama_cpp_model_client.py new file mode 100644 index 000000000000..a2ad51b847f4 --- /dev/null +++ b/python/packages/autogen-ext/tests/models/test_llama_cpp_model_client.py @@ -0,0 +1,194 @@ +import contextlib +import sys +from typing import TYPE_CHECKING, Any, ContextManager, Generator, Sequence, Union + +import pytest +import torch +from autogen_agentchat.agents import AssistantAgent +from autogen_agentchat.messages import TextMessage +from autogen_core import CancellationToken +from autogen_core.models import RequestUsage, SystemMessage, UserMessage +from autogen_core.tools import FunctionTool + +if TYPE_CHECKING: + from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient + + +# Fake Llama class to simulate responses +class FakeLlama: + def __init__( + self, + model_path: str, + **_: Any, + ) -> None: + self.model_path = model_path + self.n_ctx = lambda: 1024 + + # Added tokenize method for testing purposes. + def tokenize(self, b: bytes) -> list[int]: + return list(b) + + def create_chat_completion(self, messages: Any, stream: bool = False) -> dict[str, Any]: + # Return fake non-streaming response. + return { + "usage": {"prompt_tokens": 1, "completion_tokens": 2}, + "choices": [{"message": {"content": "Fake response"}}], + } + + def __call__(self, prompt: str, stream: bool = True) -> Generator[dict[str, Any], None, None]: + # Yield fake streaming tokens. 
+ yield {"choices": [{"text": "Hello "}]} + yield {"choices": [{"text": "World"}]} + + +@pytest.fixture +@contextlib.contextmanager +def get_completion_client( + monkeypatch: pytest.MonkeyPatch, +) -> "Generator[type[LlamaCppChatCompletionClient], None, None]": + with monkeypatch.context() as m: + m.setattr("llama_cpp.Llama", FakeLlama) + from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient + + yield LlamaCppChatCompletionClient + sys.modules.pop("autogen_ext.models.llama_cpp._llama_cpp_completion_client", None) + sys.modules.pop("llama_cpp", None) + + +@pytest.mark.asyncio +async def test_llama_cpp_create(get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]") -> None: + with get_completion_client as Client: + client = Client(filename="dummy") + messages: Sequence[Union[SystemMessage, UserMessage]] = [ + SystemMessage(content="Test system"), + UserMessage(content="Test user", source="user"), + ] + result = await client.create(messages=messages) + assert result.content == "Fake response" + usage: RequestUsage = result.usage + assert usage.prompt_tokens == 1 + assert usage.completion_tokens == 2 + assert result.finish_reason in ("stop", "unknown") + + +@pytest.mark.asyncio +async def test_llama_cpp_create_stream( + get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]", +) -> None: + with get_completion_client as Client: + client = Client(filename="dummy") + messages: Sequence[Union[SystemMessage, UserMessage]] = [ + SystemMessage(content="Test system"), + UserMessage(content="Test user", source="user"), + ] + collected = "" + async for token in client.create_stream(messages=messages): + collected += token + assert collected == "Hello World" + + +@pytest.mark.asyncio +async def test_create_invalid_message( + get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]", +) -> None: + with get_completion_client as Client: + client = Client(filename="dummy") + # Pass an unsupported message type (integer) to trigger ValueError. + with pytest.raises(ValueError, match="Unsupported message type"): + await client.create(messages=[123]) # type: ignore + + +@pytest.mark.asyncio +async def test_count_and_remaining_tokens( + get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]", monkeypatch: pytest.MonkeyPatch +) -> None: + with get_completion_client as Client: + client = Client(filename="dummy") + msg = SystemMessage(content="Test") + # count_tokens should count the bytes + token_count = client.count_tokens([msg]) + # Since "Test" encoded is 4 bytes, expect 4 tokens. + assert token_count >= 4 + remaining = client.remaining_tokens([msg]) + # remaining should be (1024 - token_count); ensure non-negative. 
+ assert remaining == max(1024 - token_count, 0) + + +@pytest.mark.asyncio +async def test_llama_cpp_integration_non_streaming() -> None: + if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()): + pytest.skip("Skipping LlamaCpp integration tests: GPU not available not set") + + from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient + + client = LlamaCppChatCompletionClient( + repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000 + ) + messages: Sequence[Union[SystemMessage, UserMessage]] = [ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content="Hello, how are you?", source="user"), + ] + result = await client.create(messages=messages) + assert isinstance(result.content, str) and len(result.content.strip()) > 0 + + +@pytest.mark.asyncio +async def test_llama_cpp_integration_streaming() -> None: + if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()): + pytest.skip("Skipping LlamaCpp integration tests: GPU not available not set") + + from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient + + client = LlamaCppChatCompletionClient( + repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000 + ) + messages: Sequence[Union[SystemMessage, UserMessage]] = [ + SystemMessage(content="You are a helpful assistant."), + UserMessage(content="Please stream your response.", source="user"), + ] + collected = "" + async for token in client.create_stream(messages=messages): + collected += token + assert isinstance(collected, str) and len(collected.strip()) > 0 + + +# Define tools (functions) for the AssistantAgent +def add(num1: int, num2: int) -> int: + """Add two numbers together""" + return num1 + num2 + + +@pytest.mark.asyncio +async def test_llama_cpp_integration_tool_use() -> None: + if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()): + pytest.skip("Skipping LlamaCpp integration tests: GPU not available not set") + + from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient + + model_client = LlamaCppChatCompletionClient( + repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000 + ) + + # Initialize the AssistantAgent + assistant = AssistantAgent( + name="assistant", + system_message=("You can add two numbers together using the `add` function. "), + model_client=model_client, + tools=[ + FunctionTool( + add, + description="Add two numbers together. The first argument is num1 and second is num2. The return value is num1 + num2", + ) + ], + reflect_on_tool_use=True, # Reflect on tool results + ) + + # Test the tool + response = await assistant.on_messages( + [ + TextMessage(content="add 3 and 4", source="user"), + ], + CancellationToken(), + ) + + assert "7" in response.chat_message.content
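
A minimal end-to-end usage sketch of the new client, assembled from the constructor signature and the integration tests above; it requires the llama-cpp extra (pip install "autogen-ext[llama-cpp]"). The Hugging Face repo_id and filename are simply the values the integration tests use, the local .gguf path in the commented-out alternative is a placeholder, and the extra keyword arguments are forwarded to llama_cpp.Llama, so treat this as an illustration rather than documented API.

import asyncio

from autogen_core.models import SystemMessage, UserMessage
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


async def main() -> None:
    # Download a GGUF model from Hugging Face via Llama.from_pretrained; these are the
    # same arguments the integration tests above pass to the client.
    client = LlamaCppChatCompletionClient(
        repo_id="unsloth/phi-4-GGUF",
        filename="phi-4-Q2_K_L.gguf",
        n_gpu_layers=-1,
        seed=1337,
        n_ctx=5000,
    )
    # Alternatively, load a local GGUF file (placeholder path); without repo_id the
    # client constructs Llama(model_path=filename, **kwargs) directly.
    # client = LlamaCppChatCompletionClient(filename="/path/to/model.gguf", n_ctx=5000)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Hello, how are you?", source="user"),
    ]

    # Non-streaming call: returns a CreateResult with content, usage, and finish_reason.
    result = await client.create(messages=messages)
    print(result.content)

    # Streaming call: create_stream yields plain text chunks as they are generated.
    async for chunk in client.create_stream(messages=messages):
        print(chunk, end="", flush=True)


asyncio.run(main())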