Feature add Add LlamaCppChatCompletionClient and llama-cpp #5326
Open
aribornstein wants to merge 4 commits into microsoft:main from aribornstein:main
+528 −10
Changes from 3 commits
Commits (4):
8296259 Add LlamaCppChatCompletionClient and llama-cpp dependency for enhance… (xhabit)
8646d54 Merge branch 'main' into main (aribornstein)
729133f Added properly typed completion client and updated init and toml (xhabit)
c708cb1 feat: enhance LlamaCppChatCompletionClient with improved initializati… (xhabit)
python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/__init__.py
8 changes: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
try:
    from ._llama_cpp_completion_client import LlamaCppChatCompletionClient
except ImportError as e:
    raise ImportError(
        "Dependencies for Llama Cpp not found. " "Please install llama-cpp-python: " "pip install autogen-ext[llama-cpp]"
    ) from e

__all__ = ["LlamaCppChatCompletionClient"]
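
For context on the guard above: once the optional extra is installed, the client is importable from the new module. A minimal check (the install command is taken from the ImportError message in this file):

# pip install "autogen-ext[llama-cpp]"  # extra name taken from the ImportError message above
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient
print(LlamaCppChatCompletionClient)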
python/packages/autogen-ext/src/autogen_ext/models/llama_cpp/_llama_cpp_completion_client.py
293 changes: 293 additions & 0 deletions
@@ -0,0 +1,293 @@
import json
import logging  # added import
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, cast

from autogen_core import CancellationToken
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FunctionExecutionResultMessage,
    ModelInfo,
    RequestUsage,
    SystemMessage,
    UserMessage,
)
from autogen_core.tools import Tool, ToolSchema
from llama_cpp import (
    ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestFunctionMessage,
    ChatCompletionRequestSystemMessage,
    ChatCompletionRequestToolMessage,
    ChatCompletionRequestUserMessage,
    CreateChatCompletionResponse,
    Llama,
)


class LlamaCppChatCompletionClient(ChatCompletionClient):
    def __init__(
        self,
        filename: str,
        verbose: bool = True,
        **kwargs: Any,
    ):
        """
        Initialize the LlamaCpp client.
        """
        self.logger = logging.getLogger(__name__)  # initialize logger
        self.logger.setLevel(logging.DEBUG if verbose else logging.INFO)  # set level based on verbosity
        self.llm = Llama(model_path=filename, **kwargs)
        self._total_usage = {"prompt_tokens": 0, "completion_tokens": 0}

    async def create(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        tools: Optional[Sequence[Tool | ToolSchema]] = None,
        **kwargs: Any,
    ) -> CreateResult:
        tools = tools or []

        # Convert LLMMessage objects to dictionaries with 'role' and 'content'
        # converted_messages: List[Dict[str, str | Image | list[str | Image] | list[FunctionCall]]] = []
        converted_messages: list[
            ChatCompletionRequestSystemMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestAssistantMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestToolMessage
            | ChatCompletionRequestFunctionMessage
        ] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "assistant", "content": msg.content})
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        # Add tool descriptions to the system message
        tool_descriptions = "\n".join(
            [f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools) if isinstance(tool, Tool)]
        )

        few_shot_example = """
        Example tool usage:
        User: Validate this request: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"}
        Assistant: Calling tool 'validate_request' with arguments: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"}
        """

        system_message = (
            "You are an assistant with access to tools. "
            "If a user query matches a tool, explicitly invoke it with JSON arguments. "
            "Here are the tools available:\n"
            f"{tool_descriptions}\n"
            f"{few_shot_example}"
        )
        converted_messages.insert(0, {"role": "system", "content": system_message})

        # Debugging outputs
        # print(f"DEBUG: System message: {system_message}")
        # print(f"DEBUG: Converted messages: {converted_messages}")

        # Generate the model response
        response = cast(
            CreateChatCompletionResponse, self.llm.create_chat_completion(messages=converted_messages, stream=False)
        )
        self._total_usage["prompt_tokens"] += response.get("usage", {}).get("prompt_tokens", 0)
        self._total_usage["completion_tokens"] += response.get("usage", {}).get("completion_tokens", 0)

        # Parse the response
        response_text = response["choices"][0]["message"]["content"]
        # print(f"DEBUG: Model response: {response_text}")

        # Detect tool usage in the response
        if not response_text:
            self.logger.debug("DEBUG: No response text found. Returning empty response.")
            return CreateResult(
                content="", usage=RequestUsage(prompt_tokens=0, completion_tokens=0), finish_reason="stop", cached=False
            )

        tool_call = await self._detect_and_execute_tool(
            response_text, [tool for tool in tools if isinstance(tool, Tool)]
        )
        if not tool_call:
            self.logger.debug("DEBUG: No tool was invoked. Returning raw model response.")
        else:
            self.logger.debug(f"DEBUG: Tool executed successfully: {tool_call}")

        # Create a CreateResult object
        finish_reason = response["choices"][0].get("finish_reason")
        if finish_reason not in ("stop", "length", "function_calls", "content_filter", "unknown"):
            finish_reason = "unknown"
        usage = cast(RequestUsage, response.get("usage", {}))
        create_result = CreateResult(
            content=tool_call if tool_call else response_text,
            usage=usage,
            finish_reason=finish_reason,  # type: ignore
            cached=False,
        )
        return create_result

    async def _detect_and_execute_tool(self, response_text: str, tools: List[Tool]) -> Optional[str]:
        """
        Detect if the model is requesting a tool and execute the tool.

        :param response_text: The raw response text from the model.
        :param tools: A list of available tools.
        :return: The result of the tool execution or None if no tool is called.
        """
        for tool in tools:
            if tool.name.lower() in response_text.lower():  # Case-insensitive matching
                self.logger.debug(f"DEBUG: Detected tool '{tool.name}' in response.")
                # Extract arguments (if any) from the response
                func_args = self._extract_tool_arguments(response_text)
                if func_args:
                    self.logger.debug(f"DEBUG: Extracted arguments for tool '{tool.name}': {func_args}")
                else:
                    self.logger.debug(f"DEBUG: No arguments found for tool '{tool.name}'.")
                    return f"Error: No valid arguments provided for tool '{tool.name}'."

                # Ensure arguments match the tool's args_type
                try:
                    args_model = tool.args_type()
                    if "request" in args_model.model_fields:  # Handle nested arguments
                        func_args = {"request": func_args}
                    args_instance = args_model(**func_args)
                except Exception as e:
                    return f"Error parsing arguments for tool '{tool.name}': {e}"

                # Execute the tool
                try:
                    if callable(getattr(tool, "run", None)):
                        result = await cast(Any, tool).run(args=args_instance, cancellation_token=CancellationToken())
                        if isinstance(result, dict):
                            return json.dumps(result)
                        elif callable(getattr(result, "model_dump", None)):  # If it's a Pydantic model
                            return json.dumps(result.model_dump())
                        else:
                            return str(result)
                except Exception as e:
                    return f"Error executing tool '{tool.name}': {e}"

        return None

    def _extract_tool_arguments(self, response_text: str) -> Dict[str, Any]:
        """
        Extract tool arguments from the response text.

        :param response_text: The raw response text.
        :return: A dictionary of extracted arguments.
        """
        try:
            args_start = response_text.find("{")
            args_end = response_text.find("}")
            if args_start != -1 and args_end != -1:
                args_str = response_text[args_start : args_end + 1]
                args = json.loads(args_str)
                if isinstance(args, dict):
                    return cast(Dict[str, Any], args)
                else:
                    return {}
        except json.JSONDecodeError as e:
            self.logger.debug(f"DEBUG: Failed to parse arguments: {e}")
        return {}

    async def create_stream(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        tools: Optional[Sequence[Tool | ToolSchema]] = None,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        tools = tools or []

        # Convert LLMMessage objects to dictionaries with 'role' and 'content'
        converted_messages: list[
            ChatCompletionRequestSystemMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestAssistantMessage
            | ChatCompletionRequestUserMessage
            | ChatCompletionRequestToolMessage
            | ChatCompletionRequestFunctionMessage
        ] = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                converted_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, UserMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AssistantMessage) and isinstance(msg.content, str):
                converted_messages.append({"role": "assistant", "content": msg.content})
            else:
                raise ValueError(f"Unsupported message type: {type(msg)}")

        # Add tool descriptions to the system message
        tool_descriptions = "\n".join(
            [f"Tool: {i+1}. {tool.name} - {tool.description}" for i, tool in enumerate(tools) if isinstance(tool, Tool)]
        )

        few_shot_example = """
        Example tool usage:
        User: Validate this request: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"}
        Assistant: Calling tool 'validate_request' with arguments: {"patient_name": "John Doe", "patient_id": "12345", "procedure": "MRI Knee"}
        """

        system_message = (
            "You are an assistant with access to tools. "
            "If a user query matches a tool, explicitly invoke it with JSON arguments. "
            "Here are the tools available:\n"
            f"{tool_descriptions}\n"
            f"{few_shot_example}"
        )
        converted_messages.insert(0, {"role": "system", "content": system_message})
        # Convert messages into a plain string prompt
        prompt = "\n".join(f"{msg['role']}: {msg.get('content', '')}" for msg in converted_messages)
        # Call the model with streaming enabled
        response_generator = self.llm(prompt=prompt, stream=True)

        for token in response_generator:
            if isinstance(token, dict):
                yield token["choices"][0]["text"]
            else:
                yield token

    # Implement abstract methods
    def actual_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_usage.get("prompt_tokens", 0),
            completion_tokens=self._total_usage.get("completion_tokens", 0),
        )

    @property
    def capabilities(self) -> ModelInfo:
        return self.model_info

    def count_tokens(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        **kwargs: Any,
    ) -> int:
        total = 0
        for msg in messages:
            # Use the Llama model's tokenizer to encode the content
            tokens = self.llm.tokenize(str(msg.content).encode("utf-8"))
            total += len(tokens)
        return total

    @property
    def model_info(self) -> ModelInfo:
        return ModelInfo(vision=False, json_output=False, family="llama-cpp", function_calling=True)

    def remaining_tokens(
        self,
        messages: Sequence[SystemMessage | UserMessage | AssistantMessage | FunctionExecutionResultMessage],
        **kwargs: Any,
    ) -> int:
        used_tokens = self.count_tokens(messages)
        return max(self.llm.n_ctx() - used_tokens, 0)

    def total_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_usage.get("prompt_tokens", 0),
            completion_tokens=self._total_usage.get("completion_tokens", 0),
        )
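
For reviewers trying the change locally, a minimal usage sketch of the new client (the GGUF path and prompt are placeholders, not part of this diff, and a real local model file is required):

import asyncio

from autogen_core.models import SystemMessage, UserMessage
from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient


async def main() -> None:
    # "model.gguf" is a placeholder path to a locally downloaded GGUF model.
    client = LlamaCppChatCompletionClient(filename="model.gguf", verbose=False)
    result = await client.create(
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content="What is 2 + 2?", source="user"),
        ]
    )
    print(result.content)


asyncio.run(main())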
Review comment: Add unit tests in the python/packages/autogen-ext/tests directory.

Reply: Will work on this tomorrow
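
One possible shape for such a test, sketched here purely as an illustration rather than as part of this PR (it assumes the llama-cpp extra and pytest-asyncio are installed; FakeLlama, the test name, and the canned response are invented for the example):

import pytest

from autogen_core.models import SystemMessage, UserMessage


class FakeLlama:
    """Stub standing in for llama_cpp.Llama so no real GGUF model is loaded."""

    def __init__(self, model_path: str, **kwargs: object) -> None:
        self.model_path = model_path

    def create_chat_completion(self, messages, stream=False):
        # Minimal response shape that LlamaCppChatCompletionClient.create() reads.
        return {
            "choices": [{"message": {"content": "4"}, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": 3, "completion_tokens": 1},
        }


@pytest.mark.asyncio
async def test_create_returns_model_text(monkeypatch: pytest.MonkeyPatch) -> None:
    from autogen_ext.models.llama_cpp import _llama_cpp_completion_client as mod

    # Replace the real Llama class with the stub before constructing the client.
    monkeypatch.setattr(mod, "Llama", FakeLlama)
    client = mod.LlamaCppChatCompletionClient(filename="fake-model.gguf")
    result = await client.create(
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content="What is 2 + 2?", source="user"),
        ]
    )
    assert result.content == "4"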