Skip to content

Commit ffe667c

Browse files
Langchain::Assistant when using MistralAI accepts a message with image_url (#803)
* Langchain::Assistant when using MistralAI accepts a message with image_url
* fix linting errors
1 parent 33ad323 commit ffe667c

File tree

5 files changed

+104
-18
lines changed

5 files changed

+104
-18
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
## [Unreleased]
2-
- Assistant can now process image_urls in the messages (currently only for OpenAI)
2+
- Assistant can now process image_urls in the messages (currently only for OpenAI and Mistral AI)
33

44
## [0.16.1] - 2024-09-30
55
- Deprecate Langchain::LLM::GooglePalm

lib/langchain/assistants/assistant.rb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,9 +569,7 @@ def build_chat_params(tools:, instructions:, messages:, tool_choice:)
569569
end
570570

571571
def build_message(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil)
572-
warn "Image URL is not supported by MistralAI currently" if image_url
573-
574-
Langchain::Messages::MistralAIMessage.new(role: role, content: content, tool_calls: tool_calls, tool_call_id: tool_call_id)
572+
Langchain::Messages::MistralAIMessage.new(role: role, content: content, image_url: image_url, tool_calls: tool_calls, tool_call_id: tool_call_id)
575573
end
576574

577575
# Extract the tool call information from the OpenAI tool call hash

lib/langchain/assistants/messages/mistral_ai_message.rb

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,20 @@ class MistralAIMessage < Base
1515

1616
# Initialize a new MistralAI message
1717
#
18-
# @param [String] The role of the message
19-
# @param [String] The content of the message
20-
# @param [Array<Hash>] The tool calls made in the message
21-
# @param [String] The ID of the tool call
22-
def initialize(role:, content: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
18+
# @param role [String] The role of the message
19+
# @param content [String] The content of the message
20+
# @param image_url [String] The URL of the image
21+
# @param tool_calls [Array<Hash>] The tool calls made in the message
22+
# @param tool_call_id [String] The ID of the tool call
23+
def initialize(role:, content: nil, image_url: nil, tool_calls: [], tool_call_id: nil) # TODO: Implement image_file: reference (https://platform.openai.com/docs/api-reference/messages/object#messages/object-content)
2324
raise ArgumentError, "Role must be one of #{ROLES.join(", ")}" unless ROLES.include?(role)
2425
raise ArgumentError, "Tool calls must be an array of hashes" unless tool_calls.is_a?(Array) && tool_calls.all? { |tool_call| tool_call.is_a?(Hash) }
2526

2627
@role = role
2728
# Some Tools return content as a JSON hence `.to_s`
2829
@content = content.to_s
30+
# Make sure you're using the Pixtral model if you want to send image_url
31+
@image_url = image_url
2932
@tool_calls = tool_calls
3033
@tool_call_id = tool_call_id
3134
end
@@ -43,9 +46,28 @@ def llm?
4346
def to_hash
4447
{}.tap do |h|
4548
h[:role] = role
46-
h[:content] = content if content # Content is nil for tool calls
47-
h[:tool_calls] = tool_calls if tool_calls.any?
48-
h[:tool_call_id] = tool_call_id if tool_call_id
49+
50+
if tool_calls.any?
51+
h[:tool_calls] = tool_calls
52+
else
53+
h[:tool_call_id] = tool_call_id if tool_call_id
54+
55+
h[:content] = []
56+
57+
if content && !content.empty?
58+
h[:content] << {
59+
type: "text",
60+
text: content
61+
}
62+
end
63+
64+
if image_url
65+
h[:content] << {
66+
type: "image_url",
67+
image_url: image_url
68+
}
69+
end
70+
end
4971
end
5072
end
5173

spec/langchain/assistants/assistant_spec.rb

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,19 @@
523523

524524
thread.add_message(role: "user", content: "foo")
525525
end
526+
527+
it "adds a message with image_url" do
528+
message_with_image = {role: "user", content: "hello", image_url: "https://example.com/image.jpg"}
529+
subject = described_class.new(llm: llm, messages: [])
530+
531+
expect {
532+
subject.add_message(**message_with_image)
533+
}.to change { subject.messages.count }.from(0).to(1)
534+
expect(subject.messages.first).to be_a(Langchain::Messages::MistralAIMessage)
535+
expect(subject.messages.first.role).to eq("user")
536+
expect(subject.messages.first.content).to eq("hello")
537+
expect(subject.messages.first.image_url).to eq("https://example.com/image.jpg")
538+
end
526539
end
527540

528541
describe "#submit_tool_output" do
@@ -568,8 +581,8 @@
568581
allow(subject.llm).to receive(:chat)
569582
.with(
570583
messages: [
571-
{role: "system", content: instructions},
572-
{role: "user", content: "Please calculate 2+2"}
584+
{role: "system", content: [{type: "text", text: instructions}]},
585+
{role: "user", content: [{type: "text", text: "Please calculate 2+2"}]}
573586
],
574587
tools: calculator.class.function_schemas.to_openai_format,
575588
tool_choice: "auto"
@@ -612,16 +625,16 @@
612625
allow(subject.llm).to receive(:chat)
613626
.with(
614627
messages: [
615-
{role: "system", content: instructions},
616-
{role: "user", content: "Please calculate 2+2"},
617-
{role: "assistant", content: "", tool_calls: [
628+
{role: "system", content: [{type: "text", text: instructions}]},
629+
{role: "user", content: [{type: "text", text: "Please calculate 2+2"}]},
630+
{role: "assistant", tool_calls: [
618631
{
619632
"function" => {"arguments" => "{\"input\":\"2+2\"}", "name" => "langchain_tool_calculator__execute"},
620633
"id" => "call_9TewGANaaIjzY31UCpAAGLeV",
621634
"type" => "function"
622635
}
623636
]},
624-
{content: "4.0", role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
637+
{content: [{type: "text", text: "4.0"}], role: "tool", tool_call_id: "call_9TewGANaaIjzY31UCpAAGLeV"}
625638
],
626639
tools: calculator.class.function_schemas.to_openai_format,
627640
tool_choice: "auto"
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# frozen_string_literal: true
2+
3+
RSpec.describe Langchain::Messages::MistralAIMessage do
4+
it "raises an error if role is not one of allowed" do
5+
expect { described_class.new(role: "foo") }.to raise_error(ArgumentError)
6+
end
7+
8+
describe "#to_hash" do
9+
context "when role and content are not nil" do
10+
let(:message) { described_class.new(role: "user", content: "Hello, world!", tool_calls: [], tool_call_id: nil) }
11+
12+
it "returns a hash with the role and content key" do
13+
expect(message.to_hash).to eq({role: "user", content: [{type: "text", text: "Hello, world!"}]})
14+
end
15+
end
16+
17+
context "when tool_call_id is not nil" do
18+
let(:message) { described_class.new(role: "tool", content: "Hello, world!", tool_calls: [], tool_call_id: "123") }
19+
20+
it "returns a hash with the tool_call_id key" do
21+
expect(message.to_hash).to eq({role: "tool", content: [{type: "text", text: "Hello, world!"}], tool_call_id: "123"})
22+
end
23+
end
24+
25+
context "when tool_calls is not empty" do
26+
let(:tool_call) {
27+
{"id" => "call_9TewGANaaIjzY31UCpAAGLeV",
28+
"type" => "function",
29+
"function" => {"name" => "weather__execute", "arguments" => "{\"input\":\"Saint Petersburg\"}"}}
30+
}
31+
32+
let(:message) { described_class.new(role: "assistant", tool_calls: [tool_call], tool_call_id: nil) }
33+
34+
it "returns a hash with the tool_calls key" do
35+
expect(message.to_hash).to eq({role: "assistant", tool_calls: [tool_call]})
36+
end
37+
end
38+
39+
context "when image_url is present" do
40+
let(:message) { described_class.new(role: "user", content: "Please describe this image", image_url: "https://example.com/image.jpg") }
41+
42+
it "returns a hash with the image_url key" do
43+
expect(message.to_hash).to eq({
44+
role: "user",
45+
content: [
46+
{type: "text", text: "Please describe this image"},
47+
{type: "image_url", image_url: "https://example.com/image.jpg"}
48+
]
49+
})
50+
end
51+
end
52+
end
53+
end

0 commit comments

Comments
 (0)