diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index 785a4d4ce5..ff88fbf0fe 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -329,6 +329,140 @@ pub(crate) async fn stream_chat_completions( } } + // Validate tool call protocol: ensure all tool_calls have responses and correct ordering. + // Collect all tool_call_ids from assistant messages with tool_calls. + let mut expected_tool_call_ids = std::collections::HashSet::new(); + for msg in &messages { + if let Some(obj) = msg.as_object() { + if obj.get("role").and_then(|v| v.as_str()) == Some("assistant") + && let Some(tool_calls) = obj.get("tool_calls").and_then(|v| v.as_array()) + { + for tc in tool_calls { + if let Some(id) = tc.get("id").and_then(|v| v.as_str()) { + expected_tool_call_ids.insert(id.to_string()); + } + } + } + } + } + + // Collect all tool_call_ids from tool response messages. + let mut provided_tool_call_ids = std::collections::HashSet::new(); + for msg in &messages { + if let Some(obj) = msg.as_object() { + if obj.get("role").and_then(|v| v.as_str()) == Some("tool") + && let Some(id) = obj.get("tool_call_id").and_then(|v| v.as_str()) + { + provided_tool_call_ids.insert(id.to_string()); + } + } + } + + // Remove assistant messages with incomplete tool calls (no matching responses). + // Also collect tool_call_ids from assistant messages that will be removed. + let incomplete_ids: Vec = expected_tool_call_ids + .difference(&provided_tool_call_ids) + .cloned() + .collect(); + + let mut removed_assistant_call_ids = std::collections::HashSet::new(); + + if !incomplete_ids.is_empty() { + messages.retain(|msg| { + if let Some(obj) = msg.as_object() { + if obj.get("role").and_then(|v| v.as_str()) == Some("assistant") + && let Some(tool_calls) = obj.get("tool_calls").and_then(|v| v.as_array()) + { + // Check if this assistant message has any incomplete tool calls. + for tc in tool_calls { + if let Some(id) = tc.get("id").and_then(|v| v.as_str()) { + if incomplete_ids.contains(&id.to_string()) { + // Collect ALL tool_call_ids from this assistant message. + // We need to remove their corresponding tool responses too. + for tc2 in tool_calls { + if let Some(id2) = tc2.get("id").and_then(|v| v.as_str()) { + removed_assistant_call_ids.insert(id2.to_string()); + } + } + return false; + } + } + } + } + } + true + }); + + // Remove tool response messages whose assistant origin was removed. + messages.retain(|msg| { + if let Some(obj) = msg.as_object() { + if obj.get("role").and_then(|v| v.as_str()) == Some("tool") { + if let Some(tool_call_id) = obj.get("tool_call_id").and_then(|v| v.as_str()) { + if removed_assistant_call_ids.contains(tool_call_id) { + return false; + } + } + } + } + true + }); + } + + // Enforce sequence ordering: OpenAI protocol requires assistant messages with tool_calls + // to be immediately followed by tool response messages. Remove any messages that violate + // this constraint. + let mut i = 0; + while i < messages.len() { + if let Some(obj) = messages[i].as_object() { + if obj.get("role").and_then(|v| v.as_str()) == Some("assistant") + && obj.contains_key("tool_calls") + { + // Extract all tool_call_ids from this message. + let mut call_ids = std::collections::HashSet::new(); + if let Some(tool_calls) = obj.get("tool_calls").and_then(|v| v.as_array()) { + for tc in tool_calls { + if let Some(id) = tc.get("id").and_then(|v| v.as_str()) { + call_ids.insert(id.to_string()); + } + } + } + + // Scan forward to find all corresponding tool responses. + let mut j = i + 1; + let mut found_responses = std::collections::HashSet::new(); + + while j < messages.len() { + if let Some(next_obj) = messages[j].as_object() { + let next_role = next_obj.get("role").and_then(|v| v.as_str()); + + if next_role == Some("tool") { + // Accumulate tool responses. + if let Some(tool_call_id) = next_obj.get("tool_call_id").and_then(|v| v.as_str()) { + if call_ids.contains(tool_call_id) { + found_responses.insert(tool_call_id.to_string()); + } + } + j += 1; + } else { + // Non-tool message encountered. + if found_responses.len() < call_ids.len() { + // Incomplete sequence: remove the intruding message. + messages.remove(j); + continue; + } else { + // All responses found, this is the next turn. + break; + } + } + } else { + j += 1; + } + } + } + } + i += 1; + } + let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?; let payload = json!({ "model": model_family.slug,