
Commit 0424f3d

Support using LangChain Chat Models with AutoGen v0.2 (#1009)
2 parents 0f446a5 + 8a766f4

File tree: 8 files changed, +503 −0 lines

THIRD_PARTY_LICENSES.txt

Lines changed: 6 additions & 0 deletions
@@ -24,6 +24,12 @@ autots
 * Source code: https://github.com/winedarksea/AutoTS
 * Project home: https://winedarksea.github.io/AutoTS/build/html/index.html
 
+autogen
+* Copyright (c) 2024 Microsoft Corporation.
+* License: MIT License
+* Source code: https://github.com/microsoft/autogen
+* Project home: microsoft.github.io/autogen/
+
 bokeh
 * Copyright (c) 2012 - 2021, Anaconda, Inc., and Bokeh Contributors
 * License: BSD 3-Clause "New" or "Revised" License

ads/llm/autogen/__init__.py

Whitespace-only changes.

ads/llm/autogen/client_v02.py

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
# coding: utf-8
# Copyright (c) 2016, 2024, Oracle and/or its affiliates. All rights reserved.
# This software is dual-licensed to you under the Universal Permissive License (UPL) 1.0 as shown at https://oss.oracle.com/licenses/upl or Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0. You may choose either license.


"""This module contains the custom LLM client for AutoGen v0.2 to use LangChain chat models.
https://microsoft.github.io/autogen/0.2/blog/2024/01/26/Custom-Models/

To use the custom client:
1. Prepare the LLM config, including the parameters for initializing the LangChain client.
2. Register the custom LLM client.

The LLM config should contain the following keys:
* model_client_cls: Required by AutoGen to identify the custom client. It should be "LangChainModelClient".
* langchain_cls: LangChain class, including the full import path.
* model: Name of the model to be used by AutoGen.
* client_params: A dictionary containing the parameters to initialize the LangChain chat model.

Although the `LangChainModelClient` is designed to be generic and can potentially support any LangChain chat model,
the invocation depends on the server API spec, so it may not be compatible with some implementations.
Following is an example config for the OCI Generative AI service:

{
    "model_client_cls": "LangChainModelClient",
    "langchain_cls": "langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI",
    "model": "cohere.command-r-plus",
    # client_params will be used to initialize the LangChain ChatOCIGenAI class.
    "client_params": {
        "model_id": "cohere.command-r-plus",
        "compartment_id": COMPARTMENT_OCID,
        "model_kwargs": {"temperature": 0, "max_tokens": 2048},
        # Update the authentication method as needed.
        "auth_type": "SECURITY_TOKEN",
        "auth_profile": "DEFAULT",
        # You may need to specify `service_endpoint` if the service is in a different region.
    },
}
Following is an example config for OCI Data Science Model Deployment:

{
    "model_client_cls": "LangChainModelClient",
    "langchain_cls": "ads.llm.ChatOCIModelDeploymentVLLM",
    "model": "odsc-llm",
    "endpoint": "https://MODEL_DEPLOYMENT_URL/predict",
    "model_kwargs": {"temperature": 0.1, "max_tokens": 2048},
    # function_call_params will only be added to the API call when functions/tools are specified.
    "function_call_params": {
        "tool_choice": "auto",
        "chat_template": ChatTemplates.mistral(),
    },
}
Note that if `client_params` is not specified in the config, all arguments from the config except
`model_client_cls`, `langchain_cls`, and `function_call_params` will be used to initialize
the LangChain chat model.

The `function_call_params` will only be used for function/tool calling when tools are specified.

To register the custom client:

from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client

register_custom_client(LangChainModelClient)

Once registered with ADS, the custom LLM class will be auto-registered for all new agents.
There is no need to call `register_model_client()` on each agent.

References:
https://microsoft.github.io/autogen/0.2/docs/notebooks/agentchat_huggingface_langchain/
https://github.com/microsoft/autogen/blob/0.2/notebook/agentchat_custom_model.ipynb

"""
import copy
import importlib
import json
import logging
from typing import Any, Dict, List, Union
from types import SimpleNamespace

from autogen import ModelClient
from autogen.oai.client import OpenAIWrapper, PlaceHolderClient
from langchain_core.messages import AIMessage


logger = logging.getLogger(__name__)

# custom_clients is a dictionary mapping the name of each custom client class to the class itself.
custom_clients = {}

# There is a bug in GroupChat when using a custom client:
# https://github.com/microsoft/autogen/issues/2956
# Here we patch the OpenAIWrapper to fix the issue.
# With this patch, you only need to register the client once with ADS.
# For example:
#
# from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client
# register_custom_client(LangChainModelClient)
#
# The patch will then auto-register the custom LLM with all new agents,
# so there is no need to call `register_model_client()` on each agent.
OpenAIWrapper._original_register_default_client = OpenAIWrapper._register_default_client

def _new_register_default_client(
    self: OpenAIWrapper, config: Dict[str, Any], openai_config: Dict[str, Any]
) -> None:
    """This is a patched version of the _register_default_client() method
    that automatically registers custom clients for agents.
    """
    model_client_cls_name = config.get("model_client_cls")
    if model_client_cls_name in custom_clients:
        self._clients.append(PlaceHolderClient(config))
        self.register_model_client(custom_clients[model_client_cls_name])
    else:
        self._original_register_default_client(
            config=config, openai_config=openai_config
        )


# Patch the _register_default_client() method.
OpenAIWrapper._register_default_client = _new_register_default_client


def register_custom_client(client_class):
    """Registers a custom client for AutoGen."""
    if client_class.__name__ not in custom_clients:
        custom_clients[client_class.__name__] = client_class

def _convert_to_langchain_tool(tool):
    """Converts the OpenAI tool spec to LangChain tool spec."""
    if tool["type"] == "function":
        tool = tool["function"]
        required = tool["parameters"].get("required", [])
        properties = copy.deepcopy(tool["parameters"]["properties"])
        for key in properties.keys():
            val = properties[key]
            val["default"] = key in required
        return {
            "title": tool["name"],
            "description": tool["description"],
            "properties": properties,
        }
    raise NotImplementedError(f"Type {tool['type']} is not supported.")

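# For illustration, a hypothetical OpenAI tool spec such as
#     {"type": "function", "function": {"name": "get_weather", "description": "Get the weather.",
#         "parameters": {"type": "object", "properties": {"city": {"type": "string"}},
#         "required": ["city"]}}}
# would be converted by _convert_to_langchain_tool() to
#     {"title": "get_weather", "description": "Get the weather.",
#         "properties": {"city": {"type": "string", "default": True}}}
# where the added `default` flag records whether each property is required.
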
def _convert_to_openai_tool_call(tool_call):
    """Converts a LangChain tool call in an AI message to an OpenAI tool call."""
    return {
        "id": tool_call.get("id"),
        "function": {
            "name": tool_call.get("name"),
            "arguments": (
                ""
                if tool_call.get("args") is None
                else json.dumps(tool_call.get("args"))
            ),
        },
        "type": "function",
    }

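# For illustration, a hypothetical LangChain tool call such as
#     {"id": "call_1", "name": "get_weather", "args": {"city": "Seattle"}}
# would be converted to the OpenAI-style tool call
#     {"id": "call_1", "type": "function",
#      "function": {"name": "get_weather", "arguments": '{"city": "Seattle"}'}}
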
class Message(AIMessage):
    """Represents a message returned from the LLM."""

    @classmethod
    def from_message(cls, message: AIMessage):
        """Converts from a LangChain AIMessage."""
        message = copy.deepcopy(message)
        message.__class__ = cls
        message.tool_calls = [
            _convert_to_openai_tool_call(tool) for tool in message.tool_calls
        ]
        return message

    @property
    def function_call(self):
        """Function calls."""
        return self.tool_calls

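# For illustration: given a hypothetical AIMessage whose tool_calls contain
#     [{"id": "call_1", "name": "get_weather", "args": {"city": "Seattle"}}],
# Message.from_message() returns a Message whose tool_calls (also exposed via the
# function_call property) hold the OpenAI-style dicts produced by
# _convert_to_openai_tool_call().
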
class LangChainModelClient(ModelClient):
    """Represents a model client wrapping a LangChain chat model."""

    def __init__(self, config: dict, **kwargs) -> None:
        super().__init__()
        logger.info("LangChain model client config: %s", str(config))
        # Make a copy of the config since we are popping some keys.
        config = copy.deepcopy(config)
        # model_client_cls will always be LangChainModelClient.
        self.client_class = config.pop("model_client_cls")

        # model_name is used in constructing the response.
        self.model_name = config.get("model", "")

        # If the config specifies function_call_params,
        # pop the params and use them only for tool calling.
        self.function_call_params = config.pop("function_call_params", {})

        # If the config specifies invoke_params,
        # pop the params and use them only for invoking.
        self.invoke_params = config.pop("invoke_params", {})

        # Import the LangChain class.
        if "langchain_cls" not in config:
            raise ValueError("Missing langchain_cls in LangChain Model Client config.")
        module_cls = config.pop("langchain_cls")
        module_name, cls_name = str(module_cls).rsplit(".", 1)
        langchain_module = importlib.import_module(module_name)
        langchain_cls = getattr(langchain_module, cls_name)

        # If the config specifies client_params,
        # use only the client_params to initialize the LangChain model.
        # Otherwise, use the remaining config.
        self.client_params = config.get("client_params", config)

        # Initialize the LangChain client.
        self.model = langchain_cls(**self.client_params)

    def create(self, params) -> ModelClient.ModelClientResponseProtocol:
        """Creates an LLM completion for a given config.

        Parameters
        ----------
        params : dict
            OpenAI API compatible parameters, including all the keys from llm_config.

        Returns
        -------
        ModelClientResponseProtocol
            Response from the LLM.
        """
        streaming = params.get("stream", False)
        # TODO: num_of_responses
        num_of_responses = params.get("n", 1)
        messages = params.pop("messages", [])

        invoke_params = copy.deepcopy(self.invoke_params)

        tools = params.get("tools")
        if tools:
            model = self.model.bind_tools(
                [_convert_to_langchain_tool(tool) for tool in tools]
            )
            # invoke_params["tools"] = tools
            invoke_params.update(self.function_call_params)
        else:
            model = self.model

        response = SimpleNamespace()
        response.choices = []
        response.model = self.model_name

        if streaming and messages:
            # If streaming is enabled and there are messages, iterate over the chunks of the response.
            raise NotImplementedError()
        else:
            # If streaming is not enabled, send a regular chat completion request.
            ai_message = model.invoke(messages, **invoke_params)
            choice = SimpleNamespace()
            choice.message = Message.from_message(ai_message)
            response.choices.append(choice)
        return response

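    # Note: the SimpleNamespace returned by create() mimics the shape of the OpenAI
    # ChatCompletion response that AutoGen expects: `response.model` carries the
    # configured model name, and each `response.choices[i].message` is a Message
    # wrapping the LangChain AIMessage.
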
    def message_retrieval(
        self, response: ModelClient.ModelClientResponseProtocol
    ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
        """
        Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response: ModelClient.ModelClientResponseProtocol) -> float:
        """Returns the cost of the response. Cost tracking is not implemented, so this always returns 0."""
        response.cost = 0
        return 0

    @staticmethod
    def get_usage(response: ModelClient.ModelClientResponseProtocol) -> Dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        return {}
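
For reference, the following sketch (not part of this commit) shows how the registered client might be wired into an AutoGen v0.2 agent. The agent names, placeholder compartment OCID, and chat settings are illustrative assumptions; the config mirrors the OCI Generative AI example from the module docstring.

from autogen import AssistantAgent, UserProxyAgent

from ads.llm.autogen.client_v02 import LangChainModelClient, register_custom_client

# Register once; the patched OpenAIWrapper auto-registers the client for all new agents.
register_custom_client(LangChainModelClient)

llm_config = {
    "config_list": [
        {
            "model_client_cls": "LangChainModelClient",
            "langchain_cls": "langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI",
            "model": "cohere.command-r-plus",
            "client_params": {
                "model_id": "cohere.command-r-plus",
                "compartment_id": "<COMPARTMENT_OCID>",  # illustrative placeholder
                "model_kwargs": {"temperature": 0, "max_tokens": 2048},
            },
        }
    ]
}

assistant = AssistantAgent("assistant", llm_config=llm_config)
user_proxy = UserProxyAgent(
    "user_proxy",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=1,
    code_execution_config=False,
)
user_proxy.initiate_chat(assistant, message="Say hello.")

Because the module patches OpenAIWrapper at import time, registering the client once is enough; no per-agent register_model_client() call is needed.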
