diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 350773a3..1884dc72 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -6019,6 +6019,7 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", "pin-project-lite", "tokio", ] diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 367ccbce..4259e64f 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -61,7 +61,7 @@ tokio = { workspace = true, features = [ "rt-multi-thread", "signal", ] } -tokio-util = { workspace = true } +tokio-util = { workspace = true, features = ["rt"] } toml = { workspace = true } toml_edit = { workspace = true } tracing = { workspace = true, features = ["log"] } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 25f93d54..7d7bd60a 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -23,7 +23,9 @@ use codex_protocol::protocol::SessionSource; use codex_protocol::protocol::TaskStartedEvent; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnContextItem; +use futures::future::BoxFuture; use futures::prelude::*; +use futures::stream::FuturesOrdered; use mcp_types::CallToolResult; use serde_json; use serde_json::Value; @@ -2101,14 +2103,15 @@ async fn try_run_turn( sess.persist_rollout_items(&[rollout_item]).await; let mut stream = turn_context.client.clone().stream(&prompt).await?; - let mut output = Vec::new(); - let mut tool_runtime = ToolCallRuntime::new( + let tool_runtime = ToolCallRuntime::new( Arc::clone(&router), Arc::clone(&sess), Arc::clone(&turn_context), Arc::clone(&turn_diff_tracker), sub_id.to_string(), ); + let mut output: FuturesOrdered>> = + FuturesOrdered::new(); loop { // Poll the next item from the model stream. We must inspect *both* Ok and Err @@ -2116,9 +2119,8 @@ async fn try_run_turn( // `response.completed`) bubble up and trigger the caller's retry logic. let event = stream.next().await; let event = match event { - Some(event) => event, + Some(res) => res?, None => { - tool_runtime.abort_all(); return Err(CodexErr::Stream( "stream closed before response.completed".into(), None, @@ -2126,14 +2128,8 @@ async fn try_run_turn( } }; - let event = match event { - Ok(ev) => ev, - Err(e) => { - tool_runtime.abort_all(); - // Propagate the underlying stream error to the caller (run_turn), which - // will apply the configured `stream_max_retries` policy. - return Err(e); - } + let add_completed = &mut |response_item: ProcessedResponseItem| { + output.push_back(future::ready(Ok(response_item)).boxed()); }; match event { @@ -2143,14 +2139,18 @@ async fn try_run_turn( Ok(Some(call)) => { let payload_preview = call.payload.log_payload().into_owned(); tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview); - let index = output.len(); - output.push(ProcessedResponseItem { - item, - response: None, - }); - tool_runtime - .handle_tool_call(call, index, output.as_mut_slice()) - .await?; + + let response = tool_runtime.handle_tool_call(call); + + output.push_back( + async move { + Ok(ProcessedResponseItem { + item, + response: Some(response.await?), + }) + } + .boxed(), + ); } Ok(None) => { let response = handle_non_tool_response_item( @@ -2160,7 +2160,7 @@ async fn try_run_turn( item.clone(), ) .await?; - output.push(ProcessedResponseItem { item, response }); + add_completed(ProcessedResponseItem { item, response }); } Err(FunctionCallError::MissingLocalShellCallId) => { let msg = "LocalShellCall without call_id or id"; @@ -2177,7 +2177,7 @@ async fn try_run_turn( success: None, }, }; - output.push(ProcessedResponseItem { + add_completed(ProcessedResponseItem { item, response: Some(response), }); @@ -2190,7 +2190,7 @@ async fn try_run_turn( success: None, }, }; - output.push(ProcessedResponseItem { + add_completed(ProcessedResponseItem { item, response: Some(response), }); @@ -2221,7 +2221,7 @@ async fn try_run_turn( sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref()) .await; - tool_runtime.resolve_pending(output.as_mut_slice()).await?; + let processed_items: Vec = output.try_collect().await?; let unified_diff = { let mut tracker = turn_diff_tracker.lock().await; @@ -2237,7 +2237,7 @@ async fn try_run_turn( } let result = TurnRunResult { - processed_items: output, + processed_items, total_token_usage: token_usage.clone(), }; diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index ff4104d0..26dfed8e 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -1,6 +1,8 @@ use std::sync::Arc; -use tokio::task::JoinHandle; +use tokio::sync::RwLock; +use tokio_util::either::Either; +use tokio_util::task::AbortOnDropHandle; use crate::codex::Session; use crate::codex::TurnContext; @@ -11,20 +13,13 @@ use crate::tools::router::ToolCall; use crate::tools::router::ToolRouter; use codex_protocol::models::ResponseInputItem; -use crate::codex::ProcessedResponseItem; - -struct PendingToolCall { - index: usize, - handle: JoinHandle>, -} - pub(crate) struct ToolCallRuntime { router: Arc, session: Arc, turn_context: Arc, tracker: SharedTurnDiffTracker, sub_id: String, - pending_calls: Vec, + parallel_execution: Arc>, } impl ToolCallRuntime { @@ -41,97 +36,45 @@ impl ToolCallRuntime { turn_context, tracker, sub_id, - pending_calls: Vec::new(), + parallel_execution: Arc::new(RwLock::new(())), } } - pub(crate) async fn handle_tool_call( - &mut self, + pub(crate) fn handle_tool_call( + &self, call: ToolCall, - output_index: usize, - output: &mut [ProcessedResponseItem], - ) -> Result<(), CodexErr> { + ) -> impl std::future::Future> { let supports_parallel = self.router.tool_supports_parallel(&call.tool_name); - if supports_parallel { - self.spawn_parallel(call, output_index); - } else { - self.resolve_pending(output).await?; - let response = self.dispatch_serial(call).await?; - let slot = output.get_mut(output_index).ok_or_else(|| { - CodexErr::Fatal(format!("tool output index {output_index} out of bounds")) - })?; - slot.response = Some(response); - } - Ok(()) - } - - pub(crate) fn abort_all(&mut self) { - while let Some(pending) = self.pending_calls.pop() { - pending.handle.abort(); - } - } - - pub(crate) async fn resolve_pending( - &mut self, - output: &mut [ProcessedResponseItem], - ) -> Result<(), CodexErr> { - while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() { - match handle.await { - Ok(Ok(response)) => { - if let Some(slot) = output.get_mut(index) { - slot.response = Some(response); - } - } - Ok(Err(FunctionCallError::Fatal(message))) => { - self.abort_all(); - return Err(CodexErr::Fatal(message)); - } - Ok(Err(other)) => { - self.abort_all(); - return Err(CodexErr::Fatal(other.to_string())); - } - Err(join_err) => { - self.abort_all(); - return Err(CodexErr::Fatal(format!( - "tool task failed to join: {join_err}" - ))); - } - } - } - - Ok(()) - } - - fn spawn_parallel(&mut self, call: ToolCall, index: usize) { let router = Arc::clone(&self.router); let session = Arc::clone(&self.session); let turn = Arc::clone(&self.turn_context); let tracker = Arc::clone(&self.tracker); let sub_id = self.sub_id.clone(); - let handle = tokio::spawn(async move { - router - .dispatch_tool_call(session, turn, tracker, sub_id, call) - .await - }); - self.pending_calls.push(PendingToolCall { index, handle }); - } + let lock = Arc::clone(&self.parallel_execution); - async fn dispatch_serial(&self, call: ToolCall) -> Result { - match self - .router - .dispatch_tool_call( - Arc::clone(&self.session), - Arc::clone(&self.turn_context), - Arc::clone(&self.tracker), - self.sub_id.clone(), - call, - ) - .await - { - Ok(response) => Ok(response), - Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), - Err(other) => Err(CodexErr::Fatal(other.to_string())), + let handle: AbortOnDropHandle> = + AbortOnDropHandle::new(tokio::spawn(async move { + let _guard = if supports_parallel { + Either::Left(lock.read().await) + } else { + Either::Right(lock.write().await) + }; + + router + .dispatch_tool_call(session, turn, tracker, sub_id, call) + .await + })); + + async move { + match handle.await { + Ok(Ok(response)) => Ok(response), + Ok(Err(FunctionCallError::Fatal(message))) => Err(CodexErr::Fatal(message)), + Ok(Err(other)) => Err(CodexErr::Fatal(other.to_string())), + Err(err) => Err(CodexErr::Fatal(format!( + "tool task failed to receive: {err:?}" + ))), + } } } }