diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 1b21f6e0..5f4f2a1c 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -19,6 +19,7 @@ use tracing::debug; use tracing::trace; use tracing::warn; +use crate::chat_completions::AggregateStreamExt; use crate::chat_completions::stream_chat_completions; use crate::client_common::Payload; use crate::client_common::Prompt; @@ -111,7 +112,31 @@ impl ModelClient { match self.provider.wire_api { WireApi::Responses => self.stream_responses(prompt).await, WireApi::Chat => { - stream_chat_completions(prompt, &self.model, &self.client, &self.provider).await + // Create the raw streaming connection first. + let response_stream = + stream_chat_completions(prompt, &self.model, &self.client, &self.provider) + .await?; + + // Wrap it with the aggregation adapter so callers see *only* + // the final assistant message per turn (matching the + // behaviour of the Responses API). + let mut aggregated = response_stream.aggregate(); + + // Bridge the aggregated stream back into a standard + // `ResponseStream` by forwarding events through a channel. + let (tx, rx) = mpsc::channel::>(16); + + tokio::spawn(async move { + use futures::StreamExt; + while let Some(ev) = aggregated.next().await { + // Exit early if receiver hung up. + if tx.send(ev).await.is_err() { + break; + } + } + }); + + Ok(ResponseStream { rx_event: rx }) } } } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index f68eb73f..7d056adc 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -32,7 +32,6 @@ use tracing::trace; use tracing::warn; use crate::WireApi; -use crate::chat_completions::AggregateStreamExt; use crate::client::ModelClient; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; @@ -864,7 +863,7 @@ async fn try_run_turn( sub_id: &str, prompt: &Prompt, ) -> CodexResult> { - let mut stream = sess.client.clone().stream(prompt).await?.aggregate(); + let mut stream = sess.client.clone().stream(prompt).await?; // Buffer all the incoming messages from the stream first, then execute them. // If we execute a function call in the middle of handling the stream, it can time out.