Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions codex-rs/core/src/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,140 @@ pub(crate) async fn stream_chat_completions(
}
}

// Validate tool call protocol: ensure all tool_calls have responses and correct ordering.
// Collect all tool_call_ids from assistant messages with tool_calls.
let mut expected_tool_call_ids = std::collections::HashSet::new();
for msg in &messages {
if let Some(obj) = msg.as_object() {
if obj.get("role").and_then(|v| v.as_str()) == Some("assistant")
&& let Some(tool_calls) = obj.get("tool_calls").and_then(|v| v.as_array())
{
for tc in tool_calls {
if let Some(id) = tc.get("id").and_then(|v| v.as_str()) {
expected_tool_call_ids.insert(id.to_string());
}
}
}
}
}

// Collect all tool_call_ids from tool response messages.
let mut provided_tool_call_ids = std::collections::HashSet::new();
for msg in &messages {
if let Some(obj) = msg.as_object() {
if obj.get("role").and_then(|v| v.as_str()) == Some("tool")
&& let Some(id) = obj.get("tool_call_id").and_then(|v| v.as_str())
{
provided_tool_call_ids.insert(id.to_string());
}
}
}

// Remove assistant messages with incomplete tool calls (no matching responses).
// Also collect tool_call_ids from assistant messages that will be removed.
let incomplete_ids: Vec<String> = expected_tool_call_ids
.difference(&provided_tool_call_ids)
.cloned()
.collect();

let mut removed_assistant_call_ids = std::collections::HashSet::new();

if !incomplete_ids.is_empty() {
messages.retain(|msg| {
if let Some(obj) = msg.as_object() {
if obj.get("role").and_then(|v| v.as_str()) == Some("assistant")
&& let Some(tool_calls) = obj.get("tool_calls").and_then(|v| v.as_array())
{
// Check if this assistant message has any incomplete tool calls.
for tc in tool_calls {
if let Some(id) = tc.get("id").and_then(|v| v.as_str()) {
if incomplete_ids.contains(&id.to_string()) {
// Collect ALL tool_call_ids from this assistant message.
// We need to remove their corresponding tool responses too.
for tc2 in tool_calls {
if let Some(id2) = tc2.get("id").and_then(|v| v.as_str()) {
removed_assistant_call_ids.insert(id2.to_string());
}
}
return false;
}
}
}
}
}
true
});

// Remove tool response messages whose assistant origin was removed.
messages.retain(|msg| {
if let Some(obj) = msg.as_object() {
if obj.get("role").and_then(|v| v.as_str()) == Some("tool") {
if let Some(tool_call_id) = obj.get("tool_call_id").and_then(|v| v.as_str()) {
if removed_assistant_call_ids.contains(tool_call_id) {
return false;
}
}
}
}
true
});
}

// Enforce sequence ordering: OpenAI protocol requires assistant messages with tool_calls
// to be immediately followed by tool response messages. Remove any messages that violate
// this constraint.
let mut i = 0;
while i < messages.len() {
if let Some(obj) = messages[i].as_object() {
if obj.get("role").and_then(|v| v.as_str()) == Some("assistant")
&& obj.contains_key("tool_calls")
{
// Extract all tool_call_ids from this message.
let mut call_ids = std::collections::HashSet::new();
if let Some(tool_calls) = obj.get("tool_calls").and_then(|v| v.as_array()) {
for tc in tool_calls {
if let Some(id) = tc.get("id").and_then(|v| v.as_str()) {
call_ids.insert(id.to_string());
}
}
}

// Scan forward to find all corresponding tool responses.
let mut j = i + 1;
let mut found_responses = std::collections::HashSet::new();

while j < messages.len() {
if let Some(next_obj) = messages[j].as_object() {
let next_role = next_obj.get("role").and_then(|v| v.as_str());

if next_role == Some("tool") {
// Accumulate tool responses.
if let Some(tool_call_id) = next_obj.get("tool_call_id").and_then(|v| v.as_str()) {
if call_ids.contains(tool_call_id) {
found_responses.insert(tool_call_id.to_string());
}
}
j += 1;
} else {
// Non-tool message encountered.
if found_responses.len() < call_ids.len() {
// Incomplete sequence: remove the intruding message.
messages.remove(j);
continue;
} else {
// All responses found, this is the next turn.
break;
}
}
} else {
j += 1;
}
}
}
}
i += 1;
}

let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let payload = json!({
"model": model_family.slug,
Expand Down
Loading