Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Update Agent invoke abstraction parameters #11192

Merged
merged 2 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def main():
# Invoke the agent
# The chat history is maintained in the session
async for response in bedrock_agent.invoke(
input_text=user_input,
messages=user_input,
thread=thread,
):
print(f"Bedrock agent: {response}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def main():
# Invoke the agent
# The chat history is maintained in the session
response = await bedrock_agent.get_response(
input_text=user_input,
messages=user_input,
thread=thread,
)
print(f"Bedrock agent: {response}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def main():
# Invoke the agent
# The chat history is maintained in the thread
print("Bedrock agent: ", end="")
async for response in bedrock_agent.invoke_stream(input_text=user_input, thread=thread):
async for response in bedrock_agent.invoke_stream(messages=user_input, thread=thread):
print(response, end="")
thread = response.thread
print()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def main():
try:
# Invoke the agent
async for response in bedrock_agent.invoke(
input_text=ASK,
messages=ASK,
thread=thread,
):
print(f"Response:\n{response}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def main():
# Invoke the agent
print("Response: ")
async for response in bedrock_agent.invoke_stream(
input_text=ASK,
messages=ASK,
thread=thread,
):
print(response, end="")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def main():
try:
# Invoke the agent
async for response in bedrock_agent.invoke(
input_text="What is the weather in Seattle?",
messages="What is the weather in Seattle?",
thread=thread,
):
print(f"Response:\n{response}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def main():
try:
# Invoke the agent
async for response in bedrock_agent.invoke(
input_text="What is the weather in Seattle?",
messages="What is the weather in Seattle?",
thread=thread,
):
print(f"Response:\n{response}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async def main():
# Invoke the agent
print("Response: ")
async for response in bedrock_agent.invoke_stream(
input_text="What is the weather in Seattle?",
messages="What is the weather in Seattle?",
thread=thread,
):
print(response, end="")
Expand Down
48 changes: 45 additions & 3 deletions python/semantic_kernel/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,13 @@ def _configure_plugins(cls, data: Any) -> Any:
return data

@abstractmethod
def get_response(self, *args, **kwargs) -> Awaitable[AgentResponseItem[ChatMessageContent]]:
def get_response(
self,
*,
messages: str | ChatMessageContent | list[str | ChatMessageContent],
thread: AgentThread | None = None,
**kwargs,
) -> Awaitable[AgentResponseItem[ChatMessageContent]]:
"""Get a response from the agent.

This method returns the final result of the agent's execution
Expand All @@ -219,28 +225,64 @@ def get_response(self, *args, **kwargs) -> Awaitable[AgentResponseItem[ChatMessa
objects. Streaming only the final result is not feasible because the timing of
the final result's availability is unknown, and blocking the caller until then
is undesirable in streaming scenarios.

Args:
messages: The message(s) to send to the agent.
thread: The conversation thread associated with the message(s).
kwargs: Additional keyword arguments.

Returns:
An agent response item.
"""
pass

@abstractmethod
def invoke(self, *args, **kwargs) -> AsyncIterable[AgentResponseItem[ChatMessageContent]]:
def invoke(
self,
*,
messages: str | ChatMessageContent | list[str | ChatMessageContent],
thread: AgentThread | None = None,
**kwargs,
) -> AsyncIterable[AgentResponseItem[ChatMessageContent]]:
"""Invoke the agent.

This invocation method will return the intermediate steps and the final results
of the agent's execution as a stream of ChatMessageContent objects to the caller.

Note: A ChatMessageContent object contains an entire message.

Args:
messages: The message(s) to send to the agent.
thread: The conversation thread associated with the message(s).
kwargs: Additional keyword arguments.

Yields:
An agent response item.
"""
pass

@abstractmethod
def invoke_stream(self, *args, **kwargs) -> AsyncIterable[AgentResponseItem[StreamingChatMessageContent]]:
def invoke_stream(
self,
*,
messages: str | ChatMessageContent | list[str | ChatMessageContent],
thread: AgentThread | None = None,
**kwargs,
) -> AsyncIterable[AgentResponseItem[StreamingChatMessageContent]]:
"""Invoke the agent as a stream.

This invocation method will return the intermediate steps and final results of the
agent's execution as a stream of StreamingChatMessageContent objects to the caller.

Note: A StreamingChatMessageContent object contains a chunk of a message.

Args:
messages: The message(s) to send to the agent.
thread: The conversation thread associated with the message(s).
kwargs: Additional keyword arguments.

Yields:
An agent response item.
"""
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,18 @@ def __init__(self, conversable_agent: ConversableAgent, **kwargs: Any) -> None:
@trace_agent_get_response
@override
async def get_response(
self, messages: str | ChatMessageContent | list[str | ChatMessageContent], thread: AgentThread | None = None
self,
messages: str | ChatMessageContent | list[str | ChatMessageContent],
thread: AgentThread | None = None,
**kwargs: Any,
) -> AgentResponseItem[ChatMessageContent]:
"""Get a response from the agent.

Args:
messages: The input chat message content either as a string, ChatMessageContent or
a list of strings or ChatMessageContent.
thread: The thread to use for the conversation. If None, a new thread will be created.
kwargs: Additional keyword arguments

Returns:
An AgentResponseItem of type ChatMessageContent object with the response and the thread.
Expand All @@ -153,6 +157,7 @@ async def get_response(

reply = await self.conversable_agent.a_generate_reply(
messages=[message.to_dict() for message in chat_history.messages],
**kwargs,
)

logger.info("Called AutoGenConversableAgent.a_generate_reply.")
Expand Down Expand Up @@ -245,7 +250,9 @@ async def invoke(
@override
def invoke_stream(
self,
message: str,
*,
messages: str | ChatMessageContent | list[str | ChatMessageContent],
thread: AgentThread | None = None,
kernel: "Kernel | None" = None,
arguments: KernelArguments | None = None,
**kwargs: Any,
Expand Down
33 changes: 21 additions & 12 deletions python/semantic_kernel/agents/bedrock/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def create_and_prepare_agent(
async def get_response(
self,
*,
input_text: str,
messages: str | ChatMessageContent | list[str | ChatMessageContent],
thread: AgentThread | None = None,
agent_alias: str | None = None,
arguments: KernelArguments | None = None,
Expand All @@ -272,7 +272,7 @@ async def get_response(
"""Get a response from the agent.

Args:
input_text (str): The input text.
messages (str | ChatMessageContent | list[str | ChatMessageContent]): The messages.
thread (AgentThread, optional): The thread. This is used to maintain the session state in the service.
agent_alias (str, optional): The agent alias.
arguments (KernelArguments, optional): The kernel arguments to override the current arguments.
Expand All @@ -282,8 +282,11 @@ async def get_response(
Returns:
A chat message content with the response.
"""
if not isinstance(messages, str) and not isinstance(messages, ChatMessageContent):
raise AgentInvokeException("Messages must be a string or a ChatMessageContent for BedrockAgent.")

thread = await self._ensure_thread_exists_with_messages(
messages=[input_text],
messages=messages,
thread=thread,
construct_thread=lambda: BedrockAgentThread(bedrock_runtime_client=self.bedrock_runtime_client),
expected_type=BedrockAgentThread,
Expand All @@ -302,7 +305,7 @@ async def get_response(
kwargs.setdefault("sessionState", {})

for _ in range(self.function_choice_behavior.maximum_auto_invoke_attempts):
response = await self._invoke_agent(thread.id, input_text, agent_alias, **kwargs)
response = await self._invoke_agent(thread.id, messages, agent_alias, **kwargs)

events: list[dict[str, Any]] = []
for event in response.get("completion", []):
Expand Down Expand Up @@ -355,7 +358,7 @@ async def get_response(
@override
async def invoke(
self,
input_text: str,
messages: str | ChatMessageContent | list[str | ChatMessageContent],
thread: AgentThread | None = None,
*,
agent_alias: str | None = None,
Expand All @@ -366,7 +369,7 @@ async def invoke(
"""Invoke an agent.

Args:
input_text (str): The input text.
messages (str | ChatMessageContent | list[str | ChatMessageContent]): The messages.
thread (AgentThread, optional): The thread. This is used to maintain the session state in the service.
agent_alias (str, optional): The agent alias.
arguments (KernelArguments, optional): The kernel arguments to override the current arguments.
Expand All @@ -376,8 +379,11 @@ async def invoke(
Returns:
An async iterable of chat message content.
"""
if not isinstance(messages, str) and not isinstance(messages, ChatMessageContent):
raise AgentInvokeException("Messages must be a string or a ChatMessageContent for BedrockAgent.")

thread = await self._ensure_thread_exists_with_messages(
messages=[input_text],
messages=messages,
thread=thread,
construct_thread=lambda: BedrockAgentThread(bedrock_runtime_client=self.bedrock_runtime_client),
expected_type=BedrockAgentThread,
Expand All @@ -396,7 +402,7 @@ async def invoke(
kwargs.setdefault("sessionState", {})

for _ in range(self.function_choice_behavior.maximum_auto_invoke_attempts):
response = await self._invoke_agent(thread.id, input_text, agent_alias, **kwargs)
response = await self._invoke_agent(thread.id, messages, agent_alias, **kwargs)

events: list[dict[str, Any]] = []
for event in response.get("completion", []):
Expand Down Expand Up @@ -451,7 +457,7 @@ async def invoke(
@override
async def invoke_stream(
self,
input_text: str,
messages: str | ChatMessageContent | list[str | ChatMessageContent],
thread: AgentThread | None = None,
*,
agent_alias: str | None = None,
Expand All @@ -462,7 +468,7 @@ async def invoke_stream(
"""Invoke an agent with streaming.

Args:
input_text (str): The input text.
messages (str | ChatMessageContent | list[str | ChatMessageContent]): The messages.
thread (AgentThread, optional): The thread. This is used to maintain the session state in the service.
agent_alias (str, optional): The agent alias.
arguments (KernelArguments, optional): The kernel arguments to override the current arguments.
Expand All @@ -472,8 +478,11 @@ async def invoke_stream(
Returns:
An async iterable of streaming chat message content
"""
if not isinstance(messages, str) and not isinstance(messages, ChatMessageContent):
raise AgentInvokeException("Messages must be a string or a ChatMessageContent for BedrockAgent.")

thread = await self._ensure_thread_exists_with_messages(
messages=[input_text],
messages=messages,
thread=thread,
construct_thread=lambda: BedrockAgentThread(bedrock_runtime_client=self.bedrock_runtime_client),
expected_type=BedrockAgentThread,
Expand All @@ -492,7 +501,7 @@ async def invoke_stream(
kwargs.setdefault("sessionState", {})

for request_index in range(self.function_choice_behavior.maximum_auto_invoke_attempts):
response = await self._invoke_agent(thread.id, input_text, agent_alias, **kwargs)
response = await self._invoke_agent(thread.id, messages, agent_alias, **kwargs)

all_function_call_messages: list[StreamingChatMessageContent] = []
for event in response.get("completion", []):
Expand Down
9 changes: 7 additions & 2 deletions python/semantic_kernel/agents/bedrock/bedrock_agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from semantic_kernel.agents.bedrock.models.bedrock_agent_model import BedrockAgentModel
from semantic_kernel.agents.bedrock.models.bedrock_agent_status import BedrockAgentStatus
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior, FunctionChoiceType
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.utils.async_utils import run_in_executor
from semantic_kernel.utils.feature_stage_decorator import experimental

Expand Down Expand Up @@ -349,14 +351,17 @@ async def list_associated_agent_knowledge_bases(self, **kwargs) -> dict[str, Any
async def _invoke_agent(
self,
thread_id: str,
input_text: str,
message: str | ChatMessageContent,
agent_alias: str | None = None,
**kwargs,
) -> dict[str, Any]:
"""Invoke an agent."""
if not self.agent_model.agent_id:
raise ValueError("Agent does not exist. Please create the agent before invoking it.")

if isinstance(message, ChatMessageContent) and message.role != AuthorRole.USER:
raise ValueError("Only user messages are supported for invoking a Bedrock agent.")

agent_alias = agent_alias or self.WORKING_DRAFT_AGENT_ALIAS

try:
Expand All @@ -367,7 +372,7 @@ async def _invoke_agent(
agentAliasId=agent_alias,
agentId=self.agent_model.agent_id,
sessionId=thread_id,
inputText=input_text,
inputText=message if isinstance(message, str) else message.content,
**kwargs,
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def invoke(self, agent: "Agent", **kwargs: Any) -> AsyncIterable[tuple[boo
await self._ensure_last_message_is_user()

async for response in agent.invoke(
input_text=self.messages[-1].content,
messages=self.messages[-1].content,
thread=self.thread,
sessionState=await self._parse_chat_history_to_session_state(),
):
Expand Down Expand Up @@ -105,7 +105,7 @@ async def invoke_stream(

full_message: list[StreamingChatMessageContent] = []
async for response_chunk in agent.invoke_stream(
input_text=self.messages[-1].content,
messages=self.messages[-1].content,
thread=self.thread,
sessionState=await self._parse_chat_history_to_session_state(),
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ async def test_bedrock_agent_get_response(
mock_invoke_agent.return_value = bedrock_agent_non_streaming_simple_response
mock_start.return_value = "test_session_id"

response = await agent.get_response(input_text="test_input_text", thread=thread)
response = await agent.get_response(messages="test_input_text", thread=thread)
assert response.message.content == simple_response

mock_invoke_agent.assert_called_once()
Expand Down Expand Up @@ -495,7 +495,7 @@ async def test_bedrock_agent_get_response_exception(
mock_start.return_value = "test_session_id"

with pytest.raises(AgentInvokeException):
await agent.get_response(input_text="test_input_text")
await agent.get_response(messages="test_input_text")


# Test case to verify the invocation of BedrockAgent
Expand Down
Loading
Loading