use std::sync::Arc; use tokio::sync::RwLock; use tokio_util::either::Either; use tokio_util::sync::CancellationToken; use tokio_util::task::AbortOnDropHandle; use crate::codex::Session; use crate::codex::TurnContext; use crate::error::CodexErr; use crate::function_tool::FunctionCallError; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolPayload; use crate::tools::router::ToolCall; use crate::tools::router::ToolRouter; use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; use codex_utils_readiness::Readiness; pub(crate) struct ToolCallRuntime { router: Arc, session: Arc, turn_context: Arc, tracker: SharedTurnDiffTracker, parallel_execution: Arc>, } impl ToolCallRuntime { pub(crate) fn new( router: Arc, session: Arc, turn_context: Arc, tracker: SharedTurnDiffTracker, ) -> Self { Self { router, session, turn_context, tracker, parallel_execution: Arc::new(RwLock::new(())), } } pub(crate) fn handle_tool_call( &self, call: ToolCall, cancellation_token: CancellationToken, ) -> impl std::future::Future> { let supports_parallel = self.router.tool_supports_parallel(&call.tool_name); let router = Arc::clone(&self.router); let session = Arc::clone(&self.session); let turn = Arc::clone(&self.turn_context); let tracker = Arc::clone(&self.tracker); let lock = Arc::clone(&self.parallel_execution); let aborted_response = Self::aborted_response(&call); let readiness = self.turn_context.tool_call_gate.clone(); let handle: AbortOnDropHandle> = AbortOnDropHandle::new(tokio::spawn(async move { tokio::select! { _ = cancellation_token.cancelled() => Ok(aborted_response), res = async { tracing::info!("waiting for tool gate"); readiness.wait_ready().await; tracing::info!("tool gate released"); let _guard = if supports_parallel { Either::Left(lock.read().await) } else { Either::Right(lock.write().await) }; router .dispatch_tool_call(session, turn, tracker, call) .await } => res, } })); async move { match handle.await { Ok(Ok(response)) => Ok(response), Ok(Err(FunctionCallError::Fatal(message))) => Err(CodexErr::Fatal(message)), Ok(Err(other)) => Err(CodexErr::Fatal(other.to_string())), Err(err) => Err(CodexErr::Fatal(format!( "tool task failed to receive: {err:?}" ))), } } } } impl ToolCallRuntime { fn aborted_response(call: &ToolCall) -> ResponseInputItem { match &call.payload { ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput { call_id: call.call_id.clone(), output: "aborted".to_string(), }, ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput { call_id: call.call_id.clone(), result: Err("aborted".to_string()), }, _ => ResponseInputItem::FunctionCallOutput { call_id: call.call_id.clone(), output: FunctionCallOutputPayload { content: "aborted".to_string(), success: None, }, }, } } }