Skip to content

Commit 9eb78b7

Browse files
authored
fix(@langchain/mistralai): Added logic to ensure toolCalls have corresponding toolResponses when sending messages to the Mistral API (#9023)
1 parent dafd038 commit 9eb78b7

File tree

3 files changed

+120
-21
lines changed

3 files changed

+120
-21
lines changed

.changeset/real-dogs-listen.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@langchain/mistralai": patch
3+
---
4+
5+
Added logic to ensure toolCalls have corresponding toolResponses when sending messages to the Mistral API

libs/langchain-mistralai/src/chat_models.ts

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ export interface ChatMistralAIInput
221221
numCompletions?: number;
222222
}
223223

224-
function convertMessagesToMistralMessages(
224+
export function convertMessagesToMistralMessages(
225225
messages: Array<BaseMessage>
226226
): Array<MistralAIMessage> {
227227
const getRole = (role: MessageType) => {
@@ -323,35 +323,80 @@ function convertMessagesToMistralMessages(
323323
return undefined;
324324
};
325325

326-
return messages.map((message) => {
326+
// Build a set of toolCallIds that have corresponding tool responses present
327+
// to ensure 1:1 assistant toolCalls <-> tool responses.
328+
const toolResponseIds = new Set<string>();
329+
for (const m of messages) {
330+
if ("tool_call_id" in m && typeof m.tool_call_id === "string") {
331+
toolResponseIds.add(
332+
_convertToolCallIdToMistralCompatible(m.tool_call_id)
333+
);
334+
}
335+
}
336+
337+
return messages.flatMap((message) => {
327338
const toolCalls = getTools(message);
328339
const content = getContent(message.content, message.getType());
329340
if ("tool_call_id" in message && typeof message.tool_call_id === "string") {
330-
return {
331-
role: getRole(message.getType()),
332-
content,
333-
name: message.name,
334-
toolCallId: _convertToolCallIdToMistralCompatible(message.tool_call_id),
335-
};
341+
return [
342+
{
343+
role: getRole(message.getType()),
344+
content,
345+
name: message.name,
346+
toolCallId: _convertToolCallIdToMistralCompatible(
347+
message.tool_call_id
348+
),
349+
} as MistralAIMessage,
350+
];
336351
// Mistral "assistant" role can only support either content or tool calls but not both
337352
} else if (isAIMessage(message)) {
338353
if (toolCalls === undefined) {
339-
return {
340-
role: getRole(message.getType()),
341-
content,
342-
};
354+
return [
355+
{
356+
role: getRole(message.getType()),
357+
content,
358+
} as MistralAIMessage,
359+
];
343360
} else {
344-
return {
345-
role: getRole(message.getType()),
346-
toolCalls,
347-
};
361+
// Filter out toolCalls that do not have a matching tool response later in the list
362+
const filteredToolCalls = toolCalls.filter((tc) =>
363+
toolResponseIds.has(
364+
_convertToolCallIdToMistralCompatible(tc.id ?? "")
365+
)
366+
);
367+
368+
if (filteredToolCalls.length === 0) {
369+
// If there are no matching tool responses, and there's no content, drop this message
370+
const isEmptyContent =
371+
(typeof content === "string" && content.trim() === "") ||
372+
(Array.isArray(content) && content.length === 0);
373+
if (isEmptyContent) {
374+
return [];
375+
}
376+
// Otherwise, send content only
377+
return [
378+
{
379+
role: getRole(message.getType()),
380+
content,
381+
} as MistralAIMessage,
382+
];
383+
}
384+
385+
return [
386+
{
387+
role: getRole(message.getType()),
388+
toolCalls: filteredToolCalls as MistralAIToolCall[],
389+
} as MistralAIMessage,
390+
];
348391
}
349392
}
350393

351-
return {
352-
role: getRole(message.getType()),
353-
content,
354-
};
394+
return [
395+
{
396+
role: getRole(message.getType()),
397+
content,
398+
} as MistralAIMessage,
399+
];
355400
}) as MistralAIMessage[];
356401
}
357402

libs/langchain-mistralai/src/tests/chat_models.test.ts

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
import { ChatMistralAI } from "../chat_models.js";
1+
import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages";
2+
import {
3+
ChatMistralAI,
4+
convertMessagesToMistralMessages,
5+
} from "../chat_models.js";
26
import {
37
_isValidMistralToolCallId,
48
_convertToolCallIdToMistralCompatible,
@@ -38,3 +42,48 @@ test("Serialization", () => {
3842
`{"lc":1,"type":"constructor","id":["langchain","chat_models","mistralai","ChatMistralAI"],"kwargs":{"mistral_api_key":{"lc":1,"type":"secret","id":["MISTRAL_API_KEY"]}}}`
3943
);
4044
});
45+
46+
/**
47+
* Test to make sure that the logic in convertMessagesToMistralMessages that makes sure
48+
* tool calls are only included if there is a corresponding ToolMessage works as expected
49+
*
50+
* Or else the Mistral API will reject the request
51+
*/
52+
test("convertMessagesToMistralMessages converts roles and filters toolCalls", () => {
53+
const msgs = [
54+
new HumanMessage("hi"),
55+
new AIMessage({
56+
content: "",
57+
tool_calls: [
58+
{
59+
id: "123456789",
60+
name: "extract-1",
61+
args: { answer: "x" },
62+
type: "tool_call",
63+
},
64+
{ id: "ORPHAN123", name: "noop", args: {}, type: "tool_call" },
65+
],
66+
}),
67+
new ToolMessage({ tool_call_id: "123456789", content: "result payload" }),
68+
];
69+
70+
const converted = convertMessagesToMistralMessages(msgs) as {
71+
role: "user" | "assistant" | "tool";
72+
toolCalls?: { id: string; name: string; args: Record<string, unknown> }[];
73+
toolCallId?: string;
74+
}[];
75+
// Expect user, assistant (toolCalls), tool
76+
const roles = converted.map((m) => m.role);
77+
expect(roles).toContain("user");
78+
expect(roles).toContain("assistant");
79+
expect(roles).toContain("tool");
80+
81+
const assistantMsg = converted.find((m) => Array.isArray(m.toolCalls)) as {
82+
toolCalls: { id: string }[];
83+
};
84+
expect(assistantMsg.toolCalls.length).toBe(1);
85+
expect(assistantMsg.toolCalls[0].id).toBe("123456789");
86+
87+
const toolMsg = converted.find((m) => m.role === "tool");
88+
expect(toolMsg?.toolCallId).toBe("123456789");
89+
});

0 commit comments

Comments
 (0)