diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index 6eca119f..fc8602de 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -43,7 +43,107 @@ pub(crate) async fn stream_chat_completions(
     let input = prompt.get_formatted_input();
 
+    // Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
+    // - If the last emitted message is a user message, drop all reasoning.
+    // - Otherwise, for each Reasoning item after the last user message, attach it
+    //   to the immediate previous assistant message (stop turns) or the immediate
+    //   next assistant anchor (tool-call turns: function/local shell call, or assistant message).
+    let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
+        std::collections::HashMap::new();
+
+    // Determine the last role that would be emitted to Chat Completions.
+    let mut last_emitted_role: Option<&str> = None;
     for item in &input {
+        match item {
+            ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
+            ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
+                last_emitted_role = Some("assistant")
+            }
+            ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
+            ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
+            ResponseItem::CustomToolCall { .. } => {}
+            ResponseItem::CustomToolCallOutput { .. } => {}
+            ResponseItem::WebSearchCall { .. } => {}
+        }
+    }
+
+    // Find the last user message index in the input.
+    let mut last_user_index: Option<usize> = None;
+    for (idx, item) in input.iter().enumerate() {
+        if let ResponseItem::Message { role, .. } = item
+            && role == "user"
+        {
+            last_user_index = Some(idx);
+        }
+    }
+
+    // Attach reasoning only if the conversation does not end with a user message.
+    if !matches!(last_emitted_role, Some("user")) {
+        for (idx, item) in input.iter().enumerate() {
+            // Only consider reasoning that appears after the last user message.
+            if let Some(u_idx) = last_user_index
+                && idx <= u_idx
+            {
+                continue;
+            }
+
+            if let ResponseItem::Reasoning {
+                content: Some(items),
+                ..
+            } = item
+            {
+                let mut text = String::new();
+                for c in items {
+                    match c {
+                        ReasoningItemContent::ReasoningText { text: t }
+                        | ReasoningItemContent::Text { text: t } => text.push_str(t),
+                    }
+                }
+                if text.trim().is_empty() {
+                    continue;
+                }
+
+                // Prefer immediate previous assistant message (stop turns)
+                let mut attached = false;
+                if idx > 0
+                    && let ResponseItem::Message { role, .. } = &input[idx - 1]
+                    && role == "assistant"
+                {
+                    reasoning_by_anchor_index
+                        .entry(idx - 1)
+                        .and_modify(|v| v.push_str(&text))
+                        .or_insert(text.clone());
+                    attached = true;
+                }
+
+                // Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
+                if !attached && idx + 1 < input.len() {
+                    match &input[idx + 1] {
+                        ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
+                            reasoning_by_anchor_index
+                                .entry(idx + 1)
+                                .and_modify(|v| v.push_str(&text))
+                                .or_insert(text.clone());
+                        }
+                        ResponseItem::Message { role, .. } if role == "assistant" => {
+                            reasoning_by_anchor_index
+                                .entry(idx + 1)
+                                .and_modify(|v| v.push_str(&text))
+                                .or_insert(text.clone());
+                        }
+                        _ => {}
+                    }
+                }
+            }
+        }
+    }
+
+    // Track last assistant text we emitted to avoid duplicate assistant messages
+    // in the outbound Chat Completions payload (can happen if a final
+    // aggregated assistant message was recorded alongside an earlier partial).
+    let mut last_assistant_text: Option<String> = None;
+
+    for (idx, item) in input.iter().enumerate() {
         match item {
             ResponseItem::Message { role, content, .. } => {
                 let mut text = String::new();
@@ -56,7 +156,24 @@ pub(crate) async fn stream_chat_completions(
                         _ => {}
                     }
                 }
-                messages.push(json!({"role": role, "content": text}));
+                // Skip exact-duplicate assistant messages.
+ if role == "assistant" { + if let Some(prev) = &last_assistant_text + && prev == &text + { + continue; + } + last_assistant_text = Some(text.clone()); + } + + let mut msg = json!({"role": role, "content": text}); + if role == "assistant" + && let Some(reasoning) = reasoning_by_anchor_index.get(&idx) + && let Some(obj) = msg.as_object_mut() + { + obj.insert("reasoning".to_string(), json!(reasoning)); + } + messages.push(msg); } ResponseItem::FunctionCall { name, @@ -64,7 +181,7 @@ pub(crate) async fn stream_chat_completions( call_id, .. } => { - messages.push(json!({ + let mut msg = json!({ "role": "assistant", "content": null, "tool_calls": [{ @@ -75,7 +192,13 @@ pub(crate) async fn stream_chat_completions( "arguments": arguments, } }] - })); + }); + if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) + && let Some(obj) = msg.as_object_mut() + { + obj.insert("reasoning".to_string(), json!(reasoning)); + } + messages.push(msg); } ResponseItem::LocalShellCall { id, @@ -84,7 +207,7 @@ pub(crate) async fn stream_chat_completions( action, } => { // Confirm with API team. - messages.push(json!({ + let mut msg = json!({ "role": "assistant", "content": null, "tool_calls": [{ @@ -93,7 +216,13 @@ pub(crate) async fn stream_chat_completions( "status": status, "action": action, }] - })); + }); + if let Some(reasoning) = reasoning_by_anchor_index.get(&idx) + && let Some(obj) = msg.as_object_mut() + { + obj.insert("reasoning".to_string(), json!(reasoning)); + } + messages.push(msg); } ResponseItem::FunctionCallOutput { call_id, output } => { messages.push(json!({ @@ -331,7 +460,10 @@ async fn process_chat_sse( // Some providers stream `reasoning` as a plain string while others // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`). 
if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) { - let mut maybe_text = reasoning_val.as_str().map(|s| s.to_string()); + let mut maybe_text = reasoning_val + .as_str() + .map(|s| s.to_string()) + .filter(|s| !s.is_empty()); if maybe_text.is_none() && reasoning_val.is_object() { if let Some(s) = reasoning_val @@ -350,12 +482,39 @@ async fn process_chat_sse( } if let Some(reasoning) = maybe_text { + // Accumulate so we can emit a terminal Reasoning item at the end. + reasoning_text.push_str(&reasoning); let _ = tx_event .send(Ok(ResponseEvent::ReasoningContentDelta(reasoning))) .await; } } + // Some providers only include reasoning on the final message object. + if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning")) + { + // Accept either a plain string or an object with { text | content } + if let Some(s) = message_reasoning.as_str() { + if !s.is_empty() { + reasoning_text.push_str(s); + let _ = tx_event + .send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string()))) + .await; + } + } else if let Some(obj) = message_reasoning.as_object() + && let Some(s) = obj + .get("text") + .and_then(|v| v.as_str()) + .or_else(|| obj.get("content").and_then(|v| v.as_str())) + && !s.is_empty() + { + reasoning_text.push_str(s); + let _ = tx_event + .send(Ok(ResponseEvent::ReasoningContentDelta(s.to_string()))) + .await; + } + } + // Handle streaming function / tool calls. if let Some(tool_calls) = choice .get("delta") @@ -511,27 +670,47 @@ where // do NOT emit yet. Forward any other item (e.g. FunctionCall) right // away so downstream consumers see it. - let is_assistant_delta = matches!(&item, codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"); + let is_assistant_message = matches!( + &item, + codex_protocol::models::ResponseItem::Message { role, .. 
} if role == "assistant" + ); - if is_assistant_delta { - // Only use the final assistant message if we have not - // seen any deltas; otherwise, deltas already built the - // cumulative text and this would duplicate it. - if this.cumulative.is_empty() - && let codex_protocol::models::ResponseItem::Message { content, .. } = - &item - && let Some(text) = content.iter().find_map(|c| match c { - codex_protocol::models::ContentItem::OutputText { text } => { - Some(text) + if is_assistant_message { + match this.mode { + AggregateMode::AggregatedOnly => { + // Only use the final assistant message if we have not + // seen any deltas; otherwise, deltas already built the + // cumulative text and this would duplicate it. + if this.cumulative.is_empty() + && let codex_protocol::models::ResponseItem::Message { + content, + .. + } = &item + && let Some(text) = content.iter().find_map(|c| match c { + codex_protocol::models::ContentItem::OutputText { + text, + } => Some(text), + _ => None, + }) + { + this.cumulative.push_str(text); } - _ => None, - }) - { - this.cumulative.push_str(text); + // Swallow assistant message here; emit on Completed. + continue; + } + AggregateMode::Streaming => { + // In streaming mode, if we have not seen any deltas, forward + // the final assistant message directly. If deltas were seen, + // suppress the final message to avoid duplication. + if this.cumulative.is_empty() { + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone( + item, + )))); + } else { + continue; + } + } } - - // Swallow assistant message here; emit on Completed. - continue; } // Not an assistant message – forward immediately. @@ -563,6 +742,11 @@ where emitted_any = true; } + // Always emit the final aggregated assistant message when any + // content deltas have been observed. 
In AggregatedOnly mode this
+        // is the sole assistant output; in Streaming mode this finalizes
+        // the streamed deltas into a terminal OutputItemDone so callers
+        // can persist/render the message once per turn.
         if !this.cumulative.is_empty() {
             let aggregated_message = codex_protocol::models::ResponseItem::Message {
                 id: None,
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 27401051..77f6a63a 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -35,7 +35,7 @@ pub struct Prompt {
 
     /// Tools available to the model, including additional tools sourced from
     /// external MCP servers.
-    pub tools: Vec<OpenAiTool>,
+    pub(crate) tools: Vec<OpenAiTool>,
 
     /// Optional override for the built-in BASE_INSTRUCTIONS.
     pub base_instructions_override: Option<String>,
@@ -174,7 +174,7 @@ pub(crate) fn create_text_param_for_request(
     })
 }
 
-pub(crate) struct ResponseStream {
+pub struct ResponseStream {
     pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
 }
 
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index e3d787e7..b4966e79 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -69,3 +69,14 @@ pub use codex_protocol::protocol;
 // Re-export protocol config enums to ensure call sites can use the same types
 // as those in the protocol crate when constructing protocol messages.
 pub use codex_protocol::config_types as protocol_config_types;
+
+pub use client::ModelClient;
+pub use client_common::Prompt;
+pub use client_common::ResponseEvent;
+pub use client_common::ResponseStream;
+pub use codex_protocol::models::ContentItem;
+pub use codex_protocol::models::LocalShellAction;
+pub use codex_protocol::models::LocalShellExecAction;
+pub use codex_protocol::models::LocalShellStatus;
+pub use codex_protocol::models::ReasoningItemContent;
+pub use codex_protocol::models::ResponseItem;
diff --git a/codex-rs/core/tests/chat_completions_payload.rs b/codex-rs/core/tests/chat_completions_payload.rs
new file mode 100644
index 00000000..6b21894d
--- /dev/null
+++ b/codex-rs/core/tests/chat_completions_payload.rs
@@ -0,0 +1,345 @@
+use std::sync::Arc;
+
+use codex_core::ContentItem;
+use codex_core::LocalShellAction;
+use codex_core::LocalShellExecAction;
+use codex_core::LocalShellStatus;
+use codex_core::ModelClient;
+use codex_core::ModelProviderInfo;
+use codex_core::Prompt;
+use codex_core::ReasoningItemContent;
+use codex_core::ResponseItem;
+use codex_core::WireApi;
+use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
+use core_test_support::load_default_config_for_test;
+use futures::StreamExt;
+use serde_json::Value;
+use tempfile::TempDir;
+use uuid::Uuid;
+use wiremock::Mock;
+use wiremock::MockServer;
+use wiremock::ResponseTemplate;
+use wiremock::matchers::method;
+use wiremock::matchers::path;
+
+fn network_disabled() -> bool {
+    std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok()
+}
+
+async fn run_request(input: Vec<ResponseItem>) -> Value {
+    let server = MockServer::start().await;
+
+    let template = ResponseTemplate::new(200)
+        .insert_header("content-type", "text/event-stream")
+        .set_body_raw(
+            "data: {\"choices\":[{\"delta\":{}}]}\n\ndata: [DONE]\n\n",
+            "text/event-stream",
+        );
+
+    Mock::given(method("POST"))
+        .and(path("/v1/chat/completions"))
+        .respond_with(template)
+        .expect(1)
+        .mount(&server)
+        .await;
+
+    let
provider = ModelProviderInfo { + name: "mock".into(), + base_url: Some(format!("{}/v1", server.uri())), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + requires_openai_auth: false, + }; + + let codex_home = match TempDir::new() { + Ok(dir) => dir, + Err(e) => panic!("failed to create TempDir: {e}"), + }; + let mut config = load_default_config_for_test(&codex_home); + config.model_provider_id = provider.name.clone(); + config.model_provider = provider.clone(); + config.show_raw_agent_reasoning = true; + let effort = config.model_reasoning_effort; + let summary = config.model_reasoning_summary; + let config = Arc::new(config); + + let client = ModelClient::new( + Arc::clone(&config), + None, + provider, + effort, + summary, + Uuid::new_v4(), + ); + + let mut prompt = Prompt::default(); + prompt.input = input; + + let mut stream = match client.stream(&prompt).await { + Ok(s) => s, + Err(e) => panic!("stream chat failed: {e}"), + }; + while let Some(event) = stream.next().await { + if let Err(e) = event { + panic!("stream event error: {e}"); + } + } + + let requests = match server.received_requests().await { + Some(reqs) => reqs, + None => panic!("request not made"), + }; + match requests[0].body_json() { + Ok(v) => v, + Err(e) => panic!("invalid json body: {e}"), + } +} + +fn user_message(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: text.to_string(), + }], + } +} + +fn assistant_message(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: text.to_string(), + }], + } +} + +fn reasoning_item(text: &str) -> ResponseItem { + ResponseItem::Reasoning { + id: String::new(), + summary: 
Vec::new(),
+        content: Some(vec![ReasoningItemContent::ReasoningText {
+            text: text.to_string(),
+        }]),
+        encrypted_content: None,
+    }
+}
+
+fn function_call() -> ResponseItem {
+    ResponseItem::FunctionCall {
+        id: None,
+        name: "f".to_string(),
+        arguments: "{}".to_string(),
+        call_id: "c1".to_string(),
+    }
+}
+
+fn local_shell_call() -> ResponseItem {
+    ResponseItem::LocalShellCall {
+        id: Some("id1".to_string()),
+        call_id: None,
+        status: LocalShellStatus::InProgress,
+        action: LocalShellAction::Exec(LocalShellExecAction {
+            command: vec!["echo".to_string()],
+            timeout_ms: Some(1_000),
+            working_directory: None,
+            env: None,
+            user: None,
+        }),
+    }
+}
+
+fn messages_from(body: &Value) -> Vec<Value> {
+    match body["messages"].as_array() {
+        Some(arr) => arr.clone(),
+        None => panic!("messages array missing"),
+    }
+}
+
+fn first_assistant(messages: &[Value]) -> &Value {
+    match messages.iter().find(|msg| msg["role"] == "assistant") {
+        Some(v) => v,
+        None => panic!("assistant message not present"),
+    }
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn omits_reasoning_when_none_present() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![user_message("u1"), assistant_message("a1")]).await;
+    let messages = messages_from(&body);
+    let assistant = first_assistant(&messages);
+
+    assert_eq!(assistant["content"], Value::String("a1".into()));
+    assert!(assistant.get("reasoning").is_none());
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn attaches_reasoning_to_previous_assistant() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+ ); + return; + } + + let body = run_request(vec![ + user_message("u1"), + assistant_message("a1"), + reasoning_item("rA"), + ]) + .await; + let messages = messages_from(&body); + let assistant = first_assistant(&messages); + + assert_eq!(assistant["content"], Value::String("a1".into())); + assert_eq!(assistant["reasoning"], Value::String("rA".into())); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn attaches_reasoning_to_function_call_anchor() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let body = run_request(vec![ + user_message("u1"), + reasoning_item("rFunc"), + function_call(), + ]) + .await; + let messages = messages_from(&body); + let assistant = first_assistant(&messages); + + assert_eq!(assistant["reasoning"], Value::String("rFunc".into())); + let tool_calls = match assistant["tool_calls"].as_array() { + Some(arr) => arr, + None => panic!("tool call list missing"), + }; + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0]["type"], Value::String("function".into())); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn attaches_reasoning_to_local_shell_call() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." 
+ ); + return; + } + + let body = run_request(vec![ + user_message("u1"), + reasoning_item("rShell"), + local_shell_call(), + ]) + .await; + let messages = messages_from(&body); + let assistant = first_assistant(&messages); + + assert_eq!(assistant["reasoning"], Value::String("rShell".into())); + assert_eq!( + assistant["tool_calls"][0]["type"], + Value::String("local_shell_call".into()) + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn drops_reasoning_when_last_role_is_user() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let body = run_request(vec![ + assistant_message("aPrev"), + reasoning_item("rHist"), + user_message("uNew"), + ]) + .await; + let messages = messages_from(&body); + assert!(messages.iter().all(|msg| msg.get("reasoning").is_none())); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn ignores_reasoning_before_last_user() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let body = run_request(vec![ + user_message("u1"), + assistant_message("a1"), + user_message("u2"), + reasoning_item("rAfterU1"), + ]) + .await; + let messages = messages_from(&body); + assert!(messages.iter().all(|msg| msg.get("reasoning").is_none())); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn skips_empty_reasoning_segments() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." 
+        );
+        return;
+    }
+
+    let body = run_request(vec![
+        user_message("u1"),
+        assistant_message("a1"),
+        reasoning_item(""),
+        reasoning_item(" "),
+    ])
+    .await;
+    let messages = messages_from(&body);
+    let assistant = first_assistant(&messages);
+    assert!(assistant.get("reasoning").is_none());
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn suppresses_duplicate_assistant_messages() {
+    if network_disabled() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    let body = run_request(vec![assistant_message("dup"), assistant_message("dup")]).await;
+    let messages = messages_from(&body);
+    let assistant_messages: Vec<_> = messages
+        .iter()
+        .filter(|msg| msg["role"] == "assistant")
+        .collect();
+    assert_eq!(assistant_messages.len(), 1);
+    assert_eq!(
+        assistant_messages[0]["content"],
+        Value::String("dup".into())
+    );
+}
diff --git a/codex-rs/core/tests/chat_completions_sse.rs b/codex-rs/core/tests/chat_completions_sse.rs
new file mode 100644
index 00000000..1df658da
--- /dev/null
+++ b/codex-rs/core/tests/chat_completions_sse.rs
@@ -0,0 +1,320 @@
+use std::sync::Arc;
+
+use codex_core::ContentItem;
+use codex_core::ModelClient;
+use codex_core::ModelProviderInfo;
+use codex_core::Prompt;
+use codex_core::ResponseEvent;
+use codex_core::ResponseItem;
+use codex_core::WireApi;
+use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
+use core_test_support::load_default_config_for_test;
+use futures::StreamExt;
+use tempfile::TempDir;
+use uuid::Uuid;
+use wiremock::Mock;
+use wiremock::MockServer;
+use wiremock::ResponseTemplate;
+use wiremock::matchers::method;
+use wiremock::matchers::path;
+
+fn network_disabled() -> bool {
+    std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok()
+}
+
+async fn run_stream(sse_body: &str) -> Vec<ResponseEvent> {
+    let server = MockServer::start().await;
+
+    let template = ResponseTemplate::new(200)
+        .insert_header("content-type",
"text/event-stream") + .set_body_raw(sse_body.to_string(), "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/chat/completions")) + .respond_with(template) + .expect(1) + .mount(&server) + .await; + + let provider = ModelProviderInfo { + name: "mock".into(), + base_url: Some(format!("{}/v1", server.uri())), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Chat, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + requires_openai_auth: false, + }; + + let codex_home = match TempDir::new() { + Ok(dir) => dir, + Err(e) => panic!("failed to create TempDir: {e}"), + }; + let mut config = load_default_config_for_test(&codex_home); + config.model_provider_id = provider.name.clone(); + config.model_provider = provider.clone(); + config.show_raw_agent_reasoning = true; + let effort = config.model_reasoning_effort; + let summary = config.model_reasoning_summary; + let config = Arc::new(config); + + let client = ModelClient::new( + Arc::clone(&config), + None, + provider, + effort, + summary, + Uuid::new_v4(), + ); + + let mut prompt = Prompt::default(); + prompt.input = vec![ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: "hello".to_string(), + }], + }]; + + let mut stream = match client.stream(&prompt).await { + Ok(s) => s, + Err(e) => panic!("stream chat failed: {e}"), + }; + let mut events = Vec::new(); + while let Some(event) = stream.next().await { + match event { + Ok(ev) => events.push(ev), + Err(e) => panic!("stream event error: {e}"), + } + } + events +} + +fn assert_message(item: &ResponseItem, expected: &str) { + if let ResponseItem::Message { content, .. 
} = item { + let text = content.iter().find_map(|part| match part { + ContentItem::OutputText { text } | ContentItem::InputText { text } => Some(text), + _ => None, + }); + let Some(text) = text else { + panic!("message missing text: {item:?}"); + }; + assert_eq!(text, expected); + } else { + panic!("expected message item, got: {item:?}"); + } +} + +fn assert_reasoning(item: &ResponseItem, expected: &str) { + if let ResponseItem::Reasoning { + content: Some(parts), + .. + } = item + { + let mut combined = String::new(); + for part in parts { + match part { + codex_core::ReasoningItemContent::ReasoningText { text } + | codex_core::ReasoningItemContent::Text { text } => combined.push_str(text), + } + } + assert_eq!(combined, expected); + } else { + panic!("expected reasoning item, got: {item:?}"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streams_text_without_reasoning() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n", + "data: {\"choices\":[{\"delta\":{}}]}\n\n", + "data: [DONE]\n\n", + ); + + let events = run_stream(sse).await; + assert_eq!(events.len(), 3, "unexpected events: {events:?}"); + + match &events[0] { + ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "hi"), + other => panic!("expected text delta, got {other:?}"), + } + + match &events[1] { + ResponseEvent::OutputItemDone(item) => assert_message(item, "hi"), + other => panic!("expected terminal message, got {other:?}"), + } + + assert!(matches!(events[2], ResponseEvent::Completed { .. })); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streams_reasoning_from_string_delta() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." 
+ ); + return; + } + + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"reasoning\":\"think1\"}}]}\n\n", + "data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n", + "data: {\"choices\":[{\"delta\":{} ,\"finish_reason\":\"stop\"}]}\n\n", + ); + + let events = run_stream(sse).await; + assert_eq!(events.len(), 5, "unexpected events: {events:?}"); + + match &events[0] { + ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "think1"), + other => panic!("expected reasoning delta, got {other:?}"), + } + + match &events[1] { + ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "ok"), + other => panic!("expected text delta, got {other:?}"), + } + + match &events[2] { + ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "think1"), + other => panic!("expected reasoning item, got {other:?}"), + } + + match &events[3] { + ResponseEvent::OutputItemDone(item) => assert_message(item, "ok"), + other => panic!("expected message item, got {other:?}"), + } + + assert!(matches!(events[4], ResponseEvent::Completed { .. })); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streams_reasoning_from_object_delta() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." 
+ ); + return; + } + + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"reasoning\":{\"text\":\"partA\"}}}]}\n\n", + "data: {\"choices\":[{\"delta\":{\"reasoning\":{\"content\":\"partB\"}}}]}\n\n", + "data: {\"choices\":[{\"delta\":{\"content\":\"answer\"}}]}\n\n", + "data: {\"choices\":[{\"delta\":{} ,\"finish_reason\":\"stop\"}]}\n\n", + ); + + let events = run_stream(sse).await; + assert_eq!(events.len(), 6, "unexpected events: {events:?}"); + + match &events[0] { + ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "partA"), + other => panic!("expected reasoning delta, got {other:?}"), + } + + match &events[1] { + ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "partB"), + other => panic!("expected reasoning delta, got {other:?}"), + } + + match &events[2] { + ResponseEvent::OutputTextDelta(text) => assert_eq!(text, "answer"), + other => panic!("expected text delta, got {other:?}"), + } + + match &events[3] { + ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "partApartB"), + other => panic!("expected reasoning item, got {other:?}"), + } + + match &events[4] { + ResponseEvent::OutputItemDone(item) => assert_message(item, "answer"), + other => panic!("expected message item, got {other:?}"), + } + + assert!(matches!(events[5], ResponseEvent::Completed { .. })); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streams_reasoning_from_final_message() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." 
+ ); + return; + } + + let sse = "data: {\"choices\":[{\"message\":{\"reasoning\":\"final-cot\"},\"finish_reason\":\"stop\"}]}\n\n"; + + let events = run_stream(sse).await; + assert_eq!(events.len(), 3, "unexpected events: {events:?}"); + + match &events[0] { + ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "final-cot"), + other => panic!("expected reasoning delta, got {other:?}"), + } + + match &events[1] { + ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "final-cot"), + other => panic!("expected reasoning item, got {other:?}"), + } + + assert!(matches!(events[2], ResponseEvent::Completed { .. })); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streams_reasoning_before_tool_call() { + if network_disabled() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"reasoning\":\"pre-tool\"}}]}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"run\",\"arguments\":\"{}\"}}]},\"finish_reason\":\"tool_calls\"}]}\n\n", + ); + + let events = run_stream(sse).await; + assert_eq!(events.len(), 4, "unexpected events: {events:?}"); + + match &events[0] { + ResponseEvent::ReasoningContentDelta(text) => assert_eq!(text, "pre-tool"), + other => panic!("expected reasoning delta, got {other:?}"), + } + + match &events[1] { + ResponseEvent::OutputItemDone(item) => assert_reasoning(item, "pre-tool"), + other => panic!("expected reasoning item, got {other:?}"), + } + + match &events[2] { + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { + name, + arguments, + call_id, + .. + }) => { + assert_eq!(name, "run"); + assert_eq!(arguments, "{}"); + assert_eq!(call_id, "call_1"); + } + other => panic!("expected function call, got {other:?}"), + } + + assert!(matches!(events[3], ResponseEvent::Completed { .. })); +}