From b127a3643fd3dc81b0fce1c2707f0c72fb62ed65 Mon Sep 17 00:00:00 2001
From: Dominik Kundel
Date: Tue, 2 Sep 2025 19:49:03 -0700
Subject: [PATCH] Improve gpt-oss compatibility (#2461)

The gpt-oss models require the reasoning from previous turns to be sent
back with subsequent Chat Completions requests; otherwise the model
forgets why the tools were called. This change fixes that and also adds
missing documentation on how to handle context windows in Ollama and how
to show the CoT if desired.
---
 codex-rs/core/src/chat_completions.rs        | 232 ++++++++++--
 codex-rs/core/src/client_common.rs           |   4 +-
 codex-rs/core/src/lib.rs                     |  11 +
 .../core/tests/chat_completions_payload.rs   | 345 ++++++++++++++++++
 codex-rs/core/tests/chat_completions_sse.rs  | 320 ++++++++++++++++
 5 files changed, 886 insertions(+), 26 deletions(-)
 create mode 100644 codex-rs/core/tests/chat_completions_payload.rs
 create mode 100644 codex-rs/core/tests/chat_completions_sse.rs

diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index 6eca119f..fc8602de 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -43,7 +43,107 @@ pub(crate) async fn stream_chat_completions(
 
     let input = prompt.get_formatted_input();
 
+    // Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
+    // - If the last emitted message is a user message, drop all reasoning.
+    // - Otherwise, for each Reasoning item after the last user message, attach it
+    //   to the immediate previous assistant message (stop turns) or the immediate
+    //   next assistant anchor (tool-call turns: function/local shell call, or assistant message).
+    let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
+        std::collections::HashMap::new();
+
+    // Determine the last role that would be emitted to Chat Completions.
+    let mut last_emitted_role: Option<&str> = None;
     for item in &input {
+        match item {
+            ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
+            ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
+                last_emitted_role = Some("assistant")
+            }
+            ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
+            ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
+            ResponseItem::CustomToolCall { .. } => {}
+            ResponseItem::CustomToolCallOutput { .. } => {}
+            ResponseItem::WebSearchCall { .. } => {}
+        }
+    }
+
+    // Find the last user message index in the input.
+    let mut last_user_index: Option<usize> = None;
+    for (idx, item) in input.iter().enumerate() {
+        if let ResponseItem::Message { role, .. } = item
+            && role == "user"
+        {
+            last_user_index = Some(idx);
+        }
+    }
+
+    // Attach reasoning only if the conversation does not end with a user message.
+    if !matches!(last_emitted_role, Some("user")) {
+        for (idx, item) in input.iter().enumerate() {
+            // Only consider reasoning that appears after the last user message.
+            if let Some(u_idx) = last_user_index
+                && idx <= u_idx
+            {
+                continue;
+            }
+
+            if let ResponseItem::Reasoning {
+                content: Some(items),
+                ..
+            } = item
+            {
+                let mut text = String::new();
+                for c in items {
+                    match c {
+                        ReasoningItemContent::ReasoningText { text: t }
+                        | ReasoningItemContent::Text { text: t } => text.push_str(t),
+                    }
+                }
+                if text.trim().is_empty() {
+                    continue;
+                }
+
+                // Prefer immediate previous assistant message (stop turns)
+                let mut attached = false;
+                if idx > 0
+                    && let ResponseItem::Message { role, .. } = &input[idx - 1]
+                    && role == "assistant"
+                {
+                    reasoning_by_anchor_index
+                        .entry(idx - 1)
+                        .and_modify(|v| v.push_str(&text))
+                        .or_insert(text.clone());
+                    attached = true;
+                }
+
+                // Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
+                if !attached && idx + 1 < input.len() {
+                    match &input[idx + 1] {
+                        ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
+                            reasoning_by_anchor_index
+                                .entry(idx + 1)
+                                .and_modify(|v| v.push_str(&text))
+                                .or_insert(text.clone());
+                        }
+                        ResponseItem::Message { role, .. } if role == "assistant" => {
+                            reasoning_by_anchor_index
+                                .entry(idx + 1)
+                                .and_modify(|v| v.push_str(&text))
+                                .or_insert(text.clone());
+                        }
+                        _ => {}
+                    }
+                }
+            }
+        }
+    }
+
+    // Track last assistant text we emitted to avoid duplicate assistant messages
+    // in the outbound Chat Completions payload (can happen if a final
+    // aggregated assistant message was recorded alongside an earlier partial).
+    let mut last_assistant_text: Option<String> = None;
+
+    for (idx, item) in input.iter().enumerate() {
         match item {
             ResponseItem::Message { role, content, .. } => {
                 let mut text = String::new();
@@ -56,7 +156,24 @@ pub(crate) async fn stream_chat_completions(
                         _ => {}
                     }
                 }
-                messages.push(json!({"role": role, "content": text}));
+                // Skip exact-duplicate assistant messages.
+                if role == "assistant" {
+                    if let Some(prev) = &last_assistant_text
+                        && prev == &text
+                    {
+                        continue;
+                    }
+                    last_assistant_text = Some(text.clone());
+                }
+
+                let mut msg = json!({"role": role, "content": text});
+                if role == "assistant"
+                    && let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
+                    && let Some(obj) = msg.as_object_mut()
+                {
+                    obj.insert("reasoning".to_string(), json!(reasoning));
+                }
+                messages.push(msg);
             }
             ResponseItem::FunctionCall {
                 name,
@@ -64,7 +181,7 @@ pub(crate) async fn stream_chat_completions(
                 arguments,
                 call_id,
                 ..
             } => {
-                messages.push(json!({
+                let mut msg = json!({
                     "role": "assistant",
                     "content": null,
                     "tool_calls": [{
@@ -75,7 +192,13 @@ pub(crate) async fn stream_chat_completions(
                         "arguments": arguments,
                     }
                 }]
-                }));
+                });
+                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
+                    && let Some(obj) = msg.as_object_mut()
+                {
+                    obj.insert("reasoning".to_string(), json!(reasoning));
+                }
+                messages.push(msg);
             }
             ResponseItem::LocalShellCall {
                 id,
@@ -84,7 +207,7 @@ pub(crate) async fn stream_chat_completions(
                 action,
             } => {
                 // Confirm with API team.
-                messages.push(json!({
+                let mut msg = json!({
                     "role": "assistant",
                     "content": null,
                     "tool_calls": [{
@@ -93,7 +216,13 @@ pub(crate) async fn stream_chat_completions(
                         "status": status,
                         "action": action,
                     }]
-                }));
+                });
+                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
+                    && let Some(obj) = msg.as_object_mut()
+                {
+                    obj.insert("reasoning".to_string(), json!(reasoning));
+                }
+                messages.push(msg);
             }
             ResponseItem::FunctionCallOutput { call_id, output } => {
                 messages.push(json!({
@@ -331,7 +460,10 @@ async fn process_chat_sse(
         // Some providers stream `reasoning` as a plain string while others
         // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
         if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
-            let mut maybe_text = reasoning_val.as_str().map(|s| s.to_string());
+            let mut maybe_text = reasoning_val
+                .as_str()
+                .map(|s| s.to_string())
+                .filter(|s| !s.is_empty());
 
             if maybe_text.is_none() && reasoning_val.is_object() {
                 if let Some(s) = reasoning_val
@@ -350,12 +482,39 @@
             }
 
             if let Some(reasoning) = maybe_text {
+                // Accumulate so we can emit a terminal Reasoning item at the end.
+                reasoning_text.push_str(&reasoning);
                 let _ = tx_event
                     .send(Ok(ResponseEvent::ReasoningContentDelta(reasoning)))
                     .await;
             }
         }
 
+        // Some providers only include reasoning on the final message object.
+        if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
+        {
+            // Accept either a plain string or an object with { text | content }
+            if let Some(s) = message_reasoning.as_str() {
+                if !s.is_empty() {
+                    reasoning_text.push_str(s);
+                    let _ = tx_event
+                        .send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
+                        .await;
+                }
+            } else if let Some(obj) = message_reasoning.as_object()
+                && let Some(s) = obj
+                    .get("text")
+                    .and_then(|v| v.as_str())
+                    .or_else(|| obj.get("content").and_then(|v| v.as_str()))
+                && !s.is_empty()
+            {
+                reasoning_text.push_str(s);
+                let _ = tx_event
+                    .send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string())))
+                    .await;
+            }
+        }
+
         // Handle streaming function / tool calls.
         if let Some(tool_calls) = choice
             .get("delta")
@@ -511,27 +670,47 @@ where
         // do NOT emit yet. Forward any other item (e.g. FunctionCall) right
        // away so downstream consumers see it.
 
-        let is_assistant_delta = matches!(&item, codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant");
+        let is_assistant_message = matches!(
+            &item,
+            codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
+        );
 
-        if is_assistant_delta {
-            // Only use the final assistant message if we have not
-            // seen any deltas; otherwise, deltas already built the
-            // cumulative text and this would duplicate it.
-            if this.cumulative.is_empty()
-                && let codex_protocol::models::ResponseItem::Message { content, .. } =
-                    &item
-                && let Some(text) = content.iter().find_map(|c| match c {
-                    codex_protocol::models::ContentItem::OutputText { text } => {
-                        Some(text)
+        if is_assistant_message {
+            match this.mode {
+                AggregateMode::AggregatedOnly => {
+                    // Only use the final assistant message if we have not
+                    // seen any deltas; otherwise, deltas already built the
+                    // cumulative text and this would duplicate it.
+                    if this.cumulative.is_empty()
+                        && let codex_protocol::models::ResponseItem::Message {
+                            content,
+                            ..
+                        } = &item
+                        && let Some(text) = content.iter().find_map(|c| match c {
+                            codex_protocol::models::ContentItem::OutputText {
+                                text,
+                            } => Some(text),
+                            _ => None,
+                        })
+                    {
+                        this.cumulative.push_str(text);
                     }
-                    _ => None,
-                })
-            {
-                this.cumulative.push_str(text);
+                    // Swallow assistant message here; emit on Completed.
+                    continue;
+                }
+                AggregateMode::Streaming => {
+                    // In streaming mode, if we have not seen any deltas, forward
+                    // the final assistant message directly. If deltas were seen,
+                    // suppress the final message to avoid duplication.
+                    if this.cumulative.is_empty() {
+                        return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
+                            item,
+                        ))));
+                    } else {
+                        continue;
+                    }
+                }
             }
-
-            // Swallow assistant message here; emit on Completed.
-            continue;
         }
 
         // Not an assistant message – forward immediately.
@@ -563,6 +742,11 @@ where
                     emitted_any = true;
                 }
 
+                // Always emit the final aggregated assistant message when any
+                // content deltas have been observed. In AggregatedOnly mode this
+                // is the sole assistant output; in Streaming mode this finalizes
+                // the streamed deltas into a terminal OutputItemDone so callers
+                // can persist/render the message once per turn.
                 if !this.cumulative.is_empty() {
                     let aggregated_message = codex_protocol::models::ResponseItem::Message {
                         id: None,
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 27401051..77f6a63a 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -35,7 +35,7 @@ pub struct Prompt {
 
     /// Tools available to the model, including additional tools sourced from
     /// external MCP servers.
-    pub tools: Vec<OpenAiTool>,
+    pub(crate) tools: Vec<OpenAiTool>,
 
     /// Optional override for the built-in BASE_INSTRUCTIONS.
     pub base_instructions_override: Option<String>,
@@ -174,7 +174,7 @@ pub(crate) fn create_text_param_for_request(
     })
 }
 
-pub(crate) struct ResponseStream {
+pub struct ResponseStream {
     pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent, CodexErr>>,
 }
 
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index e3d787e7..b4966e79 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -69,3 +69,14 @@ pub use codex_protocol::protocol;
 // Re-export protocol config enums to ensure call sites can use the same types
 // as those in the protocol crate when constructing protocol messages.
 pub use codex_protocol::config_types as protocol_config_types;
+
+pub use client::ModelClient;
+pub use client_common::Prompt;
+pub use client_common::ResponseEvent;
+pub use client_common::ResponseStream;
+pub use codex_protocol::models::ContentItem;
+pub use codex_protocol::models::LocalShellAction;
+pub use codex_protocol::models::LocalShellExecAction;
+pub use codex_protocol::models::LocalShellStatus;
+pub use codex_protocol::models::ReasoningItemContent;
+pub use codex_protocol::models::ResponseItem;
diff --git a/codex-rs/core/tests/chat_completions_payload.rs b/codex-rs/core/tests/chat_completions_payload.rs
new file mode 100644
index 00000000..6b21894d
--- /dev/null
+++ b/codex-rs/core/tests/chat_completions_payload.rs
@@ -0,0 +1,345 @@
+use std::sync::Arc;
+
+use codex_core::ContentItem;
+use codex_core::LocalShellAction;
+use codex_core::LocalShellExecAction;
+use codex_core::LocalShellStatus;
+use codex_core::ModelClient;
+use codex_core::ModelProviderInfo;
+use codex_core::Prompt;
+use codex_core::ReasoningItemContent;
+use codex_core::ResponseItem;
+use codex_core::WireApi;
+use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
+use core_test_support::load_default_config_for_test;
+use futures::StreamExt;
+use serde_json::Value;
+use tempfile::TempDir;
+use uuid::Uuid;
+use wiremock::Mock;
+use wiremock::MockServer;
+use wiremock::ResponseTemplate;
+use wiremock::matchers::method;
+use wiremock::matchers::path;
+
+fn network_disabled() -> bool {
+    std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok()
+}
+
+async fn run_request(input: Vec<ResponseItem>) -> Value {
+    let server = MockServer::start().await;
+
+    let template = ResponseTemplate::new(200)
+        .insert_header("content-type", "text/event-stream")
+        .set_body_raw(
+            "data: {\"choices\":[{\"delta\":{}}]}\n\ndata: [DONE]\n\n",
+            "text/event-stream",
+        );
+
+    Mock::given(method("POST"))
+        .and(path("/v1/chat/completions"))
+        .respond_with(template)
+        .expect(1)
+        .mount(&server)
+        .await;
+
+    let provider = ModelProviderInfo {
+        name: "mock".into(),
+        base_url: Some(format!("{}/v1", server.uri())),
+        env_key: None,
+        env_key_instructions: None,
+        wire_api: WireApi::Chat,
+        query_params: None,
+        http_headers: None,
+        env_http_headers: None,
+        request_max_retries: Some(0),
+        stream_max_retries: Some(0),
+        stream_idle_timeout_ms: Some(5_000),
+        requires_openai_auth: false,
+    };
+
+    let codex_home = match TempDir::new() {
+        Ok(dir) => dir,
+        Err(e) => panic!("failed to create TempDir: {e}"),
+    };
+    let mut config = load_default_config_for_test(&codex_home);
+    config.model_provider_id = provider.name.clone();
+    config.model_provider = provider.clone();
+    config.show_raw_agent_reasoning = true;
+    let effort = config.model_reasoning_effort;
+    let summary = config.model_reasoning_summary;
+    let config = Arc::new(config);
+
+    let client = ModelClient::new(
+        Arc::clone(&config),
+        None,
+        provider,
+        effort,
+        summary,
+        Uuid::new_v4(),
+    );
+
+    let mut prompt = Prompt::default();
+    prompt.input = input;
+
+    let mut stream = match client.stream(&prompt).await {
+        Ok(s) => s,
+        Err(e) => panic!("stream chat failed: {e}"),
+    };
+    while let Some(event) = stream.next().await {
+        if let Err(e) = event {
+            panic!("stream event error: {e}");
+        }
+    }
+
+    let requests = match server.received_requests().await {
+        Some(reqs) => reqs,
+        None => panic!("request not made"),
+    };
+    match requests[0].body_json() {
+        Ok(v) => v,
+        Err(e) => panic!("invalid json body: {e}"),
+    }
+}
+
+fn user_message(text: &str) -> ResponseItem {
+    ResponseItem::Message {
+        id: None,
+        role: "user".to_string(),
+        content: vec![ContentItem::InputText {
+            text: text.to_string(),
+        }],
+    }
+}
+
+fn assistant_message(text: &str) -> ResponseItem {
+    ResponseItem::Message {
+        id: None,
+        role: "assistant".to_string(),
+        content: vec![ContentItem::OutputText {
+            text: text.to_string(),
+        }],
+    }
+}
+
+fn reasoning_item(text: &str) -> ResponseItem {
+    ResponseItem::Reasoning {
+        id: String::new(),
+        summary: Vec::new(),
+        content: Some(vec![ReasoningItemContent::ReasoningText {
+            text: text.to_string(),
+        }]),
+        encrypted_content: None,
+    }
+}
+
+fn function_call() -> ResponseItem {
+    ResponseItem::FunctionCall {
+        id: None,
+        name: "f".to_string(),
+        arguments: "{}".to_string(),
+        call_id: "c1".to_string(),
+    }
+}
+
+fn local_shell_call() -> ResponseItem {
+    ResponseItem::LocalShellCall {
+        id: Some("id1".to_string()),
+        call_id: None,
+        status: LocalShellStatus::InProgress,
+        action: LocalShellAction::Exec(LocalShellExecAction {
+            command: vec!["echo".to_string()],
+            timeout_ms: Some(1_000),
+            working_directory: None,
+            env: None,
+            user: None,
+        }),
+    }
+}
+
+fn messages_from(body: &Value) -> Vec<Value> {
+    match body["messages"].as_array() {
+        Some(arr) => arr.clone(),
+        None => panic!("messages array missing"),
+    }
+}
+
+fn first_assistant(messages: &[Value]) -> &Value {
+    match messages.iter().find(|msg| msg["role"] == "assistant") {
+        Some(v) => v,
+        None => panic!("assistant message not present"),
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn omits_reasoning_when_none_present() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![user_message("u1"), assistant_message("a1")]).await;
+    let messages = messages_from(&body);
+    let assistant = first_assistant(&messages);
+
+    assert_eq!(assistant["content"], Value::String("a1".into()));
+    assert!(assistant.get("reasoning").is_none());
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn attaches_reasoning_to_previous_assistant() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![
+        user_message("u1"),
+        assistant_message("a1"),
+        reasoning_item("rA"),
+    ])
+    .await;
+    let messages = messages_from(&body);
+    let assistant = first_assistant(&messages);
+
+    assert_eq!(assistant["content"], Value::String("a1".into()));
+    assert_eq!(assistant["reasoning"], Value::String("rA".into()));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn attaches_reasoning_to_function_call_anchor() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![
+        user_message("u1"),
+        reasoning_item("rFunc"),
+        function_call(),
+    ])
+    .await;
+    let messages = messages_from(&body);
+    let assistant = first_assistant(&messages);
+
+    assert_eq!(assistant["reasoning"], Value::String("rFunc".into()));
+    let tool_calls = match assistant["tool_calls"].as_array() {
+        Some(arr) => arr,
+        None => panic!("tool call list missing"),
+    };
+    assert_eq!(tool_calls.len(), 1);
+    assert_eq!(tool_calls[0]["type"], Value::String("function".into()));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn attaches_reasoning_to_local_shell_call() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![
+        user_message("u1"),
+        reasoning_item("rShell"),
+        local_shell_call(),
+    ])
+    .await;
+    let messages = messages_from(&body);
+    let assistant = first_assistant(&messages);
+
+    assert_eq!(assistant["reasoning"], Value::String("rShell".into()));
+    assert_eq!(
+        assistant["tool_calls"][0]["type"],
+        Value::String("local_shell_call".into())
+    );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn drops_reasoning_when_last_role_is_user() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![
+        assistant_message("aPrev"),
+        reasoning_item("rHist"),
+        user_message("uNew"),
+    ])
+    .await;
+    let messages = messages_from(&body);
+    assert!(messages.iter().all(|msg| msg.get("reasoning").is_none()));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn ignores_reasoning_before_last_user() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![
+        user_message("u1"),
+        assistant_message("a1"),
+        user_message("u2"),
+        reasoning_item("rAfterU1"),
+    ])
+    .await;
+    let messages = messages_from(&body);
+    assert!(messages.iter().all(|msg| msg.get("reasoning").is_none()));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn skips_empty_reasoning_segments() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![
+        user_message("u1"),
+        assistant_message("a1"),
+        reasoning_item(""),
+        reasoning_item("   "),
+    ])
+    .await;
+    let messages = messages_from(&body);
+    let assistant = first_assistant(&messages);
+    assert!(assistant.get("reasoning").is_none());
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn suppresses_duplicate_assistant_messages() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![assistant_message("dup"), assistant_message("dup")]).await;
+    let messages = messages_from(&body);
+    let assistant_messages: Vec<_> = messages
+        .iter()
+        .filter(|msg| msg["role"] == "assistant")
+        .collect();
+    assert_eq!(assistant_messages.len(), 1);
+    assert_eq!(
+        assistant_messages[0]["content"],
+        Value::String("dup".into())
+    );
+}
diff --git a/codex-rs/core/tests/chat_completions_sse.rs b/codex-rs/core/tests/chat_completions_sse.rs
new file mode 100644
index 00000000..1df658da
--- /dev/null
+++ b/codex-rs/core/tests/chat_completions_sse.rs
@@ -0,0 +1,320 @@
+use std::sync::Arc;
+
+use codex_core::ContentItem;
+use codex_core::ModelClient;
+use codex_core::ModelProviderInfo;
+use codex_core::Prompt;
+use codex_core::ResponseEvent;
+use codex_core::ResponseItem;
+use codex_core::WireApi;
+use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
+use core_test_support::load_default_config_for_test;
+use futures::StreamExt;
+use tempfile::TempDir;
+use uuid::Uuid;
+use wiremock::Mock;
+use wiremock::MockServer;
+use wiremock::ResponseTemplate;
+use wiremock::matchers::method;
+use wiremock::matchers::path;
+
+fn network_disabled() -> bool {
+    std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok()
+}
+
+async fn run_stream(sse_body: &str) -> Vec<ResponseEvent> {
+    let server = MockServer::start().await;
+
+    let template = ResponseTemplate::new(200)
+        .insert_header("content-type", "text/event-stream")
+        .set_body_raw(sse_body.to_string(), "text/event-stream");
+
+    Mock::given(method("POST"))
+        .and(path("/v1/chat/completions"))
+        .respond_with(template)
+        .expect(1)
+        .mount(&server)
+        .await;
+
+    let provider = ModelProviderInfo {
+        name: "mock".into(),
+        base_url: Some(format!("{}/v1", server.uri())),
+        env_key: None,
+        env_key_instructions: None,
+        wire_api: WireApi::Chat,
+        query_params: None,
+        http_headers: None,
+        env_http_headers: None,
+        request_max_retries: Some(0),
+        stream_max_retries: Some(0),
+        stream_idle_timeout_ms: Some(5_000),
+        requires_openai_auth: false,
+    };
+
+    let codex_home = match TempDir::new() {
+        Ok(dir) => dir,
+        Err(e) => panic!("failed to create TempDir: {e}"),
+    };
+    let mut config = load_default_config_for_test(&codex_home);
+    config.model_provider_id = provider.name.clone();
+    config.model_provider = provider.clone();
+    config.show_raw_agent_reasoning = true;
+    let effort = config.model_reasoning_effort;
+    let summary = config.model_reasoning_summary;
+    let config = Arc::new(config);
+
+    let client = ModelClient::new(
+        Arc::clone(&config),
+        None,
+        provider,
+        effort,
+        summary,
+        Uuid::new_v4(),
+    );
+
+    let mut prompt = Prompt::default();
+    prompt.input = vec![ResponseItem::Message {
+        id: None,
+        role: "user".to_string(),
+        content: vec![ContentItem::InputText {
+            text: "hello".to_string(),
+        }],
+    }];
+
+    let mut stream = match client.stream(&prompt).await {
+        Ok(s) => s,
+        Err(e) => panic!("stream chat failed: {e}"),
+    };
+    let mut events = Vec::new();
+    while let Some(event) = stream.next().await {
+        match event {
+            Ok(ev) => events.push(ev),
+            Err(e) => panic!("stream event error: {e}"),
+        }
+    }
+    events
+}
+
+fn assert_message(item: &ResponseItem, expected: &str) {
+    if let ResponseItem::Message { content, .. } = item {
+        let text = content.iter().find_map(|part| match part {
+            ContentItem::OutputText { text } | ContentItem::InputText { text } => Some(text),
+            _ => None,
+        });
+        let Some(text) = text else {
+            panic!("message missing text: {item:?}");
+        };
+        assert_eq!(text, expected);
+    } else {
+        panic!("expected message item, got: {item:?}");
+    }
+}
+
+fn assert_reasoning(item: &ResponseItem, expected: &str) {
+    if let ResponseItem::Reasoning {
+        content: Some(parts),
+        ..
+    } = item
+    {
+        let mut combined = String::new();
+        for part in parts {
+            match part {
+                codex_core::ReasoningItemContent::ReasoningText { text }
+                | codex_core::ReasoningItemContent::Text { text } => combined.push_str(text),
+            }
+        }
+        assert_eq!(combined, expected);
+    } else {
+        panic!("expected reasoning item, got: {item:?}");
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn streams_text_without_reasoning() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let sse = concat!(
+        "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n",
+        "data: {\"choices\":[{\"delta\":{}}]}\n\n",
+        "data: [DONE]\n\n",
+    );
+
+    let events = run_stream(sse).await;
+    assert_eq!(events.len(), 3, "unexpected events: {events:?}");
+
+    match &events[0] {
+        ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "hi"),
+        other => panic!("expected text delta, got {other:?}"),
+    }
+
+    match &events[1] {
+        ResponseEvent::OutputItemDone(item) => assert_message(item, "hi"),
+        other => panic!("expected terminal message, got {other:?}"),
+    }
+
+    assert!(matches!(events[2], ResponseEvent::Completed { .. }));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn streams_reasoning_from_string_delta() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let sse = concat!(
+        "data: {\"choices\":[{\"delta\":{\"reasoning\":\"think1\"}}]}\n\n",
+        "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n",
+        "data: {\"choices\":[{\"delta\":{} ,\"finish_reason\":\"stop\"}]}\n\n",
+    );
+
+    let events = run_stream(sse).await;
+    assert_eq!(events.len(), 5, "unexpected events: {events:?}");
+
+    match &events[0] {
+        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "think1"),
+        other => panic!("expected reasoning delta, got {other:?}"),
+    }
+
+    match &events[1] {
+        ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "ok"),
+        other => panic!("expected text delta, got {other:?}"),
+    }
+
+    match &events[2] {
+        ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "think1"),
+        other => panic!("expected reasoning item, got {other:?}"),
+    }
+
+    match &events[3] {
+        ResponseEvent::OutputItemDone(item) => assert_message(item, "ok"),
+        other => panic!("expected message item, got {other:?}"),
+    }
+
+    assert!(matches!(events[4], ResponseEvent::Completed { .. }));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn streams_reasoning_from_object_delta() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let sse = concat!(
+        "data: {\"choices\":[{\"delta\":{\"reasoning\":{\"text\":\"partA\"}}}]}\n\n",
+        "data: {\"choices\":[{\"delta\":{\"reasoning\":{\"content\":\"partB\"}}}]}\n\n",
+        "data: {\"choices\":[{\"delta\":{\"content\":\"answer\"}}]}\n\n",
+        "data: {\"choices\":[{\"delta\":{} ,\"finish_reason\":\"stop\"}]}\n\n",
+    );
+
+    let events = run_stream(sse).await;
+    assert_eq!(events.len(), 6, "unexpected events: {events:?}");
+
+    match &events[0] {
+        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "partA"),
+        other => panic!("expected reasoning delta, got {other:?}"),
+    }
+
+    match &events[1] {
+        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "partB"),
+        other => panic!("expected reasoning delta, got {other:?}"),
+    }
+
+    match &events[2] {
+        ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "answer"),
+        other => panic!("expected text delta, got {other:?}"),
+    }
+
+    match &events[3] {
+        ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "partApartB"),
+        other => panic!("expected reasoning item, got {other:?}"),
+    }
+
+    match &events[4] {
+        ResponseEvent::OutputItemDone(item) => assert_message(item, "answer"),
+        other => panic!("expected message item, got {other:?}"),
+    }
+
+    assert!(matches!(events[5], ResponseEvent::Completed { .. }));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn streams_reasoning_from_final_message() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let sse = "data: {\"choices\":[{\"message\":{\"reasoning\":\"final-cot\"},\"finish_reason\":\"stop\"}]}\n\n";
+
+    let events = run_stream(sse).await;
+    assert_eq!(events.len(), 3, "unexpected events: {events:?}");
+
+    match &events[0] {
+        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "final-cot"),
+        other => panic!("expected reasoning delta, got {other:?}"),
+    }
+
+    match &events[1] {
+        ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "final-cot"),
+        other => panic!("expected reasoning item, got {other:?}"),
+    }
+
+    assert!(matches!(events[2], ResponseEvent::Completed { .. }));
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn streams_reasoning_before_tool_call() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let sse = concat!(
+        "data: {\"choices\":[{\"delta\":{\"reasoning\":\"pre-tool\"}}]}\n\n",
+        "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"run\",\"arguments\":\"{}\"}}]},\"finish_reason\":\"tool_calls\"}]}\n\n",
+    );
+
+    let events = run_stream(sse).await;
+    assert_eq!(events.len(), 4, "unexpected events: {events:?}");
+
+    match &events[0] {
+        ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "pre-tool"),
+        other => panic!("expected reasoning delta, got {other:?}"),
+    }
+
+    match &events[1] {
+        ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "pre-tool"),
+        other => panic!("expected reasoning item, got {other:?}"),
+    }
+
+    match &events[2] {
+        ResponseEvent::OutputItemDone(ResponseItem::FunctionCall {
+            name,
+            arguments,
+            call_id,
+            ..
+        }) => {
+            assert_eq!(name, "run");
+            assert_eq!(arguments, "{}");
+            assert_eq!(call_id, "call_1");
+        }
+        other => panic!("expected function call, got {other:?}"),
+    }
+
+    assert!(matches!(events[3], ResponseEvent::Completed { .. }));
+}
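
Note (reviewer sketch, not part of the patch): the net effect of the payload
changes above is that accumulated reasoning text is replayed on the assistant
anchor message under a "reasoning" key, next to any tool calls. A minimal
illustration of the outbound message shape for a tool-call turn, with made-up
ids and strings:

    use serde_json::json;

    fn main() {
        // Assistant tool-call message with the prior turn's reasoning attached,
        // so a gpt-oss model still knows why the tool was called on the next turn.
        let msg = json!({
            "role": "assistant",
            "content": null,
            "reasoning": "Need to run the command before answering.",
            "tool_calls": [{
                "id": "call_1",
                "type": "function",
                "function": { "name": "run", "arguments": "{}" }
            }]
        });
        println!("{msg}");
    }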