Skip to content

Commit 1ece675

Browse files
massi-angbigadsoleiman
authored and committed
feat(claude_mm): adding multimodal chat for Claude
redesign of the Idefics model interface to accept multiple models via adapters.
1 parent 0545d45 commit 1ece675

File tree

19 files changed

+461
-241
lines changed

19 files changed

+461
-241
lines changed

lib/aws-genai-llm-chatbot-stack.ts

Lines changed: 57 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
5353
});
5454

5555
// Langchain Interface Construct
56-
// This is the model interface recieving messages from the websocket interface via the message topic
56+
// This is the model interface receiving messages from the websocket interface via the message topic
5757
// and interacting with the model via LangChain library
5858
const langchainModels = models.models.filter(
5959
(model) => model.interface === ModelInterface.LangChain
@@ -100,46 +100,45 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
100100
}
101101

102102
// IDEFICS Interface Construct
103-
// This is the model interface recieving messages from the websocket interface via the message topic
103+
// This is the model interface receiving messages from the websocket interface via the message topic
104104
// and interacting with IDEFICS visual language models
105105
const ideficsModels = models.models.filter(
106-
(model) => model.interface === ModelInterface.Idefics
106+
(model) => model.interface === ModelInterface.MultiModal
107107
);
108108

109109
// check if any deployed model requires idefics interface
110-
if (ideficsModels.length > 0) {
111-
const ideficsInterface = new IdeficsInterface(this, "IdeficsInterface", {
112-
shared,
113-
config: props.config,
114-
messagesTopic: chatBotApi.messagesTopic,
115-
sessionsTable: chatBotApi.sessionsTable,
116-
byUserIdIndex: chatBotApi.byUserIdIndex,
117-
chatbotFilesBucket: chatBotApi.filesBucket,
118-
});
119110

120-
// Route all incoming messages targeted to idefics to the idefics model interface queue
121-
chatBotApi.messagesTopic.addSubscription(
122-
new subscriptions.SqsSubscription(ideficsInterface.ingestionQueue, {
123-
filterPolicyWithMessageBody: {
124-
direction: sns.FilterOrPolicy.filter(
125-
sns.SubscriptionFilter.stringFilter({
126-
allowlist: [Direction.In],
127-
})
128-
),
129-
modelInterface: sns.FilterOrPolicy.filter(
130-
sns.SubscriptionFilter.stringFilter({
131-
allowlist: [ModelInterface.Idefics],
132-
})
133-
),
134-
},
135-
})
136-
);
111+
const ideficsInterface = new IdeficsInterface(this, "IdeficsInterface", {
112+
shared,
113+
config: props.config,
114+
messagesTopic: chatBotApi.messagesTopic,
115+
sessionsTable: chatBotApi.sessionsTable,
116+
byUserIdIndex: chatBotApi.byUserIdIndex,
117+
chatbotFilesBucket: chatBotApi.filesBucket,
118+
});
137119

138-
for (const model of models.models) {
139-
// if model name contains idefics then add to idefics interface
140-
if (model.interface === ModelInterface.Idefics) {
141-
ideficsInterface.addSageMakerEndpoint(model);
142-
}
120+
// Route all incoming messages targeted to idefics to the idefics model interface queue
121+
chatBotApi.messagesTopic.addSubscription(
122+
new subscriptions.SqsSubscription(ideficsInterface.ingestionQueue, {
123+
filterPolicyWithMessageBody: {
124+
direction: sns.FilterOrPolicy.filter(
125+
sns.SubscriptionFilter.stringFilter({
126+
allowlist: [Direction.In],
127+
})
128+
),
129+
modelInterface: sns.FilterOrPolicy.filter(
130+
sns.SubscriptionFilter.stringFilter({
131+
allowlist: [ModelInterface.MultiModal],
132+
})
133+
),
134+
},
135+
})
136+
);
137+
138+
for (const model of models.models) {
139+
// if model name contains idefics then add to idefics interface
140+
if (model.interface === ModelInterface.MultiModal) {
141+
ideficsInterface.addSageMakerEndpoint(model);
143142
}
144143
}
145144

@@ -415,38 +414,40 @@ export class AwsGenAILLMChatbotStack extends cdk.Stack {
415414
reason: "Not yet upgraded from Python 3.11 to 3.12.",
416415
},
417416
]);
418-
417+
419418
if (props.config.privateWebsite) {
420419
const paths = [];
421-
for(let index = 0; index < shared.vpc.availabilityZones.length; index++) {
422-
paths.push(`/${this.stackName}/UserInterface/PrivateWebsite/DescribeNetworkInterfaces-${index}/CustomResourcePolicy/Resource`,)
420+
for (
421+
let index = 0;
422+
index < shared.vpc.availabilityZones.length;
423+
index++
424+
) {
425+
paths.push(
426+
`/${this.stackName}/UserInterface/PrivateWebsite/DescribeNetworkInterfaces-${index}/CustomResourcePolicy/Resource`
427+
);
423428
}
424-
paths.push(`/${this.stackName}/UserInterface/PrivateWebsite/describeVpcEndpoints/CustomResourcePolicy/Resource`,)
429+
paths.push(
430+
`/${this.stackName}/UserInterface/PrivateWebsite/describeVpcEndpoints/CustomResourcePolicy/Resource`
431+
);
432+
NagSuppressions.addResourceSuppressionsByPath(this, paths, [
433+
{
434+
id: "AwsSolutions-IAM5",
435+
reason:
436+
"Custom Resource requires permissions to Describe VPC Endpoint Network Interfaces",
437+
},
438+
]);
425439
NagSuppressions.addResourceSuppressionsByPath(
426440
this,
427-
paths,
441+
[
442+
`/${this.stackName}/AWS679f53fac002430cb0da5b7982bd2287/ServiceRole/Resource`,
443+
],
428444
[
429445
{
430-
id: "AwsSolutions-IAM5",
431-
reason:
432-
"Custom Resource requires permissions to Describe VPC Endpoint Network Interfaces",
446+
id: "AwsSolutions-IAM4",
447+
reason: "IAM role implicitly created by CDK.",
433448
},
434449
]
435450
);
436-
NagSuppressions.addResourceSuppressionsByPath(
437-
this,
438-
[
439-
`/${this.stackName}/AWS679f53fac002430cb0da5b7982bd2287/ServiceRole/Resource`
440-
],
441-
[
442-
{
443-
id: "AwsSolutions-IAM4",
444-
reason:
445-
"IAM role implicitly created by CDK.",
446-
},
447-
]
448-
);
449-
450451
}
451452
}
452453
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .idefics import Idefics
2+
from .claude import Claude3
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from abc import abstractmethod
2+
3+
4+
class MultiModalModelBase:
    """Contract shared by all multimodal chat adapters (e.g. IDEFICS, Claude 3).

    Concrete adapters must implement prompt formatting and model invocation;
    `clean_prompt` may be overridden to strip bulky payloads (such as base64
    images) before a prompt is persisted or logged.
    """

    @abstractmethod
    def handle_run(self, prompt: str, model_kwargs: dict) -> str:
        """Invoke the underlying model with a formatted prompt and return its text reply."""

    @abstractmethod
    def format_prompt(self, prompt: str, messages: list, files: list) -> str:
        """Render chat history, the new user prompt, and attachments into the model's wire format."""

    def clean_prompt(self, prompt: str) -> str:
        """Return the prompt unchanged; adapters override to redact heavy content."""
        return prompt
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from .base import MultiModalModelBase
2+
from genai_core.types import ChatbotAction, ChatbotMessageType
3+
from urllib.parse import urljoin
4+
import os
5+
from genai_core.clients import get_bedrock_client
6+
import json
7+
import requests
8+
from base64 import b64encode
9+
from genai_core.registry import registry
10+
11+
12+
def get_image_message(
    file: dict,
):
    """Fetch a chat attachment from the private files API and wrap it as an
    Anthropic Messages-API image content block (base64-encoded).

    `file` is expected to carry the S3 object key under "key"
    — TODO confirm against caller.
    """
    import mimetypes  # local import: only this helper needs it

    img = requests.get(
        f"{urljoin(os.environ['CHATBOT_FILES_PRIVATE_API'], file['key'])}"
    ).content
    # The original hard-coded "image/jpeg" for every attachment; Claude
    # validates media_type against the payload, so non-JPEG uploads would be
    # rejected. Infer from the file key, falling back to jpeg (old behavior)
    # when the key has no recognizable extension.
    media_type = mimetypes.guess_type(file["key"])[0] or "image/jpeg"
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            "data": str(b64encode(img), "ascii"),
        },
    }
26+
27+
28+
class Claude3(MultiModalModelBase):
    """Adapter for Anthropic Claude 3 multimodal models served via Amazon Bedrock."""

    model_id: str
    # NOTE(review): `any` here is the builtin function, not typing.Any — it is
    # only stored in __annotations__; consider typing.Any for clarity.
    client: any

    def __init__(self, model_id: str):
        self.model_id = model_id
        self.client = get_bedrock_client()

    def format_prompt(self, prompt: str, messages: list, files: list) -> str:
        """Serialize chat history + the new user turn into a Messages-API JSON body.

        Images attached to user turns are inlined as base64 content blocks.
        Returns the JSON string passed later to `handle_run`.
        """
        prompts = []

        # Chat history: replay each turn, attaching any images the user sent.
        for message in messages:
            if message.type.lower() == ChatbotMessageType.Human.value.lower():
                user_msg = {
                    "role": "user",
                    "content": [{"type": "text", "text": message.content}],
                }
                prompts.append(user_msg)
                message_files = message.additional_kwargs.get("files", [])
                for message_file in message_files:
                    user_msg["content"].append(get_image_message(message_file))
            if message.type.lower() == ChatbotMessageType.AI.value.lower():
                prompts.append({"role": "assistant", "content": message.content})

        # Current user prompt plus any newly uploaded files.
        user_msg = {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        }
        prompts.append(user_msg)
        for file in files:
            user_msg["content"].append(get_image_message(file))

        return json.dumps(
            {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 512,
                "messages": prompts,
                "temperature": 0.3,
            }
        )

    def handle_run(self, prompt: str, model_kwargs: dict):
        """Invoke Bedrock with the serialized body, applying caller overrides.

        Recognized model_kwargs: temperature, topP, maxTokens, topK.
        Returns the text of the first content block in the model response.
        """
        print(model_kwargs)  # debug: surface requested overrides in the logs
        body = json.loads(prompt)

        if "temperature" in model_kwargs:
            body["temperature"] = model_kwargs["temperature"]
        if "topP" in model_kwargs:
            body["top_p"] = model_kwargs["topP"]
        if "maxTokens" in model_kwargs:
            body["max_tokens"] = model_kwargs["maxTokens"]
        if "topK" in model_kwargs:
            body["top_k"] = model_kwargs["topK"]

        body_str = json.dumps(body)
        mlm_response = self.client.invoke_model(
            modelId=self.model_id,
            body=body_str,
            accept="application/json",
            contentType="application/json",
        )

        return json.loads(mlm_response["body"].read())["content"][0]["text"]

    def clean_prompt(self, prompt: str) -> str:
        """Blank out base64 image payloads so the stored prompt stays small."""
        p = json.loads(prompt)
        for m in p["messages"]:
            if m["role"] == "user" and isinstance(m["content"], list):
                for c in m["content"]:
                    if c["type"] == "image":
                        # BUG FIX: image bytes live at c["source"]["data"]
                        # (see get_image_message); the original wrote
                        # c["data"]["source"], raising KeyError for any
                        # prompt that contains an image.
                        c["source"]["data"] = ""
        return json.dumps(p)
102+
103+
104+
# Escape the literal dots: unescaped '.' matches any character, so the old
# pattern could match unintended provider/model ids.
registry.register(r"^bedrock\.anthropic\.claude-3.*", Claude3)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from .base import MultiModalModelBase
2+
from genai_core.types import ChatbotAction, ChatbotMessageType
3+
from urllib.parse import urljoin
4+
import os
5+
from langchain.llms import SagemakerEndpoint
6+
from content_handler import ContentHandler
7+
from genai_core.registry import registry
8+
9+
10+
class Idefics(MultiModalModelBase):
    """Adapter for IDEFICS visual-language models hosted on a SageMaker endpoint."""

    model_id: str

    def __init__(self, model_id: str):
        self.model_id = model_id

    def format_prompt(self, prompt: str, messages: list, files: list) -> str:
        """Render history plus the current turn into IDEFICS dialogue markup.

        User turns with attachments become `User:<text>![](<image-url>)`
        segments (one per file); the string ends with an Assistant cue so the
        model continues the conversation.
        """
        api_base = os.environ["CHATBOT_FILES_PRIVATE_API"]
        human_kind = ChatbotMessageType.Human.value.lower()
        ai_kind = ChatbotMessageType.AI.value.lower()

        segments = []
        for message in messages:
            kind = message.type.lower()
            if kind == human_kind:
                attachments = message.additional_kwargs.get("files", [])
                if not attachments:
                    segments.append(f"User:{message.content}")
                for attachment in attachments:
                    image_url = urljoin(api_base, attachment["key"])
                    segments.append(f"User:{message.content}![]({image_url})")
            if kind == ai_kind:
                segments.append(f"Assistant:{message.content}")

        # Current turn: text-only, or one segment per newly uploaded file.
        if not files:
            segments.append(f"User:{prompt}")
        for file in files:
            image_url = urljoin(api_base, file["key"])
            segments.append(f"User:{prompt}![]({image_url})")

        segments.append("<end_of_utterance>\nAssistant:")

        rendered = "".join(segments)
        print(rendered)  # debug: log the final prompt sent to the endpoint
        return rendered

    def handle_run(self, prompt: str, model_kwargs: dict):
        """Call the SageMaker endpoint; defaults may be overridden via model_kwargs
        (temperature, topP, maxTokens)."""
        print(model_kwargs)  # debug: surface requested overrides in the logs
        params = {
            "do_sample": True,
            "top_p": 0.2,
            "temperature": 0.4,
            "top_k": 50,
            "max_new_tokens": 512,
            "stop": ["User:", "<end_of_utterance>"],
        }
        # Map chatbot-style option names onto the endpoint's parameter names.
        overrides = {
            "temperature": "temperature",
            "topP": "top_p",
            "maxTokens": "max_new_tokens",
        }
        for option, target in overrides.items():
            if option in model_kwargs:
                params[target] = model_kwargs[option]

        endpoint = SagemakerEndpoint(
            endpoint_name=self.model_id,
            region_name=os.environ["AWS_REGION"],
            model_kwargs=params,
            content_handler=ContentHandler(),
        )
        return endpoint.predict(prompt)
82+
83+
84+
# BUG FIX: the original pattern was r"^sagemaker.*idefics*" — the '*' bound to
# the final 's' ("idefic" plus zero-or-more 's'), not to a wildcard suffix.
registry.register(r"^sagemaker.*idefics.*", Idefics)

0 commit comments

Comments
 (0)