diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock
index 7bd60b12fb..b2fcd77fcb 100644
--- a/codex-rs/Cargo.lock
+++ b/codex-rs/Cargo.lock
@@ -829,6 +829,29 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "codex-api-client"
+version = "0.0.0"
+dependencies = [
+ "async-trait",
+ "bytes",
+ "codex-app-server-protocol",
+ "codex-otel",
+ "codex-protocol",
+ "eventsource-stream",
+ "futures",
+ "maplit",
+ "regex-lite",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "thiserror 2.0.16",
+ "tokio",
+ "tokio-util",
+ "toml",
+ "tracing",
+]
+
 [[package]]
 name = "codex-app-server"
 version = "0.0.0"
@@ -1039,6 +1062,7 @@ name = "codex-common"
 version = "0.0.0"
 dependencies = [
  "clap",
+ "codex-api-client",
  "codex-app-server-protocol",
  "codex-core",
  "codex-protocol",
@@ -1059,6 +1083,7 @@ dependencies = [
  "base64",
  "bytes",
  "chrono",
+ "codex-api-client",
  "codex-app-server-protocol",
  "codex-apply-patch",
  "codex-async-utils",
@@ -1131,6 +1156,7 @@ dependencies = [
  "anyhow",
  "assert_cmd",
  "clap",
+ "codex-api-client",
  "codex-arg0",
  "codex-common",
  "codex-core",
@@ -1206,7 +1232,6 @@ name = "codex-git"
 version = "0.0.0"
 dependencies = [
  "assert_matches",
- "once_cell",
  "pretty_assertions",
  "regex",
  "schemars 0.8.22",
@@ -1296,6 +1321,7 @@ dependencies = [
  "assert_matches",
  "async-stream",
  "bytes",
+ "codex-api-client",
  "codex-core",
  "futures",
  "reqwest",
@@ -1437,6 +1463,7 @@ dependencies = [
  "chrono",
  "clap",
  "codex-ansi-escape",
+ "codex-api-client",
  "codex-app-server-protocol",
  "codex-arg0",
  "codex-common",
@@ -1670,6 +1697,7 @@ version = "0.0.0"
 dependencies = [
  "anyhow",
  "assert_cmd",
+ "codex-api-client",
  "codex-core",
  "codex-protocol",
  "notify",
diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml
index a2c52f44ad..82b7e4b32d 100644
--- a/codex-rs/Cargo.toml
+++ b/codex-rs/Cargo.toml
@@ -38,7 +38,7 @@ members = [
     "utils/pty",
     "utils/readiness",
     "utils/string",
-    "utils/tokenizer",
+    "utils/tokenizer", "api-client",
 ]
 resolver = "2"
@@ -87,6 +87,7 @@ codex-utils-pty = { path = "utils/pty" }
 codex-utils-readiness = { path = "utils/readiness" }
 codex-utils-string = { path = "utils/string" }
 codex-utils-tokenizer = { path = "utils/tokenizer" }
+codex-api-client = { path = "api-client" }
 core_test_support = { path = "core/tests/common" }
 mcp-types = { path = "mcp-types" }
 mcp_test_support = { path = "mcp-server/tests/common" }
diff --git a/codex-rs/api-client/Cargo.toml b/codex-rs/api-client/Cargo.toml
new file mode 100644
index 0000000000..a7a2a35a1f
--- /dev/null
+++ b/codex-rs/api-client/Cargo.toml
@@ -0,0 +1,28 @@
+[package]
+name = "codex-api-client"
+version.workspace = true
+edition.workspace = true
+
+[dependencies]
+async-trait = { workspace = true }
+bytes = { workspace = true }
+codex-app-server-protocol = { workspace = true }
+codex-otel = { workspace = true }
+codex-protocol = { path = "../protocol" }
+eventsource-stream = { workspace = true }
+futures = { workspace = true, default-features = false, features = ["std"] }
+regex-lite = { workspace = true }
+reqwest = { workspace = true, features = ["json", "stream"] }
+serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
+thiserror = { workspace = true }
+tokio = { workspace = true, features = ["sync", "time", "rt", "rt-multi-thread", "macros", "io-util"] }
+tokio-util = { workspace = true }
+tracing = { workspace = true }
+
+[dev-dependencies]
+maplit = "1.0.2"
+toml = { workspace = true }
+
+[lints]
+workspace = true
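For orientation before the source files below: a minimal sketch of what a consumer of the new crate could look like, assuming the re-exports declared later in `lib.rs`. The model name and the `otel`/`session_source` arguments are placeholders supplied by the caller, not values from this PR.

```rust
use codex_api_client::{
    ApiClient, ChatAggregationMode, ChatCompletionsApiClient, ChatCompletionsApiClientConfig,
};
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::protocol::SessionSource;

// Hypothetical helper: the caller assembles the config struct and picks a model.
async fn make_client(
    provider: codex_api_client::ModelProviderInfo,
    otel: OtelEventManager,
    session_source: SessionSource,
) -> codex_api_client::Result<ChatCompletionsApiClient> {
    ChatCompletionsApiClient::new(ChatCompletionsApiClientConfig {
        http_client: reqwest::Client::new(),
        provider,
        model: "gpt-4o-mini".to_string(), // placeholder model name
        otel_event_manager: otel,
        session_source,
        aggregation_mode: ChatAggregationMode::AggregatedOnly,
    })
    .await
}
```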
diff --git a/codex-rs/api-client/src/api.rs b/codex-rs/api-client/src/api.rs
new file mode 100644
index 0000000000..8ebb356a1a
--- /dev/null
+++ b/codex-rs/api-client/src/api.rs
@@ -0,0 +1,16 @@
+use async_trait::async_trait;
+
+use crate::error::Error;
+use crate::prompt::Prompt;
+use crate::stream::ResponseStream;
+
+#[async_trait]
+pub trait ApiClient: Send + Sync {
+    type Config: Send + Sync;
+
+    async fn new(config: Self::Config) -> Result<Self, Error>
+    where
+        Self: Sized;
+
+    async fn stream(&self, prompt: Prompt) -> Result<ResponseStream, Error>;
+}
diff --git a/codex-rs/api-client/src/auth.rs b/codex-rs/api-client/src/auth.rs
new file mode 100644
index 0000000000..0c6fb94490
--- /dev/null
+++ b/codex-rs/api-client/src/auth.rs
@@ -0,0 +1,15 @@
+use async_trait::async_trait;
+use codex_app_server_protocol::AuthMode;
+
+#[derive(Debug, Clone)]
+pub struct AuthContext {
+    pub mode: AuthMode,
+    pub bearer_token: Option<String>,
+    pub account_id: Option<String>,
+}
+
+#[async_trait]
+pub trait AuthProvider: Send + Sync {
+    async fn auth_context(&self) -> Option<AuthContext>;
+    async fn refresh_token(&self) -> Result<Option<AuthContext>, String>;
+}
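For reference, a minimal implementation of the `AuthProvider` trait above, as a test might supply it. The struct and token value are hypothetical, and the `Result<Option<AuthContext>, String>` return type of `refresh_token` is reconstructed from its call site in `responses.rs` (the extraction dropped the generics), so treat the signature as an assumption.

```rust
use async_trait::async_trait;
use codex_api_client::{AuthContext, AuthProvider};
use codex_app_server_protocol::AuthMode;

// Hypothetical fixed-token provider for tests.
struct StaticTokenProvider {
    token: String,
}

#[async_trait]
impl AuthProvider for StaticTokenProvider {
    async fn auth_context(&self) -> Option<AuthContext> {
        Some(AuthContext {
            mode: AuthMode::ApiKey,
            bearer_token: Some(self.token.clone()),
            account_id: None,
        })
    }

    // A static token cannot be refreshed; report that as an error string.
    async fn refresh_token(&self) -> Result<Option<AuthContext>, String> {
        Err("static tokens cannot be refreshed".to_string())
    }
}
```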
diff --git a/codex-rs/api-client/src/chat.rs b/codex-rs/api-client/src/chat.rs
new file mode 100644
index 0000000000..a002fb3122
--- /dev/null
+++ b/codex-rs/api-client/src/chat.rs
@@ -0,0 +1,866 @@
+use std::collections::VecDeque;
+use std::pin::Pin;
+use std::task::Context;
+use std::task::Poll;
+use std::time::Duration;
+
+use async_trait::async_trait;
+use bytes::Bytes;
+use codex_otel::otel_event_manager::OtelEventManager;
+use codex_protocol::models::ContentItem;
+use codex_protocol::models::FunctionCallOutputContentItem;
+use codex_protocol::models::ReasoningItemContent;
+use codex_protocol::models::ResponseItem;
+use codex_protocol::protocol::SessionSource;
+use codex_protocol::protocol::SubAgentSource;
+use eventsource_stream::Eventsource;
+use futures::Stream;
+use futures::StreamExt;
+use futures::TryStreamExt;
+use serde_json::Value;
+use serde_json::json;
+use tokio::sync::mpsc;
+use tokio::time::timeout;
+use tracing::debug;
+use tracing::trace;
+
+use crate::api::ApiClient;
+use crate::error::Error;
+use crate::model_provider::ModelProviderInfo;
+use crate::prompt::Prompt;
+use crate::stream::ResponseEvent;
+use crate::stream::ResponseStream;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Clone, Copy, Debug)]
+pub enum ChatAggregationMode {
+    AggregatedOnly,
+    Streaming,
+}
+
+#[derive(Clone)]
+pub struct ChatCompletionsApiClientConfig {
+    pub http_client: reqwest::Client,
+    pub provider: ModelProviderInfo,
+    pub model: String,
+    pub otel_event_manager: OtelEventManager,
+    pub session_source: SessionSource,
+    pub aggregation_mode: ChatAggregationMode,
+}
+
+#[derive(Clone)]
+pub struct ChatCompletionsApiClient {
+    config: ChatCompletionsApiClientConfig,
+}
+
+#[async_trait]
+impl ApiClient for ChatCompletionsApiClient {
+    type Config = ChatCompletionsApiClientConfig;
+
+    async fn new(config: Self::Config) -> Result<Self> {
+        Ok(Self { config })
+    }
+
+    async fn stream(&self, prompt: Prompt) -> Result<ResponseStream> {
+        Self::validate_prompt(&prompt)?;
+
+        let payload = self.build_payload(&prompt)?;
+        let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
+
+        let mut attempt = 0u64;
+        let max_retries = self.config.provider.request_max_retries();
+
+        loop {
+            attempt += 1;
+
+            let mut req_builder = self
+                .config
+                .provider
+                .create_request_builder(&self.config.http_client, &None)
+                .await?;
+
+            if let SessionSource::SubAgent(sub) = &self.config.session_source {
+                let subagent = if let SubAgentSource::Other(label) = sub {
+                    label.clone()
+                } else {
+                    serde_json::to_value(sub)
+                        .ok()
+                        .and_then(|v| v.as_str().map(std::string::ToString::to_string))
+                        .unwrap_or_else(|| "other".to_string())
+                };
+                req_builder = req_builder.header("x-openai-subagent", subagent);
+            }
+
+            let res = self
+                .config
+                .otel_event_manager
+                .log_request(attempt, || {
+                    req_builder
+                        .header(reqwest::header::ACCEPT, "text/event-stream")
+                        .json(&payload)
+                        .send()
+                })
+                .await;
+
+            match res {
+                Ok(resp) if resp.status().is_success() => {
+                    let stream = resp
+                        .bytes_stream()
+                        .map_err(|err| Error::ResponseStreamFailed {
+                            source: err,
+                            request_id: None,
+                        });
+                    let idle_timeout = self.config.provider.stream_idle_timeout();
+                    let otel = self.config.otel_event_manager.clone();
+                    let mode = self.config.aggregation_mode;
+
+                    tokio::spawn(process_chat_sse(
+                        stream,
+                        tx_event.clone(),
+                        idle_timeout,
+                        otel,
+                        mode,
+                    ));
+
+                    return Ok(ResponseStream { rx_event });
+                }
+                Ok(resp) => {
+                    if attempt >= max_retries {
+                        let status = resp.status();
+                        let body = resp
+                            .text()
+                            .await
+                            .unwrap_or_else(|_| "".to_string());
+                        return Err(Error::UnexpectedStatus { status, body });
+                    }
+
+                    let retry_after = resp
+                        .headers()
+                        .get(reqwest::header::RETRY_AFTER)
+                        .and_then(|v| v.to_str().ok())
+                        .and_then(|s| s.parse::<u64>().ok())
+                        .map(Duration::from_secs);
+                    tokio::time::sleep(retry_after.unwrap_or_else(|| backoff(attempt))).await;
+                }
+                Err(error) => {
+                    if attempt >= max_retries {
+                        return Err(Error::Http(error));
+                    }
+                    tokio::time::sleep(backoff(attempt)).await;
+                }
+            }
+        }
+    }
+}
+
+impl ChatCompletionsApiClient {
+    fn validate_prompt(prompt: &Prompt) -> Result<()> {
+        if prompt.output_schema.is_some() {
+            return Err(Error::UnsupportedOperation(
+                "output_schema is not supported for Chat Completions API".to_string(),
+            ));
+        }
+        Ok(())
+    }
+
+    fn build_payload(&self, prompt: &Prompt) -> Result<Value> {
+        let mut messages = Vec::<Value>::new();
+        messages.push(json!({ "role": "system", "content": prompt.instructions }));
+
+        let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
+            std::collections::HashMap::new();
+
+        let mut last_emitted_role: Option<&str> = None;
+        for item in &prompt.input {
+            match item {
+                ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
+                ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
+                    last_emitted_role = Some("assistant");
+                }
+                ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
+                ResponseItem::Reasoning { .. }
+                | ResponseItem::Other
+                | ResponseItem::CustomToolCall { .. }
+                | ResponseItem::CustomToolCallOutput { .. }
+                | ResponseItem::WebSearchCall { .. }
+                | ResponseItem::GhostSnapshot { .. } => {}
+            }
+        }
+
+        let mut last_user_index: Option<usize> = None;
+        for (idx, item) in prompt.input.iter().enumerate() {
+            if let ResponseItem::Message { role, .. } = item
+                && role == "user"
+            {
+                last_user_index = Some(idx);
+            }
+        }
+
+        if !matches!(last_emitted_role, Some("user")) {
+            for (idx, item) in prompt.input.iter().enumerate() {
+                if let Some(u_idx) = last_user_index
+                    && idx <= u_idx
+                {
+                    continue;
+                }
+
+                if let ResponseItem::Reasoning {
+                    content: Some(items),
+                    ..
+                } = item
+                {
+                    let mut text = String::new();
+                    for entry in items {
+                        match entry {
+                            ReasoningItemContent::ReasoningText { text: segment }
+                            | ReasoningItemContent::Text { text: segment } => {
+                                text.push_str(segment);
+                            }
+                        }
+                    }
+                    if text.trim().is_empty() {
+                        continue;
+                    }
+
+                    let mut attached = false;
+                    if idx > 0
+                        && let ResponseItem::Message { role, .. } = &prompt.input[idx - 1]
+                        && role == "assistant"
+                    {
+                        reasoning_by_anchor_index
+                            .entry(idx - 1)
+                            .and_modify(|v| v.push_str(&text))
+                            .or_insert(text.clone());
+                        attached = true;
+                    }
+
+                    if !attached && idx + 1 < prompt.input.len() {
+                        match &prompt.input[idx + 1] {
+                            ResponseItem::FunctionCall { .. }
+                            | ResponseItem::LocalShellCall { .. } => {
+                                reasoning_by_anchor_index
+                                    .entry(idx + 1)
+                                    .and_modify(|v| v.push_str(&text))
+                                    .or_insert(text.clone());
+                            }
+                            ResponseItem::Message { role, .. } if role == "assistant" => {
+                                reasoning_by_anchor_index
+                                    .entry(idx + 1)
+                                    .and_modify(|v| v.push_str(&text))
+                                    .or_insert(text.clone());
+                            }
+                            _ => {}
+                        }
+                    }
+                }
+            }
+        }
+
+        let mut last_assistant_text: Option<String> = None;
+
+        for (idx, item) in prompt.input.iter().enumerate() {
+            match item {
+                ResponseItem::Message { role, content, .. } => {
+                    let mut text = String::new();
+                    let mut items: Vec<Value> = Vec::new();
+                    let mut saw_image = false;
+
+                    for c in content {
+                        match c {
+                            ContentItem::InputText { text: t }
+                            | ContentItem::OutputText { text: t } => {
+                                text.push_str(t);
+                                items.push(json!({"type":"text","text": t}));
+                            }
+                            ContentItem::InputImage { image_url } => {
+                                saw_image = true;
+                                items.push(
+                                    json!({"type":"image_url","image_url": {"url": image_url}}),
+                                );
+                            }
+                        }
+                    }
+
+                    if role == "assistant" {
+                        if let Some(prev) = &last_assistant_text
+                            && prev == &text
+                        {
+                            continue;
+                        }
+                        last_assistant_text = Some(text.clone());
+                    }
+
+                    let content_value = if role == "assistant" {
+                        json!(text)
+                    } else if saw_image {
+                        json!(items)
+                    } else {
+                        json!(text)
+                    };
+
+                    let mut message = json!({
+                        "role": role,
+                        "content": content_value,
+                    });
+
+                    if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
+                        && let Some(obj) = message.as_object_mut()
+                    {
+                        obj.insert("reasoning".to_string(), json!({"text": reasoning}));
+                    }
+
+                    messages.push(message);
+                }
+                ResponseItem::FunctionCall {
+                    name,
+                    arguments,
+                    call_id,
+                    ..
+                } => {
+                    messages.push(json!({
+                        "role": "assistant",
+                        "tool_calls": [{
+                            "id": call_id,
+                            "type": "function",
+                            "function": {
+                                "name": name,
+                                "arguments": arguments,
+                            },
+                        }],
+                    }));
+                }
+                ResponseItem::FunctionCallOutput { call_id, output } => {
+                    let content_value = if let Some(items) = &output.content_items {
+                        let mapped: Vec<Value> = items
+                            .iter()
+                            .map(|item| match item {
+                                FunctionCallOutputContentItem::InputText { text } => {
+                                    json!({"type":"text","text": text})
+                                }
+                                FunctionCallOutputContentItem::InputImage { image_url } => {
+                                    json!({"type":"image_url","image_url": {"url": image_url}})
+                                }
+                            })
+                            .collect();
+                        json!(mapped)
+                    } else {
+                        json!(output.content)
+                    };
+
+                    messages.push(json!({
+                        "role": "tool",
+                        "tool_call_id": call_id,
+                        "content": content_value,
+                    }));
+                }
+                ResponseItem::LocalShellCall {
+                    id,
+                    call_id,
+                    action,
+                    ..
+                } => {
+                    let tool_id = call_id
+                        .clone()
+                        .filter(|value| !value.is_empty())
+                        .or_else(|| id.clone())
+                        .unwrap_or_default();
+                    messages.push(json!({
+                        "role": "assistant",
+                        "tool_calls": [{
+                            "id": tool_id,
+                            "type": "function",
+                            "function": {
+                                "name": "shell",
+                                "arguments": serde_json::to_string(action).unwrap_or_default(),
+                            },
+                        }],
+                    }));
+                }
+                ResponseItem::CustomToolCall {
+                    call_id,
+                    name,
+                    input,
+                    ..
+                } => {
+                    messages.push(json!({
+                        "role": "assistant",
+                        "tool_calls": [{
+                            "id": call_id.clone(),
+                            "type": "function",
+                            "function": {
+                                "name": name,
+                                "arguments": input,
+                            },
+                        }],
+                    }));
+                }
+                ResponseItem::CustomToolCallOutput { call_id, output } => {
+                    messages.push(json!({
+                        "role": "tool",
+                        "tool_call_id": call_id,
+                        "content": output,
+                    }));
+                }
+                ResponseItem::WebSearchCall { .. }
+                | ResponseItem::Reasoning { .. }
+                | ResponseItem::Other
+                | ResponseItem::GhostSnapshot { .. } => {}
+            }
+        }
+
+        let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
+        let payload = json!({
+            "model": self.config.model,
+            "messages": messages,
+            "stream": true,
+            "tools": tools_json,
+        });
+
+        trace!("chat completions payload: {}", payload);
+        Ok(payload)
+    }
+}
+
+async fn append_assistant_text(
+    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
+    assistant_item: &mut Option<ResponseItem>,
+    text: String,
+) {
+    if assistant_item.is_none() {
+        let item = ResponseItem::Message {
+            id: None,
+            role: "assistant".to_string(),
+            content: vec![],
+        };
+        *assistant_item = Some(item.clone());
+        let _ = tx_event
+            .send(Ok(ResponseEvent::OutputItemAdded(item)))
+            .await;
+    }
+
+    if let Some(ResponseItem::Message { content, .. }) = assistant_item {
+        content.push(ContentItem::OutputText { text: text.clone() });
+        let _ = tx_event
+            .send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
+            .await;
+    }
+}
+
+async fn append_reasoning_text(
+    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
+    reasoning_item: &mut Option<ResponseItem>,
+    text: String,
+) {
+    if reasoning_item.is_none() {
+        let item = ResponseItem::Reasoning {
+            id: String::new(),
+            summary: Vec::new(),
+            content: Some(vec![]),
+            encrypted_content: None,
+        };
+        *reasoning_item = Some(item.clone());
+        let _ = tx_event
+            .send(Ok(ResponseEvent::OutputItemAdded(item)))
+            .await;
+    }
+
+    if let Some(ResponseItem::Reasoning {
+        content: Some(content),
+        ..
+    }) = reasoning_item
+    {
+        content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
+
+        let _ = tx_event
+            .send(Ok(ResponseEvent::ReasoningContentDelta(text.clone())))
+            .await;
+    }
+}
+
+async fn process_chat_sse<S>(
+    stream: S,
+    tx_event: mpsc::Sender<Result<ResponseEvent>>,
+    idle_timeout: Duration,
+    otel_event_manager: OtelEventManager,
+    _aggregation_mode: ChatAggregationMode,
+) where
+    S: Stream<Item = Result<Bytes, Error>> + Unpin,
+{
+    let mut stream = stream.eventsource();
+
+    #[derive(Default)]
+    struct FunctionCallState {
+        name: Option<String>,
+        arguments: String,
+        call_id: Option<String>,
+    }
+
+    let mut function_call_state = FunctionCallState::default();
+    let mut assistant_item: Option<ResponseItem> = None;
+    let mut reasoning_item: Option<ResponseItem> = None;
+
+    loop {
+        let response = timeout(idle_timeout, stream.next()).await;
+        otel_event_manager.log_sse_event(&response, idle_timeout);
+
+        let sse = match response {
+            Ok(Some(Ok(sse))) => sse,
+            Ok(Some(Err(e))) => {
+                debug!("SSE Error: {e:#}");
+                let event = Error::Stream(e.to_string(), None);
+                let _ = tx_event.send(Err(event)).await;
+                return;
+            }
+            Ok(None) => {
+                if let Some(item) = assistant_item.take() {
+                    let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
+                }
+                if let Some(item) = reasoning_item.take() {
+                    let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
+                }
+                let _ = tx_event
+                    .send(Ok(ResponseEvent::Completed {
+                        response_id: String::new(),
+                        token_usage: None,
+                    }))
+                    .await;
+                return;
+            }
+            Err(_) => {
+                let _ = tx_event
+                    .send(Err(Error::Stream(
+                        "idle timeout waiting for SSE".into(),
+                        None,
+                    )))
+                    .await;
+                return;
+            }
+        };
+
+        trace!("chat_completions received SSE chunk: {}", sse.data);
+
+        if sse.data.trim() == "[DONE]" {
+            if let Some(item) = assistant_item.take() {
+                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
+            }
+            if let Some(item) = reasoning_item.take() {
+                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
+            }
+            let _ = tx_event
+                .send(Ok(ResponseEvent::Completed {
+                    response_id: String::new(),
+                    token_usage: None,
+                }))
+                .await;
+            return;
+        }
+
+        let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
+            Ok(v) => v,
+            Err(_) => continue,
+        };
+
+        let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
+
+        if let Some(choice) = choice_opt {
+            if let Some(content) = choice
+                .get("delta")
+                .and_then(|d| d.get("content"))
+                .and_then(|c| c.as_str())
+                && !content.is_empty()
+            {
+                append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await;
+            }
+
+            if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
+                let mut maybe_text = reasoning_val
+                    .as_str()
+                    .map(str::to_string)
+                    .filter(|s| !s.is_empty());
+
+                if maybe_text.is_none() && reasoning_val.is_object() {
+                    if let Some(s) = reasoning_val
+                        .get("text")
+                        .and_then(|t| t.as_str())
+                        .filter(|s| !s.is_empty())
+                    {
+                        maybe_text = Some(s.to_string());
+                    } else if let Some(s) = reasoning_val
+                        .get("content")
+                        .and_then(|t| t.as_str())
+                        .filter(|s| !s.is_empty())
+                    {
+                        maybe_text = Some(s.to_string());
+                    }
+                }
+
+                if let Some(reasoning) = maybe_text {
+                    append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await;
+                }
+            }
+
+            if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
+            {
+                if let Some(s) = message_reasoning.as_str() {
+                    if !s.is_empty() {
+                        append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
+                    }
+                } else if let Some(obj) = message_reasoning.as_object()
+                    && let Some(s) = obj
+                        .get("text")
+                        .and_then(|v| v.as_str())
+                        .or_else(|| obj.get("content").and_then(|v| v.as_str()))
+                    && !s.is_empty()
+                {
+                    append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
+                }
+            }
+
+            if let Some(tool_calls) = choice
+                .get("delta")
+                .and_then(|d| d.get("tool_calls"))
+                .and_then(|v| v.as_array())
+            {
+                for call in tool_calls {
+                    if let Some(index) = call.get("index").and_then(serde_json::Value::as_u64)
+                        && index == 0
+                        && let Some(function) = call.get("function")
+                    {
+                        if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
+                            function_call_state.name = Some(name.to_string());
+                        }
+                        if let Some(arguments) = function.get("arguments").and_then(|a| a.as_str())
+                        {
+                            function_call_state.arguments.push_str(arguments);
+                        }
+                        if let Some(id) = call.get("id").and_then(|i| i.as_str()) {
+                            function_call_state.call_id = Some(id.to_string());
+                        }
+
+                        if let Some(finish) = choice.get("finish_reason").and_then(|f| f.as_str())
+                            && finish == "tool_calls"
+                            && let Some(name) = function_call_state.name.take()
+                        {
+                            let call_id = function_call_state.call_id.take().unwrap_or_default();
+                            let arguments = std::mem::take(&mut function_call_state.arguments);
+                            let item = ResponseItem::FunctionCall {
+                                id: None,
+                                name,
+                                arguments,
+                                call_id,
+                            };
+                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+pub trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
+    fn aggregate(self) -> AggregatedChatStream<Self>
+    where
+        Self: Unpin,
+    {
+        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
+    }
+
+    fn streaming_mode(self) -> AggregatedChatStream<Self>
+    where
+        Self: Unpin,
+    {
+        AggregatedChatStream::new(self, AggregateMode::Streaming)
+    }
+}
+
+impl<S> AggregateStreamExt for S where S: Stream<Item = Result<ResponseEvent>> + Sized + Unpin {}
+
+enum AggregateMode {
+    AggregatedOnly,
+    Streaming,
+}
+
+pub struct AggregatedChatStream<S> {
+    inner: S,
+    cumulative: String,
+    cumulative_reasoning: String,
+    pending: VecDeque<ResponseEvent>,
+    mode: AggregateMode,
+}
+
+impl<S> AggregatedChatStream<S>
+where
+    S: Stream<Item = Result<ResponseEvent>> + Unpin,
+{
+    fn new(inner: S, mode: AggregateMode) -> Self {
+        Self {
+            inner,
+            cumulative: String::new(),
+            cumulative_reasoning: String::new(),
+            pending: VecDeque::new(),
+            mode,
+        }
+    }
+}
+
+impl<S> Stream for AggregatedChatStream<S>
+where
+    S: Stream<Item = Result<ResponseEvent>> + Unpin,
+{
+    type Item = Result<ResponseEvent>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if let Some(ev) = self.pending.pop_front() {
+            return Poll::Ready(Some(Ok(ev)));
+        }
+
+        loop {
+            match Pin::new(&mut self.inner).poll_next(cx) {
+                Poll::Pending => return Poll::Pending,
+                Poll::Ready(None) => return Poll::Ready(None),
+                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
+                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
+                    let is_assistant_message = matches!(
+                        &item,
+                        ResponseItem::Message { role, .. } if role == "assistant"
+                    );
+
+                    if is_assistant_message {
+                        match self.mode {
+                            AggregateMode::AggregatedOnly => {
+                                if self.cumulative.is_empty()
+                                    && let ResponseItem::Message { content, .. } = &item
+                                    && let Some(text) = content.iter().find_map(|c| match c {
+                                        ContentItem::OutputText { text } => Some(text),
+                                        _ => None,
+                                    })
+                                {
+                                    self.cumulative.push_str(text);
+                                }
+                                continue;
+                            }
+                            AggregateMode::Streaming => {
+                                if self.cumulative.is_empty() {
+                                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
+                                        item,
+                                    ))));
+                                } else {
+                                    continue;
+                                }
+                            }
+                        }
+                    }
+
+                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
+                }
+                Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
+                    return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
+                }
+                Poll::Ready(Some(Ok(ResponseEvent::Completed {
+                    response_id,
+                    token_usage,
+                }))) => {
+                    let mut emitted_any = false;
+
+                    if !self.cumulative_reasoning.is_empty()
+                        && matches!(self.mode, AggregateMode::AggregatedOnly)
+                    {
+                        let aggregated_reasoning = ResponseItem::Reasoning {
+                            id: String::new(),
+                            summary: Vec::new(),
+                            content: Some(vec![ReasoningItemContent::ReasoningText {
+                                text: std::mem::take(&mut self.cumulative_reasoning),
+                            }]),
+                            encrypted_content: None,
+                        };
+                        self.pending
+                            .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
+                        emitted_any = true;
+                    }
+
+                    if !self.cumulative.is_empty() {
+                        let aggregated_message = ResponseItem::Message {
+                            id: None,
+                            role: "assistant".to_string(),
+                            content: vec![ContentItem::OutputText {
+                                text: std::mem::take(&mut self.cumulative),
+                            }],
+                        };
+                        self.pending
+                            .push_back(ResponseEvent::OutputItemDone(aggregated_message));
+                        emitted_any = true;
+                    }
+
+                    if emitted_any {
+                        self.pending.push_back(ResponseEvent::Completed {
+                            response_id: response_id.clone(),
+                            token_usage: token_usage.clone(),
+                        });
+                        if let Some(ev) = self.pending.pop_front() {
+                            return Poll::Ready(Some(Ok(ev)));
+                        }
+                    }
+
+                    return Poll::Ready(Some(Ok(ResponseEvent::Completed {
+                        response_id,
+                        token_usage,
+                    })));
+                }
+                Poll::Ready(Some(Ok(ResponseEvent::Created))) => continue,
+                Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
+                    self.cumulative.push_str(&delta);
+                    if matches!(self.mode, AggregateMode::Streaming) {
+                        return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
+                    }
+                }
+                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
+                    self.cumulative_reasoning.push_str(&delta);
+                    if matches!(self.mode, AggregateMode::Streaming) {
+                        return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
+                    }
+                }
+                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => continue,
+                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => continue,
+                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
+                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
+                }
+            }
+        }
+    }
+}
+
+fn create_tools_json_for_chat_completions_api(
+    tools: &[serde_json::Value],
+) -> Result<Vec<Value>> {
+    let tools_json = tools
+        .iter()
+        .filter_map(|tool| {
+            if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) {
+                return None;
+            }
+
+            let function_value = if let Some(function) = tool.get("function") {
+                function.clone()
+            } else if let Some(map) = tool.as_object() {
+                let mut function = map.clone();
+                function.remove("type");
+                Value::Object(function)
+            } else {
+                return None;
+            };
+
+            Some(json!({
+                "type": "function",
+                "function": function_value,
+            }))
+        })
+        .collect::<Vec<Value>>();
+    Ok(tools_json)
+}
+
+fn backoff(attempt: u64) -> Duration {
+    let capped = attempt.min(6);
+    Duration::from_millis(100 * 2u64.pow(capped as u32))
+}
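A sketch of how the aggregation adapter above is meant to be consumed: `aggregate()` suppresses assistant deltas and re-emits one combined `OutputItemDone` message just before `Completed`. Stream construction is elided; `stream` is any `ResponseStream` produced by the client.

```rust
use codex_api_client::{AggregateStreamExt, ResponseEvent, ResponseStream};
use codex_protocol::models::{ContentItem, ResponseItem};
use futures::StreamExt;

// Drain a stream in aggregated mode and return the final assistant text.
async fn final_text(stream: ResponseStream) -> Option<String> {
    let mut events = stream.aggregate();
    while let Some(event) = events.next().await {
        match event {
            Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, content, .. }))
                if role == "assistant" =>
            {
                // The aggregated assistant message arrives as a single item.
                return content.iter().find_map(|c| match c {
                    ContentItem::OutputText { text } => Some(text.clone()),
                    _ => None,
                });
            }
            Ok(ResponseEvent::Completed { .. }) => break,
            Ok(_) => {}
            Err(_) => break,
        }
    }
    None
}
```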
diff --git a/codex-rs/api-client/src/error.rs b/codex-rs/api-client/src/error.rs
new file mode 100644
index 0000000000..c02dcab71a
--- /dev/null
+++ b/codex-rs/api-client/src/error.rs
@@ -0,0 +1,42 @@
+use std::time::Duration;
+
+use thiserror::Error;
+
+pub type Result<T, E = Error> = std::result::Result<T, E>;
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("{0}")]
+    UnsupportedOperation(String),
+    #[error(transparent)]
+    Http(#[from] reqwest::Error),
+    #[error("{source}")]
+    ResponseStreamFailed {
+        #[source]
+        source: reqwest::Error,
+        request_id: Option<String>,
+    },
+    #[error("{0}")]
+    Stream(String, Option<Duration>),
+    #[error("unexpected status {status}: {body}")]
+    UnexpectedStatus {
+        status: reqwest::StatusCode,
+        body: String,
+    },
+    #[error("retry limit reached (status {status}, request id: {request_id:?})")]
+    RetryLimit {
+        status: reqwest::StatusCode,
+        request_id: Option<String>,
+    },
+    #[error("missing environment variable {var}")]
+    MissingEnvVar {
+        var: String,
+        instructions: Option<String>,
+    },
+    #[error("{0}")]
+    Auth(String),
+    #[error(transparent)]
+    Json(#[from] serde_json::Error),
+    #[error("{0}")]
+    Other(String),
+}
diff --git a/codex-rs/api-client/src/lib.rs b/codex-rs/api-client/src/lib.rs
new file mode 100644
index 0000000000..f57ac46c8f
--- /dev/null
+++ b/codex-rs/api-client/src/lib.rs
@@ -0,0 +1,35 @@
+pub mod api;
+pub mod auth;
+pub mod chat;
+pub mod error;
+pub mod model_provider;
+pub mod prompt;
+pub mod responses;
+pub mod stream;
+
+pub use crate::api::ApiClient;
+pub use crate::auth::AuthContext;
+pub use crate::auth::AuthProvider;
+pub use crate::chat::AggregateStreamExt;
+pub use crate::chat::ChatAggregationMode;
+pub use crate::chat::ChatCompletionsApiClient;
+pub use crate::chat::ChatCompletionsApiClientConfig;
+pub use crate::error::Error;
+pub use crate::error::Result;
+pub use crate::model_provider::BUILT_IN_OSS_MODEL_PROVIDER_ID;
+pub use crate::model_provider::ModelProviderInfo;
+pub use crate::model_provider::WireApi;
+pub use crate::model_provider::built_in_model_providers;
+pub use crate::model_provider::create_oss_provider;
+pub use crate::model_provider::create_oss_provider_with_base_url;
+pub use crate::prompt::Prompt;
+pub use crate::responses::ResponsesApiClient;
+pub use crate::responses::ResponsesApiClientConfig;
+pub use crate::responses::stream_from_fixture;
+pub use crate::stream::EventStream;
+pub use crate::stream::Reasoning;
+pub use crate::stream::ResponseEvent;
+pub use crate::stream::ResponseStream;
+pub use crate::stream::TextControls;
+pub use crate::stream::TextFormat;
+pub use crate::stream::TextFormatType;
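One intended use of the error taxonomy above is letting callers separate retryable from fatal failures. A hypothetical retry-policy helper, assuming the reconstructed `Stream(String, Option<Duration>)` and `UnexpectedStatus` shapes:

```rust
use codex_api_client::Error;
use std::time::Duration;

// Hypothetical policy: only rate-limit style failures are retried.
fn retry_delay(err: &Error) -> Option<Duration> {
    match err {
        // Stream errors may carry a server-suggested delay.
        Error::Stream(_, delay) => Some(delay.unwrap_or(Duration::from_secs(1))),
        Error::UnexpectedStatus { status, .. } if status.as_u16() == 429 => {
            Some(Duration::from_secs(1))
        }
        // Auth and configuration problems are not retryable.
        Error::Auth(_) | Error::MissingEnvVar { .. } => None,
        _ => None,
    }
}
```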
diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/api-client/src/model_provider.rs
similarity index 79%
rename from codex-rs/core/src/model_provider_info.rs
rename to codex-rs/api-client/src/model_provider.rs
index 8dc252aa7c..43ecc32d38 100644
--- a/codex-rs/core/src/model_provider_info.rs
+++ b/codex-rs/api-client/src/model_provider.rs
@@ -5,17 +5,18 @@
 //! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
 //!    key. These override or extend the defaults at runtime.
 
-use crate::CodexAuth;
-use crate::default_client::CodexHttpClient;
-use crate::default_client::CodexRequestBuilder;
-use codex_app_server_protocol::AuthMode;
-use serde::Deserialize;
-use serde::Serialize;
 use std::collections::HashMap;
 use std::env::VarError;
 use std::time::Duration;
 
-use crate::error::EnvVarError;
+use codex_app_server_protocol::AuthMode;
+use serde::Deserialize;
+use serde::Serialize;
+
+use crate::auth::AuthContext;
+use crate::error::Error;
+use crate::error::Result;
+
 const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
 const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
 const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
@@ -23,19 +24,19 @@ const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
 const MAX_STREAM_MAX_RETRIES: u64 = 100;
 /// Hard cap for user-configured `request_max_retries`.
 const MAX_REQUEST_MAX_RETRIES: u64 = 100;
+const DEFAULT_OLLAMA_PORT: u32 = 11434;
 
 /// Wire protocol that the provider speaks. Most third-party services only
 /// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
 /// itself (and a handful of others) additionally expose the more modern
-/// *Responses* API. The two protocols use different request/response shapes
-/// and *cannot* be auto-detected at runtime, therefore each provider entry
+/// Responses API. The two protocols use different request/response shapes
+/// and cannot be auto-detected at runtime, therefore each provider entry
 /// must declare which one it expects.
 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "lowercase")]
 pub enum WireApi {
     /// The Responses API exposed by OpenAI at `/v1/responses`.
     Responses,
-
     /// Regular Chat Completions compatible with `/v1/chat/completions`.
     #[default]
     Chat,
@@ -50,87 +51,79 @@ pub struct ModelProviderInfo {
     pub base_url: Option<String>,
     /// Environment variable that stores the user's API key for this provider.
     pub env_key: Option<String>,
-
     /// Optional instructions to help the user get a valid value for the
     /// variable and set it.
     pub env_key_instructions: Option<String>,
-
     /// Value to use with `Authorization: Bearer <token>` header. Use of this
     /// config is discouraged in favor of `env_key` for security reasons, but
     /// this may be necessary when using this programmatically.
     pub experimental_bearer_token: Option<String>,
-
     /// Which wire protocol this provider expects.
     #[serde(default)]
     pub wire_api: WireApi,
-
     /// Optional query parameters to append to the base URL.
     pub query_params: Option<HashMap<String, String>>,
-
     /// Additional HTTP headers to include in requests to this provider where
     /// the (key, value) pairs are the header name and value.
     pub http_headers: Option<HashMap<String, String>>,
-
     /// Optional HTTP headers to include in requests to this provider where the
-    /// (key, value) pairs are the header name and _environment variable_ whose
+    /// (key, value) pairs are the header name and environment variable whose
     /// value should be used. If the environment variable is not set, or the
     /// value is empty, the header will not be included in the request.
     pub env_http_headers: Option<HashMap<String, String>>,
-
     /// Maximum number of times to retry a failed HTTP request to this provider.
     pub request_max_retries: Option<u64>,
-
     /// Number of times to retry reconnecting a dropped streaming response before failing.
     pub stream_max_retries: Option<u64>,
-
     /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
     /// the connection as lost.
     pub stream_idle_timeout_ms: Option<u64>,
-
     /// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
-    /// user is presented with login screen on first run, and login preference and token/key
-    /// are stored in auth.json. If false (which is the default), login screen is skipped,
-    /// and API key (if needed) comes from the "env_key" environment variable.
+    /// the user is presented with a login screen on first run, and login preference and token/key
+    /// are stored in auth.json. If false (which is the default), the login screen is skipped,
+    /// and the API key (if needed) comes from the `env_key` environment variable.
     #[serde(default)]
     pub requires_openai_auth: bool,
 }
 
 impl ModelProviderInfo {
-    /// Construct a `POST` RequestBuilder for the given URL using the provided
-    /// [`CodexHttpClient`] applying:
-    /// • provider-specific headers (static + env based)
-    /// • Bearer auth header when an API key is available.
-    /// • Auth token for OAuth.
+    /// Construct a `POST` request builder for the given URL using the provided
+    /// [`reqwest::Client`] applying:
+    /// - provider-specific headers (static and environment based)
+    /// - Bearer auth header when an API key is available
+    /// - Auth token for OAuth
     ///
-    /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
-    /// one produced by [`ModelProviderInfo::api_key`].
-    pub async fn create_request_builder<'a>(
-        &'a self,
-        client: &'a CodexHttpClient,
-        auth: &Option<CodexAuth>,
-    ) -> crate::error::Result<CodexRequestBuilder> {
+    /// If the provider declares an `env_key` but the variable is missing or empty, this returns an
+    /// error identical to the one produced by [`ModelProviderInfo::api_key`].
+    pub async fn create_request_builder(
+        &self,
+        client: &reqwest::Client,
+        auth: &Option<AuthContext>,
+    ) -> Result<reqwest::RequestBuilder> {
         let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token {
-            Some(CodexAuth::from_api_key(secret_key))
+            Some(AuthContext {
+                mode: AuthMode::ApiKey,
+                bearer_token: Some(secret_key.clone()),
+                account_id: None,
+            })
         } else {
-            match self.api_key() {
-                Ok(Some(key)) => Some(CodexAuth::from_api_key(&key)),
-                Ok(None) => auth.clone(),
-                Err(err) => {
-                    if auth.is_some() {
-                        auth.clone()
-                    } else {
-                        return Err(err);
-                    }
-                }
+            match self.api_key()? {
+                Some(key) => Some(AuthContext {
+                    mode: AuthMode::ApiKey,
+                    bearer_token: Some(key),
+                    account_id: None,
+                }),
+                None => auth.clone(),
             }
         };
 
-        let url = self.get_full_url(&effective_auth);
-
+        let url = self.get_full_url(effective_auth.as_ref());
         let mut builder = client.post(url);
-        if let Some(auth) = effective_auth.as_ref() {
-            builder = builder.bearer_auth(auth.get_token().await?);
+        if let Some(context) = effective_auth.as_ref()
+            && let Some(token) = context.bearer_token.as_ref()
+        {
+            builder = builder.bearer_auth(token);
         }
 
         Ok(self.apply_http_headers(builder))
@@ -149,10 +142,10 @@ impl ModelProviderInfo {
         })
     }
 
-    pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
+    pub fn get_full_url(&self, auth: Option<&AuthContext>) -> String {
         let default_base_url = if matches!(
             auth,
-            Some(CodexAuth {
+            Some(AuthContext {
                 mode: AuthMode::ChatGPT,
                 ..
            })
@@ -165,7 +158,7 @@ impl ModelProviderInfo {
         let base_url = self
             .base_url
             .clone()
-            .unwrap_or(default_base_url.to_string());
+            .unwrap_or_else(|| default_base_url.to_string());
 
         match self.wire_api {
             WireApi::Responses => format!("{base_url}/responses{query_string}"),
@@ -173,7 +166,7 @@
         }
     }
 
-    pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
+    pub fn is_azure_responses_endpoint(&self) -> bool {
         if self.wire_api != WireApi::Responses {
             return false;
         }
@@ -188,10 +181,9 @@
             .unwrap_or(false)
     }
 
-    /// Apply provider-specific HTTP headers (both static and environment-based)
-    /// onto an existing [`CodexRequestBuilder`] and return the updated
-    /// builder.
-    fn apply_http_headers(&self, mut builder: CodexRequestBuilder) -> CodexRequestBuilder {
+    /// Apply provider-specific HTTP headers (both static and environment-based) onto an existing
+    /// [`reqwest::RequestBuilder`] and return the updated builder.
+    fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
         if let Some(extra) = &self.http_headers {
             for (k, v) in extra {
                 builder = builder.header(k, v);
@@ -210,10 +202,9 @@
         builder
     }
 
-    /// If `env_key` is Some, returns the API key for this provider if present
-    /// (and non-empty) in the environment. If `env_key` is required but
-    /// cannot be found, returns an error.
-    pub fn api_key(&self) -> crate::error::Result<Option<String>> {
+    /// If `env_key` is `Some`, returns the API key for this provider if present (and non-empty) in
+    /// the environment. If `env_key` is required but cannot be found, returns an error.
+    pub fn api_key(&self) -> Result<Option<String>> {
         match &self.env_key {
             Some(env_key) => {
                 let env_value = std::env::var(env_key);
@@ -225,11 +216,9 @@
                             Ok(Some(v))
                         }
                    })
-                    .map_err(|_| {
-                        crate::error::CodexErr::EnvVar(EnvVarError {
-                            var: env_key.clone(),
-                            instructions: self.env_key_instructions.clone(),
-                        })
+                    .map_err(|_| Error::MissingEnvVar {
+                        var: env_key.clone(),
+                        instructions: self.env_key_instructions.clone(),
                     })
             }
             None => Ok(None),
@@ -258,28 +247,23 @@
     }
 }
 
-const DEFAULT_OLLAMA_PORT: u32 = 11434;
-
 pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
 
 /// Built-in default provider list.
 pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
     use ModelProviderInfo as P;
 
-    // We do not want to be in the business of adjucating which third-party
-    // providers are bundled with Codex CLI, so we only include the OpenAI and
-    // open source ("oss") providers by default. Users are encouraged to add to
-    // `model_providers` in config.toml to add their own providers.
+    // We do not want to be in the business of adjudicating which third-party providers are bundled
+    // with Codex CLI, so we only include the OpenAI and open source ("oss") providers by default.
+    // Users are encouraged to add to `model_providers` in config.toml to add their own providers.
     [
         (
             "openai",
             P {
                 name: "OpenAI".into(),
-                // Allow users to override the default OpenAI endpoint by
-                // exporting `OPENAI_BASE_URL`. This is useful when pointing
-                // Codex at a proxy, mock server, or Azure-style deployment
-                // without requiring a full TOML override for the built-in
-                // OpenAI provider.
+                // Allow users to override the default OpenAI endpoint by exporting `OPENAI_BASE_URL`.
+                // This is useful when pointing Codex at a proxy, mock server, or Azure-style
+                // deployment without requiring a full TOML override for the built-in OpenAI provider.
                 base_url: std::env::var("OPENAI_BASE_URL")
                     .ok()
                     .filter(|v| !v.trim().is_empty()),
@@ -318,9 +302,10 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
     .collect()
 }
 
+/// Convenience helper for the built-in OSS provider.
 pub fn create_oss_provider() -> ModelProviderInfo {
-    // These CODEX_OSS_ environment variables are experimental: we may
-    // switch to reading values from config.toml instead.
+    // These CODEX_OSS_ environment variables are experimental: we may switch to reading values from
+    // config.toml instead.
     let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
         .ok()
         .filter(|v| !v.trim().is_empty())
@@ -366,23 +351,23 @@ fn matches_azure_responses_base_url(base_url: &str) -> bool {
         "azure-api.",
         "azurefd.",
     ];
-    AZURE_MARKERS.iter().any(|marker| base.contains(marker))
+    AZURE_MARKERS.iter().any(|needle| base.contains(needle))
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use pretty_assertions::assert_eq;
+    use maplit::hashmap;
 
     #[test]
-    fn test_deserialize_ollama_model_provider_toml() {
+    fn deserializes_defaults_without_optional_fields() {
         let azure_provider_toml = r#"
-name = "Ollama"
-base_url = "http://localhost:11434/v1"
+name = "Azure"
+base_url = "https://xxxxx.openai.azure.com/openai"
 "#;
         let expected_provider = ModelProviderInfo {
-            name: "Ollama".into(),
-            base_url: Some("http://localhost:11434/v1".into()),
+            name: "Azure".into(),
+            base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
             env_key: None,
             env_key_instructions: None,
             experimental_bearer_token: None,
             wire_api: WireApi::Chat,
@@ -415,7 +400,7 @@ query_params = { api-version = "2025-04-01-preview" }
             env_key_instructions: None,
             experimental_bearer_token: None,
             wire_api: WireApi::Chat,
-            query_params: Some(maplit::hashmap! {
+            query_params: Some(hashmap! {
                 "api-version".to_string() => "2025-04-01-preview".to_string(),
            }),
             http_headers: None,
@@ -447,10 +432,10 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
             experimental_bearer_token: None,
             wire_api: WireApi::Chat,
             query_params: None,
-            http_headers: Some(maplit::hashmap! {
+            http_headers: Some(hashmap! {
                 "X-Example-Header".to_string() => "example-value".to_string(),
            }),
-            env_http_headers: Some(maplit::hashmap! {
+            env_http_headers: Some(hashmap! {
                 "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
            }),
             request_max_retries: None,
@@ -516,16 +501,12 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
         };
         assert!(named_provider.is_azure_responses_endpoint());
 
-        let negative_cases = [
-            "https://api.openai.com/v1",
-            "https://example.com/openai",
-            "https://myproxy.azurewebsites.net/openai",
-        ];
+        let negative_cases = ["https://api.openai.com/v1", "https://example.com"];
         for base_url in negative_cases {
             let provider = provider_for(base_url);
             assert!(
                 !provider.is_azure_responses_endpoint(),
-                "expected {base_url} not to be detected as Azure"
+                "expected {base_url} to be non-Azure"
            );
        }
    }
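The tests above deserialize providers from TOML, and the same shape is what users put under `model_providers` in `config.toml`. A minimal sketch mirroring the test fixtures (the provider name and URL are placeholders):

```rust
use codex_api_client::{ModelProviderInfo, WireApi};

fn parse_provider() -> ModelProviderInfo {
    let toml_src = r#"
name = "Example"
base_url = "https://example.com/v1"
wire_api = "responses"
"#;
    // Unlisted fields fall back to their serde defaults, as the first test checks.
    let provider: ModelProviderInfo = toml::from_str(toml_src).expect("valid provider TOML");
    assert_eq!(provider.wire_api, WireApi::Responses);
    provider
}
```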
diff --git a/codex-rs/api-client/src/prompt.rs b/codex-rs/api-client/src/prompt.rs
new file mode 100644
index 0000000000..83ad55dec3
--- /dev/null
+++ b/codex-rs/api-client/src/prompt.rs
@@ -0,0 +1,49 @@
+use codex_protocol::models::ResponseItem;
+use codex_protocol::protocol::SessionSource;
+use serde_json::Value;
+
+use crate::Reasoning;
+use crate::TextControls;
+
+#[derive(Debug, Clone, Default)]
+pub struct Prompt {
+    pub instructions: String,
+    pub input: Vec<ResponseItem>,
+    pub tools: Vec<Value>,
+    pub parallel_tool_calls: bool,
+    pub output_schema: Option<Value>,
+    pub reasoning: Option<Reasoning>,
+    pub text_controls: Option<TextControls>,
+    pub prompt_cache_key: Option<String>,
+    pub previous_response_id: Option<String>,
+    pub session_source: Option<SessionSource>,
+}
+
+impl Prompt {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        instructions: String,
+        input: Vec<ResponseItem>,
+        tools: Vec<Value>,
+        parallel_tool_calls: bool,
+        output_schema: Option<Value>,
+        reasoning: Option<Reasoning>,
+        text_controls: Option<TextControls>,
+        prompt_cache_key: Option<String>,
+        previous_response_id: Option<String>,
+        session_source: Option<SessionSource>,
+    ) -> Self {
+        Self {
+            instructions,
+            input,
+            tools,
+            parallel_tool_calls,
+            output_schema,
+            reasoning,
+            text_controls,
+            prompt_cache_key,
+            previous_response_id,
+            session_source,
+        }
+    }
+}
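Since `Prompt` derives `Default`, call sites that only need a few fields can use struct-update syntax instead of the ten-argument constructor. A sketch:

```rust
use codex_api_client::Prompt;
use codex_protocol::models::{ContentItem, ResponseItem};

// Build a prompt carrying a single user message; everything else defaults.
fn simple_prompt(instructions: String, user_text: String) -> Prompt {
    Prompt {
        instructions,
        input: vec![ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::InputText { text: user_text }],
        }],
        ..Prompt::default()
    }
}
```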
conversation_id: ConversationId, + pub auth_provider: Option>, + pub otel_event_manager: OtelEventManager, +} + +#[derive(Clone)] +pub struct ResponsesApiClient { + config: ResponsesApiClientConfig, +} + +#[async_trait] +impl ApiClient for ResponsesApiClient { + type Config = ResponsesApiClientConfig; + + async fn new(config: Self::Config) -> Result { + Ok(Self { config }) + } + + async fn stream(&self, prompt: Prompt) -> Result { + if self.config.provider.wire_api != crate::model_provider::WireApi::Responses { + return Err(Error::UnsupportedOperation( + "ResponsesApiClient requires a Responses provider".to_string(), + )); + } + + let mut payload_json = self.build_payload(&prompt)?; + + if self.config.provider.is_azure_responses_endpoint() + && let Some(input_value) = payload_json.get_mut("input") + && let Some(array) = input_value.as_array_mut() + { + attach_item_ids_array(array, &prompt.input); + } + + let max_attempts = self.config.provider.request_max_retries(); + for attempt in 0..=max_attempts { + match self + .attempt_stream_responses(attempt, &prompt, &payload_json) + .await + { + Ok(stream) => return Ok(stream), + Err(StreamAttemptError::Fatal(err)) => return Err(err), + Err(retryable) => { + if attempt == max_attempts { + return Err(retryable.into_error()); + } + + tokio::time::sleep(retryable.delay(attempt)).await; + } + } + } + + unreachable!("attempt_stream_responses should always return"); + } +} + +impl ResponsesApiClient { + fn build_payload(&self, prompt: &Prompt) -> Result { + let azure_workaround = self.config.provider.is_azure_responses_endpoint(); + + let mut payload = json!({ + "model": self.config.model, + "instructions": prompt.instructions, + "input": prompt.input, + "tools": prompt.tools, + "tool_choice": "auto", + "parallel_tool_calls": prompt.parallel_tool_calls, + "store": azure_workaround, + "stream": true, + "prompt_cache_key": prompt + .prompt_cache_key + .clone() + .unwrap_or_else(|| self.config.conversation_id.to_string()), + }); + + if let Some(reasoning) = prompt.reasoning.as_ref() + && let Some(obj) = payload.as_object_mut() + { + obj.insert("reasoning".to_string(), serde_json::to_value(reasoning)?); + } + + if let Some(text) = prompt.text_controls.as_ref() + && let Some(obj) = payload.as_object_mut() + { + obj.insert("text".to_string(), serde_json::to_value(text)?); + } + + if let Some(prev) = prompt.previous_response_id.as_ref() + && let Some(obj) = payload.as_object_mut() + { + obj.insert( + "previous_response_id".to_string(), + Value::String(prev.clone()), + ); + } + + let include = if prompt.reasoning.is_some() { + vec!["reasoning.encrypted_content".to_string()] + } else { + Vec::new() + }; + if let Some(obj) = payload.as_object_mut() { + obj.insert( + "include".to_string(), + Value::Array(include.into_iter().map(Value::String).collect()), + ); + } + + Ok(payload) + } + + async fn attempt_stream_responses( + &self, + attempt: u64, + prompt: &Prompt, + payload_json: &Value, + ) -> std::result::Result { + let auth = match &self.config.auth_provider { + Some(provider) => provider.auth_context().await, + None => None, + }; + + trace!( + "POST to {}: {:?}", + self.config.provider.get_full_url(auth.as_ref()), + serde_json::to_string(payload_json) + .unwrap_or_else(|_| "".to_string()) + ); + + let mut req_builder = self + .config + .provider + .create_request_builder(&self.config.http_client, &auth) + .await + .map_err(StreamAttemptError::Fatal)?; + + if let Some(SessionSource::SubAgent(sub)) = prompt.session_source.as_ref() { + let subagent = match sub { 
+ SubAgentSource::Other(label) => label.clone(), + other => serde_json::to_value(other) + .ok() + .and_then(|v| v.as_str().map(ToString::to_string)) + .unwrap_or_else(|| "other".to_string()), + }; + req_builder = req_builder.header("x-openai-subagent", subagent); + } + + req_builder = req_builder + .header("conversation_id", self.config.conversation_id.to_string()) + .header("session_id", self.config.conversation_id.to_string()) + .header(reqwest::header::ACCEPT, "text/event-stream") + .json(payload_json); + + if let Some(ctx) = auth.as_ref() + && ctx.mode == AuthMode::ChatGPT + && let Some(account_id) = ctx.account_id.as_ref() + { + req_builder = req_builder.header("chatgpt-account-id", account_id); + } + + let res = self + .config + .otel_event_manager + .log_request(attempt, || req_builder.send()) + .await; + + let mut request_id = None; + if let Ok(resp) = &res { + request_id = resp + .headers() + .get("cf-ray") + .and_then(|v| v.to_str().ok()) + .map(std::string::ToString::to_string); + } + + match res { + Ok(resp) if resp.status().is_success() => { + let (tx_event, rx_event) = mpsc::channel::>(1600); + + if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers()) + && tx_event + .send(Ok(ResponseEvent::RateLimits(snapshot))) + .await + .is_err() + { + debug!("receiver dropped rate limit snapshot event"); + } + + let request_id_for_stream = request_id.clone(); + let stream = resp + .bytes_stream() + .map_err(move |err| Error::ResponseStreamFailed { + source: err, + request_id: request_id_for_stream.clone(), + }); + tokio::spawn(process_sse( + stream, + tx_event, + self.config.provider.stream_idle_timeout(), + self.config.otel_event_manager.clone(), + )); + + Ok(ResponseStream { rx_event }) + } + Ok(res) => { + let status = res.status(); + + let retry_after_secs = res + .headers() + .get(reqwest::header::RETRY_AFTER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000)); + + if status == StatusCode::UNAUTHORIZED + && let Some(provider) = self.config.auth_provider.as_ref() + && let Some(ctx) = auth.as_ref() + && ctx.mode == AuthMode::ChatGPT + { + provider + .refresh_token() + .await + .map_err(|err| StreamAttemptError::Fatal(Error::Auth(err)))?; + } + + if !(status == StatusCode::TOO_MANY_REQUESTS + || status == StatusCode::UNAUTHORIZED + || status.is_server_error()) + { + // Surface error body. + let body = res + .text() + .await + .unwrap_or_else(|_| "".to_string()); + return Err(StreamAttemptError::Fatal(Error::UnexpectedStatus { + status, + body, + })); + } + + Err(StreamAttemptError::RetryableHttpError { + status, + retry_after, + request_id, + }) + } + Err(err) => Err(StreamAttemptError::RetryableTransportError(Error::Http( + err, + ))), + } + } +} + +enum StreamAttemptError { + RetryableHttpError { + status: StatusCode, + retry_after: Option, + request_id: Option, + }, + RetryableTransportError(Error), + Fatal(Error), +} + +impl StreamAttemptError { + fn delay(&self, attempt: u64) -> Duration { + let backoff_attempt = attempt + 1; + match self { + StreamAttemptError::RetryableHttpError { retry_after, .. } => { + retry_after.unwrap_or_else(|| backoff(backoff_attempt)) + } + StreamAttemptError::RetryableTransportError { .. } => backoff(backoff_attempt), + StreamAttemptError::Fatal(_) => Duration::from_secs(0), + } + } + + fn into_error(self) -> Error { + match self { + StreamAttemptError::RetryableHttpError { + status, request_id, .. 
+ } => Error::RetryLimit { status, request_id }, + StreamAttemptError::RetryableTransportError(error) => error, + StreamAttemptError::Fatal(error) => error, + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +struct SseEvent { + #[serde(rename = "type")] + kind: String, + response: Option, + item: Option, + delta: Option, +} + +#[derive(Debug, Deserialize)] +struct ResponseCompleted { + id: String, + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ResponseCompletedUsage { + input_tokens: i64, + input_tokens_details: Option, + output_tokens: i64, + output_tokens_details: Option, + total_tokens: i64, +} + +impl From for TokenUsage { + fn from(val: ResponseCompletedUsage) -> Self { + TokenUsage { + input_tokens: val.input_tokens, + cached_input_tokens: val + .input_tokens_details + .map(|d| d.cached_tokens) + .unwrap_or(0), + output_tokens: val.output_tokens, + reasoning_output_tokens: val + .output_tokens_details + .map(|d| d.reasoning_tokens) + .unwrap_or(0), + total_tokens: val.total_tokens, + } + } +} + +#[derive(Debug, Deserialize)] +struct ResponseCompletedInputTokensDetails { + cached_tokens: i64, +} + +#[derive(Debug, Deserialize)] +struct ResponseCompletedOutputTokensDetails { + reasoning_tokens: i64, +} + +fn attach_item_ids_array(items: &mut [Value], original_items: &[ResponseItem]) { + for (value, item) in items.iter_mut().zip(original_items.iter()) { + if let ResponseItem::Reasoning { id, .. } + | ResponseItem::Message { id: Some(id), .. } + | ResponseItem::WebSearchCall { id: Some(id), .. } + | ResponseItem::FunctionCall { id: Some(id), .. } + | ResponseItem::LocalShellCall { id: Some(id), .. } + | ResponseItem::CustomToolCall { id: Some(id), .. } + | ResponseItem::CustomToolCallOutput { call_id: id, .. } + | ResponseItem::FunctionCallOutput { call_id: id, .. } = item + { + if id.is_empty() { + continue; + } + + if let Some(obj) = value.as_object_mut() { + obj.insert("id".to_string(), Value::String(id.clone())); + } + } + } +} + +fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option { + let primary = parse_rate_limit_window( + headers, + "x-codex-primary-used-percent", + "x-codex-primary-window-minutes", + "x-codex-primary-reset-at", + ); + + let secondary = parse_rate_limit_window( + headers, + "x-codex-secondary-used-percent", + "x-codex-secondary-window-minutes", + "x-codex-secondary-reset-at", + ); + + Some(RateLimitSnapshot { primary, secondary }) +} + +fn parse_rate_limit_window( + headers: &HeaderMap, + used_percent_header: &str, + window_minutes_header: &str, + resets_at_header: &str, +) -> Option { + let used_percent: Option = parse_header_f64(headers, used_percent_header); + + used_percent.and_then(|used_percent| { + let window_minutes = parse_header_i64(headers, window_minutes_header); + let resets_at = parse_header_i64(headers, resets_at_header); + + let has_data = used_percent != 0.0 + || window_minutes.is_some_and(|minutes| minutes != 0) + || resets_at.is_some(); + + has_data.then_some(RateLimitWindow { + used_percent, + window_minutes, + resets_at, + }) + }) +} + +fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option { + parse_header_str(headers, name)? 
+ .parse::() + .ok() + .filter(|v| v.is_finite()) +} + +fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option { + parse_header_str(headers, name)?.parse::().ok() +} + +fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> { + headers.get(name)?.to_str().ok() +} + +async fn process_sse( + stream: S, + tx_event: mpsc::Sender>, + idle_timeout: Duration, + otel_event_manager: OtelEventManager, +) where + S: Stream> + Unpin + Send + 'static, +{ + let mut stream = stream.eventsource(); + + let mut response_completed: Option = None; + let mut response_error: Option = None; + + loop { + let start = std::time::Instant::now(); + let response = timeout(idle_timeout, stream.next()).await; + let duration = start.elapsed(); + otel_event_manager.log_sse_event(&response, duration); + + let sse = match response { + Ok(Some(Ok(sse))) => sse, + Ok(Some(Err(e))) => { + debug!("SSE Error: {e:#}"); + let event = Error::Stream(e.to_string(), None); + let _ = tx_event.send(Err(event)).await; + return; + } + Ok(None) => { + match response_completed { + Some(ResponseCompleted { + id: response_id, + usage, + }) => { + if let Some(token_usage) = &usage { + otel_event_manager.sse_event_completed( + token_usage.input_tokens, + token_usage.output_tokens, + token_usage + .input_tokens_details + .as_ref() + .map(|d| d.cached_tokens), + token_usage + .output_tokens_details + .as_ref() + .map(|d| d.reasoning_tokens), + token_usage.total_tokens, + ); + } + let event = ResponseEvent::Completed { + response_id, + token_usage: usage.map(Into::into), + }; + let _ = tx_event.send(Ok(event)).await; + } + None => { + let error = response_error.unwrap_or(Error::Stream( + "stream closed before response.completed".into(), + None, + )); + otel_event_manager.see_event_completed_failed(&error); + + let _ = tx_event.send(Err(error)).await; + } + } + return; + } + Err(_) => { + let _ = tx_event + .send(Err(Error::Stream( + "idle timeout waiting for SSE".into(), + None, + ))) + .await; + return; + } + }; + + let raw = sse.data.clone(); + trace!("SSE event: {}", raw); + + let event: SseEvent = match serde_json::from_str(&sse.data) { + Ok(event) => event, + Err(e) => { + debug!("Failed to parse SSE event: {e}, data: {}", &sse.data); + continue; + } + }; + + match event.kind.as_str() { + "response.output_item.done" => { + let Some(item_val) = event.item else { continue }; + let Ok(item) = serde_json::from_value::(item_val) else { + debug!("failed to parse ResponseItem from output_item.done"); + continue; + }; + + let event = ResponseEvent::OutputItemDone(item); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + "response.output_text.delta" => { + if let Some(delta) = event.delta { + let event = ResponseEvent::OutputTextDelta(delta); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } + "response.reasoning_summary_text.delta" => { + if let Some(delta) = event.delta { + let event = ResponseEvent::ReasoningSummaryDelta(delta); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } + "response.reasoning_text.delta" => { + if let Some(delta) = event.delta { + let event = ResponseEvent::ReasoningContentDelta(delta); + if tx_event.send(Ok(event)).await.is_err() { + return; + } + } + } + "response.created" => { + if event.response.is_some() { + let _ = tx_event.send(Ok(ResponseEvent::Created)).await; + } + } + "response.failed" => { + if let Some(resp_val) = event.response { + response_error = Some(Error::Stream( + "response.failed event received".to_string(), + None, + 
+async fn process_sse<S>(
+    stream: S,
+    tx_event: mpsc::Sender<Result<ResponseEvent>>,
+    idle_timeout: Duration,
+    otel_event_manager: OtelEventManager,
+) where
+    S: Stream<Item = Result<Bytes>> + Unpin + Send + 'static,
+{
+    let mut stream = stream.eventsource();
+
+    let mut response_completed: Option<ResponseCompleted> = None;
+    let mut response_error: Option<Error> = None;
+
+    loop {
+        let start = std::time::Instant::now();
+        let response = timeout(idle_timeout, stream.next()).await;
+        let duration = start.elapsed();
+        otel_event_manager.log_sse_event(&response, duration);
+
+        let sse = match response {
+            Ok(Some(Ok(sse))) => sse,
+            Ok(Some(Err(e))) => {
+                debug!("SSE Error: {e:#}");
+                let event = Error::Stream(e.to_string(), None);
+                let _ = tx_event.send(Err(event)).await;
+                return;
+            }
+            Ok(None) => {
+                match response_completed {
+                    Some(ResponseCompleted {
+                        id: response_id,
+                        usage,
+                    }) => {
+                        if let Some(token_usage) = &usage {
+                            otel_event_manager.sse_event_completed(
+                                token_usage.input_tokens,
+                                token_usage.output_tokens,
+                                token_usage
+                                    .input_tokens_details
+                                    .as_ref()
+                                    .map(|d| d.cached_tokens),
+                                token_usage
+                                    .output_tokens_details
+                                    .as_ref()
+                                    .map(|d| d.reasoning_tokens),
+                                token_usage.total_tokens,
+                            );
+                        }
+                        let event = ResponseEvent::Completed {
+                            response_id,
+                            token_usage: usage.map(Into::into),
+                        };
+                        let _ = tx_event.send(Ok(event)).await;
+                    }
+                    None => {
+                        let error = response_error.unwrap_or(Error::Stream(
+                            "stream closed before response.completed".into(),
+                            None,
+                        ));
+                        otel_event_manager.see_event_completed_failed(&error);
+
+                        let _ = tx_event.send(Err(error)).await;
+                    }
+                }
+                return;
+            }
+            Err(_) => {
+                let _ = tx_event
+                    .send(Err(Error::Stream(
+                        "idle timeout waiting for SSE".into(),
+                        None,
+                    )))
+                    .await;
+                return;
+            }
+        };
+
+        let raw = sse.data.clone();
+        trace!("SSE event: {}", raw);
+
+        let event: SseEvent = match serde_json::from_str(&sse.data) {
+            Ok(event) => event,
+            Err(e) => {
+                debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
+                continue;
+            }
+        };
+
+        match event.kind.as_str() {
+            "response.output_item.done" => {
+                let Some(item_val) = event.item else { continue };
+                let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
+                    debug!("failed to parse ResponseItem from output_item.done");
+                    continue;
+                };
+
+                let event = ResponseEvent::OutputItemDone(item);
+                if tx_event.send(Ok(event)).await.is_err() {
+                    return;
+                }
+            }
+            "response.output_text.delta" => {
+                if let Some(delta) = event.delta {
+                    let event = ResponseEvent::OutputTextDelta(delta);
+                    if tx_event.send(Ok(event)).await.is_err() {
+                        return;
+                    }
+                }
+            }
+            "response.reasoning_summary_text.delta" => {
+                if let Some(delta) = event.delta {
+                    let event = ResponseEvent::ReasoningSummaryDelta(delta);
+                    if tx_event.send(Ok(event)).await.is_err() {
+                        return;
+                    }
+                }
+            }
+            "response.reasoning_text.delta" => {
+                if let Some(delta) = event.delta {
+                    let event = ResponseEvent::ReasoningContentDelta(delta);
+                    if tx_event.send(Ok(event)).await.is_err() {
+                        return;
+                    }
+                }
+            }
+            "response.created" => {
+                if event.response.is_some() {
+                    let _ = tx_event.send(Ok(ResponseEvent::Created)).await;
+                }
+            }
+            "response.failed" => {
+                if let Some(resp_val) = event.response {
+                    response_error = Some(Error::Stream(
+                        "response.failed event received".to_string(),
+                        None,
+                    ));
+
+                    if let Some(error) = resp_val.get("error") {
+                        match serde_json::from_value::<ErrorResponse>(error.clone()) {
+                            Ok(error) => {
+                                if is_context_window_error(&error) {
+                                    response_error = Some(Error::UnsupportedOperation(
+                                        "context window exceeded".to_string(),
+                                    ));
+                                } else {
+                                    let delay = try_parse_retry_after(&error);
+                                    let message = error.message.clone().unwrap_or_default();
+                                    response_error = Some(Error::Stream(message, delay));
+                                }
+                            }
+                            Err(e) => {
+                                let error = format!("failed to parse ErrorResponse: {e}");
+                                debug!(error);
+                                response_error = Some(Error::Stream(error, None))
+                            }
+                        }
+                    }
+                }
+            }
+            "response.completed" => {
+                if let Some(resp_val) = event.response {
+                    match serde_json::from_value::<ResponseCompleted>(resp_val) {
+                        Ok(r) => {
+                            response_completed = Some(r);
+                        }
+                        Err(e) => {
+                            let error = format!("failed to parse ResponseCompleted: {e}");
+                            debug!(error);
+                            response_error = Some(Error::Stream(error, None));
+                            continue;
+                        }
+                    };
+                };
+            }
+            "response.output_item.added" => {
+                let Some(item_val) = event.item else { continue };
+                let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
+                    debug!("failed to parse ResponseItem from output_item.added");
+                    continue;
+                };
+
+                let event = ResponseEvent::OutputItemAdded(item);
+                if tx_event.send(Ok(event)).await.is_err() {
+                    return;
+                }
+            }
+            "response.reasoning_summary_part.added" => {
+                let event = ResponseEvent::ReasoningSummaryPartAdded;
+                if tx_event.send(Ok(event)).await.is_err() {
+                    return;
+                }
+            }
+            _ => {}
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct ErrorResponse {
+    code: Option<String>,
+    message: Option<String>,
+}
+
+fn backoff(attempt: u64) -> Duration {
+    let exponent = attempt.min(6) as u32;
+    let base = 2u64.pow(exponent);
+    Duration::from_millis(base * 100)
+}
+
+fn rate_limit_regex() -> Option<&'static Regex> {
+    static RE: OnceLock<Option<Regex>> = OnceLock::new();
+
+    RE.get_or_init(|| Regex::new(r"Please try again in (\d+(?:\.\d+)?)(s|ms)").ok())
+        .as_ref()
+}
+
+fn try_parse_retry_after(err: &ErrorResponse) -> Option<Duration> {
+    if err.code.as_deref() != Some("rate_limit_exceeded") {
+        return None;
+    }
+
+    if let Some(re) = rate_limit_regex()
+        && let Some(message) = &err.message
+        && let Some(captures) = re.captures(message)
+    {
+        let seconds = captures.get(1);
+        let unit = captures.get(2);
+
+        if let (Some(value), Some(unit)) = (seconds, unit) {
+            let value = value.as_str().parse::<f64>().ok()?;
+            let unit = unit.as_str();
+
+            if unit == "s" {
+                return Some(Duration::from_secs_f64(value));
+            } else if unit == "ms" {
+                return Some(Duration::from_millis(value as u64));
+            }
+        }
+    }
+    None
+}
+
+fn is_context_window_error(error: &ErrorResponse) -> bool {
+    error.code.as_deref() == Some("context_length_exceeded")
+}
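A quick check of the two retry helpers above in isolation: backoff doubles per attempt with the exponent capped at 6, and the rate-limit regex accepts both "1.5s" and "450ms" forms. Standalone sketch assuming the regex_lite dependency from the crate's Cargo.toml; the sample error message is made up.

    use std::time::Duration;
    use regex_lite::Regex;

    fn backoff(attempt: u64) -> Duration {
        let exponent = attempt.min(6) as u32;
        Duration::from_millis(2u64.pow(exponent) * 100)
    }

    fn main() {
        for attempt in [1, 3, 6, 10] {
            // 200ms, 800ms, 6.4s, then capped at 6.4s.
            println!("attempt {attempt}: {:?}", backoff(attempt));
        }
        let re = Regex::new(r"Please try again in (\d+(?:\.\d+)?)(s|ms)").unwrap();
        let caps = re.captures("Rate limit reached. Please try again in 1.5s.").unwrap();
        let value: f64 = caps.get(1).unwrap().as_str().parse().unwrap();
        let unit = caps.get(2).unwrap().as_str();
        let delay = if unit == "s" {
            Duration::from_secs_f64(value)
        } else {
            Duration::from_millis(value as u64)
        };
        println!("parsed retry-after: {delay:?}"); // 1.5s
    }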
+/// used in tests to stream from a text SSE file
+pub async fn stream_from_fixture(
+    path: impl AsRef<Path>,
+    provider: ModelProviderInfo,
+    otel_event_manager: OtelEventManager,
+) -> Result<ResponseStream> {
+    let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
+    let display_path = path.as_ref().display().to_string();
+    let file = std::fs::File::open(path.as_ref())
+        .map_err(|e| Error::Other(format!("failed to open fixture {display_path}: {e}")))?;
+    let lines = std::io::BufReader::new(file).lines();
+
+    let mut content = String::new();
+    for line in lines {
+        let line =
+            line.map_err(|e| Error::Other(format!("failed to read fixture {display_path}: {e}")))?;
+        content.push_str(&line);
+        content.push_str("\n\n");
+    }
+
+    let rdr = std::io::Cursor::new(content);
+    let stream = ReaderStream::new(rdr).map_err(|e| Error::Other(e.to_string()));
+    tokio::spawn(process_sse(
+        stream,
+        tx_event,
+        provider.stream_idle_timeout(),
+        otel_event_manager,
+    ));
+    Ok(ResponseStream { rx_event })
+}
diff --git a/codex-rs/api-client/src/stream.rs b/codex-rs/api-client/src/stream.rs
new file mode 100644
index 0000000000..ac76b282fa
--- /dev/null
+++ b/codex-rs/api-client/src/stream.rs
@@ -0,0 +1,83 @@
+use std::pin::Pin;
+use std::task::Context;
+use std::task::Poll;
+
+use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
+use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
+use codex_protocol::models::ResponseItem;
+use codex_protocol::protocol::RateLimitSnapshot;
+use codex_protocol::protocol::TokenUsage;
+use futures::Stream;
+use serde::Serialize;
+use serde_json::Value;
+use tokio::sync::mpsc;
+
+use crate::error::Result;
+
+#[derive(Debug, Serialize, Clone)]
+pub struct Reasoning {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub effort: Option<ReasoningEffortConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<ReasoningSummaryConfig>,
+}
+
+#[derive(Debug, Serialize, Default, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum TextFormatType {
+    #[default]
+    JsonSchema,
+}
+
+#[derive(Debug, Serialize, Default, Clone)]
+pub struct TextFormat {
+    pub r#type: TextFormatType,
+    pub strict: bool,
+    pub schema: Value,
+    pub name: String,
+}
+
+#[derive(Debug, Serialize, Default, Clone)]
+pub struct TextControls {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub verbosity: Option<OpenAiVerbosity>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub format: Option<TextFormat>,
+}
+
+#[derive(Debug)]
+pub enum ResponseEvent {
+    Created,
+    OutputItemDone(ResponseItem),
+    OutputItemAdded(ResponseItem),
+    Completed {
+        response_id: String,
+        token_usage: Option<TokenUsage>,
+    },
+    OutputTextDelta(String),
+    ReasoningSummaryDelta(String),
+    ReasoningContentDelta(String),
+    ReasoningSummaryPartAdded,
+    RateLimits(RateLimitSnapshot),
+}
+
+#[derive(Debug)]
+pub struct EventStream<T> {
+    pub(crate) rx_event: mpsc::Receiver<T>,
+}
+
+impl<T> EventStream<T> {
+    pub fn from_receiver(rx_event: mpsc::Receiver<T>) -> Self {
+        Self { rx_event }
+    }
+}
+
+impl<T> Stream for EventStream<T> {
+    type Item = T;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
+        self.rx_event.poll_recv(cx)
+    }
+}
+
+pub type ResponseStream = EventStream<Result<ResponseEvent>>;
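stream.rs is small enough that its core pattern fits in a standalone sketch: a tokio mpsc receiver exposed as a futures::Stream by delegating poll_next to poll_recv, which is exactly how EventStream lets producers push events from a spawned task. This toy assumes tokio and futures; the integer events are arbitrary.

    use std::pin::Pin;
    use std::task::{Context, Poll};
    use futures::{Stream, StreamExt};
    use tokio::sync::mpsc;

    struct EventStream<T> { rx: mpsc::Receiver<T> }

    impl<T> Stream for EventStream<T> {
        type Item = T;
        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
            // tokio's Receiver already speaks the poll protocol; just forward.
            self.rx.poll_recv(cx)
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = mpsc::channel(8);
        tokio::spawn(async move {
            for i in 0..3 { tx.send(i).await.unwrap(); }
            // Dropping tx closes the channel, which ends the stream below.
        });
        let mut events = EventStream { rx };
        while let Some(i) = events.next().await { println!("got {i}"); }
    }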
diff --git a/codex-rs/common/Cargo.toml b/codex-rs/common/Cargo.toml
index d8f30cc09d..a836e5d6cb 100644
--- a/codex-rs/common/Cargo.toml
+++ b/codex-rs/common/Cargo.toml
@@ -8,6 +8,7 @@ workspace = true
 [dependencies]
 clap = { workspace = true, features = ["derive", "wrap_help"], optional = true }
+codex-api-client = { workspace = true }
 codex-core = { workspace = true }
 codex-protocol = { workspace = true }
 codex-app-server-protocol = { workspace = true }
diff --git a/codex-rs/common/src/config_summary.rs b/codex-rs/common/src/config_summary.rs
index dabc606ce1..b3019e3ea9 100644
--- a/codex-rs/common/src/config_summary.rs
+++ b/codex-rs/common/src/config_summary.rs
@@ -1,4 +1,4 @@
-use codex_core::WireApi;
+use codex_api_client::WireApi;
 use codex_core::config::Config;
 
 use crate::sandbox_summary::summarize_sandbox_policy;
diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml
index 921ca2843a..57604bf2ae 100644
--- a/codex-rs/core/Cargo.toml
+++ b/codex-rs/core/Cargo.toml
@@ -22,6 +22,7 @@
 chrono = { workspace = true, features = ["serde"] }
 codex-app-server-protocol = { workspace = true }
 codex-apply-patch = { workspace = true }
 codex-async-utils = { workspace = true }
+codex-api-client = { workspace = true }
 codex-file-search = { workspace = true }
 codex-git = { workspace = true }
 codex-keyring-store = { workspace = true }
diff --git a/codex-rs/core/src/auth.rs b/codex-rs/core/src/auth.rs
index b18cae5fa6..b222fa745b 100644
--- a/codex-rs/core/src/auth.rs
+++ b/codex-rs/core/src/auth.rs
@@ -22,7 +22,6 @@ use crate::auth::storage::AuthStorageBackend;
 use crate::auth::storage::create_auth_storage;
 use crate::config::Config;
 use crate::default_client::CodexHttpClient;
-use crate::token_data::PlanType;
 use crate::token_data::TokenData;
 use crate::token_data::parse_id_token;
 use crate::util::try_parse_error_message;
@@ -153,11 +152,6 @@ impl CodexAuth {
         self.get_current_token_data().and_then(|t| t.id_token.email)
     }
 
-    pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
-        self.get_current_token_data()
-            .and_then(|t| t.id_token.chatgpt_plan_type)
-    }
-
     fn get_current_auth_json(&self) -> Option<AuthDotJson> {
         #[expect(clippy::unwrap_used)]
         self.auth_dot_json.lock().unwrap().clone()
diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
deleted file mode 100644
index abb27d9b55..0000000000
--- a/codex-rs/core/src/chat_completions.rs
+++ /dev/null
@@ -1,967 +0,0 @@
-use std::time::Duration;
-
-use crate::ModelProviderInfo;
-use crate::client_common::Prompt;
-use crate::client_common::ResponseEvent;
-use crate::client_common::ResponseStream;
-use crate::default_client::CodexHttpClient;
-use crate::error::CodexErr;
-use crate::error::ConnectionFailedError;
-use crate::error::ResponseStreamFailed;
-use crate::error::Result;
-use crate::error::RetryLimitReachedError;
-use crate::error::UnexpectedResponseError;
-use crate::model_family::ModelFamily;
-use crate::tools::spec::create_tools_json_for_chat_completions_api;
-use crate::util::backoff;
-use bytes::Bytes;
-use codex_otel::otel_event_manager::OtelEventManager;
-use codex_protocol::models::ContentItem;
-use codex_protocol::models::FunctionCallOutputContentItem;
-use codex_protocol::models::ReasoningItemContent;
-use codex_protocol::models::ResponseItem;
-use codex_protocol::protocol::SessionSource;
-use codex_protocol::protocol::SubAgentSource;
-use eventsource_stream::Eventsource;
-use futures::Stream;
-use futures::StreamExt;
-use futures::TryStreamExt;
-use reqwest::StatusCode;
-use serde_json::json;
-use std::pin::Pin;
-use std::task::Context;
-use std::task::Poll;
-use tokio::sync::mpsc;
-use tokio::time::timeout;
-use tracing::debug;
-use tracing::trace;
-
-/// Implementation for the classic Chat Completions API.
-pub(crate) async fn stream_chat_completions(
-    prompt: &Prompt,
-    model_family: &ModelFamily,
-    client: &CodexHttpClient,
-    provider: &ModelProviderInfo,
-    otel_event_manager: &OtelEventManager,
-    session_source: &SessionSource,
-) -> Result<ResponseStream> {
-    if prompt.output_schema.is_some() {
-        return Err(CodexErr::UnsupportedOperation(
-            "output_schema is not supported for Chat Completions API".to_string(),
-        ));
-    }
-
-    // Build messages array
-    let mut messages = Vec::<serde_json::Value>::new();
-
-    let full_instructions = prompt.get_full_instructions(model_family);
-    messages.push(json!({"role": "system", "content": full_instructions}));
-
-    let input = prompt.get_formatted_input();
-
-    // Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
-    // - If the last emitted message is a user message, drop all reasoning.
-    // - Otherwise, for each Reasoning item after the last user message, attach it
-    //   to the immediate previous assistant message (stop turns) or the immediate
-    //   next assistant anchor (tool-call turns: function/local shell call, or assistant message).
-    let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
-        std::collections::HashMap::new();
-
-    // Determine the last role that would be emitted to Chat Completions.
-    let mut last_emitted_role: Option<&str> = None;
-    for item in &input {
-        match item {
-            ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
-            ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
-                last_emitted_role = Some("assistant")
-            }
-            ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
-            ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
-            ResponseItem::CustomToolCall { .. } => {}
-            ResponseItem::CustomToolCallOutput { .. } => {}
-            ResponseItem::WebSearchCall { .. } => {}
-            ResponseItem::GhostSnapshot { .. } => {}
-        }
-    }
-
-    // Find the last user message index in the input.
-    let mut last_user_index: Option<usize> = None;
-    for (idx, item) in input.iter().enumerate() {
-        if let ResponseItem::Message { role, .. } = item
-            && role == "user"
-        {
-            last_user_index = Some(idx);
-        }
-    }
-
-    // Attach reasoning only if the conversation does not end with a user message.
-    if !matches!(last_emitted_role, Some("user")) {
-        for (idx, item) in input.iter().enumerate() {
-            // Only consider reasoning that appears after the last user message.
-            if let Some(u_idx) = last_user_index
-                && idx <= u_idx
-            {
-                continue;
-            }
-
-            if let ResponseItem::Reasoning {
-                content: Some(items),
-                ..
-            } = item
-            {
-                let mut text = String::new();
-                for entry in items {
-                    match entry {
-                        ReasoningItemContent::ReasoningText { text: segment }
-                        | ReasoningItemContent::Text { text: segment } => text.push_str(segment),
-                    }
-                }
-                if text.trim().is_empty() {
-                    continue;
-                }
-
-                // Prefer immediate previous assistant message (stop turns)
-                let mut attached = false;
-                if idx > 0
-                    && let ResponseItem::Message { role, .. } = &input[idx - 1]
-                    && role == "assistant"
-                {
-                    reasoning_by_anchor_index
-                        .entry(idx - 1)
-                        .and_modify(|v| v.push_str(&text))
-                        .or_insert(text.clone());
-                    attached = true;
-                }
-
-                // Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
-                if !attached && idx + 1 < input.len() {
-                    match &input[idx + 1] {
-                        ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
-                            reasoning_by_anchor_index
-                                .entry(idx + 1)
-                                .and_modify(|v| v.push_str(&text))
-                                .or_insert(text.clone());
-                        }
-                        ResponseItem::Message { role, .. } if role == "assistant" => {
-                            reasoning_by_anchor_index
-                                .entry(idx + 1)
-                                .and_modify(|v| v.push_str(&text))
-                                .or_insert(text.clone());
-                        }
-                        _ => {}
-                    }
-                }
-            }
-        }
-    }
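The anchoring rule being removed here is easier to see on a toy item list than in the full loop. A hedged illustration, not the removed code itself: reasoning after the last user message prefers the previous assistant message and otherwise falls forward to the next assistant/tool-call anchor. The `Item` enum and values are invented for the demo.

    use std::collections::HashMap;

    #[derive(Debug, PartialEq)]
    enum Item { User, Assistant, Reasoning(&'static str), ToolCall }

    fn attach(items: &[Item]) -> HashMap<usize, String> {
        let mut by_anchor: HashMap<usize, String> = HashMap::new();
        let last_user = items.iter().rposition(|i| *i == Item::User);
        for (idx, item) in items.iter().enumerate() {
            // Only reasoning after the last user message is kept.
            if last_user.is_some_and(|u| idx <= u) { continue; }
            if let Item::Reasoning(text) = item {
                // Prefer the immediately preceding assistant message (stop turns)...
                if idx > 0 && items[idx - 1] == Item::Assistant {
                    by_anchor.entry(idx - 1).or_default().push_str(text);
                // ...otherwise the immediately following anchor (tool-call turns).
                } else if matches!(items.get(idx + 1), Some(Item::Assistant | Item::ToolCall)) {
                    by_anchor.entry(idx + 1).or_default().push_str(text);
                }
            }
        }
        by_anchor
    }

    fn main() {
        let items = [Item::User, Item::Reasoning("plan"), Item::ToolCall, Item::Assistant];
        println!("{:?}", attach(&items)); // reasoning attaches to the tool call at index 2
    }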
-
-    // Track last assistant text we emitted to avoid duplicate assistant messages
-    // in the outbound Chat Completions payload (can happen if a final
-    // aggregated assistant message was recorded alongside an earlier partial).
-    let mut last_assistant_text: Option<String> = None;
-
-    for (idx, item) in input.iter().enumerate() {
-        match item {
-            ResponseItem::Message { role, content, .. } => {
-                // Build content either as a plain string (typical for assistant text)
-                // or as an array of content items when images are present (user/tool multimodal).
-                let mut text = String::new();
-                let mut items: Vec<serde_json::Value> = Vec::new();
-                let mut saw_image = false;
-
-                for c in content {
-                    match c {
-                        ContentItem::InputText { text: t }
-                        | ContentItem::OutputText { text: t } => {
-                            text.push_str(t);
-                            items.push(json!({"type":"text","text": t}));
-                        }
-                        ContentItem::InputImage { image_url } => {
-                            saw_image = true;
-                            items.push(json!({"type":"image_url","image_url": {"url": image_url}}));
-                        }
-                    }
-                }
-
-                // Skip exact-duplicate assistant messages.
-                if role == "assistant" {
-                    if let Some(prev) = &last_assistant_text
-                        && prev == &text
-                    {
-                        continue;
-                    }
-                    last_assistant_text = Some(text.clone());
-                }
-
-                // For assistant messages, always send a plain string for compatibility.
-                // For user messages, if an image is present, send an array of content items.
-                let content_value = if role == "assistant" {
-                    json!(text)
-                } else if saw_image {
-                    json!(items)
-                } else {
-                    json!(text)
-                };
-
-                let mut msg = json!({"role": role, "content": content_value});
-                if role == "assistant"
-                    && let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
-                    && let Some(obj) = msg.as_object_mut()
-                {
-                    obj.insert("reasoning".to_string(), json!(reasoning));
-                }
-                messages.push(msg);
-            }
-            ResponseItem::FunctionCall {
-                name,
-                arguments,
-                call_id,
-                ..
-            } => {
-                let mut msg = json!({
-                    "role": "assistant",
-                    "content": null,
-                    "tool_calls": [{
-                        "id": call_id,
-                        "type": "function",
-                        "function": {
-                            "name": name,
-                            "arguments": arguments,
-                        }
-                    }]
-                });
-                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
-                    && let Some(obj) = msg.as_object_mut()
-                {
-                    obj.insert("reasoning".to_string(), json!(reasoning));
-                }
-                messages.push(msg);
-            }
-            ResponseItem::LocalShellCall {
-                id,
-                call_id: _,
-                status,
-                action,
-            } => {
-                // Confirm with API team.
-                let mut msg = json!({
-                    "role": "assistant",
-                    "content": null,
-                    "tool_calls": [{
-                        "id": id.clone().unwrap_or_else(|| "".to_string()),
-                        "type": "local_shell_call",
-                        "status": status,
-                        "action": action,
-                    }]
-                });
-                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
-                    && let Some(obj) = msg.as_object_mut()
-                {
-                    obj.insert("reasoning".to_string(), json!(reasoning));
-                }
-                messages.push(msg);
-            }
-            ResponseItem::FunctionCallOutput { call_id, output } => {
-                // Prefer structured content items when available (e.g., images)
-                // otherwise fall back to the legacy plain-string content.
-                let content_value = if let Some(items) = &output.content_items {
-                    let mapped: Vec<serde_json::Value> = items
-                        .iter()
-                        .map(|it| match it {
-                            FunctionCallOutputContentItem::InputText { text } => {
-                                json!({"type":"text","text": text})
-                            }
-                            FunctionCallOutputContentItem::InputImage { image_url } => {
-                                json!({"type":"image_url","image_url": {"url": image_url}})
-                            }
-                        })
-                        .collect();
-                    json!(mapped)
-                } else {
-                    json!(output.content)
-                };
-
-                messages.push(json!({
-                    "role": "tool",
-                    "tool_call_id": call_id,
-                    "content": content_value,
-                }));
-            }
-            ResponseItem::CustomToolCall {
-                id,
-                call_id: _,
-                name,
-                input,
-                status: _,
-            } => {
-                messages.push(json!({
-                    "role": "assistant",
-                    "content": null,
-                    "tool_calls": [{
-                        "id": id,
-                        "type": "custom",
-                        "custom": {
-                            "name": name,
-                            "input": input,
-                        }
-                    }]
-                }));
-            }
-            ResponseItem::CustomToolCallOutput { call_id, output } => {
-                messages.push(json!({
-                    "role": "tool",
-                    "tool_call_id": call_id,
-                    "content": output,
-                }));
-            }
-            ResponseItem::GhostSnapshot { .. } => {
-                // Ghost snapshots annotate history but are not sent to the model.
-                continue;
-            }
-            ResponseItem::Reasoning { .. }
-            | ResponseItem::WebSearchCall { .. }
-            | ResponseItem::Other => {
-                // Omit these items from the conversation history.
-                continue;
-            }
-        }
-    }
-
-    let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
-    let payload = json!({
-        "model": model_family.slug,
-        "messages": messages,
-        "stream": true,
-        "tools": tools_json,
-    });
-
-    debug!(
-        "POST to {}: {}",
-        provider.get_full_url(&None),
-        serde_json::to_string_pretty(&payload).unwrap_or_default()
-    );
-
-    let mut attempt = 0;
-    let max_retries = provider.request_max_retries();
-    loop {
-        attempt += 1;
-
-        let mut req_builder = provider.create_request_builder(client, &None).await?;
-
-        // Include subagent header only for subagent sessions.
-        if let SessionSource::SubAgent(sub) = session_source.clone() {
-            let subagent = if let SubAgentSource::Other(label) = sub {
-                label
-            } else {
-                serde_json::to_value(&sub)
-                    .ok()
-                    .and_then(|v| v.as_str().map(std::string::ToString::to_string))
-                    .unwrap_or_else(|| "other".to_string())
-            };
-            req_builder = req_builder.header("x-openai-subagent", subagent);
-        }
-
-        let res = otel_event_manager
-            .log_request(attempt, || {
-                req_builder
-                    .header(reqwest::header::ACCEPT, "text/event-stream")
-                    .json(&payload)
-                    .send()
-            })
-            .await;
-
-        match res {
-            Ok(resp) if resp.status().is_success() => {
-                let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
-                let stream = resp.bytes_stream().map_err(|e| {
-                    CodexErr::ResponseStreamFailed(ResponseStreamFailed {
-                        source: e,
-                        request_id: None,
-                    })
-                });
-                tokio::spawn(process_chat_sse(
-                    stream,
-                    tx_event,
-                    provider.stream_idle_timeout(),
-                    otel_event_manager.clone(),
-                ));
-                return Ok(ResponseStream { rx_event });
-            }
-            Ok(res) => {
-                let status = res.status();
-                if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
-                    let body = (res.text().await).unwrap_or_default();
-                    return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError {
-                        status,
-                        body,
-                        request_id: None,
-                    }));
-                }
-
-                if attempt > max_retries {
-                    return Err(CodexErr::RetryLimit(RetryLimitReachedError {
-                        status,
-                        request_id: None,
-                    }));
-                }
-
-                let retry_after_secs = res
-                    .headers()
-                    .get(reqwest::header::RETRY_AFTER)
-                    .and_then(|v| v.to_str().ok())
-                    .and_then(|s| s.parse::<u64>().ok());
-
-                let delay = retry_after_secs
-                    .map(|s| Duration::from_millis(s * 1_000))
-                    .unwrap_or_else(|| backoff(attempt));
-                tokio::time::sleep(delay).await;
-            }
-            Err(e) => {
-                if attempt > max_retries {
-                    return Err(CodexErr::ConnectionFailed(ConnectionFailedError {
-                        source: e,
-                    }));
-                }
-                let delay = backoff(attempt);
-                tokio::time::sleep(delay).await;
-            }
-        }
-    }
-}
-
-async fn append_assistant_text(
-    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
-    assistant_item: &mut Option<ResponseItem>,
-    text: String,
-) {
-    if assistant_item.is_none() {
-        let item = ResponseItem::Message {
-            id: None,
-            role: "assistant".to_string(),
-            content: vec![],
-        };
-        *assistant_item = Some(item.clone());
-        let _ = tx_event
-            .send(Ok(ResponseEvent::OutputItemAdded(item)))
-            .await;
-    }
-
-    if let Some(ResponseItem::Message { content, .. }) = assistant_item {
-        content.push(ContentItem::OutputText { text: text.clone() });
-        let _ = tx_event
-            .send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
-            .await;
-    }
-}
-
-async fn append_reasoning_text(
-    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
-    reasoning_item: &mut Option<ResponseItem>,
-    text: String,
-) {
-    if reasoning_item.is_none() {
-        let item = ResponseItem::Reasoning {
-            id: String::new(),
-            summary: Vec::new(),
-            content: Some(vec![]),
-            encrypted_content: None,
-        };
-        *reasoning_item = Some(item.clone());
-        let _ = tx_event
-            .send(Ok(ResponseEvent::OutputItemAdded(item)))
-            .await;
-    }
-
-    if let Some(ResponseItem::Reasoning {
-        content: Some(content),
-        ..
-    }) = reasoning_item
-    {
-        content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
-
-        let _ = tx_event
-            .send(Ok(ResponseEvent::ReasoningContentDelta(text.clone())))
-            .await;
-    }
-}
-/// Lightweight SSE processor for the Chat Completions streaming format. The
-/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
-/// of the pipeline can stay agnostic of the underlying wire format.
-async fn process_chat_sse<S>(
-    stream: S,
-    tx_event: mpsc::Sender<Result<ResponseEvent>>,
-    idle_timeout: Duration,
-    otel_event_manager: OtelEventManager,
-) where
-    S: Stream<Item = Result<Bytes>> + Unpin,
-{
-    let mut stream = stream.eventsource();
-
-    // State to accumulate a function call across streaming chunks.
-    // OpenAI may split the `arguments` string over multiple `delta` events
-    // until the chunk whose `finish_reason` is `tool_calls` is emitted. We
-    // keep collecting the pieces here and forward a single
-    // `ResponseItem::FunctionCall` once the call is complete.
-    #[derive(Default)]
-    struct FunctionCallState {
-        name: Option<String>,
-        arguments: String,
-        call_id: Option<String>,
-        active: bool,
-    }
-
-    let mut fn_call_state = FunctionCallState::default();
-    let mut assistant_item: Option<ResponseItem> = None;
-    let mut reasoning_item: Option<ResponseItem> = None;
-
-    loop {
-        let start = std::time::Instant::now();
-        let response = timeout(idle_timeout, stream.next()).await;
-        let duration = start.elapsed();
-        otel_event_manager.log_sse_event(&response, duration);
-
-        let sse = match response {
-            Ok(Some(Ok(ev))) => ev,
-            Ok(Some(Err(e))) => {
-                let _ = tx_event
-                    .send(Err(CodexErr::Stream(e.to_string(), None)))
-                    .await;
-                return;
-            }
-            Ok(None) => {
-                // Stream closed gracefully – emit Completed with dummy id.
-                let _ = tx_event
-                    .send(Ok(ResponseEvent::Completed {
-                        response_id: String::new(),
-                        token_usage: None,
-                    }))
-                    .await;
-                return;
-            }
-            Err(_) => {
-                let _ = tx_event
-                    .send(Err(CodexErr::Stream(
-                        "idle timeout waiting for SSE".into(),
-                        None,
-                    )))
-                    .await;
-                return;
-            }
-        };
-
-        // OpenAI Chat streaming sends a literal string "[DONE]" when finished.
-        if sse.data.trim() == "[DONE]" {
-            // Emit any finalized items before closing so downstream consumers receive
-            // terminal events for both assistant content and raw reasoning.
-            if let Some(item) = assistant_item {
-                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
-            }
-
-            if let Some(item) = reasoning_item {
-                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
-            }
-
-            let _ = tx_event
-                .send(Ok(ResponseEvent::Completed {
-                    response_id: String::new(),
-                    token_usage: None,
-                }))
-                .await;
-            return;
-        }
-
-        // Parse JSON chunk
-        let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
-            Ok(v) => v,
-            Err(_) => continue,
-        };
-        trace!("chat_completions received SSE chunk: {chunk:?}");
-
-        let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
-
-        if let Some(choice) = choice_opt {
-            // Handle assistant content tokens as streaming deltas.
-            if let Some(content) = choice
-                .get("delta")
-                .and_then(|d| d.get("content"))
-                .and_then(|c| c.as_str())
-                && !content.is_empty()
-            {
-                append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await;
-            }
-
-            // Forward any reasoning/thinking deltas if present.
-            // Some providers stream `reasoning` as a plain string while others
-            // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
-            if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
-                let mut maybe_text = reasoning_val
-                    .as_str()
-                    .map(str::to_string)
-                    .filter(|s| !s.is_empty());
-
-                if maybe_text.is_none() && reasoning_val.is_object() {
-                    if let Some(s) = reasoning_val
-                        .get("text")
-                        .and_then(|t| t.as_str())
-                        .filter(|s| !s.is_empty())
-                    {
-                        maybe_text = Some(s.to_string());
-                    } else if let Some(s) = reasoning_val
-                        .get("content")
-                        .and_then(|t| t.as_str())
-                        .filter(|s| !s.is_empty())
-                    {
-                        maybe_text = Some(s.to_string());
-                    }
-                }
-
-                if let Some(reasoning) = maybe_text {
-                    // Accumulate so we can emit a terminal Reasoning item at the end.
-                    append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await;
-                }
-            }
-
-            // Some providers only include reasoning on the final message object.
-            if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
-            {
-                // Accept either a plain string or an object with { text | content }
-                if let Some(s) = message_reasoning.as_str() {
-                    if !s.is_empty() {
-                        append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
-                    }
-                } else if let Some(obj) = message_reasoning.as_object()
-                    && let Some(s) = obj
-                        .get("text")
-                        .and_then(|v| v.as_str())
-                        .or_else(|| obj.get("content").and_then(|v| v.as_str()))
-                    && !s.is_empty()
-                {
-                    append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
-                }
-            }
-
-            // Handle streaming function / tool calls.
-            if let Some(tool_calls) = choice
-                .get("delta")
-                .and_then(|d| d.get("tool_calls"))
-                .and_then(|tc| tc.as_array())
-                && let Some(tool_call) = tool_calls.first()
-            {
-                // Mark that we have an active function call in progress.
-                fn_call_state.active = true;
-
-                // Extract call_id if present.
-                if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
-                    fn_call_state.call_id.get_or_insert_with(|| id.to_string());
-                }
-
-                // Extract function details if present.
-                if let Some(function) = tool_call.get("function") {
-                    if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
-                        fn_call_state.name.get_or_insert_with(|| name.to_string());
-                    }
-
-                    if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str())
-                    {
-                        fn_call_state.arguments.push_str(args_fragment);
-                    }
-                }
-            }
-
-            // Emit end-of-turn when finish_reason signals completion.
-            if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
-                match finish_reason {
-                    "tool_calls" if fn_call_state.active => {
-                        // First, flush the terminal raw reasoning so UIs can finalize
-                        // the reasoning stream before any exec/tool events begin.
-                        if let Some(item) = reasoning_item.take() {
-                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
-                        }
-
-                        // Then emit the FunctionCall response item.
-                        let item = ResponseItem::FunctionCall {
-                            id: None,
-                            name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
-                            arguments: fn_call_state.arguments.clone(),
-                            call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
-                        };
-
-                        let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
-                    }
-                    "stop" => {
-                        // Regular turn without tool-call. Emit the final assistant message
-                        // as a single OutputItemDone so non-delta consumers see the result.
-                        if let Some(item) = assistant_item.take() {
-                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
-                        }
-                        // Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
-                        if let Some(item) = reasoning_item.take() {
-                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
-                        }
-                    }
-                    _ => {}
-                }
-
-                // Emit Completed regardless of reason so the agent can advance.
-                let _ = tx_event
-                    .send(Ok(ResponseEvent::Completed {
-                        response_id: String::new(),
-                        token_usage: None,
-                    }))
-                    .await;
-
-                // Prepare for potential next turn (should not happen in same stream).
-                // fn_call_state = FunctionCallState::default();
-
-                return; // End processing for this SSE stream.
-            }
-        }
-    }
-}
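The FunctionCallState accumulation that this processor relied on is worth seeing in isolation: the tool-call id and name arrive once, while `arguments` streams as string fragments that are concatenated until `finish_reason == "tool_calls"`. A standalone toy with invented delta tuples, not the removed code:

    #[derive(Default, Debug)]
    struct FunctionCallState {
        name: Option<String>,
        arguments: String,
        call_id: Option<String>,
        active: bool,
    }

    fn main() {
        // (id, name, arguments-fragment) as they might arrive across chunks.
        let deltas = [
            (Some("call_1"), Some("get_weather"), Some("{\"city\":")),
            (None, None, Some("\"Paris\"}")),
        ];
        let mut state = FunctionCallState::default();
        for (id, name, args) in deltas {
            state.active = true;
            if let Some(id) = id { state.call_id.get_or_insert_with(|| id.to_string()); }
            if let Some(name) = name { state.name.get_or_insert_with(|| name.to_string()); }
            if let Some(args) = args { state.arguments.push_str(args); }
        }
        // On finish_reason == "tool_calls" the completed call is emitted once.
        println!("{state:?}"); // arguments == {"city":"Paris"}
    }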
-
-/// Optional client-side aggregation helper
-///
-/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
-/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
-/// per-token deltas**. The stream stays silent while the model is thinking
-/// and only emits two events per turn:
-///
-/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
-///    (fully concatenated).
-/// 2. The original `ResponseEvent::Completed` right after it.
-///
-/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
-///
-/// The adapter is intentionally *lossless*: callers who do **not** opt in via
-/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
-/// events.
-#[derive(Copy, Clone, Eq, PartialEq)]
-enum AggregateMode {
-    AggregatedOnly,
-    Streaming,
-}
-pub(crate) struct AggregatedChatStream<S> {
-    inner: S,
-    cumulative: String,
-    cumulative_reasoning: String,
-    pending: std::collections::VecDeque<ResponseEvent>,
-    mode: AggregateMode,
-}
-
-impl<S> Stream for AggregatedChatStream<S>
-where
-    S: Stream<Item = Result<ResponseEvent>> + Unpin,
-{
-    type Item = Result<ResponseEvent>;
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let this = self.get_mut();
-
-        // First, flush any buffered events from the previous call.
-        if let Some(ev) = this.pending.pop_front() {
-            return Poll::Ready(Some(Ok(ev)));
-        }
-
-        loop {
-            match Pin::new(&mut this.inner).poll_next(cx) {
-                Poll::Pending => return Poll::Pending,
-                Poll::Ready(None) => return Poll::Ready(None),
-                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
-                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
-                    // If this is an incremental assistant message chunk, accumulate but
-                    // do NOT emit yet. Forward any other item (e.g. FunctionCall) right
-                    // away so downstream consumers see it.
-
-                    let is_assistant_message = matches!(
-                        &item,
-                        codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
-                    );
-
-                    if is_assistant_message {
-                        match this.mode {
-                            AggregateMode::AggregatedOnly => {
-                                // Only use the final assistant message if we have not
-                                // seen any deltas; otherwise, deltas already built the
-                                // cumulative text and this would duplicate it.
-                                if this.cumulative.is_empty()
-                                    && let codex_protocol::models::ResponseItem::Message {
-                                        content,
-                                        ..
-                                    } = &item
-                                    && let Some(text) = content.iter().find_map(|c| match c {
-                                        codex_protocol::models::ContentItem::OutputText {
-                                            text,
-                                        } => Some(text),
-                                        _ => None,
-                                    })
-                                {
-                                    this.cumulative.push_str(text);
-                                }
-                                // Swallow assistant message here; emit on Completed.
-                                continue;
-                            }
-                            AggregateMode::Streaming => {
-                                // In streaming mode, if we have not seen any deltas, forward
-                                // the final assistant message directly. If deltas were seen,
-                                // suppress the final message to avoid duplication.
-                                if this.cumulative.is_empty() {
-                                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
-                                        item,
-                                    ))));
-                                } else {
-                                    continue;
-                                }
-                            }
-                        }
-                    }
-
-                    // Not an assistant message – forward immediately.
-                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
-                }
-                Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
-                    return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
-                }
-                Poll::Ready(Some(Ok(ResponseEvent::Completed {
-                    response_id,
-                    token_usage,
-                }))) => {
-                    // Build any aggregated items in the correct order: Reasoning first, then Message.
-                    let mut emitted_any = false;
-
-                    if !this.cumulative_reasoning.is_empty()
-                        && matches!(this.mode, AggregateMode::AggregatedOnly)
-                    {
-                        let aggregated_reasoning =
-                            codex_protocol::models::ResponseItem::Reasoning {
-                                id: String::new(),
-                                summary: Vec::new(),
-                                content: Some(vec![
-                                    codex_protocol::models::ReasoningItemContent::ReasoningText {
-                                        text: std::mem::take(&mut this.cumulative_reasoning),
-                                    },
-                                ]),
-                                encrypted_content: None,
-                            };
-                        this.pending
-                            .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
-                        emitted_any = true;
-                    }
-
-                    // Always emit the final aggregated assistant message when any
-                    // content deltas have been observed. In AggregatedOnly mode this
-                    // is the sole assistant output; in Streaming mode this finalizes
-                    // the streamed deltas into a terminal OutputItemDone so callers
-                    // can persist/render the message once per turn.
-                    if !this.cumulative.is_empty() {
-                        let aggregated_message = codex_protocol::models::ResponseItem::Message {
-                            id: None,
-                            role: "assistant".to_string(),
-                            content: vec![codex_protocol::models::ContentItem::OutputText {
-                                text: std::mem::take(&mut this.cumulative),
-                            }],
-                        };
-                        this.pending
-                            .push_back(ResponseEvent::OutputItemDone(aggregated_message));
-                        emitted_any = true;
-                    }
-
-                    // Always emit Completed last when anything was aggregated.
-                    if emitted_any {
-                        this.pending.push_back(ResponseEvent::Completed {
-                            response_id: response_id.clone(),
-                            token_usage: token_usage.clone(),
-                        });
-                        // Return the first pending event now.
-                        if let Some(ev) = this.pending.pop_front() {
-                            return Poll::Ready(Some(Ok(ev)));
-                        }
-                    }
-
-                    // Nothing aggregated – forward Completed directly.
-                    return Poll::Ready(Some(Ok(ResponseEvent::Completed {
-                        response_id,
-                        token_usage,
-                    })));
-                }
-                Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
-                    // These events are exclusive to the Responses API and
-                    // will never appear in a Chat Completions stream.
-                    continue;
-                }
-                Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
-                    // Always accumulate deltas so we can emit a final OutputItemDone at Completed.
-                    this.cumulative.push_str(&delta);
-                    if matches!(this.mode, AggregateMode::Streaming) {
-                        // In streaming mode, also forward the delta immediately.
-                        return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
-                    } else {
-                        continue;
-                    }
-                }
-                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
-                    // Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
-                    this.cumulative_reasoning.push_str(&delta);
-                    if matches!(this.mode, AggregateMode::Streaming) {
-                        // In streaming mode, also forward the delta immediately.
-                        return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
-                    } else {
-                        continue;
-                    }
-                }
-                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {
-                    continue;
-                }
-                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {
-                    continue;
-                }
-                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
-                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
-                }
-            }
-        }
-    }
-}
-
-/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
-pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
-    /// Returns a new stream that emits **only** the final assistant message
-    /// per turn instead of every incremental delta. The produced
-    /// `ResponseEvent` sequence for a typical text turn looks like:
-    ///
-    /// ```ignore
-    /// OutputItemDone(<final assistant message>)
-    /// Completed
-    /// ```
-    ///
-    /// No other `OutputItemDone` events will be seen by the caller.
-    ///
-    /// Usage:
-    ///
-    /// ```ignore
-    /// let agg_stream = client.stream(&prompt).await?.aggregate();
-    /// while let Some(event) = agg_stream.next().await {
-    ///     // event now contains cumulative text
-    /// }
-    /// ```
-    fn aggregate(self) -> AggregatedChatStream<Self> {
-        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
-    }
-}
-
-impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}
-
-impl<S> AggregatedChatStream<S> {
-    fn new(inner: S, mode: AggregateMode) -> Self {
-        AggregatedChatStream {
-            inner,
-            cumulative: String::new(),
-            cumulative_reasoning: String::new(),
-            pending: std::collections::VecDeque::new(),
-            mode,
-        }
-    }
-
-    pub(crate) fn streaming_mode(inner: S) -> Self {
-        Self::new(inner, AggregateMode::Streaming)
-    }
-}
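The contract this adapter enforced (and which the api-client crate's AggregateStreamExt now provides) reduces to: deltas accumulate silently, and the turn ends with exactly one final message followed by Completed. A toy model of that contract, with invented event names where noted, pure std:

    #[derive(Debug)]
    enum Event { Delta(String), Completed }

    fn aggregate(events: Vec<Event>) -> Vec<String> {
        let mut cumulative = String::new();
        let mut out = Vec::new();
        for ev in events {
            match ev {
                Event::Delta(d) => cumulative.push_str(&d), // swallow deltas
                Event::Completed => {
                    if !cumulative.is_empty() {
                        out.push(format!("OutputItemDone({cumulative})"));
                    }
                    out.push("Completed".to_string()); // Completed is always last
                }
            }
        }
        out
    }

    fn main() {
        let events = vec![Event::Delta("Hel".into()), Event::Delta("lo".into()), Event::Completed];
        assert_eq!(aggregate(events), ["OutputItemDone(Hello)", "Completed"]);
    }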
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 683091b174..9a82a160d0 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -1,95 +1,122 @@
-use std::io::BufRead;
-use std::path::Path;
+use std::fmt;
 use std::sync::Arc;
-use std::sync::OnceLock;
-use std::time::Duration;
-use bytes::Bytes;
-use chrono::DateTime;
-use chrono::Utc;
-use codex_app_server_protocol::AuthMode;
+use async_trait::async_trait;
+use codex_api_client::AggregateStreamExt;
+use codex_api_client::ApiClient;
+use codex_api_client::AuthContext;
+use codex_api_client::AuthProvider;
+use codex_api_client::ChatAggregationMode;
+use codex_api_client::ChatCompletionsApiClient;
+use codex_api_client::ChatCompletionsApiClientConfig;
+use codex_api_client::ResponsesApiClient;
+use codex_api_client::ResponsesApiClientConfig;
+use codex_api_client::Result as ApiClientResult;
+use codex_api_client::stream_from_fixture;
 use codex_otel::otel_event_manager::OtelEventManager;
 use codex_protocol::ConversationId;
 use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
 use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
-use codex_protocol::models::ResponseItem;
 use codex_protocol::protocol::SessionSource;
-use eventsource_stream::Eventsource;
-use futures::prelude::*;
-use regex_lite::Regex;
-use reqwest::StatusCode;
-use reqwest::header::HeaderMap;
-use serde::Deserialize;
-use serde::Serialize;
-use serde_json::Value;
+use futures::StreamExt;
+use futures::stream::BoxStream;
+use tokio::sync::OnceCell;
 use tokio::sync::mpsc;
-use tokio::time::timeout;
-use tokio_util::io::ReaderStream;
-use tracing::debug;
-use tracing::trace;
 use tracing::warn;
 
 use crate::AuthManager;
-use crate::auth::CodexAuth;
-use crate::chat_completions::AggregateStreamExt;
-use crate::chat_completions::stream_chat_completions;
 use crate::client_common::Prompt;
 use crate::client_common::ResponseEvent;
 use crate::client_common::ResponseStream;
-use crate::client_common::ResponsesApiRequest;
 use crate::client_common::create_reasoning_param_for_request;
 use crate::client_common::create_text_param_for_request;
 use crate::config::Config;
-use crate::default_client::CodexHttpClient;
 use crate::default_client::create_client;
 use crate::error::CodexErr;
 use crate::error::ConnectionFailedError;
+use crate::error::EnvVarError;
 use crate::error::ResponseStreamFailed;
 use crate::error::Result;
 use crate::error::RetryLimitReachedError;
 use crate::error::UnexpectedResponseError;
-use crate::error::UsageLimitReachedError;
+use crate::features::Feature;
 use crate::flags::CODEX_RS_SSE_FIXTURE;
 use crate::model_family::ModelFamily;
-use crate::model_provider_info::ModelProviderInfo;
-use crate::model_provider_info::WireApi;
 use crate::openai_model_info::get_model_info;
-use crate::protocol::RateLimitSnapshot;
-use crate::protocol::RateLimitWindow;
-use crate::protocol::TokenUsage;
-use crate::token_data::PlanType;
-use crate::tools::spec::create_tools_json_for_responses_api;
-use crate::util::backoff;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 
-#[derive(Debug, Deserialize)]
-struct ErrorResponse {
-    error: Error,
-}
-
-#[derive(Debug, Deserialize)]
-struct Error {
-    r#type: Option<String>,
-    code: Option<String>,
-    message: Option<String>,
-
-    // Optional fields available on "usage_limit_reached" and "usage_not_included" errors
-    plan_type: Option<PlanType>,
-    resets_at: Option<i64>,
-}
-
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub struct ModelClient {
     config: Arc<Config>,
     auth_manager: Option<Arc<AuthManager>>,
     otel_event_manager: OtelEventManager,
-    client: CodexHttpClient,
+    http_client: reqwest::Client,
     provider: ModelProviderInfo,
+    backend: Arc<OnceCell<ModelBackend>>,
     conversation_id: ConversationId,
     effort: Option<ReasoningEffortConfig>,
     summary: ReasoningSummaryConfig,
    session_source: SessionSource,
 }
 
+impl fmt::Debug for ModelClient {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ModelClient")
+            .field("provider", &self.provider.name)
+            .field("model", &self.config.model)
+            .field("conversation_id", &self.conversation_id)
+            .field("backend_initialized", &self.backend.get().is_some())
+            .finish()
+    }
+}
+
+type ApiClientStream = BoxStream<'static, ApiClientResult<ResponseEvent>>;
+
+enum ModelBackend {
+    Responses(ResponsesBackend),
+    Chat(ChatBackend),
+}
+
+impl ModelBackend {
+    async fn stream(&self, prompt: Prompt) -> ApiClientResult<ApiClientStream> {
+        match self {
+            ModelBackend::Responses(backend) => backend.stream(prompt).await,
+            ModelBackend::Chat(backend) => backend.stream(prompt).await,
+        }
+    }
+}
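The reason both backend impls box their streams is that the Responses and Chat clients produce different concrete Stream types, and BoxStream<'static, T> gives them one return type the ModelBackend enum can dispatch over. A minimal standalone sketch of that shape, with stand-in stream contents and no relation to the real client types:

    use futures::stream::{self, BoxStream, StreamExt};

    fn responses_stream() -> BoxStream<'static, String> {
        stream::iter(vec!["created".to_string(), "completed".to_string()]).boxed()
    }

    fn chat_stream() -> BoxStream<'static, String> {
        // A different combinator chain, hence a different concrete type,
        // unified by .boxed().
        stream::iter(vec!["delta".to_string()]).map(|d| format!("chat:{d}")).boxed()
    }

    enum Backend { Responses, Chat }

    impl Backend {
        fn stream(&self) -> BoxStream<'static, String> {
            match self {
                Backend::Responses => responses_stream(),
                Backend::Chat => chat_stream(),
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let mut s = Backend::Chat.stream();
        while let Some(ev) = s.next().await { println!("{ev}"); }
    }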
+
+struct ResponsesBackend {
+    client: ResponsesApiClient,
+}
+
+impl ResponsesBackend {
+    async fn stream(&self, prompt: Prompt) -> ApiClientResult<ApiClientStream> {
+        self.client
+            .stream(prompt)
+            .await
+            .map(futures::StreamExt::boxed)
+    }
+}
+
+struct ChatBackend {
+    client: ChatCompletionsApiClient,
+    show_reasoning: bool,
+}
+
+impl ChatBackend {
+    async fn stream(&self, prompt: Prompt) -> ApiClientResult<ApiClientStream> {
+        let stream = self.client.stream(prompt).await?;
+        let stream = if self.show_reasoning {
+            stream.streaming_mode().boxed()
+        } else {
+            stream.aggregate().boxed()
+        };
+        Ok(stream)
+    }
+}
+
 #[allow(clippy::too_many_arguments)]
 impl ModelClient {
     pub fn new(
@@ -102,14 +129,16 @@ impl ModelClient {
         conversation_id: ConversationId,
         session_source: SessionSource,
     ) -> Self {
-        let client = create_client();
+        let http_client = create_client().clone_inner();
+        let backend = Arc::new(OnceCell::new());
 
         Self {
             config,
             auth_manager,
             otel_event_manager,
-            client,
+            http_client,
             provider,
+            backend,
             conversation_id,
             effort,
             summary,
@@ -139,320 +168,120 @@ impl ModelClient {
         &self.provider
     }
 
-    pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
-        match self.provider.wire_api {
-            WireApi::Responses => self.stream_responses(prompt).await,
-            WireApi::Chat => {
-                // Create the raw streaming connection first.
-                let response_stream = stream_chat_completions(
-                    prompt,
-                    &self.config.model_family,
-                    &self.client,
-                    &self.provider,
-                    &self.otel_event_manager,
-                    &self.session_source,
-                )
-                .await?;
+    pub fn supports_responses_api_chaining(&self) -> bool {
+        self.provider.wire_api == WireApi::Responses
+            && self.config.features.enabled(Feature::ResponsesApiChaining)
+    }
 
-                // Wrap it with the aggregation adapter so callers see *only*
-                // the final assistant message per turn (matching the
-                // behaviour of the Responses API).
-                let mut aggregated = if self.config.show_raw_agent_reasoning {
-                    crate::chat_completions::AggregatedChatStream::streaming_mode(response_stream)
-                } else {
-                    response_stream.aggregate()
-                };
+    pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
+        let mut prompt = prompt.clone();
+        self.populate_prompt(&mut prompt);
+        if self.provider.wire_api == WireApi::Responses
+            && let Some(path) = &*CODEX_RS_SSE_FIXTURE
+        {
+            warn!(path, "Streaming from fixture");
+            let stream =
+                stream_from_fixture(path, self.provider.clone(), self.otel_event_manager.clone())
+                    .await
+                    .map_err(map_api_error)?
                    .boxed();
+            return Ok(wrap_stream(stream));
+        }
 
-                // Bridge the aggregated stream back into a standard
-                // `ResponseStream` by forwarding events through a channel.
-                let (tx, rx) = mpsc::channel::<Result<ResponseEvent>>(16);
+        let backend = self
+            .backend
+            .get_or_try_init(|| async { self.build_backend().await })
+            .await
+            .map_err(map_api_error)?;
 
-                tokio::spawn(async move {
-                    use futures::StreamExt;
-                    while let Some(ev) = aggregated.next().await {
-                        // Exit early if receiver hung up.
-                        if tx.send(ev).await.is_err() {
-                            break;
-                        }
-                    }
-                });
+        let api_stream = backend.stream(prompt).await.map_err(map_api_error)?;
 
-                Ok(ResponseStream { rx_event: rx })
-            }
-        }
+        Ok(wrap_stream(api_stream))
     }
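The new stream() builds its backend lazily through tokio's OnceCell::get_or_try_init: the first caller constructs it, concurrent callers await the same initialization, and every later call reuses the cached value. A standalone sketch of just that pattern, assuming tokio; the Backend type and build function are stand-ins:

    use std::sync::Arc;
    use tokio::sync::OnceCell;

    #[derive(Debug)]
    struct Backend(u32);

    async fn build_backend() -> Result<Backend, String> {
        Ok(Backend(42)) // stand-in for async client construction
    }

    #[tokio::main]
    async fn main() -> Result<(), String> {
        let cell: Arc<OnceCell<Backend>> = Arc::new(OnceCell::new());
        // First call initializes; later calls return the cached value.
        let a = cell.get_or_try_init(build_backend).await?;
        println!("{a:?}");
        let b = cell.get_or_try_init(build_backend).await?;
        assert!(std::ptr::eq(a, b)); // same shared instance
        Ok(())
    }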
-    /// Implementation for the OpenAI *Responses* experimental API.
-    async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
-        if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
-            // short circuit for tests
-            warn!(path, "Streaming from fixture");
-            return stream_from_fixture(
-                path,
-                self.provider.clone(),
-                self.otel_event_manager.clone(),
-            )
-            .await;
+    fn populate_prompt(&self, prompt: &mut Prompt) {
+        if prompt.prompt_cache_key.is_none() {
+            prompt.prompt_cache_key = Some(self.conversation_id.to_string());
         }
 
-        let auth_manager = self.auth_manager.clone();
+        prompt.session_source = Some(self.session_source.clone());
 
-        let full_instructions = prompt.get_full_instructions(&self.config.model_family);
-        let tools_json = create_tools_json_for_responses_api(&prompt.tools)?;
-        let reasoning = create_reasoning_param_for_request(
+        prompt.reasoning = create_reasoning_param_for_request(
             &self.config.model_family,
             self.effort,
             self.summary,
         );
 
-        let include: Vec<String> = if reasoning.is_some() {
-            vec!["reasoning.encrypted_content".to_string()]
-        } else {
-            vec![]
-        };
-
-        let input_with_instructions = prompt.get_formatted_input();
-
         let verbosity = if self.config.model_family.support_verbosity {
             self.config.model_verbosity
         } else {
-            warn!(
-                "model_verbosity is set but ignored as the model does not support verbosity: {}",
-                self.config.model_family.family
-            );
+            if self.config.model_verbosity.is_some() {
+                warn!(
+                    "model_verbosity is set but ignored as the model does not support verbosity: {}",
+                    self.config.model_family.family
+                );
+            }
             None
         };
 
-        // Only include `text.verbosity` for GPT-5 family models
-        let text = create_text_param_for_request(verbosity, &prompt.output_schema);
-
-        // In general, we want to explicitly send `store: false` when using the Responses API,
-        // but in practice, the Azure Responses API rejects `store: false`:
-        //
-        // - If store = false and id is sent an error is thrown that ID is not found
-        // - If store = false and id is not sent an error is thrown that ID is required
-        //
-        // For Azure, we send `store: true` and preserve reasoning item IDs.
-        let azure_workaround = self.provider.is_azure_responses_endpoint();
-
-        let payload = ResponsesApiRequest {
-            model: &self.config.model,
-            instructions: &full_instructions,
-            input: &input_with_instructions,
-            tools: &tools_json,
-            tool_choice: "auto",
-            parallel_tool_calls: prompt.parallel_tool_calls,
-            reasoning,
-            store: azure_workaround,
-            stream: true,
-            include,
-            prompt_cache_key: Some(self.conversation_id.to_string()),
-            text,
-        };
-
-        let mut payload_json = serde_json::to_value(&payload)?;
-        if azure_workaround {
-            attach_item_ids(&mut payload_json, &input_with_instructions);
-        }
-
-        let max_attempts = self.provider.request_max_retries();
-        for attempt in 0..=max_attempts {
-            match self
-                .attempt_stream_responses(attempt, &payload_json, &auth_manager)
-                .await
-            {
-                Ok(stream) => {
-                    return Ok(stream);
-                }
-                Err(StreamAttemptError::Fatal(e)) => {
-                    return Err(e);
-                }
-                Err(retryable_attempt_error) => {
-                    if attempt == max_attempts {
-                        return Err(retryable_attempt_error.into_error());
-                    }
+        prompt.text_controls = create_text_param_for_request(verbosity, &prompt.output_schema);
+    }
 
-                    tokio::time::sleep(retryable_attempt_error.delay(attempt)).await;
-                }
-            }
+    async fn build_backend(&self) -> ApiClientResult<ModelBackend> {
+        match self.provider.wire_api {
+            WireApi::Responses => self.build_responses_backend().await,
+            WireApi::Chat => self.build_chat_backend().await,
         }
-
-        unreachable!("stream_responses_attempt should always return");
     }
 
-    /// Single attempt to start a streaming Responses API call.
-    async fn attempt_stream_responses(
-        &self,
-        attempt: u64,
-        payload_json: &Value,
-        auth_manager: &Option<Arc<AuthManager>>,
-    ) -> std::result::Result<ResponseStream, StreamAttemptError> {
-        // Always fetch the latest auth in case a prior attempt refreshed the token.
-        let auth = auth_manager.as_ref().and_then(|m| m.auth());
+    async fn build_responses_backend(&self) -> ApiClientResult<ModelBackend> {
+        let auth_provider = self.auth_manager.as_ref().map(|manager| {
+            Arc::new(AuthManagerProvider::new(Arc::clone(manager))) as Arc<dyn AuthProvider>
+        });
 
-        trace!(
-            "POST to {}: {:?}",
-            self.provider.get_full_url(&auth),
-            serde_json::to_string(payload_json)
-                .unwrap_or("".to_string())
-        );
+        let config = ResponsesApiClientConfig {
+            http_client: self.http_client.clone(),
+            provider: self.provider.clone(),
+            model: self.config.model.clone(),
+            conversation_id: self.conversation_id,
+            auth_provider,
+            otel_event_manager: self.otel_event_manager.clone(),
+        };
 
-        let mut req_builder = self
-            .provider
-            .create_request_builder(&self.client, &auth)
-            .await
-            .map_err(StreamAttemptError::Fatal)?;
+        let client = ResponsesApiClient::new(config).await?;
+        Ok(ModelBackend::Responses(ResponsesBackend { client }))
+    }
 
-        // Include subagent header only for subagent sessions.
-        if let SessionSource::SubAgent(sub) = &self.session_source {
-            let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub {
-                label.clone()
+    async fn build_chat_backend(&self) -> ApiClientResult<ModelBackend> {
+        let show_reasoning = self.config.show_raw_agent_reasoning;
+        let config = ChatCompletionsApiClientConfig {
+            http_client: self.http_client.clone(),
+            provider: self.provider.clone(),
+            model: self.config.model.clone(),
+            otel_event_manager: self.otel_event_manager.clone(),
+            session_source: self.session_source.clone(),
+            aggregation_mode: if show_reasoning {
+                ChatAggregationMode::Streaming
             } else {
-                serde_json::to_value(sub)
-                    .ok()
-                    .and_then(|v| v.as_str().map(std::string::ToString::to_string))
-                    .unwrap_or_else(|| "other".to_string())
-            };
-            req_builder = req_builder.header("x-openai-subagent", subagent);
-        }
-
-        req_builder = req_builder
-            // Send session_id for compatibility.
- .header("conversation_id", self.conversation_id.to_string()) - .header("session_id", self.conversation_id.to_string()) - .header(reqwest::header::ACCEPT, "text/event-stream") - .json(payload_json); - - if let Some(auth) = auth.as_ref() - && auth.mode == AuthMode::ChatGPT - && let Some(account_id) = auth.get_account_id() - { - req_builder = req_builder.header("chatgpt-account-id", account_id); - } - - let res = self - .otel_event_manager - .log_request(attempt, || req_builder.send()) - .await; - - let mut request_id = None; - if let Ok(resp) = &res { - request_id = resp - .headers() - .get("cf-ray") - .map(|v| v.to_str().unwrap_or_default().to_string()); - } - - match res { - Ok(resp) if resp.status().is_success() => { - let (tx_event, rx_event) = mpsc::channel::>(1600); - - if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers()) - && tx_event - .send(Ok(ResponseEvent::RateLimits(snapshot))) - .await - .is_err() - { - debug!("receiver dropped rate limit snapshot event"); - } - - // spawn task to process SSE - let stream = resp.bytes_stream().map_err(move |e| { - CodexErr::ResponseStreamFailed(ResponseStreamFailed { - source: e, - request_id: request_id.clone(), - }) - }); - tokio::spawn(process_sse( - stream, - tx_event, - self.provider.stream_idle_timeout(), - self.otel_event_manager.clone(), - )); - - Ok(ResponseStream { rx_event }) - } - Ok(res) => { - let status = res.status(); - - // Pull out Retry‑After header if present. - let retry_after_secs = res - .headers() - .get(reqwest::header::RETRY_AFTER) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()); - let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000)); - - if status == StatusCode::UNAUTHORIZED - && let Some(manager) = auth_manager.as_ref() - && let Some(auth) = auth.as_ref() - && auth.mode == AuthMode::ChatGPT - { - manager.refresh_token().await.map_err(|err| { - StreamAttemptError::Fatal(CodexErr::Fatal(format!( - "Failed to refresh ChatGPT credentials: {err}" - ))) - })?; - } - - // The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx - // errors. When we bubble early with only the HTTP status the caller sees an opaque - // "unexpected status 400 Bad Request" which makes debugging nearly impossible. - // Instead, read (and include) the response text so higher layers and users see the - // exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is - // small and this branch only runs on error paths so the extra allocation is - // negligible. - if !(status == StatusCode::TOO_MANY_REQUESTS - || status == StatusCode::UNAUTHORIZED - || status.is_server_error()) - { - // Surface the error body to callers. Use `unwrap_or_default` per Clippy. - let body = res.text().await.unwrap_or_default(); - return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus( - UnexpectedResponseError { - status, - body, - request_id: None, - }, - ))); - } + ChatAggregationMode::AggregatedOnly + }, + }; - if status == StatusCode::TOO_MANY_REQUESTS { - let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers()); - let body = res.json::().await.ok(); - if let Some(ErrorResponse { error }) = body { - if error.r#type.as_deref() == Some("usage_limit_reached") { - // Prefer the plan_type provided in the error message if present - // because it's more up to date than the one encoded in the auth - // token. 
-                            let plan_type = error
-                                .plan_type
-                                .or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
-                            let resets_at = error
-                                .resets_at
-                                .and_then(|seconds| DateTime::<Utc>::from_timestamp(seconds, 0));
-                            let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError {
-                                plan_type,
-                                resets_at,
-                                rate_limits: rate_limit_snapshot,
-                            });
-                            return Err(StreamAttemptError::Fatal(codex_err));
-                        } else if error.r#type.as_deref() == Some("usage_not_included") {
-                            return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded));
-                        }
-                    }
-                }
+        let client = ChatCompletionsApiClient::new(config).await?;
+        Ok(ModelBackend::Chat(ChatBackend {
+            client,
+            show_reasoning,
+        }))
+    }
 
-                Err(StreamAttemptError::RetryableHttpError {
-                    status,
-                    retry_after,
-                    request_id,
-                })
-            }
-            Err(e) => Err(StreamAttemptError::RetryableTransportError(
-                CodexErr::ConnectionFailed(ConnectionFailedError { source: e }),
-            )),
-        }
+    pub async fn stream_for_test(&self, mut prompt: Prompt) -> Result<ResponseStream> {
+        crate::conversation_history::format_prompt_items(&mut prompt.input, false);
+        let instructions =
+            crate::client_common::compute_full_instructions(None, &self.config.model_family, false)
+                .into_owned();
+        prompt.instructions = instructions;
+        prompt.previous_response_id = None;
+        self.stream(&prompt).await
     }
 
     pub fn get_provider(&self) -> ModelProviderInfo {
@@ -492,965 +321,91 @@ impl ModelClient {
     }
 }
 
-enum StreamAttemptError {
-    RetryableHttpError {
-        status: StatusCode,
-        retry_after: Option<Duration>,
-        request_id: Option<String>,
-    },
-    RetryableTransportError(CodexErr),
-    Fatal(CodexErr),
+struct AuthManagerProvider {
+    manager: Arc<AuthManager>,
 }
 
-impl StreamAttemptError {
-    /// attempt is 0-based.
-    fn delay(&self, attempt: u64) -> Duration {
-        // backoff() uses 1-based attempts.
-        let backoff_attempt = attempt + 1;
-        match self {
-            Self::RetryableHttpError { retry_after, .. } => {
-                retry_after.unwrap_or_else(|| backoff(backoff_attempt))
-            }
-            Self::RetryableTransportError { .. } => backoff(backoff_attempt),
-            Self::Fatal(_) => {
-                // Should not be called on Fatal errors.
-                Duration::from_secs(0)
-            }
-        }
-    }
-
-    fn into_error(self) -> CodexErr {
-        match self {
-            Self::RetryableHttpError {
-                status, request_id, ..
-            } => {
-                if status == StatusCode::INTERNAL_SERVER_ERROR {
-                    CodexErr::InternalServerError
-                } else {
-                    CodexErr::RetryLimit(RetryLimitReachedError { status, request_id })
-                }
-            }
-            Self::RetryableTransportError(error) => error,
-            Self::Fatal(error) => error,
-        }
-    }
-}
-
-#[derive(Debug, Deserialize, Serialize)]
-struct SseEvent {
-    #[serde(rename = "type")]
-    kind: String,
-    response: Option<Value>,
-    item: Option<Value>,
-    delta: Option<String>,
-}
-
-#[derive(Debug, Deserialize)]
-struct ResponseCompleted {
-    id: String,
-    usage: Option<ResponseCompletedUsage>,
-}
-
-#[derive(Debug, Deserialize)]
-struct ResponseCompletedUsage {
-    input_tokens: i64,
-    input_tokens_details: Option<ResponseCompletedInputTokensDetails>,
-    output_tokens: i64,
-    output_tokens_details: Option<ResponseCompletedOutputTokensDetails>,
-    total_tokens: i64,
-}
-
-impl From<ResponseCompletedUsage> for TokenUsage {
-    fn from(val: ResponseCompletedUsage) -> Self {
-        TokenUsage {
-            input_tokens: val.input_tokens,
-            cached_input_tokens: val
-                .input_tokens_details
-                .map(|d| d.cached_tokens)
-                .unwrap_or(0),
-            output_tokens: val.output_tokens,
-            reasoning_output_tokens: val
-                .output_tokens_details
-                .map(|d| d.reasoning_tokens)
-                .unwrap_or(0),
-            total_tokens: val.total_tokens,
-        }
-    }
-}
-
-#[derive(Debug, Deserialize)]
-struct ResponseCompletedInputTokensDetails {
-    cached_tokens: i64,
-}
-
-#[derive(Debug, Deserialize)]
-struct ResponseCompletedOutputTokensDetails {
-    reasoning_tokens: i64,
-}
-
-fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) {
-    let Some(input_value) = payload_json.get_mut("input") else {
-        return;
-    };
-    let serde_json::Value::Array(items) = input_value else {
-        return;
-    };
-
-    for (value, item) in items.iter_mut().zip(original_items.iter()) {
-        if let ResponseItem::Reasoning { id, .. }
-        | ResponseItem::Message { id: Some(id), .. }
-        | ResponseItem::WebSearchCall { id: Some(id), .. }
-        | ResponseItem::FunctionCall { id: Some(id), .. }
-        | ResponseItem::LocalShellCall { id: Some(id), .. }
-        | ResponseItem::CustomToolCall { id: Some(id), .. } = item
-        {
-            if id.is_empty() {
-                continue;
-            }
-
-            if let Some(obj) = value.as_object_mut() {
-                obj.insert("id".to_string(), Value::String(id.clone()));
-            }
-        }
+impl AuthManagerProvider {
+    fn new(manager: Arc<AuthManager>) -> Self {
+        Self { manager }
     }
 }
 
-fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
-    let primary = parse_rate_limit_window(
-        headers,
-        "x-codex-primary-used-percent",
-        "x-codex-primary-window-minutes",
-        "x-codex-primary-reset-at",
-    );
-
-    let secondary = parse_rate_limit_window(
-        headers,
-        "x-codex-secondary-used-percent",
-        "x-codex-secondary-window-minutes",
-        "x-codex-secondary-reset-at",
-    );
-
-    Some(RateLimitSnapshot { primary, secondary })
-}
-
-fn parse_rate_limit_window(
-    headers: &HeaderMap,
-    used_percent_header: &str,
-    window_minutes_header: &str,
-    resets_at_header: &str,
-) -> Option<RateLimitWindow> {
-    let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);
-
-    used_percent.and_then(|used_percent| {
-        let window_minutes = parse_header_i64(headers, window_minutes_header);
-        let resets_at = parse_header_i64(headers, resets_at_header);
-
-        let has_data = used_percent != 0.0
-            || window_minutes.is_some_and(|minutes| minutes != 0)
-            || resets_at.is_some();
-
-        has_data.then_some(RateLimitWindow {
-            used_percent,
-            window_minutes,
-            resets_at,
-        })
-    })
-}
-
-fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
-    parse_header_str(headers, name)?
-        .parse::<f64>()
-        .ok()
-        .filter(|v| v.is_finite())
-}
-
-fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
-    parse_header_str(headers, name)?.parse::<i64>().ok()
-}
-
-fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
-    headers.get(name)?.to_str().ok()
-}
-
-async fn process_sse<S>(
-    stream: S,
-    tx_event: mpsc::Sender<Result<ResponseEvent>>,
-    idle_timeout: Duration,
-    otel_event_manager: OtelEventManager,
-) where
-    S: Stream<Item = Result<Bytes>> + Unpin,
-{
-    let mut stream = stream.eventsource();
-
-    // If the stream stays completely silent for an extended period treat it as disconnected.
-    // The response id returned from the "complete" message.
-    let mut response_completed: Option<ResponseCompleted> = None;
-    let mut response_error: Option<CodexErr> = None;
-
-    loop {
-        let start = std::time::Instant::now();
-        let response = timeout(idle_timeout, stream.next()).await;
-        let duration = start.elapsed();
-        otel_event_manager.log_sse_event(&response, duration);
-
-        let sse = match response {
-            Ok(Some(Ok(sse))) => sse,
-            Ok(Some(Err(e))) => {
-                debug!("SSE Error: {e:#}");
-                let event = CodexErr::Stream(e.to_string(), None);
-                let _ = tx_event.send(Err(event)).await;
-                return;
-            }
-            Ok(None) => {
-                match response_completed {
-                    Some(ResponseCompleted {
-                        id: response_id,
-                        usage,
-                    }) => {
-                        if let Some(token_usage) = &usage {
-                            otel_event_manager.sse_event_completed(
-                                token_usage.input_tokens,
-                                token_usage.output_tokens,
-                                token_usage
-                                    .input_tokens_details
-                                    .as_ref()
-                                    .map(|d| d.cached_tokens),
-                                token_usage
-                                    .output_tokens_details
-                                    .as_ref()
-                                    .map(|d| d.reasoning_tokens),
-                                token_usage.total_tokens,
-                            );
-                        }
-                        let event = ResponseEvent::Completed {
-                            response_id,
-                            token_usage: usage.map(Into::into),
-                        };
-                        let _ = tx_event.send(Ok(event)).await;
-                    }
-                    None => {
-                        let error = response_error.unwrap_or(CodexErr::Stream(
-                            "stream closed before response.completed".into(),
-                            None,
-                        ));
-                        otel_event_manager.see_event_completed_failed(&error);
-
-                        let _ = tx_event.send(Err(error)).await;
-                    }
-                }
-                return;
-            }
-            Err(_) => {
-                let _ = tx_event
-                    .send(Err(CodexErr::Stream(
-                        "idle timeout waiting for SSE".into(),
-                        None,
-                    )))
-                    .await;
-                return;
-            }
-        };
-
-        let raw = sse.data.clone();
-        trace!("SSE event: {}", raw);
-
-        let event: SseEvent = match serde_json::from_str(&sse.data) {
-            Ok(event) => event,
-            Err(e) => {
-                debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
-                continue;
-            }
-        };
-
-        match event.kind.as_str() {
-            // Individual output item finalised. Forward immediately so the
-            // rest of the agent can stream assistant text/functions *live*
-            // instead of waiting for the final `response.completed` envelope.
-            //
-            // IMPORTANT: We used to ignore these events and forward the
-            // duplicated `output` array embedded in the `response.completed`
-            // payload. That produced two concrete issues:
-            //   1. No real‑time streaming – the user only saw output after the
-            //      entire turn had finished, which broke the "typing" UX and
-            //      made long‑running turns look stalled.
-            //   2. Duplicate `function_call_output` items – both the
-            //      individual *and* the completed array were forwarded, which
-            //      confused the backend and triggered 400
-            //      "previous_response_not_found" errors because the duplicated
-            //      IDs did not match the incremental turn chain.
-            //
-            // The fix is to forward the incremental events *as they come* and
-            // drop the duplicated list inside `response.completed`.
-            "response.output_item.done" => {
-                let Some(item_val) = event.item else { continue };
-                let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
-                    debug!("failed to parse ResponseItem from output_item.done");
-                    continue;
-                };
-
-                let event = ResponseEvent::OutputItemDone(item);
-                if tx_event.send(Ok(event)).await.is_err() {
-                    return;
-                }
-            }
-            "response.output_text.delta" => {
-                if let Some(delta) = event.delta {
-                    let event = ResponseEvent::OutputTextDelta(delta);
-                    if tx_event.send(Ok(event)).await.is_err() {
-                        return;
-                    }
-                }
-            }
-            "response.reasoning_summary_text.delta" => {
-                if let Some(delta) = event.delta {
-                    let event = ResponseEvent::ReasoningSummaryDelta(delta);
-                    if tx_event.send(Ok(event)).await.is_err() {
-                        return;
-                    }
-                }
-            }
-            "response.reasoning_text.delta" => {
-                if let Some(delta) = event.delta {
-                    let event = ResponseEvent::ReasoningContentDelta(delta);
-                    if tx_event.send(Ok(event)).await.is_err() {
-                        return;
-                    }
-                }
-            }
-            "response.created" => {
-                if event.response.is_some() {
-                    let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
-                }
-            }
-            "response.failed" => {
-                if let Some(resp_val) = event.response {
-                    response_error = Some(CodexErr::Stream(
-                        "response.failed event received".to_string(),
-                        None,
-                    ));
-
-                    let error = resp_val.get("error");
-
-                    if let Some(error) = error {
-                        match serde_json::from_value::<Error>(error.clone()) {
-                            Ok(error) => {
-                                if is_context_window_error(&error) {
-                                    response_error = Some(CodexErr::ContextWindowExceeded);
-                                } else {
-                                    let delay = try_parse_retry_after(&error);
-                                    let message = error.message.clone().unwrap_or_default();
-                                    response_error = Some(CodexErr::Stream(message, delay));
-                                }
-                            }
-                            Err(e) => {
-                                let error = format!("failed to parse ErrorResponse: {e}");
-                                debug!(error);
-                                response_error = Some(CodexErr::Stream(error, None))
-                            }
-                        }
-                    }
-                }
-            }
-            // Final response completed – includes array of output items & id
-            "response.completed" => {
-                if let Some(resp_val) = event.response {
-                    match serde_json::from_value::<ResponseCompleted>(resp_val) {
-                        Ok(r) => {
-                            response_completed = Some(r);
-                        }
-                        Err(e) => {
-                            let error = format!("failed to parse ResponseCompleted: {e}");
-                            debug!(error);
-                            response_error = Some(CodexErr::Stream(error, None));
-                            continue;
-                        }
-                    };
-                };
-            }
-            "response.content_part.done"
-            | "response.function_call_arguments.delta"
-            | "response.custom_tool_call_input.delta"
-            | "response.custom_tool_call_input.done" // also emitted as response.output_item.done
-            | "response.in_progress"
-            | "response.output_text.done" => {}
-            "response.output_item.added" => {
-                let Some(item_val) = event.item else { continue };
-                let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
-                    debug!("failed to parse ResponseItem from output_item.done");
-                    continue;
-                };
-
-                let event = ResponseEvent::OutputItemAdded(item);
-                if tx_event.send(Ok(event)).await.is_err() {
-                    return;
-                }
-            }
-            "response.reasoning_summary_part.added" => {
-                // Boundary between reasoning summary sections (e.g., titles).
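// Review note: for orientation, a hedged sketch of the wire shape this arm
// consumes. The field set is inferred from the SseEvent struct above, not
// from an external spec; a reasoning-summary boundary arrives as a bare
// typed event with no payload the parser needs to read:
//
//     event: response.reasoning_summary_part.added
//     data: {"type":"response.reasoning_summary_part.added"}
//
// The arm below only emits ReasoningSummaryPartAdded so the UI can open a
// new summary section (e.g., a new title).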
-                let event = ResponseEvent::ReasoningSummaryPartAdded;
-                if tx_event.send(Ok(event)).await.is_err() {
-                    return;
-                }
-            }
-            "response.reasoning_summary_text.done" => {}
-            _ => {}
-        }
-    }
-}
+#[async_trait]
+impl AuthProvider for AuthManagerProvider {
+    async fn auth_context(&self) -> Option<AuthContext> {
+        let auth = self.manager.auth()?;
+        let mode = auth.mode;
+        let account_id = auth.get_account_id();
+        let bearer_token = match auth.get_token().await {
+            Ok(token) if !token.is_empty() => Some(token),
+            Ok(_) => None,
+            Err(err) => {
+                warn!("failed to resolve auth token: {err}");
+                None
+            }
+        };
+        Some(AuthContext {
+            mode,
+            bearer_token,
+            account_id,
+        })
+    }
-
-/// used in tests to stream from a text SSE file
-async fn stream_from_fixture(
-    path: impl AsRef<Path>,
-    provider: ModelProviderInfo,
-    otel_event_manager: OtelEventManager,
-) -> Result<ResponseStream> {
-    let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
-    let f = std::fs::File::open(path.as_ref())?;
-    let lines = std::io::BufReader::new(f).lines();
-
-    // insert \n\n after each line for proper SSE parsing
-    let mut content = String::new();
-    for line in lines {
-        content.push_str(&line?);
-        content.push_str("\n\n");
+    async fn refresh_token(&self) -> std::result::Result<Option<String>, String> {
+        self.manager
+            .refresh_token()
+            .await
+            .map_err(|err| err.to_string())
+    }
+}
-
-    let rdr = std::io::Cursor::new(content);
-    let stream = ReaderStream::new(rdr).map_err(CodexErr::Io);
-    tokio::spawn(process_sse(
-        stream,
-        tx_event,
-        provider.stream_idle_timeout(),
-        otel_event_manager,
-    ));
-    Ok(ResponseStream { rx_event })
-}
-
-fn rate_limit_regex() -> &'static Regex {
-    static RE: OnceLock<Regex> = OnceLock::new();
-
-    #[expect(clippy::unwrap_used)]
-    RE.get_or_init(|| Regex::new(r"Please try again in (\d+(?:\.\d+)?)(s|ms)").unwrap())
-}
+fn wrap_stream(stream: ApiClientStream) -> ResponseStream {
+    let (tx, rx) = mpsc::channel::<Result<ResponseEvent>>(1600);
-
-fn try_parse_retry_after(err: &Error) -> Option<Duration> {
-    if err.code != Some("rate_limit_exceeded".to_string()) {
-        return None;
-    }
-
-    // parse the Please try again in 1.898s format using regex
-    let re = rate_limit_regex();
-    if let Some(message) = &err.message
-        && let Some(captures) = re.captures(message)
-    {
-        let seconds = captures.get(1);
-        let unit = captures.get(2);
-
-        if let (Some(value), Some(unit)) = (seconds, unit) {
-            let value = value.as_str().parse::<f64>().ok()?;
-            let unit = unit.as_str();
+    tokio::spawn(async move {
+        let mut stream = stream;
+        while let Some(item) = stream.next().await {
+            let mapped = match item {
+                Ok(event) => Ok(event),
+                Err(err) => Err(map_api_error(err)),
+            };
-
-            if unit == "s" {
-                return Some(Duration::from_secs_f64(value));
-            } else if unit == "ms" {
-                return Some(Duration::from_millis(value as u64));
+            if tx.send(mapped).await.is_err() {
+                break;
            }
        }
-    }
-    None
-}
+    });
-
-fn is_context_window_error(error: &Error) -> bool {
-    error.code.as_deref() == Some("context_length_exceeded")
+    codex_api_client::EventStream::from_receiver(rx)
 }

-#[cfg(test)]
-mod tests {
-    use super::*;
-    use assert_matches::assert_matches;
-    use serde_json::json;
-    use tokio::sync::mpsc;
-    use tokio_test::io::Builder as IoBuilder;
-    use tokio_util::io::ReaderStream;
-
-    // ────────────────────────────
-    // Helpers
-    // ────────────────────────────
-
-    /// Runs the SSE parser on pre-chunked byte slices and returns every event
-    /// (including any final `Err` from a stream-closure check).
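    // Review note: a hedged usage sketch of this helper (names exactly as
    // defined in this test module; the byte chunk is illustrative):
    //
    //     let events = collect_events(
    //         &[b"event: response.completed\ndata: {\"type\":\"response.completed\",\"response\":{\"id\":\"r1\"}}\n\n"],
    //         provider,
    //         otel_event_manager(),
    //     )
    //     .await;
    //     assert!(matches!(events.last(), Some(Ok(ResponseEvent::Completed { .. }))));
    //
    // Chunk boundaries are arbitrary: the eventsource parser reassembles
    // frames, so splitting one SSE event across reads exercises buffering.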
-    async fn collect_events(
-        chunks: &[&[u8]],
-        provider: ModelProviderInfo,
-        otel_event_manager: OtelEventManager,
-    ) -> Vec<Result<ResponseEvent>> {
-        let mut builder = IoBuilder::new();
-        for chunk in chunks {
-            builder.read(chunk);
-        }
-
-        let reader = builder.build();
-        let stream = ReaderStream::new(reader).map_err(CodexErr::Io);
-        let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(16);
-        tokio::spawn(process_sse(
-            stream,
-            tx,
-            provider.stream_idle_timeout(),
-            otel_event_manager,
-        ));
-
-        let mut events = Vec::new();
-        while let Some(ev) = rx.recv().await {
-            events.push(ev);
-        }
-        events
-    }
-
-    /// Builds an in-memory SSE stream from JSON fixtures and returns only the
-    /// successfully parsed events (panics on internal channel errors).
-    async fn run_sse(
-        events: Vec<serde_json::Value>,
-        provider: ModelProviderInfo,
-        otel_event_manager: OtelEventManager,
-    ) -> Vec<ResponseEvent> {
-        let mut body = String::new();
-        for e in events {
-            let kind = e
-                .get("type")
-                .and_then(|v| v.as_str())
-                .expect("fixture event missing type");
-            if e.as_object().map(|o| o.len() == 1).unwrap_or(false) {
-                body.push_str(&format!("event: {kind}\n\n"));
-            } else {
-                body.push_str(&format!("event: {kind}\ndata: {e}\n\n"));
-            }
-        }
-
-        let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(8);
-        let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io);
-        tokio::spawn(process_sse(
-            stream,
-            tx,
-            provider.stream_idle_timeout(),
-            otel_event_manager,
-        ));
-
-        let mut out = Vec::new();
-        while let Some(ev) = rx.recv().await {
-            out.push(ev.expect("channel closed"));
-        }
-        out
-    }
-
-    fn otel_event_manager() -> OtelEventManager {
-        OtelEventManager::new(
-            ConversationId::new(),
-            "test",
-            "test",
-            None,
-            Some("test@test.com".to_string()),
-            Some(AuthMode::ChatGPT),
-            false,
-            "test".to_string(),
-        )
-    }
-
-    // ────────────────────────────
-    // Tests from `implement-test-for-responses-api-sse-parser`
-    // ────────────────────────────
-
-    #[tokio::test]
-    async fn parses_items_and_completed() {
-        let item1 = json!({
-            "type": "response.output_item.done",
-            "item": {
-                "type": "message",
-                "role": "assistant",
-                "content": [{"type": "output_text", "text": "Hello"}]
-            }
-        })
-        .to_string();
-
-        let item2 = json!({
-            "type": "response.output_item.done",
-            "item": {
-                "type": "message",
-                "role": "assistant",
-                "content": [{"type": "output_text", "text": "World"}]
-            }
-        })
-        .to_string();
-
-        let completed = json!({
-            "type": "response.completed",
-            "response": { "id": "resp1" }
-        })
-        .to_string();
-
-        let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
-        let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
-        let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
-
-        let provider = ModelProviderInfo {
-            name: "test".to_string(),
-            base_url: Some("https://test.com".to_string()),
-            env_key: Some("TEST_API_KEY".to_string()),
-            env_key_instructions: None,
-            experimental_bearer_token: None,
-            wire_api: WireApi::Responses,
-            query_params: None,
-            http_headers: None,
-            env_http_headers: None,
-            request_max_retries: Some(0),
-            stream_max_retries: Some(0),
-            stream_idle_timeout_ms: Some(1000),
-            requires_openai_auth: false,
-        };
-
-        let otel_event_manager = otel_event_manager();
-
-        let events = collect_events(
-            &[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()],
-            provider,
-            otel_event_manager,
-        )
-        .await;
-
-        assert_eq!(events.len(), 3);
-
-        matches!(
-            &events[0],
-            Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
-                if role == "assistant"
-        );
-
-        matches!(
-            &events[1],
-            Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. }))
-                if role == "assistant"
-        );
-
-        match &events[2] {
-            Ok(ResponseEvent::Completed {
-                response_id,
-                token_usage,
-            }) => {
-                assert_eq!(response_id, "resp1");
-                assert!(token_usage.is_none());
-            }
-            other => panic!("unexpected third event: {other:?}"),
-        }
-    }
-
-    #[tokio::test]
-    async fn error_when_missing_completed() {
-        let item1 = json!({
-            "type": "response.output_item.done",
-            "item": {
-                "type": "message",
-                "role": "assistant",
-                "content": [{"type": "output_text", "text": "Hello"}]
-            }
-        })
-        .to_string();
-
-        let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
-        let provider = ModelProviderInfo {
-            name: "test".to_string(),
-            base_url: Some("https://test.com".to_string()),
-            env_key: Some("TEST_API_KEY".to_string()),
-            env_key_instructions: None,
-            experimental_bearer_token: None,
-            wire_api: WireApi::Responses,
-            query_params: None,
-            http_headers: None,
-            env_http_headers: None,
-            request_max_retries: Some(0),
-            stream_max_retries: Some(0),
-            stream_idle_timeout_ms: Some(1000),
-            requires_openai_auth: false,
-        };
-
-        let otel_event_manager = otel_event_manager();
-
-        let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
-
-        assert_eq!(events.len(), 2);
-
-        matches!(events[0], Ok(ResponseEvent::OutputItemDone(_)));
-
-        match &events[1] {
-            Err(CodexErr::Stream(msg, _)) => {
-                assert_eq!(msg, "stream closed before response.completed")
-            }
-            other => panic!("unexpected second event: {other:?}"),
-        }
-    }
-
-    #[tokio::test]
-    async fn error_when_error_event() {
-        let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_689bcf18d7f08194bf3440ba62fe05d803fee0cdac429894","object":"response","created_at":1755041560,"status":"failed","background":false,"error":{"code":"rate_limit_exceeded","message":"Rate limit reached for gpt-5 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."}, "usage":null,"user":null,"metadata":{}}}"#;
-
-        let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
-        let provider = ModelProviderInfo {
-            name: "test".to_string(),
-            base_url: Some("https://test.com".to_string()),
-            env_key: Some("TEST_API_KEY".to_string()),
-            env_key_instructions: None,
-            experimental_bearer_token: None,
-            wire_api: WireApi::Responses,
-            query_params: None,
-            http_headers: None,
-            env_http_headers: None,
-            request_max_retries: Some(0),
-            stream_max_retries: Some(0),
-            stream_idle_timeout_ms: Some(1000),
-            requires_openai_auth: false,
-        };
-
-        let otel_event_manager = otel_event_manager();
-
-        let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
-
-        assert_eq!(events.len(), 1);
-
-        match &events[0] {
-            Err(CodexErr::Stream(msg, delay)) => {
-                assert_eq!(
-                    msg,
-                    "Rate limit reached for gpt-5 in organization org-AAA on tokens per min (TPM): Limit 30000, Used 22999, Requested 12528. Please try again in 11.054s. Visit https://platform.openai.com/account/rate-limits to learn more."
-                );
-                assert_eq!(*delay, Some(Duration::from_secs_f64(11.054)));
-            }
-            other => panic!("unexpected second event: {other:?}"),
-        }
-    }
-
-    #[tokio::test]
-    async fn context_window_error_is_fatal() {
-        let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;
-
-        let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
-        let provider = ModelProviderInfo {
-            name: "test".to_string(),
-            base_url: Some("https://test.com".to_string()),
-            env_key: Some("TEST_API_KEY".to_string()),
-            env_key_instructions: None,
-            experimental_bearer_token: None,
-            wire_api: WireApi::Responses,
-            query_params: None,
-            http_headers: None,
-            env_http_headers: None,
-            request_max_retries: Some(0),
-            stream_max_retries: Some(0),
-            stream_idle_timeout_ms: Some(1000),
-            requires_openai_auth: false,
-        };
-
-        let otel_event_manager = otel_event_manager();
-
-        let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
-
-        assert_eq!(events.len(), 1);
-
-        match &events[0] {
-            Err(err @ CodexErr::ContextWindowExceeded) => {
-                assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
-            }
-            other => panic!("unexpected context window event: {other:?}"),
-        }
-    }
-
-    #[tokio::test]
-    async fn context_window_error_with_newline_is_fatal() {
-        let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;
-
-        let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
-        let provider = ModelProviderInfo {
-            name: "test".to_string(),
-            base_url: Some("https://test.com".to_string()),
-            env_key: Some("TEST_API_KEY".to_string()),
-            env_key_instructions: None,
-            experimental_bearer_token: None,
-            wire_api: WireApi::Responses,
-            query_params: None,
-            http_headers: None,
-            env_http_headers: None,
-            request_max_retries: Some(0),
-            stream_max_retries: Some(0),
-            stream_idle_timeout_ms: Some(1000),
-            requires_openai_auth: false,
-        };
-
-        let otel_event_manager = otel_event_manager();
-
-        let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
-
-        assert_eq!(events.len(), 1);
-
-        match &events[0] {
-            Err(err @ CodexErr::ContextWindowExceeded) => {
-                assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
-            }
-            other => panic!("unexpected context window event: {other:?}"),
+fn map_api_error(err: codex_api_client::Error) -> CodexErr {
+    match err {
+        codex_api_client::Error::UnsupportedOperation(msg) => CodexErr::UnsupportedOperation(msg),
+        codex_api_client::Error::Http(source) => {
+            CodexErr::ConnectionFailed(ConnectionFailedError { source })
         }
-    }
-
-    // ────────────────────────────
-    // Table-driven test from `main`
-    // ────────────────────────────
-
-    /// Verifies that the adapter produces the right `ResponseEvent` for a
-    /// variety of incoming `type` values.
-    #[tokio::test]
-    async fn table_driven_event_kinds() {
-        struct TestCase {
-            name: &'static str,
-            event: serde_json::Value,
-            expect_first: fn(&ResponseEvent) -> bool,
-            expected_len: usize,
+        codex_api_client::Error::ResponseStreamFailed { source, request_id } => {
+            CodexErr::ResponseStreamFailed(ResponseStreamFailed { source, request_id })
         }
-
-        fn is_created(ev: &ResponseEvent) -> bool {
-            matches!(ev, ResponseEvent::Created)
+        codex_api_client::Error::Stream(message, delay) => CodexErr::Stream(message, delay),
+        codex_api_client::Error::UnexpectedStatus { status, body } => {
+            CodexErr::UnexpectedStatus(UnexpectedResponseError {
+                status,
+                body,
+                request_id: None,
+            })
         }
-        fn is_output(ev: &ResponseEvent) -> bool {
-            matches!(ev, ResponseEvent::OutputItemDone(_))
+        codex_api_client::Error::RetryLimit { status, request_id } => {
+            CodexErr::RetryLimit(RetryLimitReachedError { status, request_id })
         }
-        fn is_completed(ev: &ResponseEvent) -> bool {
-            matches!(ev, ResponseEvent::Completed { .. })
+        codex_api_client::Error::MissingEnvVar { var, instructions } => {
+            CodexErr::EnvVar(EnvVarError { var, instructions })
         }
-
-        let completed = json!({
-            "type": "response.completed",
-            "response": {
-                "id": "c",
-                "usage": {
-                    "input_tokens": 0,
-                    "input_tokens_details": null,
-                    "output_tokens": 0,
-                    "output_tokens_details": null,
-                    "total_tokens": 0
-                },
-                "output": []
-            }
-        });
-
-        let cases = vec![
-            TestCase {
-                name: "created",
-                event: json!({"type": "response.created", "response": {}}),
-                expect_first: is_created,
-                expected_len: 2,
-            },
-            TestCase {
-                name: "output_item.done",
-                event: json!({
-                    "type": "response.output_item.done",
-                    "item": {
-                        "type": "message",
-                        "role": "assistant",
-                        "content": [
-                            {"type": "output_text", "text": "hi"}
-                        ]
-                    }
-                }),
-                expect_first: is_output,
-                expected_len: 2,
-            },
-            TestCase {
-                name: "unknown",
-                event: json!({"type": "response.new_tool_event"}),
-                expect_first: is_completed,
-                expected_len: 1,
-            },
-        ];
-
-        for case in cases {
-            let mut evs = vec![case.event];
-            evs.push(completed.clone());
-
-            let provider = ModelProviderInfo {
-                name: "test".to_string(),
-                base_url: Some("https://test.com".to_string()),
-                env_key: Some("TEST_API_KEY".to_string()),
-                env_key_instructions: None,
-                experimental_bearer_token: None,
-                wire_api: WireApi::Responses,
-                query_params: None,
-                http_headers: None,
-                env_http_headers: None,
-                request_max_retries: Some(0),
-                stream_max_retries: Some(0),
-                stream_idle_timeout_ms: Some(1000),
-                requires_openai_auth: false,
-            };
-
-            let otel_event_manager = otel_event_manager();
-
-            let out = run_sse(evs, provider, otel_event_manager).await;
-            assert_eq!(out.len(), case.expected_len, "case {}", case.name);
-            assert!(
-                (case.expect_first)(&out[0]),
-                "first event mismatch in case {}",
-                case.name
-            );
-        }
-    }
-
-    #[test]
-    fn test_try_parse_retry_after() {
-        let err = Error {
-            r#type: None,
-            message: Some("Rate limit reached for gpt-5 in organization org- on tokens per min (TPM): Limit 1, Used 1, Requested 19304. Please try again in 28ms. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
-            code: Some("rate_limit_exceeded".to_string()),
-            plan_type: None,
-            resets_at: None
-        };
-
-        let delay = try_parse_retry_after(&err);
-        assert_eq!(delay, Some(Duration::from_millis(28)));
-    }
-
-    #[test]
-    fn test_try_parse_retry_after_no_delay() {
-        let err = Error {
-            r#type: None,
-            message: Some("Rate limit reached for gpt-5 in organization on tokens per min (TPM): Limit 30000, Used 6899, Requested 24050. Please try again in 1.898s. Visit https://platform.openai.com/account/rate-limits to learn more.".to_string()),
-            code: Some("rate_limit_exceeded".to_string()),
-            plan_type: None,
-            resets_at: None
-        };
-        let delay = try_parse_retry_after(&err);
-        assert_eq!(delay, Some(Duration::from_secs_f64(1.898)));
-    }
-
-    #[test]
-    fn error_response_deserializes_schema_known_plan_type_and_serializes_back() {
-        use crate::token_data::KnownPlan;
-        use crate::token_data::PlanType;
-
-        let json =
-            r#"{"error":{"type":"usage_limit_reached","plan_type":"pro","resets_at":1704067200}}"#;
-        let resp: ErrorResponse = serde_json::from_str(json).expect("should deserialize schema");
-
-        assert_matches!(resp.error.plan_type, Some(PlanType::Known(KnownPlan::Pro)));
-
-        let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
-        assert_eq!(plan_json, "\"pro\"");
-    }
-
-    #[test]
-    fn error_response_deserializes_schema_unknown_plan_type_and_serializes_back() {
-        use crate::token_data::PlanType;
-
-        let json =
-            r#"{"error":{"type":"usage_limit_reached","plan_type":"vip","resets_at":1704067260}}"#;
-        let resp: ErrorResponse = serde_json::from_str(json).expect("should deserialize schema");
-
-        assert_matches!(resp.error.plan_type, Some(PlanType::Unknown(ref s)) if s == "vip");
-
-        let plan_json = serde_json::to_string(&resp.error.plan_type).expect("serialize plan_type");
-        assert_eq!(plan_json, "\"vip\"");
+        codex_api_client::Error::Auth(message) => CodexErr::Fatal(message),
+        codex_api_client::Error::Json(err) => CodexErr::Json(err),
+        codex_api_client::Error::Other(message) => CodexErr::Fatal(message),
     }
 }
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 2ac02f5f66..5ecc8a26a0 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -1,348 +1,45 @@
-use crate::client_common::tools::ToolSpec;
+use std::borrow::Cow;
+use std::ops::Deref;
+
 use crate::error::Result;
-use crate::model_family::ModelFamily;
-use crate::protocol::RateLimitSnapshot;
-use crate::protocol::TokenUsage;
+use codex_api_client::EventStream;
+pub use codex_api_client::Prompt;
+pub use codex_api_client::Reasoning;
+pub use codex_api_client::TextControls;
+pub use codex_api_client::TextFormat;
+pub use codex_api_client::TextFormatType;
 use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
 use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
 use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
 use codex_protocol::config_types::Verbosity as VerbosityConfig;
-use codex_protocol::models::ResponseItem;
-use futures::Stream;
-use serde::Deserialize;
-use serde::Serialize;
 use serde_json::Value;
-use std::borrow::Cow;
-use std::collections::HashSet;
-use std::ops::Deref;
-use std::pin::Pin;
-use std::task::Context;
-use std::task::Poll;
-use tokio::sync::mpsc;
+
+use crate::model_family::ModelFamily;

 /// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
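// Review note: Prompt now lives in codex-api-client and is re-exported above.
// A minimal construction sketch, assuming the Default impl and the field
// names used elsewhere in this diff (input, instructions, output_schema):
//
//     let prompt = Prompt {
//         input: vec![],
//         instructions: "You are a helpful assistant.".to_string(),
//         output_schema: None,
//         ..Default::default()
//     };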
 pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");

-// Centralized templates for review-related user messages
 pub const REVIEW_EXIT_SUCCESS_TMPL: &str = include_str!("../templates/review/exit_success.xml");
 pub const REVIEW_EXIT_INTERRUPTED_TMPL: &str =
     include_str!("../templates/review/exit_interrupted.xml");

-/// API request payload for a single model turn
-#[derive(Default, Debug, Clone)]
-pub struct Prompt {
-    /// Conversation context input items.
-    pub input: Vec<ResponseItem>,
-
-    /// Tools available to the model, including additional tools sourced from
-    /// external MCP servers.
-    pub(crate) tools: Vec<ToolSpec>,
-
-    /// Whether parallel tool calls are permitted for this prompt.
-    pub(crate) parallel_tool_calls: bool,
-
-    /// Optional override for the built-in BASE_INSTRUCTIONS.
-    pub base_instructions_override: Option<String>,
-
-    /// Optional output schema for the model's response.
-    pub output_schema: Option<Value>,
-}
-
-impl Prompt {
-    pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
-        let base = self
-            .base_instructions_override
-            .as_deref()
-            .unwrap_or(model.base_instructions.deref());
-        // When there are no custom instructions, add apply_patch_tool_instructions if:
-        // - the model needs special instructions (4.1)
-        // AND
-        // - there is no apply_patch tool present
-        let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
-            ToolSpec::Function(f) => f.name == "apply_patch",
-            ToolSpec::Freeform(f) => f.name == "apply_patch",
-            _ => false,
-        });
-        if self.base_instructions_override.is_none()
-            && model.needs_special_apply_patch_instructions
-            && !is_apply_patch_tool_present
-        {
-            Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"))
-        } else {
-            Cow::Borrowed(base)
-        }
-    }
-
-    pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
-        let mut input = self.input.clone();
-
-        // when using the *Freeform* apply_patch tool specifically, tool outputs
-        // should be structured text, not json. Do NOT reserialize when using
-        // the Function tool - note that this differs from the check above for
-        // instructions. We declare the result as a named variable for clarity.
-        let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
-            ToolSpec::Freeform(f) => f.name == "apply_patch",
-            _ => false,
-        });
-        if is_freeform_apply_patch_tool_present {
-            reserialize_shell_outputs(&mut input);
-        }
-
-        input
-    }
-}
-
-fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
-    let mut shell_call_ids: HashSet<String> = HashSet::new();
-
-    items.iter_mut().for_each(|item| match item {
-        ResponseItem::LocalShellCall { call_id, id, .. } => {
-            if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
-                shell_call_ids.insert(identifier);
-            }
-        }
-        ResponseItem::CustomToolCall {
-            id: _,
-            status: _,
-            call_id,
-            name,
-            input: _,
-        } => {
-            if name == "apply_patch" {
-                shell_call_ids.insert(call_id.clone());
-            }
-        }
-        ResponseItem::CustomToolCallOutput { call_id, output } => {
-            if shell_call_ids.remove(call_id)
-                && let Some(structured) = parse_structured_shell_output(output)
-            {
-                *output = structured
-            }
-        }
-        ResponseItem::FunctionCall { name, call_id, .. }
-            if is_shell_tool_name(name) || name == "apply_patch" =>
-        {
-            shell_call_ids.insert(call_id.clone());
-        }
-        ResponseItem::FunctionCallOutput { call_id, output } => {
-            if shell_call_ids.remove(call_id)
-                && let Some(structured) = parse_structured_shell_output(&output.content)
-            {
-                output.content = structured
-            }
-        }
-        _ => {}
-    })
-}
-
-fn is_shell_tool_name(name: &str) -> bool {
-    matches!(name, "shell" | "container.exec")
-}
-
-#[derive(Deserialize)]
-struct ExecOutputJson {
-    output: String,
-    metadata: ExecOutputMetadataJson,
-}
-
-#[derive(Deserialize)]
-struct ExecOutputMetadataJson {
-    exit_code: i32,
-    duration_seconds: f32,
-}
-
-fn parse_structured_shell_output(raw: &str) -> Option<String> {
-    let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
-    Some(build_structured_output(&parsed))
-}
-
-fn build_structured_output(parsed: &ExecOutputJson) -> String {
-    let mut sections = Vec::new();
-    sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
-    sections.push(format!(
-        "Wall time: {} seconds",
-        parsed.metadata.duration_seconds
-    ));
-
-    let mut output = parsed.output.clone();
-    if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
-        sections.push(format!("Total output lines: {total_lines}"));
-        if let Some(stripped) = strip_total_output_header(&output) {
-            output = stripped.to_string();
-        }
-    }
-
-    sections.push("Output:".to_string());
-    sections.push(output);
-
-    sections.join("\n")
-}
-
-fn extract_total_output_lines(output: &str) -> Option<usize> {
-    let marker_start = output.find("[... omitted ")?;
-    let marker = &output[marker_start..];
-    let (_, after_of) = marker.split_once(" of ")?;
-    let (total_segment, _) = after_of.split_once(' ')?;
-    total_segment.parse::<usize>().ok()
-}
-
-fn strip_total_output_header(output: &str) -> Option<&str> {
-    let after_prefix = output.strip_prefix("Total output lines: ")?;
-    let (_, remainder) = after_prefix.split_once('\n')?;
-    let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
-    Some(remainder)
-}
-
-#[derive(Debug)]
-pub enum ResponseEvent {
-    Created,
-    OutputItemDone(ResponseItem),
-    OutputItemAdded(ResponseItem),
-    Completed {
-        response_id: String,
-        token_usage: Option<TokenUsage>,
-    },
-    OutputTextDelta(String),
-    ReasoningSummaryDelta(String),
-    ReasoningContentDelta(String),
-    ReasoningSummaryPartAdded,
-    RateLimits(RateLimitSnapshot),
-}
-
-#[derive(Debug, Serialize)]
-pub(crate) struct Reasoning {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) effort: Option<ReasoningEffortConfig>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) summary: Option<ReasoningSummaryConfig>,
-}
-
-#[derive(Debug, Serialize, Default, Clone)]
-#[serde(rename_all = "snake_case")]
-pub(crate) enum TextFormatType {
-    #[default]
-    JsonSchema,
-}
-
-#[derive(Debug, Serialize, Default, Clone)]
-pub(crate) struct TextFormat {
-    pub(crate) r#type: TextFormatType,
-    pub(crate) strict: bool,
-    pub(crate) schema: Value,
-    pub(crate) name: String,
-}
-
-/// Controls under the `text` field in the Responses API for GPT-5.
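// Review note: for reference, the struct below serialized into the `text`
// field of a Responses API request roughly as follows (shape inferred from
// the serde attributes and the assertions in this file's tests; values are
// illustrative):
//
//     "text": {
//         "verbosity": "low",
//         "format": {
//             "type": "json_schema",
//             "strict": true,
//             "schema": { ... },
//             "name": "codex_output_schema"
//         }
//     }
//
// Absent options are omitted entirely via skip_serializing_if.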
-#[derive(Debug, Serialize, Default, Clone)]
-pub(crate) struct TextControls {
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) verbosity: Option<OpenAiVerbosity>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) format: Option<TextFormat>,
-}
-
-#[derive(Debug, Serialize, Default, Clone)]
-#[serde(rename_all = "lowercase")]
-pub(crate) enum OpenAiVerbosity {
-    Low,
-    #[default]
-    Medium,
-    High,
-}
-
-impl From<VerbosityConfig> for OpenAiVerbosity {
-    fn from(v: VerbosityConfig) -> Self {
-        match v {
-            VerbosityConfig::Low => OpenAiVerbosity::Low,
-            VerbosityConfig::Medium => OpenAiVerbosity::Medium,
-            VerbosityConfig::High => OpenAiVerbosity::High,
-        }
+pub fn compute_full_instructions<'a>(
+    base_override: Option<&'a str>,
+    model: &'a ModelFamily,
+    is_apply_patch_present: bool,
+) -> Cow<'a, str> {
+    let base = base_override.unwrap_or(model.base_instructions.deref());
+    if base_override.is_none()
+        && model.needs_special_apply_patch_instructions
+        && !is_apply_patch_present
+    {
+        Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"))
+    } else {
+        Cow::Borrowed(base)
     }
 }

-/// Request object that is serialized as JSON and POST'ed when using the
-/// Responses API.
-#[derive(Debug, Serialize)]
-pub(crate) struct ResponsesApiRequest<'a> {
-    pub(crate) model: &'a str,
-    pub(crate) instructions: &'a str,
-    // TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
-    // we code defensively to avoid this case, but perhaps we should use a
-    // separate enum for serialization.
-    pub(crate) input: &'a Vec<ResponseItem>,
-    pub(crate) tools: &'a [serde_json::Value],
-    pub(crate) tool_choice: &'static str,
-    pub(crate) parallel_tool_calls: bool,
-    pub(crate) reasoning: Option<Reasoning>,
-    pub(crate) store: bool,
-    pub(crate) stream: bool,
-    pub(crate) include: Vec<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) prompt_cache_key: Option<String>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) text: Option<TextControls>,
-}
-
-pub(crate) mod tools {
-    use crate::tools::spec::JsonSchema;
-    use serde::Deserialize;
-    use serde::Serialize;
-
-    /// When serialized as JSON, this produces a valid "Tool" in the OpenAI
-    /// Responses API.
-    #[derive(Debug, Clone, Serialize, PartialEq)]
-    #[serde(tag = "type")]
-    pub(crate) enum ToolSpec {
-        #[serde(rename = "function")]
-        Function(ResponsesApiTool),
-        #[serde(rename = "local_shell")]
-        LocalShell {},
-        // TODO: Understand why we get an error on web_search although the API docs say it's supported.
-        // https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
-        #[serde(rename = "web_search")]
-        WebSearch {},
-        #[serde(rename = "custom")]
-        Freeform(FreeformTool),
-    }
-
-    impl ToolSpec {
-        pub(crate) fn name(&self) -> &str {
-            match self {
-                ToolSpec::Function(tool) => tool.name.as_str(),
-                ToolSpec::LocalShell {} => "local_shell",
-                ToolSpec::WebSearch {} => "web_search",
-                ToolSpec::Freeform(tool) => tool.name.as_str(),
-            }
-        }
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-    pub struct FreeformTool {
-        pub(crate) name: String,
-        pub(crate) description: String,
-        pub(crate) format: FreeformToolFormat,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-    pub struct FreeformToolFormat {
-        pub(crate) r#type: String,
-        pub(crate) syntax: String,
-        pub(crate) definition: String,
-    }
-
-    #[derive(Debug, Clone, Serialize, PartialEq)]
-    pub struct ResponsesApiTool {
-        pub(crate) name: String,
-        pub(crate) description: String,
-        /// TODO: Validation. When strict is set to true, the JSON schema,
-        /// `required` and `additional_properties` must be present. All fields in
-        /// `properties` must be present in `required`.
-        pub(crate) strict: bool,
-        pub(crate) parameters: JsonSchema,
-    }
-}
-
-pub(crate) fn create_reasoning_param_for_request(
+pub fn create_reasoning_param_for_request(
     model_family: &ModelFamily,
     effort: Option<ReasoningEffortConfig>,
     summary: ReasoningSummaryConfig,
@@ -357,7 +54,7 @@ pub(crate) fn create_reasoning_param_for_request(
     })
 }

-pub(crate) fn create_text_param_for_request(
+pub fn create_text_param_for_request(
     verbosity: Option<VerbosityConfig>,
     output_schema: &Option<Value>,
 ) -> Option<TextControls> {
@@ -366,7 +63,11 @@ pub(crate) fn create_text_param_for_request(
     }

     Some(TextControls {
-        verbosity: verbosity.map(std::convert::Into::into),
+        verbosity: verbosity.map(|v| match v {
+            VerbosityConfig::Low => "low".to_string(),
+            VerbosityConfig::Medium => "medium".to_string(),
+            VerbosityConfig::High => "high".to_string(),
+        }),
         format: output_schema.as_ref().map(|schema| TextFormat {
             r#type: TextFormatType::JsonSchema,
             strict: true,
@@ -376,178 +77,54 @@ pub(crate) fn create_text_param_for_request(
     })
 }

-pub struct ResponseStream {
-    pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
-}
+pub use codex_api_client::ResponseEvent;

-impl Stream for ResponseStream {
-    type Item = Result<ResponseEvent>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        self.rx_event.poll_recv(cx)
-    }
-}
+pub type ResponseStream = EventStream<Result<ResponseEvent>>;

 #[cfg(test)]
 mod tests {
-    use crate::model_family::find_family_for_model;
-    use pretty_assertions::assert_eq;
-
     use super::*;
+    use crate::model_family::find_family_for_model;

-    struct InstructionsTestCase {
-        pub slug: &'static str,
-        pub expects_apply_patch_instructions: bool,
-    }
     #[test]
-    fn get_full_instructions_no_user_content() {
-        let prompt = Prompt {
-            ..Default::default()
-        };
-        let test_cases = vec![
-            InstructionsTestCase {
-                slug: "gpt-3.5",
-                expects_apply_patch_instructions: true,
-            },
-            InstructionsTestCase {
-                slug: "gpt-4.1",
-                expects_apply_patch_instructions: true,
-            },
-            InstructionsTestCase {
-                slug: "gpt-4o",
-                expects_apply_patch_instructions: true,
-            },
-            InstructionsTestCase {
-                slug: "gpt-5",
-                expects_apply_patch_instructions: true,
-            },
-            InstructionsTestCase {
-                slug: "codex-mini-latest",
-                expects_apply_patch_instructions: true,
-            },
-            InstructionsTestCase {
-                slug: "gpt-oss:120b",
-                expects_apply_patch_instructions: false,
-            },
-            InstructionsTestCase {
-                slug: "gpt-5-codex",
-                expects_apply_patch_instructions: false,
-            },
-        ];
-        for test_case in test_cases {
-            let model_family = find_family_for_model(test_case.slug).expect("known model slug");
-            let expected = if test_case.expects_apply_patch_instructions {
-                format!(
-                    "{}\n{}",
-                    model_family.clone().base_instructions,
-                    APPLY_PATCH_TOOL_INSTRUCTIONS
-                )
-            } else {
-                model_family.clone().base_instructions
-            };
-
-            let full = prompt.get_full_instructions(&model_family);
-            assert_eq!(full, expected);
-        }
+    fn compute_full_instructions_respects_apply_patch_flag() {
+        let model = find_family_for_model("gpt-4.1").expect("model");
+        let with_tool = compute_full_instructions(None, &model, true);
+        assert_eq!(with_tool.as_ref(), model.base_instructions.deref());
+
+        let without_tool = compute_full_instructions(None, &model, false);
+        assert!(
+            without_tool
+                .as_ref()
+                .ends_with(APPLY_PATCH_TOOL_INSTRUCTIONS)
+        );
     }

     #[test]
-    fn serializes_text_verbosity_when_set() {
-        let input: Vec<ResponseItem> = vec![];
-        let tools: Vec<serde_json::Value> = vec![];
-        let req = ResponsesApiRequest {
-            model: "gpt-5",
-            instructions: "i",
-            input: &input,
-            tools: &tools,
-            tool_choice: "auto",
-            parallel_tool_calls: true,
-            reasoning: None,
-            store: false,
-            stream: true,
-            include: vec![],
-            prompt_cache_key: None,
-            text: Some(TextControls {
-                verbosity: Some(OpenAiVerbosity::Low),
-                format: None,
-            }),
-        };
-
-        let v = serde_json::to_value(&req).expect("json");
-        assert_eq!(
-            v.get("text")
-                .and_then(|t| t.get("verbosity"))
-                .and_then(|s| s.as_str()),
-            Some("low")
-        );
+    fn create_text_controls_includes_verbosity() {
+        let controls = create_text_param_for_request(Some(VerbosityConfig::Low), &None)
+            .expect("text controls");
+        assert_eq!(controls.verbosity.as_deref(), Some("low"));
+        assert!(controls.format.is_none());
     }

     #[test]
-    fn serializes_text_schema_with_strict_format() {
-        let input: Vec<ResponseItem> = vec![];
-        let tools: Vec<serde_json::Value> = vec![];
+    fn create_text_controls_includes_schema() {
         let schema = serde_json::json!({
             "type": "object",
-            "properties": {
-                "answer": {"type": "string"}
-            },
+            "properties": {"answer": {"type": "string"}},
             "required": ["answer"],
         });

-        let text_controls =
+        let controls =
             create_text_param_for_request(None, &Some(schema.clone())).expect("text controls");
-
-        let req = ResponsesApiRequest {
-            model: "gpt-5",
-            instructions: "i",
-            input: &input,
-            tools: &tools,
-            tool_choice: "auto",
-            parallel_tool_calls: true,
-            reasoning: None,
-            store: false,
-            stream: true,
-            include: vec![],
-            prompt_cache_key: None,
-            text: Some(text_controls),
-        };
-
-        let v = serde_json::to_value(&req).expect("json");
-        let text = v.get("text").expect("text field");
-        assert!(text.get("verbosity").is_none());
-        let format = text.get("format").expect("format field");
-
-        assert_eq!(
-            format.get("name"),
-            Some(&serde_json::Value::String("codex_output_schema".into()))
-        );
-        assert_eq!(
-            format.get("type"),
-            Some(&serde_json::Value::String("json_schema".into()))
-        );
-        assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true)));
-        assert_eq!(format.get("schema"), Some(&schema));
+        let format = controls.format.expect("format");
+        assert_eq!(format.name, "codex_output_schema");
+        assert!(format.strict);
+        assert_eq!(format.schema, schema);
     }

     #[test]
-    fn omits_text_when_not_set() {
-        let input: Vec<ResponseItem> = vec![];
-        let tools: Vec<serde_json::Value> = vec![];
-        let req = ResponsesApiRequest {
-            model: "gpt-5",
-            instructions: "i",
-            input: &input,
-            tools: &tools,
-            tool_choice: "auto",
-            parallel_tool_calls: true,
-            reasoning: None,
-            store: false,
-            stream: true,
-            include: vec![],
-            prompt_cache_key: None,
-            text: None,
-        };
-
-        let v = serde_json::to_value(&req).expect("json");
-        assert!(v.get("text").is_none());
+    fn create_text_controls_none_when_no_options() {
+        assert!(create_text_param_for_request(None, &None).is_none());
     }
 }
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index f7a5d92bf5..6b5c1f5147 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -51,7 +51,6 @@ use tracing::error;
 use tracing::info;
 use tracing::warn;

-use crate::ModelProviderInfo;
 use crate::client::ModelClient;
 use crate::client_common::Prompt;
 use crate::client_common::ResponseEvent;
@@ -64,8 +63,10 @@ use crate::error::CodexErr;
 use crate::error::Result as CodexResult;
 #[cfg(test)]
 use crate::exec::StreamOutput;
+use codex_api_client::ModelProviderInfo;

 // Removed: legacy executor wiring replaced by ToolOrchestrator flows.
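// Review note: the hunks below move prompt assembly out of the run_task /
// run_turn call sites. A condensed sketch of the new flow, using only
// helpers introduced in this diff (prompt_for_turn, compute_full_instructions);
// error handling elided and `apply_patch_present` assumed in scope:
//
//     let mut prompt = state.prompt_for_turn();
//     prompt.instructions = compute_full_instructions(
//         turn_context.base_instructions.as_deref(),
//         &turn_context.client.get_model_family(),
//         apply_patch_present,
//     )
//     .into_owned();
//     let mut stream = turn_context.client.stream(&prompt).await?;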
 // legacy normalize_exec_result no longer used after orchestrator migration
+use crate::conversation_history::ResponsesApiChainState;
 use crate::mcp::auth::compute_auth_statuses;
 use crate::mcp_connection_manager::McpConnectionManager;
 use crate::model_family::find_family_for_model;
@@ -301,7 +302,7 @@ pub(crate) struct SessionConfiguration {
     provider: ModelProviderInfo,

     /// If not specified, server will use its default model.
-    model: String,
+    pub(crate) model: String,

     model_reasoning_effort: Option<ReasoningEffortConfig>,
     model_reasoning_summary: ReasoningSummaryConfig,
@@ -313,7 +314,7 @@ pub(crate) struct SessionConfiguration {
     user_instructions: Option<String>,

     /// Base instructions override.
-    base_instructions: Option<String>,
+    pub(crate) base_instructions: Option<String>,

     /// Compact prompt override.
     compact_prompt: Option<String>,
@@ -333,7 +334,7 @@ pub(crate) struct SessionConfiguration {
     cwd: PathBuf,

     /// Set of feature flags for this session
-    features: Features,
+    pub(crate) features: Features,

     // TODO(pakrym): Remove config from here
     original_config_do_not_use: Arc<Config>,
@@ -586,8 +587,9 @@ impl Session {
             config.active_profile.clone(),
         );

-        // Create the mutable state for the Session.
-        let state = SessionState::new(session_configuration.clone());
+        let model_family = find_family_for_model(&session_configuration.model)
+            .unwrap_or_else(|| config.model_family.clone());
+        let state = SessionState::new(session_configuration.clone(), model_family);

         let services = SessionServices {
             mcp_connection_manager,
@@ -694,7 +696,6 @@ impl Session {

     pub(crate) async fn update_settings(&self, updates: SessionSettingsUpdate) {
         let mut state = self.state.lock().await;
-
         state.session_configuration = state.session_configuration.apply(&updates);
     }

@@ -978,6 +979,31 @@ impl Session {
         state.replace_history(items);
     }

+    async fn update_responses_api_chain_state(
+        &self,
+        response_id: Option<String>,
+    ) {
+        let mut state = self.state.lock().await;
+
+        let Some(response_id) = response_id.filter(|id| !id.is_empty()) else {
+            state.reset_responses_api_chain();
+            return;
+        };
+
+        let mut history = state.clone_history();
+        let prompt_items = history.get_history_for_prompt();
+        let last_message_id = prompt_items
+            .iter()
+            .rev()
+            .find_map(crate::state::response_item_id)
+            .map(ToString::to_string);
+
+        state.set_responses_api_chain(ResponsesApiChainState {
+            last_response_id: Some(response_id),
+            last_message_id,
+        });
+    }
+
     async fn persist_rollout_response_items(&self, items: &[ResponseItem]) {
         let rollout_items: Vec<RolloutItem> = items
             .iter()
@@ -1761,30 +1787,32 @@ pub(crate) async fn run_task(
             .collect::<Vec<ResponseItem>>();

         // Construct the input that we will send to the model.
-        let turn_input: Vec<ResponseItem> = {
-            sess.record_conversation_items(&turn_context, &pending_input)
-                .await;
-            sess.clone_history().await.get_history_for_prompt()
-        };
-
-        let turn_input_messages: Vec<String> = turn_input
-            .iter()
-            .filter_map(|item| match item {
-                ResponseItem::Message { content, .. } => Some(content),
-                _ => None,
-            })
-            .flat_map(|content| {
-                content.iter().filter_map(|item| match item {
-                    ContentItem::OutputText { text } => Some(text.clone()),
-                    _ => None,
-                })
-            })
-            .collect();
+        sess.record_conversation_items(&turn_context, &pending_input)
+            .await;
+        let mut state = sess.state.lock().await;
+        let prompt = state.prompt_for_turn();
+
+        let turn_input_messages: Vec<String> = {
+            prompt
+                .input
+                .iter()
+                .filter_map(|item| match item {
+                    ResponseItem::Message { content, .. } => Some(content),
+                    _ => None,
+                })
+                .flat_map(|content| {
+                    content.iter().filter_map(|item| match item {
+                        ContentItem::OutputText { text } => Some(text.clone()),
+                        _ => None,
+                    })
+                })
+                .collect()
+        };
         match run_turn(
             Arc::clone(&sess),
             Arc::clone(&turn_context),
             Arc::clone(&turn_diff_tracker),
-            turn_input,
+            prompt,
             cancellation_token.child_token(),
         )
         .await
@@ -1870,7 +1898,7 @@ async fn run_turn(
     sess: Arc<Session>,
     turn_context: Arc<TurnContext>,
     turn_diff_tracker: SharedTurnDiffTracker,
-    input: Vec<ResponseItem>,
+    mut prompt: Prompt,
     cancellation_token: CancellationToken,
 ) -> CodexResult<TurnRunResult> {
     let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
@@ -1879,27 +1907,39 @@ async fn run_turn(
         Some(mcp_tools),
     ));

+    let tool_specs = router.specs();
+    let (tools_json, has_freeform_apply_patch) =
+        crate::tools::spec::tools_metadata_for_prompt(&tool_specs)?;
+    crate::conversation_history::format_prompt_items(&mut prompt.input, has_freeform_apply_patch);
+
+    let apply_patch_present = tool_specs.iter().any(|spec| spec.name() == "apply_patch");
+
+    let instructions = crate::client_common::compute_full_instructions(
+        turn_context.base_instructions.as_deref(),
+        &turn_context.client.get_model_family(),
+        apply_patch_present,
+    )
+    .into_owned();
+
     let model_supports_parallel = turn_context
         .client
         .get_model_family()
         .supports_parallel_tool_calls;
     let parallel_tool_calls = model_supports_parallel;
-    let prompt = Prompt {
-        input,
-        tools: router.specs(),
-        parallel_tool_calls,
-        base_instructions_override: turn_context.base_instructions.clone(),
-        output_schema: turn_context.final_output_json_schema.clone(),
-    };
+    prompt.instructions = instructions.clone();
+    prompt.tools = tools_json;
+    prompt.parallel_tool_calls = parallel_tool_calls;
+    prompt.output_schema = turn_context.final_output_json_schema.clone();

     let mut retries = 0;
     loop {
+        let attempt_prompt = prompt.clone();
         match try_run_turn(
             Arc::clone(&router),
             Arc::clone(&sess),
             Arc::clone(&turn_context),
             Arc::clone(&turn_diff_tracker),
-            &prompt,
+            attempt_prompt,
             cancellation_token.child_token(),
         )
         .await
@@ -1980,7 +2020,7 @@ async fn try_run_turn(
     sess: Arc<Session>,
     turn_context: Arc<TurnContext>,
     turn_diff_tracker: SharedTurnDiffTracker,
-    prompt: &Prompt,
+    prompt: Prompt,
     cancellation_token: CancellationToken,
 ) -> CodexResult<TurnRunResult> {
     let rollout_item = RolloutItem::TurnContext(TurnContextItem {
@@ -1996,7 +2036,7 @@ async fn try_run_turn(
     let mut stream = turn_context
         .client
         .clone()
-        .stream(prompt)
+        .stream(&prompt)
         .or_cancel(&cancellation_token)
         .await??;
@@ -2129,7 +2169,7 @@ async fn try_run_turn(
                     sess.update_rate_limits(&turn_context, snapshot).await;
                 }
                 ResponseEvent::Completed {
-                    response_id: _,
+                    response_id,
                     token_usage,
                 } => {
                     sess.update_token_usage_info(&turn_context, token_usage.as_ref())
                         .await;

                     let unified_diff = {
                         let mut tracker = turn_diff_tracker.lock().await;
                         tracker.get_unified_diff()
                     };
+                    sess.update_responses_api_chain_state(
+                        Some(response_id.clone()),
+                    )
+                    .await;
                     if let Ok(Some(unified_diff)) = unified_diff {
                         let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
                         sess.send_event(&turn_context, msg).await;
@@ -2534,7 +2578,9 @@ mod tests {
             session_source: SessionSource::Exec,
         };

-        let state = SessionState::new(session_configuration.clone());
+        let model_family = find_family_for_model(&session_configuration.model)
+            .unwrap_or_else(|| config.model_family.clone());
+        let state = SessionState::new(session_configuration.clone(), model_family);

         let services = SessionServices {
             mcp_connection_manager: McpConnectionManager::default(),
@@ -2610,7 +2656,9 @@ mod tests {
             session_source: SessionSource::Exec,
         };

-        let state = SessionState::new(session_configuration.clone());
+        let model_family = find_family_for_model(&session_configuration.model)
+            .unwrap_or_else(|| config.model_family.clone());
+        let state = SessionState::new(session_configuration.clone(), model_family);

         let services = SessionServices {
             mcp_connection_manager: McpConnectionManager::default(),
diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs
index eba9ebe286..ae6de3f2ac 100644
--- a/codex-rs/core/src/codex/compact.rs
+++ b/codex-rs/core/src/codex/compact.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;
 use super::Session;
 use super::TurnContext;
 use super::get_last_assistant_message_from_turn;
-use crate::Prompt;
+use crate::client_common::Prompt;
 use crate::client_common::ResponseEvent;
 use crate::error::CodexErr;
 use crate::error::Result as CodexResult;
@@ -84,11 +84,9 @@ async fn run_compact_task_inner(
     loop {
         let turn_input = history.get_history_for_prompt();

-        let prompt = Prompt {
-            input: turn_input.clone(),
-            ..Default::default()
-        };
-        let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
+        let turn_input_len = turn_input.len();
+        let (prompt, _) = crate::state::build_prompt_from_items(turn_input, None);
+        let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), prompt).await;

         match attempt_result {
             Ok(()) => {
@@ -107,7 +105,7 @@ async fn run_compact_task_inner(
                 return;
             }
             Err(e @ CodexErr::ContextWindowExceeded) => {
-                if turn_input.len() > 1 {
+                if turn_input_len > 1 {
                     // Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
                     error!(
                         "Context window exceeded while compacting; removing oldest history item. Error: {e}"
                     );
@@ -251,9 +249,9 @@ fn build_compacted_history_with_limit(
 async fn drain_to_completed(
     sess: &Session,
     turn_context: &TurnContext,
-    prompt: &Prompt,
+    prompt: Prompt,
 ) -> CodexResult<()> {
-    let mut stream = turn_context.client.clone().stream(prompt).await?;
+    let mut stream = turn_context.client.clone().stream(&prompt).await?;
     loop {
         let maybe_event = stream.next().await;
         let Some(event) = maybe_event else {
diff --git a/codex-rs/core/src/config/mod.rs b/codex-rs/core/src/config/mod.rs
index 17fddcf3d3..9b8884162f 100644
--- a/codex-rs/core/src/config/mod.rs
+++ b/codex-rs/core/src/config/mod.rs
@@ -25,13 +25,13 @@ use crate::git_info::resolve_root_git_project_for_trust;
 use crate::model_family::ModelFamily;
 use crate::model_family::derive_default_model_family;
 use crate::model_family::find_family_for_model;
-use crate::model_provider_info::ModelProviderInfo;
-use crate::model_provider_info::built_in_model_providers;
 use crate::openai_model_info::get_model_info;
 use crate::project_doc::DEFAULT_PROJECT_DOC_FILENAME;
 use crate::project_doc::LOCAL_PROJECT_DOC_FILENAME;
 use crate::protocol::AskForApproval;
 use crate::protocol::SandboxPolicy;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_app_server_protocol::Tools;
 use codex_app_server_protocol::UserSavedConfig;
 use codex_protocol::config_types::ForcedLoginMethod;
@@ -2802,7 +2802,7 @@ model_verbosity = "high"
             name: "OpenAI using Chat Completions".to_string(),
             base_url: Some("https://api.openai.com/v1".to_string()),
             env_key: Some("OPENAI_API_KEY".to_string()),
-            wire_api: crate::WireApi::Chat,
+            wire_api: codex_api_client::WireApi::Chat,
             env_key_instructions: None,
             experimental_bearer_token: None,
             query_params: None,
diff --git a/codex-rs/core/src/conversation_history.rs b/codex-rs/core/src/conversation_history.rs
index bc660d1cd9..b32bb727db 100644
--- a/codex-rs/core/src/conversation_history.rs
+++ b/codex-rs/core/src/conversation_history.rs
@@ -7,6 +7,7 @@ use codex_protocol::protocol::TokenUsage;
 use codex_protocol::protocol::TokenUsageInfo;
 use codex_utils_string::take_bytes_at_char_boundary;
 use codex_utils_string::take_last_bytes_at_char_boundary;
+use std::collections::HashSet;
 use std::ops::Deref;

 // Model-formatting limits: clients get full streams; only content sent to the model is truncated.
@@ -22,6 +23,13 @@ pub(crate) struct ConversationHistory {
     /// The oldest items are at the beginning of the vector.
     items: Vec<ResponseItem>,
     token_info: Option<TokenUsageInfo>,
+    responses_api_chain: Option<ResponsesApiChainState>,
+}
+
+#[derive(Debug, Clone, Default)]
+pub(crate) struct ResponsesApiChainState {
+    pub last_response_id: Option<String>,
+    pub last_message_id: Option<String>,
 }

 impl ConversationHistory {
@@ -29,6 +37,7 @@ impl ConversationHistory {
         Self {
             items: Vec::new(),
             token_info: TokenUsageInfo::new_or_append(&None, &None, None),
+            responses_api_chain: None,
         }
     }

@@ -71,6 +80,10 @@ impl ConversationHistory {
     // Returns the history prepared for sending to the model.
     // With extra response items filtered out and GhostCommits removed.
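// Review note: a hedged reading of the chain state introduced above (field
// names from this diff; the reuse semantics are an assumption based on the
// ResponsesApiChaining feature flag added in features.rs):
//
//     if let Some(chain) = history.responses_api_chain() {
//         // Chain onto the server-side response thread instead of
//         // resending the full history on the next turn.
//         prompt.previous_response_id = chain.last_response_id.clone();
//     }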
     pub(crate) fn get_history_for_prompt(&mut self) -> Vec<ResponseItem> {
+        self.build_prompt_history()
+    }
+
+    fn build_prompt_history(&mut self) -> Vec<ResponseItem> {
         let mut history = self.get_history();
         Self::remove_ghost_snapshots(&mut history);
         Self::remove_reasoning_before_last_turn(&mut history);
@@ -91,6 +104,7 @@ impl ConversationHistory {

     pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
         self.items = items;
+        self.reset_responses_api_chain();
     }

     pub(crate) fn update_token_info(
@@ -429,6 +443,18 @@ impl ConversationHistory {
             | ResponseItem::Other => item.clone(),
         }
     }
+
+    pub(crate) fn responses_api_chain(&self) -> Option<ResponsesApiChainState> {
+        self.responses_api_chain.clone()
+    }
+
+    pub(crate) fn reset_responses_api_chain(&mut self) {
+        self.responses_api_chain = None;
+    }
+
+    pub(crate) fn set_responses_api_chain(&mut self, chain: ResponsesApiChainState) {
+        self.responses_api_chain = Some(chain);
+    }
 }

 pub(crate) fn format_output_for_model_body(content: &str) -> String {
@@ -519,6 +545,102 @@ fn is_api_message(message: &ResponseItem) -> bool {
     }
 }

+fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
+    let mut shell_call_ids: HashSet<String> = HashSet::new();
+    items.iter_mut().for_each(|item| match item {
+        ResponseItem::LocalShellCall { call_id, id, .. } => {
+            if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
+                shell_call_ids.insert(identifier);
+            }
+        }
+        ResponseItem::CustomToolCall { call_id, name, .. } => {
+            if name == "apply_patch" {
+                shell_call_ids.insert(call_id.clone());
+            }
+        }
+        ResponseItem::CustomToolCallOutput { call_id, output } => {
+            if shell_call_ids.remove(call_id)
+                && let Some(structured) = parse_structured_shell_output(output)
+            {
+                *output = structured;
+            }
+        }
+        ResponseItem::FunctionCall { name, call_id, .. }
+            if name == "shell" || name == "container.exec" || name == "apply_patch" =>
+        {
+            shell_call_ids.insert(call_id.clone());
+        }
+        ResponseItem::FunctionCallOutput { call_id, output } => {
+            if shell_call_ids.remove(call_id)
+                && let Some(structured) = parse_structured_shell_output(&output.content)
+            {
+                output.content = structured;
+            }
+        }
+        _ => {}
+    });
+}
+
+#[derive(serde::Deserialize)]
+struct ExecOutputJson {
+    output: String,
+    metadata: ExecOutputMetadataJson,
+}
+
+#[derive(serde::Deserialize)]
+struct ExecOutputMetadataJson {
+    exit_code: i32,
+    duration_seconds: f32,
+}
+
+fn parse_structured_shell_output(raw: &str) -> Option<String> {
+    let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
+    Some(build_structured_output(&parsed))
+}
+
+fn build_structured_output(parsed: &ExecOutputJson) -> String {
+    let mut sections = Vec::new();
+    sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
+    sections.push(format!(
+        "Wall time: {} seconds",
+        parsed.metadata.duration_seconds
+    ));
+
+    let mut output = parsed.output.clone();
+    if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
+        sections.push(format!("Total output lines: {total_lines}"));
+        if let Some(stripped) = strip_total_output_header(&output) {
+            output = stripped.to_string();
+        }
+    }
+
+    sections.push("Output:".to_string());
+    sections.push(output);
+
+    sections.join("\n")
+}
+
+fn extract_total_output_lines(output: &str) -> Option<usize> {
+    let marker_start = output.find("[... omitted ")?;
omitted ")?; + let marker = &output[marker_start..]; + let (_, after_of) = marker.split_once(" of ")?; + let (total_segment, _) = after_of.split_once(' ')?; + total_segment.parse::().ok() +} + +fn strip_total_output_header(output: &str) -> Option<&str> { + let after_prefix = output.strip_prefix("Total output lines: ")?; + let (_, remainder) = after_prefix.split_once('\n')?; + let remainder = remainder.strip_prefix('\n').unwrap_or(remainder); + Some(remainder) +} + +pub(crate) fn format_prompt_items(items: &mut [ResponseItem], has_freeform_apply_patch: bool) { + if has_freeform_apply_patch { + reserialize_shell_outputs(items); + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/codex-rs/core/src/default_client.rs b/codex-rs/core/src/default_client.rs index 8e4635460c..b29f06e29c 100644 --- a/codex-rs/core/src/default_client.rs +++ b/codex-rs/core/src/default_client.rs @@ -41,6 +41,14 @@ impl CodexHttpClient { Self { inner } } + pub fn inner(&self) -> &reqwest::Client { + &self.inner + } + + pub fn clone_inner(&self) -> reqwest::Client { + self.inner.clone() + } + pub fn get(&self, url: U) -> CodexRequestBuilder where U: IntoUrl, diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index 0c4356d75c..c7f3021f98 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -43,6 +43,8 @@ pub enum Feature { SandboxCommandAssessment, /// Create a ghost commit at each turn. GhostCommit, + /// Enable chaining Responses API calls via previous response IDs. + ResponsesApiChaining, } impl Feature { @@ -295,4 +297,10 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::Experimental, default_enabled: false, }, + FeatureSpec { + id: Feature::ResponsesApiChaining, + key: "responses_api_chaining", + stage: Stage::Experimental, + default_enabled: false, + }, ]; diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index b9bd97ca23..e4b8e4d95c 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -8,7 +8,6 @@ mod apply_patch; pub mod auth; pub mod bash; -mod chat_completions; mod client; mod client_common; pub mod codex; @@ -19,9 +18,11 @@ mod command_safety; pub mod config; pub mod config_loader; mod conversation_history; +mod conversation_manager; pub mod custom_prompts; mod environment_context; pub mod error; +mod event_mapping; pub mod exec; pub mod exec_env; pub mod features; @@ -32,22 +33,14 @@ pub mod mcp; mod mcp_connection_manager; mod mcp_tool_call; mod message_history; -mod model_provider_info; pub mod parse_command; mod response_processing; +pub mod review_format; pub mod sandboxing; pub mod token_data; mod truncate; mod unified_exec; mod user_instructions; -pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID; -pub use model_provider_info::ModelProviderInfo; -pub use model_provider_info::WireApi; -pub use model_provider_info::built_in_model_providers; -pub use model_provider_info::create_oss_provider_with_base_url; -mod conversation_manager; -mod event_mapping; -pub mod review_format; pub use codex_protocol::protocol::InitialHistory; pub use conversation_manager::ConversationManager; pub use conversation_manager::NewConversation; diff --git a/codex-rs/core/src/sandboxing/assessment.rs b/codex-rs/core/src/sandboxing/assessment.rs index c7310c1f13..a5a33ff38f 100644 --- a/codex-rs/core/src/sandboxing/assessment.rs +++ b/codex-rs/core/src/sandboxing/assessment.rs @@ -5,13 +5,13 @@ use std::time::Duration; use std::time::Instant; use crate::AuthManager; -use crate::ModelProviderInfo; use 
crate::client::ModelClient; -use crate::client_common::Prompt; use crate::client_common::ResponseEvent; use crate::config::Config; use crate::protocol::SandboxPolicy; use askama::Template; +use codex_api_client::ModelProviderInfo; +use codex_api_client::Prompt; use codex_otel::otel_event_manager::OtelEventManager; use codex_protocol::ConversationId; use codex_protocol::models::ContentItem; @@ -126,12 +126,10 @@ pub(crate) async fn assess_command( role: "user".to_string(), content: vec![ContentItem::InputText { text: user_prompt }], }], - tools: Vec::new(), - parallel_tool_calls: false, - base_instructions_override: Some(system_prompt), output_schema: Some(sandbox_assessment_schema()), + instructions: system_prompt, + ..Default::default() }; - let child_otel = parent_otel.with_model(config.model.as_str(), config.model_family.slug.as_str()); diff --git a/codex-rs/core/src/state/mod.rs b/codex-rs/core/src/state/mod.rs index 642433a786..d6b2ed364b 100644 --- a/codex-rs/core/src/state/mod.rs +++ b/codex-rs/core/src/state/mod.rs @@ -4,6 +4,8 @@ mod turn; pub(crate) use service::SessionServices; pub(crate) use session::SessionState; +pub(crate) use session::build_prompt_from_items; +pub(crate) use session::response_item_id; pub(crate) use turn::ActiveTurn; pub(crate) use turn::RunningTask; pub(crate) use turn::TaskKind; diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index a41d2b6342..b9bb90d20f 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -2,26 +2,41 @@ use codex_protocol::models::ResponseItem; +use crate::client_common::Prompt; +use crate::client_common::compute_full_instructions; use crate::codex::SessionConfiguration; use crate::conversation_history::ConversationHistory; +use crate::conversation_history::ResponsesApiChainState; +use crate::conversation_history::format_prompt_items; +use crate::features::Feature; +use crate::model_family::ModelFamily; use crate::protocol::RateLimitSnapshot; use crate::protocol::TokenUsage; use crate::protocol::TokenUsageInfo; +use crate::tools::spec::ToolsConfig; +use crate::tools::spec::ToolsConfigParams; +use crate::tools::spec::build_specs; +use crate::tools::spec::tools_metadata_for_prompt; /// Persistent, session-scoped state previously stored directly on `Session`. pub(crate) struct SessionState { pub(crate) session_configuration: SessionConfiguration, pub(crate) history: ConversationHistory, pub(crate) latest_rate_limits: Option, + pub(crate) model_family: ModelFamily, } impl SessionState { /// Create a new session state mirroring previous `State::default()` semantics. 
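The state/session.rs changes below carry the per-turn chaining decision. As a rough behavioral summary (a sketch of the code in the next hunks, not an authoritative restatement): with no chain state the full history is sent; with chain state whose last_message_id still appears in history, only the items after it are sent together with previous_response_id; if that id is no longer present (for example after history is replaced), the full history is sent and the chain is reset.

// Sketch only: mirrors how SessionState::prompt_for_turn (below) consumes
// build_prompt_from_items.
fn sketch_prompt_for_turn(state: &mut SessionState, items: Vec<ResponseItem>) -> Prompt {
    let chain_state = state.history.responses_api_chain();
    let (prompt, reset_chain) = build_prompt_from_items(items, chain_state.as_ref());
    if reset_chain {
        // The stored ids no longer line up with history, so the next turn
        // falls back to sending full history and starts a fresh chain.
        state.reset_responses_api_chain();
    }
    prompt
}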
diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs
index a41d2b6342..b9bb90d20f 100644
--- a/codex-rs/core/src/state/session.rs
+++ b/codex-rs/core/src/state/session.rs
@@ -2,26 +2,41 @@

 use codex_protocol::models::ResponseItem;

+use crate::client_common::Prompt;
+use crate::client_common::compute_full_instructions;
 use crate::codex::SessionConfiguration;
 use crate::conversation_history::ConversationHistory;
+use crate::conversation_history::ResponsesApiChainState;
+use crate::conversation_history::format_prompt_items;
+use crate::features::Feature;
+use crate::model_family::ModelFamily;
 use crate::protocol::RateLimitSnapshot;
 use crate::protocol::TokenUsage;
 use crate::protocol::TokenUsageInfo;
+use crate::tools::spec::ToolsConfig;
+use crate::tools::spec::ToolsConfigParams;
+use crate::tools::spec::build_specs;
+use crate::tools::spec::tools_metadata_for_prompt;

 /// Persistent, session-scoped state previously stored directly on `Session`.
 pub(crate) struct SessionState {
     pub(crate) session_configuration: SessionConfiguration,
     pub(crate) history: ConversationHistory,
     pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
+    pub(crate) model_family: ModelFamily,
 }

 impl SessionState {
     /// Create a new session state mirroring previous `State::default()` semantics.
-    pub(crate) fn new(session_configuration: SessionConfiguration) -> Self {
+    pub(crate) fn new(
+        session_configuration: SessionConfiguration,
+        model_family: ModelFamily,
+    ) -> Self {
         Self {
             session_configuration,
             history: ConversationHistory::new(),
             latest_rate_limits: None,
+            model_family,
         }
     }

@@ -42,6 +57,16 @@ impl SessionState {
         self.history.replace(items);
     }

+    pub(crate) fn reset_responses_api_chain(&mut self) {
+        self.history.reset_responses_api_chain();
+    }
+
+    pub(crate) fn set_responses_api_chain(&mut self, chain: ResponsesApiChainState) {
+        if self.session_configuration.features.enabled(Feature::ResponsesApiChaining) {
+            self.history.set_responses_api_chain(chain);
+        }
+    }
+
     // Token/rate limit helpers
     pub(crate) fn update_token_info_from_usage(
         &mut self,
@@ -68,4 +93,84 @@ impl SessionState {
     pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
         self.history.set_token_usage_full(context_window);
     }
+
+    pub(crate) fn prompt_for_turn(&mut self) -> Prompt {
+        let tools_config = ToolsConfig::new(&ToolsConfigParams {
+            model_family: &self.model_family,
+            features: &self.session_configuration.features,
+        });
+        let (tool_specs, _registry) = build_specs(&tools_config, None).build();
+        let tool_specs = tool_specs.into_iter().map(|c| c.spec).collect::<Vec<_>>();
+
+        let prompt_items = self.history.get_history_for_prompt();
+        let chain_state = self.history.responses_api_chain();
+        let (mut prompt, reset_chain) = build_prompt_from_items(prompt_items, chain_state.as_ref());
+        if reset_chain {
+            self.reset_responses_api_chain();
+        }
+
+        // Populate prompt fields that depend only on session state.
+        let (tools_json, has_freeform_apply_patch) =
+            tools_metadata_for_prompt(&tool_specs).unwrap_or_default();
+        format_prompt_items(&mut prompt.input, has_freeform_apply_patch);
+
+        let apply_patch_present = tool_specs.iter().any(|spec| spec.name() == "apply_patch");
+        let base_override = self.session_configuration.base_instructions.as_deref();
+        let instructions =
+            compute_full_instructions(base_override, &self.model_family, apply_patch_present)
+                .into_owned();
+
+        prompt.instructions = instructions;
+        prompt.tools = tools_json;
+        prompt.parallel_tool_calls = self.model_family.supports_parallel_tool_calls;
+
+        prompt
+    }
+}
+
+pub(crate) fn response_item_id(item: &ResponseItem) -> Option<&str> {
+    match item {
+        ResponseItem::Message { id: Some(id), .. }
+        | ResponseItem::Reasoning { id, .. }
+        | ResponseItem::LocalShellCall { id: Some(id), .. }
+        | ResponseItem::FunctionCall { id: Some(id), .. }
+        | ResponseItem::CustomToolCall { id: Some(id), .. }
+        | ResponseItem::WebSearchCall { id: Some(id), .. } => Some(id.as_str()),
+        _ => None,
+    }
+}
+
+pub(crate) fn build_prompt_from_items(
+    prompt_items: Vec<ResponseItem>,
+    chain_state: Option<&ResponsesApiChainState>,
+) -> (Prompt, bool) {
+    let mut prompt = Prompt {
+        ..Prompt::default()
+    };
+
+    if let Some(state) = chain_state {
+        if let Some(last_message_id) = state.last_message_id.as_ref() {
+            if let Some(position) = prompt_items
+                .iter()
+                .position(|item| response_item_id(item) == Some(last_message_id.as_str()))
+            {
+                if let Some(previous_response_id) = state.last_response_id.clone() {
+                    prompt.previous_response_id = Some(previous_response_id);
+                }
+                prompt.input = prompt_items.into_iter().skip(position + 1).collect();
+                return (prompt, false);
+            }
+            prompt.input = prompt_items;
+            return (prompt, true);
+        }
+
+        if let Some(previous_response_id) = state.last_response_id.clone() {
+            prompt.previous_response_id = Some(previous_response_id);
+        }
+        prompt.input = prompt_items;
+        return (prompt, false);
+    }
+
+    prompt.input = prompt_items;
+    (prompt, false)
 }
diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs
index 1e82b9cf10..9791d73228 100644
--- a/codex-rs/core/src/tools/handlers/apply_patch.rs
+++ b/codex-rs/core/src/tools/handlers/apply_patch.rs
@@ -3,10 +3,6 @@ use std::collections::BTreeMap;
 use crate::apply_patch;
 use crate::apply_patch::InternalApplyPatchInvocation;
 use crate::apply_patch::convert_apply_patch_to_protocol;
-use crate::client_common::tools::FreeformTool;
-use crate::client_common::tools::FreeformToolFormat;
-use crate::client_common::tools::ResponsesApiTool;
-use crate::client_common::tools::ToolSpec;
 use crate::function_tool::FunctionCallError;
 use crate::tools::context::ToolInvocation;
 use crate::tools::context::ToolOutput;
@@ -20,7 +16,11 @@ use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
 use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
 use crate::tools::sandboxing::ToolCtx;
 use crate::tools::spec::ApplyPatchToolArgs;
+use crate::tools::spec::FreeformTool;
+use crate::tools::spec::FreeformToolFormat;
 use crate::tools::spec::JsonSchema;
+use crate::tools::spec::ResponsesApiTool;
+use crate::tools::spec::ToolSpec;
 use async_trait::async_trait;
 use serde::Deserialize;
 use serde::Serialize;
diff --git a/codex-rs/core/src/tools/handlers/plan.rs b/codex-rs/core/src/tools/handlers/plan.rs
index 073319bf1c..a0b3361106 100644
--- a/codex-rs/core/src/tools/handlers/plan.rs
+++ b/codex-rs/core/src/tools/handlers/plan.rs
@@ -1,5 +1,3 @@
-use crate::client_common::tools::ResponsesApiTool;
-use crate::client_common::tools::ToolSpec;
 use crate::codex::Session;
 use crate::codex::TurnContext;
 use crate::function_tool::FunctionCallError;
@@ -9,6 +7,8 @@ use crate::tools::context::ToolPayload;
 use crate::tools::registry::ToolHandler;
 use crate::tools::registry::ToolKind;
 use crate::tools::spec::JsonSchema;
+use crate::tools::spec::ResponsesApiTool;
+use crate::tools::spec::ToolSpec;
 use async_trait::async_trait;
 use codex_protocol::plan_tool::UpdatePlanArgs;
 use codex_protocol::protocol::EventMsg;
diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs
index 8769259794..b594401447 100644
--- a/codex-rs/core/src/tools/registry.rs
+++ b/codex-rs/core/src/tools/registry.rs
@@ -6,11 +6,11 @@ use async_trait::async_trait;
 use codex_protocol::models::ResponseInputItem;
 use tracing::warn;

-use crate::client_common::tools::ToolSpec;
 use crate::function_tool::FunctionCallError;
 use crate::tools::context::ToolInvocation;
 use crate::tools::context::ToolOutput;
 use crate::tools::context::ToolPayload;
+use crate::tools::spec::ToolSpec;

 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
 pub enum ToolKind {
diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs
index 19098aa80d..86d0aa7a02 100644
--- a/codex-rs/core/src/tools/router.rs
+++ b/codex-rs/core/src/tools/router.rs
@@ -1,7 +1,6 @@
 use std::collections::HashMap;
 use std::sync::Arc;

-use crate::client_common::tools::ToolSpec;
 use crate::codex::Session;
 use crate::codex::TurnContext;
 use crate::function_tool::FunctionCallError;
@@ -10,6 +9,7 @@ use crate::tools::context::ToolInvocation;
 use crate::tools::context::ToolPayload;
 use crate::tools::registry::ConfiguredToolSpec;
 use crate::tools::registry::ToolRegistry;
+use crate::tools::spec::ToolSpec;
 use crate::tools::spec::ToolsConfig;
 use crate::tools::spec::build_specs;
 use codex_protocol::models::LocalShellAction;
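tools/spec.rs below takes ownership of the serde-tagged ToolSpec type. One illustrative consequence of the #[serde(tag = "type")] attribute, shown with a unit-like variant because it needs no field values:

// Sketch: with an internal tag, the variant name becomes an inline "type"
// field, so the unit-like variants serialize to a bare tag object.
let json = serde_json::to_string(&ToolSpec::WebSearch {}).unwrap();
assert_eq!(json, r#"{"type":"web_search"}"#);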
diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs
index eba9fd517c..24e887c6f1 100644
--- a/codex-rs/core/src/tools/spec.rs
+++ b/codex-rs/core/src/tools/spec.rs
@@ -1,5 +1,3 @@
-use crate::client_common::tools::ResponsesApiTool;
-use crate::client_common::tools::ToolSpec;
 use crate::features::Feature;
 use crate::features::Features;
 use crate::model_family::ModelFamily;
@@ -22,6 +20,52 @@ pub enum ConfigShellToolType {
     Streamable,
 }

+#[derive(Debug, Clone, Serialize, PartialEq)]
+#[serde(tag = "type")]
+pub(crate) enum ToolSpec {
+    #[serde(rename = "function")]
+    Function(ResponsesApiTool),
+    #[serde(rename = "local_shell")]
+    LocalShell {},
+    #[serde(rename = "web_search")]
+    WebSearch {},
+    #[serde(rename = "custom")]
+    Freeform(FreeformTool),
+}
+
+impl ToolSpec {
+    pub(crate) fn name(&self) -> &str {
+        match self {
+            ToolSpec::Function(tool) => tool.name.as_str(),
+            ToolSpec::LocalShell {} => "local_shell",
+            ToolSpec::WebSearch {} => "web_search",
+            ToolSpec::Freeform(tool) => tool.name.as_str(),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct FreeformTool {
+    pub(crate) name: String,
+    pub(crate) description: String,
+    pub(crate) format: FreeformToolFormat,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct FreeformToolFormat {
+    pub(crate) r#type: String,
+    pub(crate) syntax: String,
+    pub(crate) definition: String,
+}
+
+#[derive(Debug, Clone, Serialize, PartialEq)]
+pub struct ResponsesApiTool {
+    pub(crate) name: String,
+    pub(crate) description: String,
+    pub(crate) strict: bool,
+    pub(crate) parameters: JsonSchema,
+}
+
 #[derive(Debug, Clone)]
 pub(crate) struct ToolsConfig {
     pub shell_type: ConfigShellToolType,
@@ -666,9 +710,6 @@ pub(crate) struct ApplyPatchToolArgs {
     pub(crate) input: String,
 }

-/// Returns JSON values that are compatible with Function Calling in the
-/// Responses API:
-/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
 pub fn create_tools_json_for_responses_api(
     tools: &[ToolSpec],
 ) -> crate::error::Result<Vec<serde_json::Value>> {
@@ -681,35 +722,16 @@ pub fn create_tools_json_for_responses_api(
     Ok(tools_json)
 }

-/// Returns JSON values that are compatible with Function Calling in the
-/// Chat Completions API:
-/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
-pub(crate) fn create_tools_json_for_chat_completions_api(
-    tools: &[ToolSpec],
-) -> crate::error::Result<Vec<serde_json::Value>> {
-    // We start with the JSON for the Responses API and than rewrite it to match
-    // the chat completions tool call format.
-    let responses_api_tools_json = create_tools_json_for_responses_api(tools)?;
-    let tools_json = responses_api_tools_json
-        .into_iter()
-        .filter_map(|mut tool| {
-            if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) {
-                return None;
-            }
-            if let Some(map) = tool.as_object_mut() {
-                // Remove "type" field as it is not needed in chat completions.
-                map.remove("type");
-                Some(json!({
-                    "type": "function",
-                    "function": map,
-                }))
-            } else {
-                None
-            }
-        })
-        .collect::<Vec<serde_json::Value>>();
-    Ok(tools_json)
+pub fn tools_metadata_for_prompt(
+    tools: &[ToolSpec],
+) -> crate::error::Result<(Vec<serde_json::Value>, bool)> {
+    let tools_json = create_tools_json_for_responses_api(tools)?;
+    let has_freeform_apply_patch = tools.iter().any(|tool| match tool {
+        ToolSpec::Freeform(freeform) => freeform.name == "apply_patch",
+        _ => false,
+    });
+    Ok((tools_json, has_freeform_apply_patch))
 }

 pub(crate) fn mcp_tool_to_openai_tool(
@@ -1002,7 +1024,6 @@

 #[cfg(test)]
 mod tests {
-    use crate::client_common::tools::FreeformTool;
     use crate::model_family::find_family_for_model;
     use crate::tools::registry::ConfiguredToolSpec;
     use mcp_types::ToolInputSchema;
diff --git a/codex-rs/core/tests/chat_completions_payload.rs b/codex-rs/core/tests/chat_completions_payload.rs
index cadf6be294..ac28cc117d 100644
--- a/codex-rs/core/tests/chat_completions_payload.rs
+++ b/codex-rs/core/tests/chat_completions_payload.rs
@@ -1,15 +1,15 @@
 use std::sync::Arc;

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_app_server_protocol::AuthMode;
 use codex_core::ContentItem;
 use codex_core::LocalShellAction;
 use codex_core::LocalShellExecAction;
 use codex_core::LocalShellStatus;
 use codex_core::ModelClient;
-use codex_core::ModelProviderInfo;
 use codex_core::Prompt;
 use codex_core::ResponseItem;
-use codex_core::WireApi;
 use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
 use codex_otel::otel_event_manager::OtelEventManager;
 use codex_protocol::ConversationId;
@@ -97,10 +97,12 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
         codex_protocol::protocol::SessionSource::Exec,
     );

-    let mut prompt = Prompt::default();
-    prompt.input = input;
+    let prompt = Prompt {
+        input,
+        ..Prompt::default()
+    };

-    let mut stream = match client.stream(&prompt).await {
+    let mut stream = match client.stream_for_test(prompt).await {
         Ok(s) => s,
         Err(e) => panic!("stream chat failed: {e}"),
     };
diff --git a/codex-rs/core/tests/chat_completions_sse.rs b/codex-rs/core/tests/chat_completions_sse.rs
index 46378b0823..06d2ebb4ca 100644
--- a/codex-rs/core/tests/chat_completions_sse.rs
+++ b/codex-rs/core/tests/chat_completions_sse.rs
@@ -2,14 +2,14 @@ use assert_matches::assert_matches;
 use std::sync::Arc;
 use tracing_test::traced_test;

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_app_server_protocol::AuthMode;
 use codex_core::ContentItem;
 use codex_core::ModelClient;
-use codex_core::ModelProviderInfo;
 use codex_core::Prompt;
 use codex_core::ResponseEvent;
 use codex_core::ResponseItem;
-use codex_core::WireApi;
 use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
 use codex_otel::otel_event_manager::OtelEventManager;
 use codex_protocol::ConversationId;
@@ -97,16 +97,18 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
         codex_protocol::protocol::SessionSource::Exec,
     );

-    let mut prompt = Prompt::default();
-    prompt.input = vec![ResponseItem::Message {
-        id: None,
-        role: "user".to_string(),
-        content: vec![ContentItem::InputText {
-            text: "hello".to_string(),
+    let prompt = Prompt {
+        input: vec![ResponseItem::Message {
+            id: None,
+            role: "user".to_string(),
+            content: vec![ContentItem::InputText {
+                text: "hello".to_string(),
+            }],
         }],
-    }];
+        ..Prompt::default()
+    };

-    let mut stream = match client.stream(&prompt).await {
+    let mut stream = match client.stream_for_test(prompt).await {
         Ok(s) => s,
         Err(e) => panic!("stream chat failed: {e}"),
     };
diff --git a/codex-rs/core/tests/common/Cargo.toml b/codex-rs/core/tests/common/Cargo.toml
index 65abe23c63..b346ca0a2c 100644
--- a/codex-rs/core/tests/common/Cargo.toml
+++ b/codex-rs/core/tests/common/Cargo.toml
@@ -9,6 +9,7 @@ path = "lib.rs"
 [dependencies]
 anyhow = { workspace = true }
 assert_cmd = { workspace = true }
+codex-api-client = { workspace = true }
 codex-core = { workspace = true }
 codex-protocol = { workspace = true }
 notify = { workspace = true }
diff --git a/codex-rs/core/tests/common/test_codex.rs b/codex-rs/core/tests/common/test_codex.rs
index 0f9fdaae8f..90d98fd05b 100644
--- a/codex-rs/core/tests/common/test_codex.rs
+++ b/codex-rs/core/tests/common/test_codex.rs
@@ -4,11 +4,11 @@ use std::path::PathBuf;
 use std::sync::Arc;

 use anyhow::Result;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::CodexConversation;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
-use codex_core::built_in_model_providers;
 use codex_core::config::Config;
 use codex_core::features::Feature;
 use codex_core::protocol::AskForApproval;
diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs
index 7b6f645f2f..cf265fcbfa 100644
--- a/codex-rs/core/tests/responses_headers.rs
+++ b/codex-rs/core/tests/responses_headers.rs
@@ -1,13 +1,13 @@
 use std::sync::Arc;

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_app_server_protocol::AuthMode;
 use codex_core::ContentItem;
 use codex_core::ModelClient;
-use codex_core::ModelProviderInfo;
 use codex_core::Prompt;
 use codex_core::ResponseEvent;
 use codex_core::ResponseItem;
-use codex_core::WireApi;
 use codex_otel::otel_event_manager::OtelEventManager;
 use codex_protocol::ConversationId;
 use codex_protocol::protocol::SessionSource;
@@ -82,16 +82,18 @@ async fn responses_stream_includes_subagent_header_on_review() {
         SessionSource::SubAgent(codex_protocol::protocol::SubAgentSource::Review),
     );

-    let mut prompt = Prompt::default();
-    prompt.input = vec![ResponseItem::Message {
-        id: None,
-        role: "user".into(),
-        content: vec![ContentItem::InputText {
-            text: "hello".into(),
+    let prompt = Prompt {
+        input: vec![ResponseItem::Message {
+            id: None,
+            role: "user".into(),
+            content: vec![ContentItem::InputText {
+                text: "hello".into(),
+            }],
         }],
-    }];
+        ..Prompt::default()
+    };

-    let mut stream = client.stream(&prompt).await.expect("stream failed");
+    let mut stream = client.stream_for_test(prompt).await.expect("stream failed");

     while let Some(event) = stream.next().await {
         if matches!(event, Ok(ResponseEvent::Completed { .. })) {
             break;
@@ -172,16 +174,18 @@ async fn responses_stream_includes_subagent_header_on_other() {
         )),
     );

-    let mut prompt = Prompt::default();
-    prompt.input = vec![ResponseItem::Message {
-        id: None,
-        role: "user".into(),
-        content: vec![ContentItem::InputText {
-            text: "hello".into(),
+    let prompt = Prompt {
+        input: vec![ResponseItem::Message {
+            id: None,
+            role: "user".into(),
+            content: vec![ContentItem::InputText {
+                text: "hello".into(),
+            }],
         }],
-    }];
+        ..Prompt::default()
+    };

-    let mut stream = client.stream(&prompt).await.expect("stream failed");
+    let mut stream = client.stream_for_test(prompt).await.expect("stream failed");

     while let Some(event) = stream.next().await {
         if matches!(event, Ok(ResponseEvent::Completed { .. })) {
             break;
diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs
index 07bee704f4..b17b8e02da 100644
--- a/codex-rs/core/tests/suite/client.rs
+++ b/codex-rs/core/tests/suite/client.rs
@@ -1,3 +1,6 @@
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
+use codex_api_client::built_in_model_providers;
 use codex_app_server_protocol::AuthMode;
 use codex_core::CodexAuth;
 use codex_core::ContentItem;
@@ -6,15 +9,13 @@ use codex_core::LocalShellAction;
 use codex_core::LocalShellExecAction;
 use codex_core::LocalShellStatus;
 use codex_core::ModelClient;
-use codex_core::ModelProviderInfo;
 use codex_core::NewConversation;
 use codex_core::Prompt;
 use codex_core::ResponseEvent;
 use codex_core::ResponseItem;
-use codex_core::WireApi;
 use codex_core::auth::AuthCredentialsStoreMode;
-use codex_core::built_in_model_providers;
 use codex_core::error::CodexErr;
+use codex_core::features::Feature;
 use codex_core::model_family::find_family_for_model;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
@@ -678,6 +679,98 @@ async fn includes_developer_instructions_message_in_request() {
     assert_message_ends_with(&request_body["input"][2], "");
 }

+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn responses_api_chaining_sets_store_and_previous_id() {
+    skip_if_no_network!();
+
+    let server = MockServer::start().await;
+    let first_response = responses::sse(vec![
+        responses::ev_response_created("resp-first"),
+        responses::ev_assistant_message("m1", "hi there"),
+        responses::ev_completed("resp-first"),
+    ]);
+    let second_response = responses::sse(vec![
+        responses::ev_response_created("resp-second"),
+        responses::ev_assistant_message("m2", "second reply"),
+        responses::ev_completed("resp-second"),
+    ]);
+    let response_mock =
+        responses::mount_sse_sequence(&server, vec![first_response, second_response]).await;
+
+    let model_provider = ModelProviderInfo {
+        base_url: Some(format!("{}/v1", server.uri())),
+        ..built_in_model_providers()["openai"].clone()
+    };
+
+    let codex_home = TempDir::new().unwrap();
+    let mut config = load_default_config_for_test(&codex_home);
+    config.model_provider = model_provider;
+    config.features.enable(Feature::ResponsesApiChaining);
+
+    let conversation_manager =
+        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
+    let codex = conversation_manager
+        .new_conversation(config)
+        .await
+        .expect("create new conversation")
+        .conversation;
+
+    codex
+        .submit(Op::UserInput {
+            items: vec![UserInput::Text {
+                text: "first turn".into(),
+            }],
+        })
+        .await
+        .unwrap();
+    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+
+    codex
+        .submit(Op::UserInput {
+            items: vec![UserInput::Text {
+                text: "second turn".into(),
+            }],
+        })
+        .await
+        .unwrap();
+    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+
+    let requests = response_mock.requests();
+    assert_eq!(
+        requests.len(),
+        2,
+        "expected two responses API calls for two turns"
+    );
+
+    let first_body = requests[0].body_json();
+    assert_eq!(first_body["store"], serde_json::Value::Bool(true));
+    assert!(
+        first_body.get("previous_response_id").is_none(),
+        "first request should not set previous_response_id"
+    );
+
+    let second_body = requests[1].body_json();
+    assert_eq!(second_body["store"], serde_json::Value::Bool(true));
+    assert_eq!(
+        second_body["previous_response_id"].as_str(),
+        Some("resp-first")
+    );
+
+    let second_input = requests[1].input();
+    assert_eq!(
+        second_input.len(),
+        1,
+        "second request should only send new user input items"
+    );
+    let user_item = &second_input[0];
+    assert_eq!(user_item["type"].as_str(), Some("message"));
+    assert_eq!(user_item["role"].as_str(), Some("user"));
+    let content = user_item["content"][0]["text"]
+        .as_str()
+        .expect("missing user message text");
+    assert_eq!(content, "second turn");
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn azure_responses_request_includes_store_and_reasoning_ids() {
     skip_if_no_network!();
@@ -800,7 +893,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
     });

     let mut stream = client
-        .stream(&prompt)
+        .stream_for_test(prompt)
         .await
         .expect("responses stream to start");
diff --git a/codex-rs/core/tests/suite/compact.rs b/codex-rs/core/tests/suite/compact.rs
index 0dea8a02cb..3ed83fb876 100644
--- a/codex-rs/core/tests/suite/compact.rs
+++ b/codex-rs/core/tests/suite/compact.rs
@@ -1,8 +1,8 @@
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
 use codex_core::NewConversation;
-use codex_core::built_in_model_providers;
 use codex_core::protocol::ErrorEvent;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
diff --git a/codex-rs/core/tests/suite/compact_resume_fork.rs b/codex-rs/core/tests/suite/compact_resume_fork.rs
index b13c6e14fd..12b5cf78fb 100644
--- a/codex-rs/core/tests/suite/compact_resume_fork.rs
+++ b/codex-rs/core/tests/suite/compact_resume_fork.rs
@@ -9,12 +9,12 @@

 use super::compact::FIRST_REPLY;
 use super::compact::SUMMARY_TEXT;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::CodexConversation;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
 use codex_core::NewConversation;
-use codex_core::built_in_model_providers;
 use codex_core::codex::compact::SUMMARIZATION_PROMPT;
 use codex_core::config::Config;
 use codex_core::config::OPENAI_DEFAULT_MODEL;
diff --git a/codex-rs/core/tests/suite/fork_conversation.rs b/codex-rs/core/tests/suite/fork_conversation.rs
index 75b37ae7ef..c7caf5662c 100644
--- a/codex-rs/core/tests/suite/fork_conversation.rs
+++ b/codex-rs/core/tests/suite/fork_conversation.rs
@@ -1,8 +1,8 @@
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
 use codex_core::NewConversation;
-use codex_core::built_in_model_providers;
 use codex_core::parse_turn_item;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
diff --git a/codex-rs/core/tests/suite/model_tools.rs b/codex-rs/core/tests/suite/model_tools.rs
index 2a04b88a16..079887e48c 100644
--- a/codex-rs/core/tests/suite/model_tools.rs
+++ b/codex-rs/core/tests/suite/model_tools.rs
@@ -1,9 +1,9 @@
 #![allow(clippy::unwrap_used)]

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
-use codex_core::built_in_model_providers;
 use codex_core::features::Feature;
 use codex_core::model_family::find_family_for_model;
 use codex_core::protocol::EventMsg;
diff --git a/codex-rs/core/tests/suite/prompt_caching.rs b/codex-rs/core/tests/suite/prompt_caching.rs
index 04304126f0..e086f81f05 100644
--- a/codex-rs/core/tests/suite/prompt_caching.rs
+++ b/codex-rs/core/tests/suite/prompt_caching.rs
@@ -1,9 +1,9 @@
 #![allow(clippy::unwrap_used)]

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
-use codex_core::built_in_model_providers;
 use codex_core::config::OPENAI_DEFAULT_MODEL;
 use codex_core::features::Feature;
 use codex_core::model_family::find_family_for_model;
diff --git a/codex-rs/core/tests/suite/review.rs b/codex-rs/core/tests/suite/review.rs
index 093fd99268..2d7680f99c 100644
--- a/codex-rs/core/tests/suite/review.rs
+++ b/codex-rs/core/tests/suite/review.rs
@@ -1,11 +1,11 @@
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::CodexConversation;
 use codex_core::ContentItem;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
 use codex_core::REVIEW_PROMPT;
 use codex_core::ResponseItem;
-use codex_core::built_in_model_providers;
 use codex_core::config::Config;
 use codex_core::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG;
 use codex_core::protocol::EventMsg;
diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs
index 672633ace5..edd47a6822 100644
--- a/codex-rs/core/tests/suite/rmcp_client.rs
+++ b/codex-rs/core/tests/suite/rmcp_client.rs
@@ -422,7 +422,7 @@ async fn stdio_image_completions_round_trip() -> anyhow::Result<()> {

     let fixture = test_codex()
         .with_config(move |config| {
-            config.model_provider.wire_api = codex_core::WireApi::Chat;
+            config.model_provider.wire_api = codex_api_client::WireApi::Chat;
             config.features.enable(Feature::RmcpClient);
             config.mcp_servers.insert(
                 server_name.to_string(),
diff --git a/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs b/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs
index ba86f8c155..cbc8c22530 100644
--- a/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs
+++ b/codex-rs/core/tests/suite/stream_error_allows_next_turn.rs
@@ -1,7 +1,7 @@
 use std::time::Duration;

-use codex_core::ModelProviderInfo;
-use codex_core::WireApi;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
 use codex_protocol::user_input::UserInput;
diff --git a/codex-rs/core/tests/suite/stream_no_completed.rs b/codex-rs/core/tests/suite/stream_no_completed.rs
index 550bb3f9c2..0f13be4043 100644
--- a/codex-rs/core/tests/suite/stream_no_completed.rs
+++ b/codex-rs/core/tests/suite/stream_no_completed.rs
@@ -3,8 +3,8 @@

 use std::time::Duration;

-use codex_core::ModelProviderInfo;
-use codex_core::WireApi;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
 use codex_protocol::user_input::UserInput;
diff --git a/codex-rs/exec/Cargo.toml b/codex-rs/exec/Cargo.toml
index 8fc1e38875..45e96f6cd7 100644
--- a/codex-rs/exec/Cargo.toml
+++ b/codex-rs/exec/Cargo.toml
@@ -24,6 +24,7 @@ codex-common = { workspace = true, features = [
     "sandbox_summary",
 ] }
 codex-core = { workspace = true }
+codex-api-client = { workspace = true }
 codex-ollama = { workspace = true }
 codex-protocol = { workspace = true }
 mcp-types = { workspace = true }
diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs
index a086990dff..74f5e1b793 100644
--- a/codex-rs/exec/src/lib.rs
+++ b/codex-rs/exec/src/lib.rs
@@ -11,8 +11,8 @@ pub mod event_processor_with_jsonl_output;
 pub mod exec_events;

 pub use cli::Cli;
+use codex_api_client::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 use codex_core::AuthManager;
-use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 use codex_core::ConversationManager;
 use codex_core::NewConversation;
 use codex_core::auth::enforce_login_restrictions;
diff --git a/codex-rs/ollama/Cargo.toml b/codex-rs/ollama/Cargo.toml
index 14dd6d2fcd..39c39874b2 100644
--- a/codex-rs/ollama/Cargo.toml
+++ b/codex-rs/ollama/Cargo.toml
@@ -13,6 +13,7 @@ workspace = true
 [dependencies]
 async-stream = { workspace = true }
 bytes = { workspace = true }
+codex-api-client = { workspace = true }
 codex-core = { workspace = true }
 futures = { workspace = true }
 reqwest = { workspace = true, features = ["json", "stream"] }
diff --git a/codex-rs/ollama/src/client.rs b/codex-rs/ollama/src/client.rs
index 04b7e9dea2..460aca4b63 100644
--- a/codex-rs/ollama/src/client.rs
+++ b/codex-rs/ollama/src/client.rs
@@ -10,9 +10,9 @@ use crate::pull::PullEvent;
 use crate::pull::PullProgressReporter;
 use crate::url::base_url_to_host_root;
 use crate::url::is_openai_compatible_base_url;
-use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
-use codex_core::ModelProviderInfo;
-use codex_core::WireApi;
+use codex_api_client::BUILT_IN_OSS_MODEL_PROVIDER_ID;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_core::config::Config;

 const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
@@ -47,7 +47,7 @@ impl OllamaClient {

     #[cfg(test)]
     async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
-        let provider = codex_core::create_oss_provider_with_base_url(base_url);
+        let provider = codex_api_client::create_oss_provider_with_base_url(base_url);
         Self::try_from_provider(&provider).await
     }
diff --git a/codex-rs/tui/Cargo.toml b/codex-rs/tui/Cargo.toml
index f087202c22..f98b2f9199 100644
--- a/codex-rs/tui/Cargo.toml
+++ b/codex-rs/tui/Cargo.toml
@@ -34,6 +34,7 @@ codex-common = { workspace = true, features = [
     "sandbox_summary",
 ] }
 codex-core = { workspace = true }
+codex-api-client = { workspace = true }
 codex-file-search = { workspace = true }
 codex-login = { workspace = true }
 codex-ollama = { workspace = true }
diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs
index 487b687a57..1102b0a819 100644
--- a/codex-rs/tui/src/lib.rs
+++ b/codex-rs/tui/src/lib.rs
@@ -6,9 +6,9 @@
 use additional_dirs::add_dir_warning_message;
 use app::App;
 pub use app::AppExitInfo;
+use codex_api_client::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 use codex_app_server_protocol::AuthMode;
 use codex_core::AuthManager;
-use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 use codex_core::CodexAuth;
 use codex_core::INTERACTIVE_SESSION_SOURCES;
 use codex_core::RolloutRecorder;
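The utils/git changes below swap once_cell::sync::Lazy for std::sync::LazyLock, which has been stable since Rust 1.80 and is a drop-in replacement for this use. A minimal sketch of the pattern:

use std::sync::LazyLock;

use regex::Regex;

// Compiled once on first access, exactly like once_cell's Lazy.
static WORD: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"\w+").expect("valid regex"));

fn has_word(s: &str) -> bool {
    WORD.is_match(s)
}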
diff --git a/codex-rs/utils/git/Cargo.toml b/codex-rs/utils/git/Cargo.toml
index 072587bdc8..293f24f871 100644
--- a/codex-rs/utils/git/Cargo.toml
+++ b/codex-rs/utils/git/Cargo.toml
@@ -8,7 +8,6 @@ readme = "README.md"
 workspace = true

 [dependencies]
-once_cell = "1"
 regex = "1"
 schemars = { workspace = true }
 serde = { workspace = true, features = ["derive"] }
diff --git a/codex-rs/utils/git/src/apply.rs b/codex-rs/utils/git/src/apply.rs
index c9e8503242..5fbc5430ac 100644
--- a/codex-rs/utils/git/src/apply.rs
+++ b/codex-rs/utils/git/src/apply.rs
@@ -6,12 +6,12 @@
 //! mode via [`ApplyGitRequest::preflight`] and inspect the resulting paths to
 //! learn what would change before applying for real.

-use once_cell::sync::Lazy;
 use regex::Regex;
 use std::ffi::OsStr;
 use std::io;
 use std::path::Path;
 use std::path::PathBuf;
+use std::sync::LazyLock;

 /// Parameters for invoking [`apply_git_patch`].
 #[derive(Debug, Clone)]
@@ -192,7 +192,7 @@ fn render_command_for_log(cwd: &Path, git_cfg: &[String], args: &[String]) -> St

 /// Collect every path referenced by the diff headers inside `diff --git` sections.
 pub fn extract_paths_from_patch(diff_text: &str) -> Vec<PathBuf> {
-    static RE: Lazy<Regex> = Lazy::new(|| {
+    static RE: LazyLock<Regex> = LazyLock::new(|| {
         Regex::new(r"(?m)^diff --git a/(.*?) b/(.*)$")
             .unwrap_or_else(|e| panic!("invalid regex: {e}"))
     });
@@ -275,62 +275,64 @@ pub fn parse_git_apply_output(
         }
     }

-    static APPLIED_CLEAN: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^Applied patch(?: to)?\\s+(?P<path>.+?)\\s+cleanly\\.?$"));
-    static APPLIED_CONFLICTS: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^Applied patch(?: to)?\\s+(?P<path>.+?)\\s+with conflicts\\.?$"));
-    static APPLYING_WITH_REJECTS: Lazy<Regex> = Lazy::new(|| {
+    static APPLIED_CLEAN: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^Applied patch(?: to)?\\s+(?P<path>.+?)\\s+cleanly\\.?$"));
+    static APPLIED_CONFLICTS: LazyLock<Regex> = LazyLock::new(|| {
+        regex_ci("^Applied patch(?: to)?\\s+(?P<path>.+?)\\s+with conflicts\\.?$")
+    });
+    static APPLYING_WITH_REJECTS: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci("^Applying patch\\s+(?P<path>.+?)\\s+with\\s+\\d+\\s+rejects?\\.{0,3}$")
     });
-    static CHECKING_PATCH: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^Checking patch\\s+(?P<path>.+?)\\.\\.\\.$"));
-    static UNMERGED_LINE: Lazy<Regex> = Lazy::new(|| regex_ci("^U\\s+(?P<path>.+)$"));
-    static PATCH_FAILED: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^error:\\s+patch failed:\\s+(?P<path>.+?)(?::\\d+)?(?:\\s|$)"));
-    static DOES_NOT_APPLY: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+patch does not apply$"));
-    static THREE_WAY_START: Lazy<Regex> = Lazy::new(|| {
+    static CHECKING_PATCH: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^Checking patch\\s+(?P<path>.+?)\\.\\.\\.$"));
+    static UNMERGED_LINE: LazyLock<Regex> = LazyLock::new(|| regex_ci("^U\\s+(?P<path>.+)$"));
+    static PATCH_FAILED: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^error:\\s+patch failed:\\s+(?P<path>.+?)(?::\\d+)?(?:\\s|$)"));
+    static DOES_NOT_APPLY: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+patch does not apply$"));
+    static THREE_WAY_START: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci("^(?:Performing three-way merge|Falling back to three-way merge)\\.\\.\\.$")
     });
-    static THREE_WAY_FAILED: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^Failed to perform three-way merge\\.\\.\\.$"));
-    static FALLBACK_DIRECT: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^Falling back to direct application\\.\\.\\.$"));
-    static LACKS_BLOB: Lazy<Regex> = Lazy::new(|| {
+    static THREE_WAY_FAILED: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^Failed to perform three-way merge\\.\\.\\.$"));
+    static FALLBACK_DIRECT: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^Falling back to direct application\\.\\.\\.$"));
+    static LACKS_BLOB: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci(
             "^(?:error: )?repository lacks the necessary blob to (?:perform|fall back on) 3-?way merge\\.?$",
         )
     });
-    static INDEX_MISMATCH: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+does not match index\\b"));
-    static NOT_IN_INDEX: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+does not exist in index\\b"));
-    static ALREADY_EXISTS_WT: Lazy<Regex> = Lazy::new(|| {
+    static INDEX_MISMATCH: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+does not match index\\b"));
+    static NOT_IN_INDEX: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+does not exist in index\\b"));
+    static ALREADY_EXISTS_WT: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci("^error:\\s+(?P<path>.+?)\\s+already exists in (?:the )?working directory\\b")
     });
-    static FILE_EXISTS: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^error:\\s+patch failed:\\s+(?P<path>.+?)\\s+File exists"));
-    static RENAMED_DELETED: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^error:\\s+path\\s+(?P<path>.+?)\\s+has been renamed\\/deleted"));
-    static CANNOT_APPLY_BINARY: Lazy<Regex> = Lazy::new(|| {
+    static FILE_EXISTS: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^error:\\s+patch failed:\\s+(?P<path>.+?)\\s+File exists"));
+    static RENAMED_DELETED: LazyLock<Regex> = LazyLock::new(|| {
+        regex_ci("^error:\\s+path\\s+(?P<path>.+?)\\s+has been renamed\\/deleted")
+    });
+    static CANNOT_APPLY_BINARY: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci(
             "^error:\\s+cannot apply binary patch to\\s+['\\\"]?(?P<path>.+?)['\\\"]?\\s+without full index line$",
         )
     });
-    static BINARY_DOES_NOT_APPLY: Lazy<Regex> = Lazy::new(|| {
+    static BINARY_DOES_NOT_APPLY: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci("^error:\\s+binary patch does not apply to\\s+['\\\"]?(?P<path>.+?)['\\\"]?$")
     });
-    static BINARY_INCORRECT_RESULT: Lazy<Regex> = Lazy::new(|| {
+    static BINARY_INCORRECT_RESULT: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci(
             "^error:\\s+binary patch to\\s+['\\\"]?(?P<path>.+?)['\\\"]?\\s+creates incorrect result\\b",
         )
     });
-    static CANNOT_READ_CURRENT: Lazy<Regex> = Lazy::new(|| {
+    static CANNOT_READ_CURRENT: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci("^error:\\s+cannot read the current contents of\\s+['\\\"]?(?P<path>.+?)['\\\"]?$")
     });
-    static SKIPPED_PATCH: Lazy<Regex> =
-        Lazy::new(|| regex_ci("^Skipped patch\\s+['\\\"]?(?P<path>.+?)['\\\"]\\.$"));
-    static CANNOT_MERGE_BINARY_WARN: Lazy<Regex> = Lazy::new(|| {
+    static SKIPPED_PATCH: LazyLock<Regex> =
+        LazyLock::new(|| regex_ci("^Skipped patch\\s+['\\\"]?(?P<path>.+?)['\\\"]\\.$"));
+    static CANNOT_MERGE_BINARY_WARN: LazyLock<Regex> = LazyLock::new(|| {
         regex_ci(
             "^warning:\\s*Cannot merge binary files:\\s+(?P<path>.+?)\\s+\\(ours\\s+vs\\.\\s+theirs\\)",