diff --git a/libs/aws/langchain_aws/__init__.py b/libs/aws/langchain_aws/__init__.py index 29f4bbc6..98307901 100644 --- a/libs/aws/langchain_aws/__init__.py +++ b/libs/aws/langchain_aws/__init__.py @@ -1,8 +1,9 @@ +from langchain_aws.chat_model_adapter import BedrockClaudeAdapter, BedrockLlamaAdapter, ModelAdapter from langchain_aws.chains import ( create_neptune_opencypher_qa_chain, create_neptune_sparql_qa_chain, ) -from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse +from langchain_aws.chat_models import ChatBedrock, ChatBedrockConverse, DemoChatBedrock from langchain_aws.embeddings import BedrockEmbeddings from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph from langchain_aws.llms import BedrockLLM, SagemakerEndpoint @@ -20,6 +21,10 @@ "BedrockLLM", "ChatBedrock", "ChatBedrockConverse", + "DemoChatBedrock", + "ModelAdapter", + "BedrockClaudeAdapter", + "BedrockLlamaAdapter", "SagemakerEndpoint", "AmazonKendraRetriever", "AmazonKnowledgeBasesRetriever", diff --git a/libs/aws/langchain_aws/chat_model_adapter/__init__.py b/libs/aws/langchain_aws/chat_model_adapter/__init__.py new file mode 100644 index 00000000..9f94aa46 --- /dev/null +++ b/libs/aws/langchain_aws/chat_model_adapter/__init__.py @@ -0,0 +1,5 @@ +from langchain_aws.chat_model_adapter.demo_chat_adapter import ModelAdapter +from langchain_aws.chat_model_adapter.anthropic_adapter import BedrockClaudeAdapter +from langchain_aws.chat_model_adapter.llama_adapter import BedrockLlamaAdapter + +__all__ = ["ModelAdapter", "BedrockClaudeAdapter", "BedrockLlamaAdapter"] diff --git a/libs/aws/langchain_aws/chat_model_adapter/anthropic_adapter.py b/libs/aws/langchain_aws/chat_model_adapter/anthropic_adapter.py new file mode 100644 index 00000000..2f37ebe5 --- /dev/null +++ b/libs/aws/langchain_aws/chat_model_adapter/anthropic_adapter.py @@ -0,0 +1,438 @@ +from typing import ( + Any, + Iterator, + List, + Optional, + Sequence, + Union, + Dict, + Callable, + Literal, + Type, + TypeVar, + Tuple, + TypedDict, + cast, + Mapping +) + +from langchain_core.language_models import BaseChatModel, LanguageModelInput +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import ( + BaseMessage, + AIMessage, + AIMessageChunk, + HumanMessage, + SystemMessage, + ToolMessage, + ChatMessage, + ToolCall +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_core.utils.pydantic import TypeBaseModel +from langchain_aws.utils import enforce_stop_tokens +from langchain_aws.function_calling import _tools_in_params, _lc_tool_calls_to_anthropic_tool_use_blocks +from langchain_core.messages.tool import tool_call, tool_call_chunk +from langchain_aws.chat_model_adapter.demo_chat_adapter import ModelAdapter +from pydantic import BaseModel +from langchain_core.outputs import Generation, GenerationChunk, LLMResult +import re +import json +import logging +import warnings + + +class AnthropicTool(TypedDict): + name: str + description: str + input_schema: Dict[str, Any] + +HUMAN_PROMPT = "\n\nHuman:" +ASSISTANT_PROMPT = "\n\nAssistant:" +ALTERNATION_ERROR = ( + "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'." 
+)
+
+# Example concrete implementation for a specific model
+class BedrockClaudeAdapter(ModelAdapter):
+
+    _message_type_lookups = {
+        "human": "user",
+        "ai": "assistant",
+        "AIMessageChunk": "assistant",
+        "HumanMessageChunk": "user",
+    }
+
+    def convert_messages_to_payload(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        model: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        # Specific implementation for converting LC messages to Claude payload
+        system, formatted_messages = self._format_anthropic_messages(messages=messages)
+
+        return {"system": system, "messages": formatted_messages}
+
+
+    def convert_response_to_chat_result(self, response: Any) -> ChatResult:
+        pass
+
+    def convert_stream_response_to_chunks(
+        self, response: Any
+    ) -> Iterator[ChatGenerationChunk]:
+        """Convert model-specific stream response to LangChain chunks"""
+        pass
+
+    def format_tools(
+        self, tools: Sequence[Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool]]
+    ) -> Any:
+        """Format tools for the specific model"""
+        pass
+
+    def _format_image(self, image_url: str) -> Dict:
+        """
+        Formats an image of format data:image/jpeg;base64,{b64_string}
+        to a dict for the Anthropic API
+
+        {
+            "type": "base64",
+            "media_type": "image/jpeg",
+            "data": "/9j/4AAQSkZJRg...",
+        }
+
+        And throws an error if it's not a b64 image
+        """
+        regex = r"^data:(?P<media_type>image/.+);base64,(?P<data>.+)$"
+        match = re.match(regex, image_url)
+        if match is None:
+            raise ValueError(
+                "Anthropic only supports base64-encoded images currently."
+                " Example: data:image/png;base64,'/9j/4AAQSk'..."
+            )
+        return {
+            "type": "base64",
+            "media_type": match.group("media_type"),
+            "data": match.group("data"),
+        }
+
+    def _merge_messages(
+        self,
+        messages: Sequence[BaseMessage],
+    ) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
+        """Merge runs of human/tool messages into single human messages with content blocks."""  # noqa: E501
+        merged: list = []
+        for curr in messages:
+            curr = curr.model_copy(deep=True)
+            if isinstance(curr, ToolMessage):
+                if isinstance(curr.content, list) and all(
+                    isinstance(block, dict) and block.get("type") == "tool_result"
+                    for block in curr.content
+                ):
+                    curr = HumanMessage(curr.content)  # type: ignore[misc]
+                else:
+                    curr = HumanMessage(  # type: ignore[misc]
+                        [
+                            {
+                                "type": "tool_result",
+                                "content": curr.content,
+                                "tool_use_id": curr.tool_call_id,
+                            }
+                        ]
+                    )
+            last = merged[-1] if merged else None
+            if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
+                if isinstance(last.content, str):
+                    new_content: List = [{"type": "text", "text": last.content}]
+                else:
+                    new_content = last.content
+                if isinstance(curr.content, str):
+                    new_content.append({"type": "text", "text": curr.content})
+                else:
+                    new_content.extend(curr.content)
+                last.content = new_content
+            else:
+                merged.append(curr)
+        return merged
+
+    def _format_anthropic_messages(
+        self,
+        messages: List[BaseMessage],
+    ) -> Tuple[Optional[str], List[Dict]]:
+        """Format messages for anthropic."""
+        system: Optional[str] = None
+        formatted_messages: List[Dict] = []
+
+        merged_messages = self._merge_messages(messages)
+        for i, message in enumerate(merged_messages):
+            if message.type == "system":
+                if i != 0:
+                    raise ValueError("System message must be at beginning of message list.")
+                if not isinstance(message.content, str):
+                    raise ValueError(
+                        "System message must be a string, "
+                        f"instead was: {type(message.content)}"
+                    )
+                system = message.content
+                continue
+
+            role = 
self._message_type_lookups[message.type] + content: Union[str, List] + + if not isinstance(message.content, str): + # parse as dict + assert isinstance( + message.content, list + ), "Anthropic message content must be str or list of dicts" + + # populate content + content = [] + for item in message.content: + if isinstance(item, str): + content.append({"type": "text", "text": item}) + elif isinstance(item, dict): + if "type" not in item: + raise ValueError("Dict content item must have a type key") + elif item["type"] == "image_url": + # convert format + source = self._format_image(item["image_url"]["url"]) + content.append({"type": "image", "source": source}) + elif item["type"] == "tool_use": + # If a tool_call with the same id as a tool_use content block + # exists, the tool_call is preferred. + if isinstance(message, AIMessage) and item["id"] in [ + tc["id"] for tc in message.tool_calls + ]: + overlapping = [ + tc + for tc in message.tool_calls + if tc["id"] == item["id"] + ] + content.extend( + _lc_tool_calls_to_anthropic_tool_use_blocks(overlapping) + ) + else: + item.pop("text", None) + content.append(item) + elif item["type"] == "text": + text = item.get("text", "") + # Only add non-empty strings for now as empty ones are not + # accepted. + # https://github.com/anthropics/anthropic-sdk-python/issues/461 + if text.strip(): + content.append({"type": "text", "text": text}) + else: + content.append(item) + else: + raise ValueError( + f"Content items must be str or dict, instead was: {type(item)}" + ) + elif isinstance(message, AIMessage) and message.tool_calls: + content = ( + [] + if not message.content + else [{"type": "text", "text": message.content}] + ) + # Note: Anthropic can't have invalid tool calls as presently defined, + # since the model already returns dicts args not JSON strings, and invalid + # tool calls are those with invalid JSON for args. 
+ content += _lc_tool_calls_to_anthropic_tool_use_blocks(message.tool_calls) + else: + content = message.content + + formatted_messages.append({"role": role, "content": content}) + return system, formatted_messages + + def _add_newlines_before_ha(self, input_text: str) -> str: + new_text = input_text + for word in ["Human:", "Assistant:"]: + new_text = new_text.replace(word, "\n\n" + word) + for i in range(2): + new_text = new_text.replace("\n\n\n" + word, "\n\n" + word) + return new_text + + def _human_assistant_format(self, input_text: str) -> str: + if input_text.count("Human:") == 0 or ( + input_text.find("Human:") > input_text.find("Assistant:") + and "Assistant:" in input_text + ): + input_text = HUMAN_PROMPT + " " + input_text # SILENT CORRECTION + if input_text.count("Assistant:") == 0: + input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION + if input_text[: len("Human:")] == "Human:": + input_text = "\n\n" + input_text + input_text = self._add_newlines_before_ha(input_text) + count = 0 + # track alternation + for i in range(len(input_text)): + if input_text[i : i + len(HUMAN_PROMPT)] == HUMAN_PROMPT: + if count % 2 == 0: + count += 1 + else: + warnings.warn(ALTERNATION_ERROR + f" Received {input_text}") + if input_text[i : i + len(ASSISTANT_PROMPT)] == ASSISTANT_PROMPT: + if count % 2 == 1: + count += 1 + else: + warnings.warn(ALTERNATION_ERROR + f" Received {input_text}") + + if count % 2 == 1: # Only saw Human, no Assistant + input_text = input_text + ASSISTANT_PROMPT # SILENT CORRECTION + + return input_text + + def _prepare_input( + self, + model_kwargs: Dict[str, Any], + prompt: Optional[str] = None, + system: Optional[str] = None, + messages: Optional[List[Dict]] = None, + tools: Optional[List[AnthropicTool]] = None, + *, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + ) -> Dict[str, Any]: + + input_body = {**model_kwargs} + if messages: + if tools: + input_body["tools"] = tools + input_body["anthropic_version"] = "bedrock-2023-05-31" + input_body["messages"] = messages + if system: + input_body["system"] = system + if max_tokens: + input_body["max_tokens"] = max_tokens + elif "max_tokens" not in input_body: + input_body["max_tokens"] = 1024 + + if prompt: + input_body["prompt"] = self._human_assistant_format(prompt) + if max_tokens: + input_body["max_tokens_to_sample"] = max_tokens + elif "max_tokens_to_sample" not in input_body: + input_body["max_tokens_to_sample"] = 1024 + + if temperature is not None: + input_body["temperature"] = temperature + return input_body + + def _extract_tool_calls(self, content: List[dict]) -> List[ToolCall]: + tool_calls = [] + for block in content: + if block["type"] != "tool_use": + continue + tool_calls.append( + tool_call(name=block["name"], args=block["input"], id=block["id"]) + ) + return tool_calls + + def _prepare_output(self, response: Any) -> dict: + text = "" + tool_calls = [] + response_body = json.loads(response.get("body").read().decode()) + + if "completion" in response_body: + text = response_body.get("completion") + elif "content" in response_body: + content = response_body.get("content") + if len(content) == 1 and content[0]["type"] == "text": + text = content[0]["text"] + elif any(block["type"] == "tool_use" for block in content): + tool_calls = self._extract_tool_calls(content) + + headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) + prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) + completion_tokens = 
int(headers.get("x-amzn-bedrock-output-token-count", 0)) + return { + "text": text, + "tool_calls": tool_calls, + "body": response_body, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "stop_reason": response_body.get("stop_reason"), + } + + + def prepare_input_and_invoke( + self, + client: Any, + model_id: str, + request_options: Dict[str, Any], + input_params: Dict[str, Any], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + **kwargs: Any, + ) -> Tuple[ + str, + List[ToolCall], + Dict[str, Any], + ]: + _model_kwargs = model_kwargs or {} + params = {**_model_kwargs, **kwargs} + + tools = None + if "claude-3" in model_id and _tools_in_params(params): + tools = params["tools"] + + input_body = self._prepare_input( + model_kwargs=params, + prompt=input_params["prompt"], + system=input_params["system"], + messages=input_params["messages"], + tools=tools, + max_tokens=max_tokens, + temperature=temperature, + ) + body = json.dumps(input_body) + request_options["body"] = body + + try: + print("anthropic adapter used for invoking response") + response = client.invoke_model(**request_options) + + ( + text, + tool_calls, + body, + usage_info, + stop_reason, + ) = self._prepare_output(response).values() + + except Exception as e: + logging.error(f"Error raised by bedrock service: {e}") + if run_manager is not None: + run_manager.on_llm_error(e) + raise e + + if stop is not None: + text = enforce_stop_tokens(text, stop) + + llm_output = {"usage": usage_info, "stop_reason": stop_reason} + + + ''' TODO: checking for intervention is body should be done in ChatBedrock''' + # Verify and raise a callback error if any intervention occurs or a signal is + # sent from a Bedrock service, + # such as when guardrails are triggered. + # services_trace = self._get_bedrock_services_signal(body) # type: ignore[arg-type] + + # if run_manager is not None and services_trace.get("signal"): + # run_manager.on_llm_error( + # Exception( + # f"Error raised by bedrock service: {services_trace.get('reason')}" + # ), + # **services_trace, + # ) + + return text, tool_calls, llm_output + # Implement other abstract methods similarly... 
\ No newline at end of file diff --git a/libs/aws/langchain_aws/chat_model_adapter/demo_chat_adapter.py b/libs/aws/langchain_aws/chat_model_adapter/demo_chat_adapter.py new file mode 100644 index 00000000..02f00929 --- /dev/null +++ b/libs/aws/langchain_aws/chat_model_adapter/demo_chat_adapter.py @@ -0,0 +1,83 @@ +from typing import ( + Any, + Iterator, + List, + Optional, + Sequence, + Union, + Dict, + Callable, + Literal, + Type, + TypeVar, + Tuple, + TypedDict, + cast, + Mapping +) + +from langchain_core.language_models import BaseChatModel, LanguageModelInput +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import ( + BaseMessage, + AIMessage, + AIMessageChunk, + HumanMessage, + SystemMessage, + ToolMessage, + ChatMessage, + ToolCall +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_core.utils.pydantic import TypeBaseModel +from langchain_aws.utils import enforce_stop_tokens +from langchain_aws.function_calling import _tools_in_params, _lc_tool_calls_to_anthropic_tool_use_blocks +from langchain_core.messages.tool import tool_call, tool_call_chunk +from pydantic import BaseModel +from langchain_core.outputs import Generation, GenerationChunk, LLMResult +from abc import ABC, abstractmethod +import re +import json +import logging +import warnings +# ModelAdapter might also need access to the data that the wrapper ChatModel class has +# for example, the provider or custom inputs passed in by the user + + +class ModelAdapter(ABC): + """Abstract base class for model-specific adaptation strategies""" + + @abstractmethod + def convert_messages_to_payload( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Any: + """Convert LangChain messages to model-specific payload""" + pass + + @abstractmethod + def convert_response_to_chat_result(self, response: Any) -> ChatResult: + """Convert model-specific response to LangChain ChatResult""" + pass + + @abstractmethod + def convert_stream_response_to_chunks( + self, response: Any + ) -> Iterator[ChatGenerationChunk]: + """Convert model-specific stream response to LangChain chunks""" + pass + + @abstractmethod + def format_tools( + self, tools: Sequence[Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool]] + ) -> Any: + """Format tools for the specific model""" + pass + + + + diff --git a/libs/aws/langchain_aws/chat_model_adapter/llama_adapter.py b/libs/aws/langchain_aws/chat_model_adapter/llama_adapter.py new file mode 100644 index 00000000..21a00381 --- /dev/null +++ b/libs/aws/langchain_aws/chat_model_adapter/llama_adapter.py @@ -0,0 +1,273 @@ +from typing import ( + Any, + Iterator, + List, + Optional, + Sequence, + Union, + Dict, + Callable, + Literal, + Type, + TypeVar, + Tuple, + TypedDict, + cast, + Mapping +) + +from langchain_core.language_models import BaseChatModel, LanguageModelInput +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import ( + BaseMessage, + AIMessage, + AIMessageChunk, + HumanMessage, + SystemMessage, + ToolMessage, + ChatMessage, + ToolCall +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable +from langchain_core.tools import BaseTool +from langchain_core.utils.pydantic import TypeBaseModel +from langchain_aws.utils import enforce_stop_tokens +from 
langchain_aws.function_calling import _tools_in_params, _lc_tool_calls_to_anthropic_tool_use_blocks
+from langchain_core.messages.tool import tool_call, tool_call_chunk
+from langchain_aws.chat_model_adapter.demo_chat_adapter import ModelAdapter
+from pydantic import BaseModel
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
+import re
+import json
+import logging
+import warnings
+
+
+class AnthropicTool(TypedDict):
+    name: str
+    description: str
+    input_schema: Dict[str, Any]
+
+HUMAN_PROMPT = "\n\nHuman:"
+ASSISTANT_PROMPT = "\n\nAssistant:"
+ALTERNATION_ERROR = (
+    "Error: Prompt must alternate between '\n\nHuman:' and '\n\nAssistant:'."
+)
+
+# Example concrete implementation for a specific model
+class BedrockLlamaAdapter(ModelAdapter):
+
+    _message_type_lookups = {
+        "human": "user",
+        "ai": "assistant",
+        "AIMessageChunk": "assistant",
+        "HumanMessageChunk": "user",
+    }
+
+    def convert_messages_to_payload(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        model: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        # Specific implementation for converting LC messages to a Llama prompt payload
+        prompt = self._convert_messages_to_prompt(messages=messages, model=model)
+
+        return {"prompt": prompt}
+
+
+    def convert_response_to_chat_result(self, response: Any) -> ChatResult:
+        pass
+
+    def convert_stream_response_to_chunks(
+        self, response: Any
+    ) -> Iterator[ChatGenerationChunk]:
+        """Convert model-specific stream response to LangChain chunks"""
+        pass
+
+    def format_tools(
+        self, tools: Sequence[Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool]]
+    ) -> Any:
+        """Format tools for the specific model"""
+        pass
+
+
+    def _convert_messages_to_prompt(
+        self, messages: List[BaseMessage], model: str
+    ) -> str:
+        if "llama3" in model:
+            prompt = self._convert_messages_to_prompt_llama3(messages=messages)
+        else:
+            prompt = self._convert_messages_to_prompt_llama(messages=messages)
+        return prompt
+
+    def _convert_one_message_to_text_llama(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
+        elif isinstance(message, HumanMessage):
+            message_text = f"[INST] {message.content} [/INST]"
+        elif isinstance(message, AIMessage):
+            message_text = f"{message.content}"
+        elif isinstance(message, SystemMessage):
+            message_text = f"<<SYS>> {message.content} <</SYS>>"
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_text
+
+
+    def _convert_messages_to_prompt_llama(self, messages: List[BaseMessage]) -> str:
+        """Convert a list of messages to a prompt for llama."""
+
+        return "\n".join(
+            [self._convert_one_message_to_text_llama(message) for message in messages]
+        )
+
+
+    def _convert_one_message_to_text_llama3(self, message: BaseMessage) -> str:
+        if isinstance(message, ChatMessage):
+            message_text = (
+                f"<|start_header_id|>{message.role}"
+                f"<|end_header_id|>{message.content}<|eot_id|>"
+            )
+        elif isinstance(message, HumanMessage):
+            message_text = (
+                f"<|start_header_id|>user" f"<|end_header_id|>{message.content}<|eot_id|>"
+            )
+        elif isinstance(message, AIMessage):
+            message_text = (
+                f"<|start_header_id|>assistant"
+                f"<|end_header_id|>{message.content}<|eot_id|>"
+            )
+        elif isinstance(message, SystemMessage):
+            message_text = (
+                f"<|start_header_id|>system" f"<|end_header_id|>{message.content}<|eot_id|>"
+            )
+        else:
+            raise ValueError(f"Got unknown type {message}")
+
+        return message_text
+
+
+    def 
_convert_messages_to_prompt_llama3(self, messages: List[BaseMessage]) -> str: + """Convert a list of messages to a prompt for llama.""" + return "\n".join( + ["<|begin_of_text|>"] + + [self._convert_one_message_to_text_llama3(message) for message in messages] + + ["<|start_header_id|>assistant<|end_header_id|>\n\n"] + ) + + def _prepare_input( + self, + model_kwargs: Dict[str, Any], + prompt: Optional[str] = None, + system: Optional[str] = None, + messages: Optional[List[Dict]] = None, + tools: Optional[List[AnthropicTool]] = None, + *, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + ) -> Dict[str, Any]: + + input_body = {**model_kwargs} + input_body["prompt"] = prompt + if max_tokens: + input_body["max_gen_len"] = max_tokens + if temperature is not None: + input_body["temperature"] = temperature + return input_body + + def _prepare_output(self, response: Any) -> dict: + text = "" + tool_calls = [] + response_body = json.loads(response.get("body").read().decode()) + text = response_body.get("generation") + + headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {}) + prompt_tokens = int(headers.get("x-amzn-bedrock-input-token-count", 0)) + completion_tokens = int(headers.get("x-amzn-bedrock-output-token-count", 0)) + return { + "text": text, + "tool_calls": tool_calls, + "body": response_body, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + "stop_reason": response_body.get("stop_reason"), + } + + + def prepare_input_and_invoke( + self, + client: Any, + model_id: str, + request_options: Dict[str, Any], + input_params: Dict[str, Any], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + **kwargs: Any, + ) -> Tuple[ + str, + List[ToolCall], + Dict[str, Any], + ]: + _model_kwargs = model_kwargs or {} + params = {**_model_kwargs, **kwargs} + + input_body = self._prepare_input( + model_kwargs=params, + prompt=input_params["prompt"], + system=input_params["system"], + messages=input_params["messages"], + max_tokens=max_tokens, + temperature=temperature, + ) + body = json.dumps(input_body) + request_options["body"] = body + + try: + print("Meta adapter used for invoke") + response = client.invoke_model(**request_options) + + ( + text, + tool_calls, + body, + usage_info, + stop_reason, + ) = self._prepare_output(response).values() + + except Exception as e: + logging.error(f"Error raised by bedrock service: {e}") + if run_manager is not None: + run_manager.on_llm_error(e) + raise e + + if stop is not None: + text = enforce_stop_tokens(text, stop) + + llm_output = {"usage": usage_info, "stop_reason": stop_reason} + + + ''' TODO: checking for intervention is body should be done in ChatBedrock''' + # Verify and raise a callback error if any intervention occurs or a signal is + # sent from a Bedrock service, + # such as when guardrails are triggered. + # services_trace = self._get_bedrock_services_signal(body) # type: ignore[arg-type] + + # if run_manager is not None and services_trace.get("signal"): + # run_manager.on_llm_error( + # Exception( + # f"Error raised by bedrock service: {services_trace.get('reason')}" + # ), + # **services_trace, + # ) + + return text, tool_calls, llm_output + # Implement other abstract methods similarly... 
\ No newline at end of file diff --git a/libs/aws/langchain_aws/chat_models/__init__.py b/libs/aws/langchain_aws/chat_models/__init__.py index 12612c41..7469414a 100644 --- a/libs/aws/langchain_aws/chat_models/__init__.py +++ b/libs/aws/langchain_aws/chat_models/__init__.py @@ -1,4 +1,5 @@ from langchain_aws.chat_models.bedrock import ChatBedrock from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse +from langchain_aws.chat_models.demo_chat import DemoChatBedrock -__all__ = ["ChatBedrock", "ChatBedrockConverse"] +__all__ = ["ChatBedrock", "ChatBedrockConverse", "DemoChatBedrock"] diff --git a/libs/aws/langchain_aws/chat_models/bedrock_converse.py b/libs/aws/langchain_aws/chat_models/bedrock_converse.py index 4748d8b6..6dbe275e 100644 --- a/libs/aws/langchain_aws/chat_models/bedrock_converse.py +++ b/libs/aws/langchain_aws/chat_models/bedrock_converse.py @@ -498,6 +498,7 @@ def _generate( params = self._converse_params( stop=stop, **_snake_to_camel_keys(kwargs, excluded_keys={"inputSchema"}) ) + print("Converse API used for getting response") response = self.client.converse( messages=bedrock_messages, system=system, **params ) diff --git a/libs/aws/langchain_aws/chat_models/demo_chat.py b/libs/aws/langchain_aws/chat_models/demo_chat.py new file mode 100644 index 00000000..f1522ef3 --- /dev/null +++ b/libs/aws/langchain_aws/chat_models/demo_chat.py @@ -0,0 +1,548 @@ +import re +from collections import defaultdict +from operator import itemgetter +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import ( + BaseChatModel, + LangSmithParams, + LanguageModelInput, +) +from langchain_core.language_models.chat_models import generate_from_stream +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.messages.ai import UsageMetadata +from langchain_core.messages.tool import ToolCall, ToolMessage +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough +from langchain_core.tools import BaseTool +from langchain_core.utils.pydantic import TypeBaseModel, is_basemodel_subclass +from pydantic import BaseModel, ConfigDict, model_validator + +from langchain_aws.chat_models.bedrock_converse import ChatBedrockConverse +from langchain_aws.function_calling import ( + ToolsOutputParser, + _lc_tool_calls_to_anthropic_tool_use_blocks, + convert_to_anthropic_tool, + get_system_message, +) +from langchain_aws.llms.bedrock import ( + BedrockBase, + _combine_generation_info_for_llm_result, +) +from langchain_aws.utils import ( + get_num_tokens_anthropic, + get_token_ids_anthropic, +) + +from langchain_aws.chat_model_adapter.demo_chat_adapter import ModelAdapter + + +_message_type_lookups = { + "human": "user", + "ai": "assistant", + "AIMessageChunk": "assistant", + "HumanMessageChunk": "user", +} + + +class DemoChatBedrock(BaseChatModel, BedrockBase): + """A chat model that uses the Bedrock API.""" + + system_prompt_with_tools: str = "" + beta_use_converse_api: bool = False + chat_prompt_adapter: ModelAdapter = None + + """Use the new Bedrock ``converse`` API which provides a standardized interface to + all Bedrock models. Support still in beta. 
See ChatBedrockConverse docs for more."""
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "amazon_bedrock_chat"
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this model can be serialized by Langchain."""
+        return True
+
+    @classmethod
+    def get_lc_namespace(cls) -> List[str]:
+        """Get the namespace of the langchain object."""
+        return ["langchain", "chat_models", "bedrock"]
+
+    @property
+    def lc_attributes(self) -> Dict[str, Any]:
+        attributes: Dict[str, Any] = {}
+
+        if self.region_name:
+            attributes["region_name"] = self.region_name
+
+        return attributes
+
+    model_config = ConfigDict(
+        extra="forbid",
+        populate_by_name=True,
+    )
+
+    # TODO: add get_ls_params() later
+
+    def get_request_options(self):
+        accept = "application/json"
+        contentType = "application/json"
+
+        request_options = {
+            "modelId": self.model_id,
+            "accept": accept,
+            "contentType": contentType,
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrailIdentifier"] = self.guardrails.get(  # type: ignore[union-attr]
+                "guardrailIdentifier", ""
+            )
+            request_options["guardrailVersion"] = self.guardrails.get(  # type: ignore[union-attr]
+                "guardrailVersion", ""
+            )
+            if self.guardrails.get("trace"):  # type: ignore[union-attr]
+                request_options["trace"] = "ENABLED"
+        return request_options
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        if self.beta_use_converse_api:
+            yield from self._as_converse._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return
+        provider = self._get_provider()
+        prompt, system, formatted_messages = None, None, None
+
+        if provider == "anthropic":
+            system, formatted_messages = self.chat_prompt_adapter._format_anthropic_messages(
+                messages
+            )
+            if self.system_prompt_with_tools:
+                if system:
+                    system = self.system_prompt_with_tools + f"\n{system}"
+                else:
+                    system = self.system_prompt_with_tools
+        else:
+            prompt = self.chat_prompt_adapter.convert_messages_to_payload(
+                messages=messages, model=self._get_model()
+            )["prompt"]
+
+        for chunk in self._prepare_input_and_invoke_stream(
+            prompt=prompt,
+            system=system,
+            messages=formatted_messages,
+            stop=stop,
+            run_manager=run_manager,
+            **kwargs,
+        ):
+            if isinstance(chunk, AIMessageChunk):
+                generation_chunk = ChatGenerationChunk(message=chunk)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk
+                    )
+                yield generation_chunk
+            else:
+                delta = chunk.text
+                if generation_info := chunk.generation_info:
+                    usage_metadata = generation_info.pop("usage_metadata", None)
+                else:
+                    usage_metadata = None
+                generation_chunk = ChatGenerationChunk(
+                    message=AIMessageChunk(
+                        content=delta,
+                        response_metadata=chunk.generation_info,
+                        usage_metadata=usage_metadata,
+                    )
+                    if chunk.generation_info is not None
+                    else AIMessageChunk(content=delta)
+                )
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk
+                    )
+                yield generation_chunk
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if not self.chat_prompt_adapter:
+            return self._as_converse._generate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+        completion = ""
+        llm_output: Dict[str, Any] = {}
+        tool_calls: List[ToolCall] = []
+        provider_stop_reason_code = self.provider_stop_reason_key_map.get(
+            self._get_provider(), "stop_reason"
+        )
+        provider = self._get_provider()
+        print(provider)
+        request_options = self.get_request_options()
+        if self.streaming:
+            if provider == "anthropic":
+                stream_iter = self._stream(messages, stop, run_manager, **kwargs)
+                return generate_from_stream(stream_iter)
+
+            response_metadata: List[Dict[str, Any]] = []
+            for chunk in self._stream(messages, stop, run_manager, **kwargs):
+                completion += chunk.text
+                response_metadata.append(chunk.message.response_metadata)
+                if "tool_calls" in chunk.message.additional_kwargs.keys():
+                    tool_calls = chunk.message.additional_kwargs["tool_calls"]
+            llm_output = _combine_generation_info_for_llm_result(
+                response_metadata, provider_stop_reason_code
+            )
+        else:
+            prompt, system, formatted_messages = None, None, None
+            params: Dict[str, Any] = {**kwargs}
+
+            input_params = self.chat_prompt_adapter.convert_messages_to_payload(
+                messages=messages, model=self._get_model()
+            )
+            # use tools the new way with claude 3
+            if self.system_prompt_with_tools:
+                if input_params.get("system"):
+                    input_params["system"] = (
+                        self.system_prompt_with_tools + f"\n{input_params['system']}"
+                    )
+                else:
+                    input_params["system"] = self.system_prompt_with_tools
+
+            if stop:
+                params["stop_sequences"] = stop
+
+            for k in {"system", "prompt", "messages"}:
+                if k not in input_params:
+                    input_params[k] = None
+            completion, tool_calls, llm_output = self.chat_prompt_adapter.prepare_input_and_invoke(
+                client=self.client,
+                model_id=self.model_id,
+                request_options=request_options,
+                input_params=input_params,
+                stop=stop,
+                run_manager=run_manager,
+                model_kwargs=self.model_kwargs,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                **params
+            )
+        # usage metadata
+        if usage := llm_output.get("usage"):
+            input_tokens = usage.get("prompt_tokens", 0)
+            output_tokens = usage.get("completion_tokens", 0)
+            usage_metadata = UsageMetadata(
+                input_tokens=input_tokens,
+                output_tokens=output_tokens,
+                total_tokens=usage.get("total_tokens", input_tokens + output_tokens),
+            )
+        else:
+            usage_metadata = None
+
+        llm_output["model_id"] = self.model_id
+
+        msg = AIMessage(
+            content=completion,
+            additional_kwargs=llm_output,
+            tool_calls=cast(List[ToolCall], tool_calls),
+            usage_metadata=usage_metadata,
+        )
+
+        return ChatResult(
+            generations=[
+                ChatGeneration(
+                    message=msg,
+                )
+            ],
+            llm_output=llm_output,
+        )
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        final_usage: Dict[str, int] = defaultdict(int)
+        final_output = {}
+        for output in llm_outputs:
+            output = output or {}
+            usage = output.get("usage", {})
+            for token_type, token_count in usage.items():
+                final_usage[token_type] += token_count
+            final_output.update(output)
+        final_output["usage"] = final_usage
+        return final_output
+
+    def get_num_tokens(self, text: str) -> int:
+        if self._model_is_anthropic:
+            return get_num_tokens_anthropic(text)
+        else:
+            return super().get_num_tokens(text)
+
+    def get_token_ids(self, text: str) -> List[int]:
+        if self._model_is_anthropic:
+            return get_token_ids_anthropic(text)
+        else:
+            return super().get_token_ids(text)
+
+    def set_system_prompt_with_tools(self, xml_tools_system_prompt: str) -> None:
+        """Workaround to bind. 
Sets the system prompt with tools""" + self.system_prompt_with_tools = xml_tools_system_prompt + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], TypeBaseModel, Callable, BaseTool]], + *, + tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + + Assumes model has a tool calling API. + + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + tool_choice: Which tool to require the model to call. + Must be the name of the single provided function or + "auto" to automatically determine which function to call + (if any), or a dict of the form: + {"type": "function", "function": {"name": <>}}. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + if self.beta_use_converse_api: + if isinstance(tool_choice, bool): + tool_choice = "any" if tool_choice else None + return self._as_converse.bind_tools( + tools, tool_choice=tool_choice, **kwargs + ) + if self._get_provider() == "anthropic": + formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools] + + # true if the model is a claude 3 model + if "claude-3" in self._get_model(): + if not tool_choice: + pass + elif isinstance(tool_choice, dict): + kwargs["tool_choice"] = tool_choice + elif isinstance(tool_choice, str) and tool_choice in ("any", "auto"): + kwargs["tool_choice"] = {"type": tool_choice} + elif isinstance(tool_choice, str): + kwargs["tool_choice"] = {"type": "tool", "name": tool_choice} + else: + raise ValueError( + f"Unrecognized 'tool_choice' type {tool_choice=}." + f"Expected dict, str, or None." + ) + return self.bind(tools=formatted_tools, **kwargs) + else: + # add tools to the system prompt, the old way + system_formatted_tools = get_system_message(formatted_tools) + self.set_system_prompt_with_tools(system_formatted_tools) + return self + + def with_structured_output( + self, + schema: Union[Dict, TypeBaseModel], + *, + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: The output schema as a dict or a Pydantic class. If a Pydantic class + then the model output will be an object of that class. If a dict then + the model output will be a dict. With a Pydantic class the returned + attributes will be validated, whereas with a dict they will not be. + include_raw: If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + Returns: + A Runnable that takes any ChatModel input. The output type depends on + include_raw and schema. + + If include_raw is True then output is a dict with keys: + raw: BaseMessage, + parsed: Optional[_DictOrPydantic], + parsing_error: Optional[BaseException], + + If include_raw is False and schema is a Dict then the runnable outputs a Dict. 
+ If include_raw is False and schema is a Type[BaseModel] then the runnable + outputs a BaseModel. + + Example: Pydantic schema (include_raw=False): + .. code-block:: python + + from langchain_aws.chat_models.bedrock import ChatBedrock + from pydantic import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm =ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + structured_llm = llm.with_structured_output(AnswerWithJustification) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + + # -> AnswerWithJustification( + # answer='They weigh the same', + # justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.' + # ) + + Example: Pydantic schema (include_raw=True): + .. code-block:: python + + from langchain_aws.chat_models.bedrock import ChatBedrock + from pydantic import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm =ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}), + # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'), + # 'parsing_error': None + # } + + Example: Dict schema (include_raw=False): + .. code-block:: python + + from langchain_aws.chat_models.bedrock import ChatBedrock + + schema = { + "name": "AnswerWithJustification", + "description": "An answer to the user question along with justification for the answer.", + "input_schema": { + "type": "object", + "properties": { + "answer": {"type": "string"}, + "justification": {"type": "string"}, + }, + "required": ["answer", "justification"] + } + } + llm =ChatBedrock( + model_id="anthropic.claude-3-sonnet-20240229-v1:0", + model_kwargs={"temperature": 0.001}, + ) # type: ignore[call-arg] + structured_llm = llm.with_structured_output(schema) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'answer': 'They weigh the same', + # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' 
+ # } + + """ # noqa: E501 + if self.beta_use_converse_api: + return self._as_converse.with_structured_output( + schema, include_raw=include_raw, **kwargs + ) + if "claude-3" not in self._get_model(): + ValueError( + f"Structured output is not supported for model {self._get_model()}" + ) + + tool_name = convert_to_anthropic_tool(schema)["name"] + llm = self.bind_tools([schema], tool_choice=tool_name) + if isinstance(schema, type) and is_basemodel_subclass(schema): + output_parser = ToolsOutputParser( + first_tool_only=True, pydantic_schemas=[schema] + ) + else: + output_parser = ToolsOutputParser(first_tool_only=True, args_only=True) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser + + @property + def _as_converse(self) -> ChatBedrockConverse: + kwargs = { + k: v + for k, v in (self.model_kwargs or {}).items() + if k + in ( + "stop", + "stop_sequences", + "max_tokens", + "temperature", + "top_p", + "additional_model_request_fields", + "additional_model_response_field_paths", + ) + } + if self.max_tokens: + kwargs["max_tokens"] = self.max_tokens + if self.temperature is not None: + kwargs["temperature"] = self.temperature + return ChatBedrockConverse( + model=self.model_id, + region_name=self.region_name, + credentials_profile_name=self.credentials_profile_name, + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, + config=self.config, + provider=self.provider or "", + base_url=self.endpoint_url, + guardrail_config=(self.guardrails if self._guardrails_enabled else None), # type: ignore[call-arg] + **kwargs, + )