diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index bf1919ce..0bd7848d 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -227,7 +227,7 @@ impl ModelClient { input: &input_with_instructions, tools: &tools_json, tool_choice: "auto", - parallel_tool_calls: false, + parallel_tool_calls: prompt.parallel_tool_calls, reasoning, store: azure_workaround, stream: true, diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index dcd244db..5a361062 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -33,6 +33,9 @@ pub struct Prompt { /// external MCP servers. pub(crate) tools: Vec, + /// Whether parallel tool calls are permitted for this prompt. + pub(crate) parallel_tool_calls: bool, + /// Optional override for the built-in BASE_INSTRUCTIONS. pub base_instructions_override: Option, @@ -288,6 +291,17 @@ pub(crate) mod tools { Freeform(FreeformTool), } + impl ToolSpec { + pub(crate) fn name(&self) -> &str { + match self { + ToolSpec::Function(tool) => tool.name.as_str(), + ToolSpec::LocalShell {} => "local_shell", + ToolSpec::WebSearch {} => "web_search", + ToolSpec::Freeform(tool) => tool.name.as_str(), + } + } + } + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct FreeformTool { pub(crate) name: String, @@ -433,7 +447,7 @@ mod tests { input: &input, tools: &tools, tool_choice: "auto", - parallel_tool_calls: false, + parallel_tool_calls: true, reasoning: None, store: false, stream: true, @@ -474,7 +488,7 @@ mod tests { input: &input, tools: &tools, tool_choice: "auto", - parallel_tool_calls: false, + parallel_tool_calls: true, reasoning: None, store: false, stream: true, @@ -510,7 +524,7 @@ mod tests { input: &input, tools: &tools, tool_choice: "auto", - parallel_tool_calls: false, + parallel_tool_calls: true, reasoning: None, store: false, stream: true, diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 1699867f..25f93d54 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -100,7 +100,9 @@ use crate::tasks::CompactTask; use crate::tasks::RegularTask; use crate::tasks::ReviewTask; use crate::tools::ToolRouter; +use crate::tools::context::SharedTurnDiffTracker; use crate::tools::format_exec_output_str; +use crate::tools::parallel::ToolCallRuntime; use crate::turn_diff_tracker::TurnDiffTracker; use crate::unified_exec::UnifiedExecSessionManager; use crate::user_instructions::UserInstructions; @@ -818,7 +820,7 @@ impl Session { async fn on_exec_command_begin( &self, - turn_diff_tracker: &mut TurnDiffTracker, + turn_diff_tracker: SharedTurnDiffTracker, exec_command_context: ExecCommandContext, ) { let ExecCommandContext { @@ -834,7 +836,10 @@ impl Session { user_explicitly_approved_this_action, changes, }) => { - turn_diff_tracker.on_patch_begin(&changes); + { + let mut tracker = turn_diff_tracker.lock().await; + tracker.on_patch_begin(&changes); + } EventMsg::PatchApplyBegin(PatchApplyBeginEvent { call_id, @@ -861,7 +866,7 @@ impl Session { async fn on_exec_command_end( &self, - turn_diff_tracker: &mut TurnDiffTracker, + turn_diff_tracker: SharedTurnDiffTracker, sub_id: &str, call_id: &str, output: &ExecToolCallOutput, @@ -909,7 +914,10 @@ impl Session { // If this is an apply_patch, after we emit the end patch, emit a second event // with the full turn diff if there is one. 
if is_apply_patch { - let unified_diff = turn_diff_tracker.get_unified_diff(); + let unified_diff = { + let mut tracker = turn_diff_tracker.lock().await; + tracker.get_unified_diff() + }; if let Ok(Some(unified_diff)) = unified_diff { let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); let event = Event { @@ -926,7 +934,7 @@ impl Session { /// Returns the output of the exec tool call. pub(crate) async fn run_exec_with_events( &self, - turn_diff_tracker: &mut TurnDiffTracker, + turn_diff_tracker: SharedTurnDiffTracker, prepared: PreparedExec, approval_policy: AskForApproval, ) -> Result { @@ -935,7 +943,7 @@ impl Session { let sub_id = context.sub_id.clone(); let call_id = context.call_id.clone(); - self.on_exec_command_begin(turn_diff_tracker, context.clone()) + self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone()) .await; let result = self @@ -1644,7 +1652,7 @@ pub(crate) async fn run_task( let mut last_agent_message: Option = None; // Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains // many turns, from the perspective of the user, it is a single turn. - let mut turn_diff_tracker = TurnDiffTracker::new(); + let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); let mut auto_compact_recently_attempted = false; loop { @@ -1692,9 +1700,9 @@ pub(crate) async fn run_task( }) .collect(); match run_turn( - &sess, - turn_context.as_ref(), - &mut turn_diff_tracker, + Arc::clone(&sess), + Arc::clone(&turn_context), + Arc::clone(&turn_diff_tracker), sub_id.clone(), turn_input, ) @@ -1917,18 +1925,27 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent { } async fn run_turn( - sess: &Session, - turn_context: &TurnContext, - turn_diff_tracker: &mut TurnDiffTracker, + sess: Arc, + turn_context: Arc, + turn_diff_tracker: SharedTurnDiffTracker, sub_id: String, input: Vec, ) -> CodexResult { let mcp_tools = sess.services.mcp_connection_manager.list_all_tools(); - let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools)); + let router = Arc::new(ToolRouter::from_config( + &turn_context.tools_config, + Some(mcp_tools), + )); + let model_supports_parallel = turn_context + .client + .get_model_family() + .supports_parallel_tool_calls; + let parallel_tool_calls = model_supports_parallel; let prompt = Prompt { input, - tools: router.specs().to_vec(), + tools: router.specs(), + parallel_tool_calls, base_instructions_override: turn_context.base_instructions.clone(), output_schema: turn_context.final_output_json_schema.clone(), }; @@ -1936,10 +1953,10 @@ async fn run_turn( let mut retries = 0; loop { match try_run_turn( - &router, - sess, - turn_context, - turn_diff_tracker, + Arc::clone(&router), + Arc::clone(&sess), + Arc::clone(&turn_context), + Arc::clone(&turn_diff_tracker), &sub_id, &prompt, ) @@ -1950,7 +1967,7 @@ async fn run_turn( Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(e @ CodexErr::Fatal(_)) => return Err(e), Err(e @ CodexErr::ContextWindowExceeded) => { - sess.set_total_tokens_full(&sub_id, turn_context).await; + sess.set_total_tokens_full(&sub_id, &turn_context).await; return Err(e); } Err(CodexErr::UsageLimitReached(e)) => { @@ -1999,9 +2016,9 @@ async fn run_turn( /// "handled" such that it produces a `ResponseInputItem` that needs to be /// sent back to the model on the next turn. 
#[derive(Debug)] -struct ProcessedResponseItem { - item: ResponseItem, - response: Option, +pub(crate) struct ProcessedResponseItem { + pub(crate) item: ResponseItem, + pub(crate) response: Option, } #[derive(Debug)] @@ -2011,10 +2028,10 @@ struct TurnRunResult { } async fn try_run_turn( - router: &crate::tools::ToolRouter, - sess: &Session, - turn_context: &TurnContext, - turn_diff_tracker: &mut TurnDiffTracker, + router: Arc, + sess: Arc, + turn_context: Arc, + turn_diff_tracker: SharedTurnDiffTracker, sub_id: &str, prompt: &Prompt, ) -> CodexResult { @@ -2085,24 +2102,34 @@ async fn try_run_turn( let mut stream = turn_context.client.clone().stream(&prompt).await?; let mut output = Vec::new(); + let mut tool_runtime = ToolCallRuntime::new( + Arc::clone(&router), + Arc::clone(&sess), + Arc::clone(&turn_context), + Arc::clone(&turn_diff_tracker), + sub_id.to_string(), + ); loop { // Poll the next item from the model stream. We must inspect *both* Ok and Err // cases so that transient stream failures (e.g., dropped SSE connection before // `response.completed`) bubble up and trigger the caller's retry logic. let event = stream.next().await; - let Some(event) = event else { - // Channel closed without yielding a final Completed event or explicit error. - // Treat as a disconnected stream so the caller can retry. - return Err(CodexErr::Stream( - "stream closed before response.completed".into(), - None, - )); + let event = match event { + Some(event) => event, + None => { + tool_runtime.abort_all(); + return Err(CodexErr::Stream( + "stream closed before response.completed".into(), + None, + )); + } }; let event = match event { Ok(ev) => ev, Err(e) => { + tool_runtime.abort_all(); // Propagate the underlying stream error to the caller (run_turn), which // will apply the configured `stream_max_retries` policy. 
return Err(e); @@ -2112,16 +2139,66 @@ async fn try_run_turn( match event { ResponseEvent::Created => {} ResponseEvent::OutputItemDone(item) => { - let response = handle_response_item( - router, - sess, - turn_context, - turn_diff_tracker, - sub_id, - item.clone(), - ) - .await?; - output.push(ProcessedResponseItem { item, response }); + match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) { + Ok(Some(call)) => { + let payload_preview = call.payload.log_payload().into_owned(); + tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview); + let index = output.len(); + output.push(ProcessedResponseItem { + item, + response: None, + }); + tool_runtime + .handle_tool_call(call, index, output.as_mut_slice()) + .await?; + } + Ok(None) => { + let response = handle_non_tool_response_item( + Arc::clone(&sess), + Arc::clone(&turn_context), + sub_id, + item.clone(), + ) + .await?; + output.push(ProcessedResponseItem { item, response }); + } + Err(FunctionCallError::MissingLocalShellCallId) => { + let msg = "LocalShellCall without call_id or id"; + turn_context + .client + .get_otel_event_manager() + .log_tool_failed("local_shell", msg); + error!(msg); + + let response = ResponseInputItem::FunctionCallOutput { + call_id: String::new(), + output: FunctionCallOutputPayload { + content: msg.to_string(), + success: None, + }, + }; + output.push(ProcessedResponseItem { + item, + response: Some(response), + }); + } + Err(FunctionCallError::RespondToModel(message)) => { + let response = ResponseInputItem::FunctionCallOutput { + call_id: String::new(), + output: FunctionCallOutputPayload { + content: message, + success: None, + }, + }; + output.push(ProcessedResponseItem { + item, + response: Some(response), + }); + } + Err(FunctionCallError::Fatal(message)) => { + return Err(CodexErr::Fatal(message)); + } + } } ResponseEvent::WebSearchCallBegin { call_id } => { let _ = sess @@ -2141,10 +2218,15 @@ async fn try_run_turn( response_id: _, token_usage, } => { - sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref()) + sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref()) .await; - let unified_diff = turn_diff_tracker.get_unified_diff(); + tool_runtime.resolve_pending(output.as_mut_slice()).await?; + + let unified_diff = { + let mut tracker = turn_diff_tracker.lock().await; + tracker.get_unified_diff() + }; if let Ok(Some(unified_diff)) = unified_diff { let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); let event = Event { @@ -2203,88 +2285,40 @@ async fn try_run_turn( } } -async fn handle_response_item( - router: &crate::tools::ToolRouter, - sess: &Session, - turn_context: &TurnContext, - turn_diff_tracker: &mut TurnDiffTracker, +async fn handle_non_tool_response_item( + sess: Arc, + turn_context: Arc, sub_id: &str, item: ResponseItem, ) -> CodexResult> { debug!(?item, "Output item"); - match ToolRouter::build_tool_call(sess, item.clone()) { - Ok(Some(call)) => { - let payload_preview = call.payload.log_payload().into_owned(); - tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview); - match router - .dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call) - .await - { - Ok(response) => Ok(Some(response)), - Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), - Err(other) => unreachable!("non-fatal tool error returned: {other:?}"), + match &item { + ResponseItem::Message { .. } + | ResponseItem::Reasoning { .. } + | ResponseItem::WebSearchCall { .. 
} => { + let msgs = match &item { + ResponseItem::Message { .. } if turn_context.is_review_mode => { + trace!("suppressing assistant Message in review mode"); + Vec::new() + } + _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()), + }; + for msg in msgs { + let event = Event { + id: sub_id.to_string(), + msg, + }; + sess.send_event(event).await; } } - Ok(None) => { - match &item { - ResponseItem::Message { .. } - | ResponseItem::Reasoning { .. } - | ResponseItem::WebSearchCall { .. } => { - let msgs = match &item { - ResponseItem::Message { .. } if turn_context.is_review_mode => { - trace!("suppressing assistant Message in review mode"); - Vec::new() - } - _ => map_response_item_to_event_messages( - &item, - sess.show_raw_agent_reasoning(), - ), - }; - for msg in msgs { - let event = Event { - id: sub_id.to_string(), - msg, - }; - sess.send_event(event).await; - } - } - ResponseItem::FunctionCallOutput { .. } - | ResponseItem::CustomToolCallOutput { .. } => { - debug!("unexpected tool output from stream"); - } - _ => {} - } - - Ok(None) + ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => { + debug!("unexpected tool output from stream"); } - Err(FunctionCallError::MissingLocalShellCallId) => { - let msg = "LocalShellCall without call_id or id"; - turn_context - .client - .get_otel_event_manager() - .log_tool_failed("local_shell", msg); - error!(msg); - - Ok(Some(ResponseInputItem::FunctionCallOutput { - call_id: String::new(), - output: FunctionCallOutputPayload { - content: msg.to_string(), - success: None, - }, - })) - } - Err(FunctionCallError::RespondToModel(msg)) => { - Ok(Some(ResponseInputItem::FunctionCallOutput { - call_id: String::new(), - output: FunctionCallOutputPayload { - content: msg, - success: None, - }, - })) - } - Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), + _ => {} } + + Ok(None) } pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option { @@ -2927,13 +2961,10 @@ mod tests { #[tokio::test] async fn fatal_tool_error_stops_turn_and_reports_error() { let (session, turn_context, _rx) = make_session_and_context_with_rx(); - let session_ref = session.as_ref(); - let turn_context_ref = turn_context.as_ref(); let router = ToolRouter::from_config( - &turn_context_ref.tools_config, - Some(session_ref.services.mcp_connection_manager.list_all_tools()), + &turn_context.tools_config, + Some(session.services.mcp_connection_manager.list_all_tools()), ); - let mut tracker = TurnDiffTracker::new(); let item = ResponseItem::CustomToolCall { id: None, status: None, @@ -2942,22 +2973,26 @@ mod tests { input: "{}".to_string(), }; - let err = handle_response_item( - &router, - session_ref, - turn_context_ref, - &mut tracker, - "sub-id", - item, - ) - .await - .expect_err("expected fatal error"); + let call = ToolRouter::build_tool_call(session.as_ref(), item.clone()) + .expect("build tool call") + .expect("tool call present"); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let err = router + .dispatch_tool_call( + Arc::clone(&session), + Arc::clone(&turn_context), + tracker, + "sub-id".to_string(), + call, + ) + .await + .expect_err("expected fatal error"); match err { - CodexErr::Fatal(message) => { + FunctionCallError::Fatal(message) => { assert_eq!(message, "tool shell invoked with incompatible payload"); } - other => panic!("expected CodexErr::Fatal, got {other:?}"), + other => panic!("expected FunctionCallError::Fatal, got 
{other:?}"), } } @@ -3071,9 +3106,11 @@ mod tests { use crate::turn_diff_tracker::TurnDiffTracker; use std::collections::HashMap; - let (session, mut turn_context) = make_session_and_context(); + let (session, mut turn_context_raw) = make_session_and_context(); // Ensure policy is NOT OnRequest so the early rejection path triggers - turn_context.approval_policy = AskForApproval::OnFailure; + turn_context_raw.approval_policy = AskForApproval::OnFailure; + let session = Arc::new(session); + let mut turn_context = Arc::new(turn_context_raw); let params = ExecParams { command: if cfg!(windows) { @@ -3101,7 +3138,7 @@ mod tests { ..params.clone() }; - let mut turn_diff_tracker = TurnDiffTracker::new(); + let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); let tool_name = "shell"; let sub_id = "test-sub".to_string(); @@ -3110,9 +3147,9 @@ mod tests { let resp = handle_container_exec_with_params( tool_name, params, - &session, - &turn_context, - &mut turn_diff_tracker, + Arc::clone(&session), + Arc::clone(&turn_context), + Arc::clone(&turn_diff_tracker), sub_id, call_id, ) @@ -3131,14 +3168,16 @@ mod tests { // Now retry the same command WITHOUT escalated permissions; should succeed. // Force DangerFullAccess to avoid platform sandbox dependencies in tests. - turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess; + Arc::get_mut(&mut turn_context) + .expect("unique turn context Arc") + .sandbox_policy = SandboxPolicy::DangerFullAccess; let resp2 = handle_container_exec_with_params( tool_name, params2, - &session, - &turn_context, - &mut turn_diff_tracker, + Arc::clone(&session), + Arc::clone(&turn_context), + Arc::clone(&turn_diff_tracker), "test-sub".to_string(), "test-call-2".to_string(), ) diff --git a/codex-rs/core/src/model_family.rs b/codex-rs/core/src/model_family.rs index 1910e070..5387cf9d 100644 --- a/codex-rs/core/src/model_family.rs +++ b/codex-rs/core/src/model_family.rs @@ -35,6 +35,10 @@ pub struct ModelFamily { // See https://platform.openai.com/docs/guides/tools-local-shell pub uses_local_shell_tool: bool, + /// Whether this model supports parallel tool calls when using the + /// Responses API. + pub supports_parallel_tool_calls: bool, + /// Present if the model performs better when `apply_patch` is provided as /// a tool call instead of just a bash command pub apply_patch_tool_type: Option, @@ -58,6 +62,7 @@ macro_rules! 
model_family { supports_reasoning_summaries: false, reasoning_summary_format: ReasoningSummaryFormat::None, uses_local_shell_tool: false, + supports_parallel_tool_calls: false, apply_patch_tool_type: None, base_instructions: BASE_INSTRUCTIONS.to_string(), experimental_supported_tools: Vec::new(), @@ -103,6 +108,18 @@ pub fn find_family_for_model(slug: &str) -> Option { model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true) } else if slug.starts_with("gpt-3.5") { model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true) + } else if slug.starts_with("test-gpt-5-codex") { + model_family!( + slug, slug, + supports_reasoning_summaries: true, + reasoning_summary_format: ReasoningSummaryFormat::Experimental, + base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(), + experimental_supported_tools: vec![ + "read_file".to_string(), + "test_sync_tool".to_string() + ], + supports_parallel_tool_calls: true, + ) } else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") { model_family!( slug, slug, @@ -110,6 +127,8 @@ pub fn find_family_for_model(slug: &str) -> Option { reasoning_summary_format: ReasoningSummaryFormat::Experimental, base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(), apply_patch_tool_type: Some(ApplyPatchToolType::Freeform), + // experimental_supported_tools: vec!["read_file".to_string()], + // supports_parallel_tool_calls: true, ) } else if slug.starts_with("gpt-5") { model_family!( @@ -130,6 +149,7 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily { supports_reasoning_summaries: false, reasoning_summary_format: ReasoningSummaryFormat::None, uses_local_shell_tool: false, + supports_parallel_tool_calls: false, apply_patch_tool_type: None, base_instructions: BASE_INSTRUCTIONS.to_string(), experimental_supported_tools: Vec::new(), diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs index b6b458f1..7ab4691a 100644 --- a/codex-rs/core/src/tools/context.rs +++ b/codex-rs/core/src/tools/context.rs @@ -14,12 +14,17 @@ use mcp_types::CallToolResult; use std::borrow::Cow; use std::collections::HashMap; use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::Mutex; -pub struct ToolInvocation<'a> { - pub session: &'a Session, - pub turn: &'a TurnContext, - pub tracker: &'a mut TurnDiffTracker, - pub sub_id: &'a str, +pub type SharedTurnDiffTracker = Arc>; + +#[derive(Clone)] +pub struct ToolInvocation { + pub session: Arc, + pub turn: Arc, + pub tracker: SharedTurnDiffTracker, + pub sub_id: String, pub call_id: String, pub tool_name: String, pub payload: ToolPayload, diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index 1ad8a95d..d85ac8b7 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; use std::collections::HashMap; +use std::sync::Arc; use crate::client_common::tools::FreeformTool; use crate::client_common::tools::FreeformToolFormat; @@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler { ) } - async fn handle( - &self, - invocation: ToolInvocation<'_>, - ) -> Result { + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, turn, @@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler { let content = handle_container_exec_with_params( tool_name.as_str(), exec_params, - session, - turn, - tracker, - sub_id.to_string(), + Arc::clone(&session), + Arc::clone(&turn), + 
Arc::clone(&tracker), + sub_id.clone(), call_id.clone(), ) .await?; diff --git a/codex-rs/core/src/tools/handlers/exec_stream.rs b/codex-rs/core/src/tools/handlers/exec_stream.rs index db9d4b0b..7f14c673 100644 --- a/codex-rs/core/src/tools/handlers/exec_stream.rs +++ b/codex-rs/core/src/tools/handlers/exec_stream.rs @@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler { ToolKind::Function } - async fn handle( - &self, - invocation: ToolInvocation<'_>, - ) -> Result { + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, tool_name, diff --git a/codex-rs/core/src/tools/handlers/mcp.rs b/codex-rs/core/src/tools/handlers/mcp.rs index 17eae7ea..ba95a5ea 100644 --- a/codex-rs/core/src/tools/handlers/mcp.rs +++ b/codex-rs/core/src/tools/handlers/mcp.rs @@ -16,10 +16,7 @@ impl ToolHandler for McpHandler { ToolKind::Mcp } - async fn handle( - &self, - invocation: ToolInvocation<'_>, - ) -> Result { + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, sub_id, @@ -45,8 +42,8 @@ impl ToolHandler for McpHandler { let arguments_str = raw_arguments; let response = handle_mcp_tool_call( - session, - sub_id, + session.as_ref(), + &sub_id, call_id.clone(), server, tool, diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs index af410b99..caa778c9 100644 --- a/codex-rs/core/src/tools/handlers/mod.rs +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -4,6 +4,7 @@ mod mcp; mod plan; mod read_file; mod shell; +mod test_sync; mod unified_exec; mod view_image; @@ -15,5 +16,6 @@ pub use mcp::McpHandler; pub use plan::PlanHandler; pub use read_file::ReadFileHandler; pub use shell::ShellHandler; +pub use test_sync::TestSyncHandler; pub use unified_exec::UnifiedExecHandler; pub use view_image::ViewImageHandler; diff --git a/codex-rs/core/src/tools/handlers/plan.rs b/codex-rs/core/src/tools/handlers/plan.rs index f5208030..386933a5 100644 --- a/codex-rs/core/src/tools/handlers/plan.rs +++ b/codex-rs/core/src/tools/handlers/plan.rs @@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler { ToolKind::Function } - async fn handle( - &self, - invocation: ToolInvocation<'_>, - ) -> Result { + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, sub_id, @@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler { } }; - let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?; + let content = + handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?; Ok(ToolOutput::Function { content, diff --git a/codex-rs/core/src/tools/handlers/read_file.rs b/codex-rs/core/src/tools/handlers/read_file.rs index 4988593b..38b76f28 100644 --- a/codex-rs/core/src/tools/handlers/read_file.rs +++ b/codex-rs/core/src/tools/handlers/read_file.rs @@ -42,10 +42,7 @@ impl ToolHandler for ReadFileHandler { ToolKind::Function } - async fn handle( - &self, - invocation: ToolInvocation<'_>, - ) -> Result { + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { payload, .. 
} = invocation; let arguments = match payload { diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index fbcb493e..1b27a58e 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use codex_protocol::models::ShellToolCallParams; +use std::sync::Arc; use crate::codex::TurnContext; use crate::exec::ExecParams; @@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler { ) } - async fn handle( - &self, - invocation: ToolInvocation<'_>, - ) -> Result { + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, turn, @@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler { "failed to parse function arguments: {e:?}" )) })?; - let exec_params = Self::to_exec_params(params, turn); + let exec_params = Self::to_exec_params(params, turn.as_ref()); let content = handle_container_exec_with_params( tool_name.as_str(), exec_params, - session, - turn, - tracker, - sub_id.to_string(), + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + sub_id.clone(), call_id.clone(), ) .await?; @@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler { }) } ToolPayload::LocalShell { params } => { - let exec_params = Self::to_exec_params(params, turn); + let exec_params = Self::to_exec_params(params, turn.as_ref()); let content = handle_container_exec_with_params( tool_name.as_str(), exec_params, - session, - turn, - tracker, - sub_id.to_string(), + Arc::clone(&session), + Arc::clone(&turn), + Arc::clone(&tracker), + sub_id.clone(), call_id.clone(), ) .await?; diff --git a/codex-rs/core/src/tools/handlers/test_sync.rs b/codex-rs/core/src/tools/handlers/test_sync.rs new file mode 100644 index 00000000..e340ab47 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/test_sync.rs @@ -0,0 +1,158 @@ +use std::collections::HashMap; +use std::collections::hash_map::Entry; +use std::sync::Arc; +use std::sync::OnceLock; +use std::time::Duration; + +use async_trait::async_trait; +use serde::Deserialize; +use tokio::sync::Barrier; +use tokio::time::sleep; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +pub struct TestSyncHandler; + +const DEFAULT_TIMEOUT_MS: u64 = 1_000; + +static BARRIERS: OnceLock>> = OnceLock::new(); + +struct BarrierState { + barrier: Arc, + participants: usize, +} + +#[derive(Debug, Deserialize)] +struct BarrierArgs { + id: String, + participants: usize, + #[serde(default = "default_timeout_ms")] + timeout_ms: u64, +} + +#[derive(Debug, Deserialize)] +struct TestSyncArgs { + #[serde(default)] + sleep_before_ms: Option, + #[serde(default)] + sleep_after_ms: Option, + #[serde(default)] + barrier: Option, +} + +fn default_timeout_ms() -> u64 { + DEFAULT_TIMEOUT_MS +} + +fn barrier_map() -> &'static tokio::sync::Mutex> { + BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new())) +} + +#[async_trait] +impl ToolHandler for TestSyncHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle(&self, invocation: ToolInvocation) -> Result { + let ToolInvocation { payload, .. 
} = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "test_sync_tool handler received unsupported payload".to_string(), + )); + } + }; + + let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {err:?}" + )) + })?; + + if let Some(delay) = args.sleep_before_ms + && delay > 0 + { + sleep(Duration::from_millis(delay)).await; + } + + if let Some(barrier) = args.barrier { + wait_on_barrier(barrier).await?; + } + + if let Some(delay) = args.sleep_after_ms + && delay > 0 + { + sleep(Duration::from_millis(delay)).await; + } + + Ok(ToolOutput::Function { + content: "ok".to_string(), + success: Some(true), + }) + } +} + +async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> { + if args.participants == 0 { + return Err(FunctionCallError::RespondToModel( + "barrier participants must be greater than zero".to_string(), + )); + } + + if args.timeout_ms == 0 { + return Err(FunctionCallError::RespondToModel( + "barrier timeout must be greater than zero".to_string(), + )); + } + + let barrier_id = args.id.clone(); + let barrier = { + let mut map = barrier_map().lock().await; + match map.entry(barrier_id.clone()) { + Entry::Occupied(entry) => { + let state = entry.get(); + if state.participants != args.participants { + let existing = state.participants; + return Err(FunctionCallError::RespondToModel(format!( + "barrier {barrier_id} already registered with {existing} participants" + ))); + } + state.barrier.clone() + } + Entry::Vacant(entry) => { + let barrier = Arc::new(Barrier::new(args.participants)); + entry.insert(BarrierState { + barrier: barrier.clone(), + participants: args.participants, + }); + barrier + } + } + }; + + let timeout = Duration::from_millis(args.timeout_ms); + let wait_result = tokio::time::timeout(timeout, barrier.wait()) + .await + .map_err(|_| { + FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string()) + })?; + + if wait_result.is_leader() { + let mut map = barrier_map().lock().await; + if let Some(state) = map.get(&barrier_id) + && Arc::ptr_eq(&state.barrier, &barrier) + { + map.remove(&barrier_id); + } + } + + Ok(()) +} diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index 7175afb9..ce47dded 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler { ) } - async fn handle( - &self, - invocation: ToolInvocation<'_>, - ) -> Result { + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, payload, .. 
} = invocation; diff --git a/codex-rs/core/src/tools/handlers/view_image.rs b/codex-rs/core/src/tools/handlers/view_image.rs index 4ebfd8f3..2396e19c 100644 --- a/codex-rs/core/src/tools/handlers/view_image.rs +++ b/codex-rs/core/src/tools/handlers/view_image.rs @@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler { ToolKind::Function } - async fn handle( - &self, - invocation: ToolInvocation<'_>, - ) -> Result { + async fn handle(&self, invocation: ToolInvocation) -> Result { let ToolInvocation { session, turn, diff --git a/codex-rs/core/src/tools/mod.rs b/codex-rs/core/src/tools/mod.rs index 4bd08c66..691c6dc0 100644 --- a/codex-rs/core/src/tools/mod.rs +++ b/codex-rs/core/src/tools/mod.rs @@ -1,5 +1,6 @@ pub mod context; pub(crate) mod handlers; +pub mod parallel; pub mod registry; pub mod router; pub mod spec; @@ -21,7 +22,7 @@ use crate::executor::linkers::PreparedExec; use crate::function_tool::FunctionCallError; use crate::tools::context::ApplyPatchCommandContext; use crate::tools::context::ExecCommandContext; -use crate::turn_diff_tracker::TurnDiffTracker; +use crate::tools::context::SharedTurnDiffTracker; use codex_apply_patch::MaybeApplyPatchVerified; use codex_apply_patch::maybe_parse_apply_patch_verified; use codex_protocol::protocol::AskForApproval; @@ -29,6 +30,7 @@ use codex_utils_string::take_bytes_at_char_boundary; use codex_utils_string::take_last_bytes_at_char_boundary; pub use router::ToolRouter; use serde::Serialize; +use std::sync::Arc; use tracing::trace; // Model-formatting limits: clients get full streams; only content sent to the model is truncated. @@ -48,9 +50,9 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str = pub(crate) async fn handle_container_exec_with_params( tool_name: &str, params: ExecParams, - sess: &Session, - turn_context: &TurnContext, - turn_diff_tracker: &mut TurnDiffTracker, + sess: Arc, + turn_context: Arc, + turn_diff_tracker: SharedTurnDiffTracker, sub_id: String, call_id: String, ) -> Result { @@ -68,7 +70,15 @@ pub(crate) async fn handle_container_exec_with_params( // check if this was a patch, and apply it if so let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { MaybeApplyPatchVerified::Body(changes) => { - match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await { + match apply_patch::apply_patch( + sess.as_ref(), + turn_context.as_ref(), + &sub_id, + &call_id, + changes, + ) + .await + { InternalApplyPatchInvocation::Output(item) => return item, InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => { Some(apply_patch_exec) @@ -139,7 +149,7 @@ pub(crate) async fn handle_container_exec_with_params( let output_result = sess .run_exec_with_events( - turn_diff_tracker, + turn_diff_tracker.clone(), prepared_exec, turn_context.approval_policy, ) diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs new file mode 100644 index 00000000..ff4104d0 --- /dev/null +++ b/codex-rs/core/src/tools/parallel.rs @@ -0,0 +1,137 @@ +use std::sync::Arc; + +use tokio::task::JoinHandle; + +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::error::CodexErr; +use crate::function_tool::FunctionCallError; +use crate::tools::context::SharedTurnDiffTracker; +use crate::tools::router::ToolCall; +use crate::tools::router::ToolRouter; +use codex_protocol::models::ResponseInputItem; + +use crate::codex::ProcessedResponseItem; + +struct PendingToolCall { + index: usize, + handle: JoinHandle>, +} + +pub(crate) struct ToolCallRuntime 
{ + router: Arc, + session: Arc, + turn_context: Arc, + tracker: SharedTurnDiffTracker, + sub_id: String, + pending_calls: Vec, +} + +impl ToolCallRuntime { + pub(crate) fn new( + router: Arc, + session: Arc, + turn_context: Arc, + tracker: SharedTurnDiffTracker, + sub_id: String, + ) -> Self { + Self { + router, + session, + turn_context, + tracker, + sub_id, + pending_calls: Vec::new(), + } + } + + pub(crate) async fn handle_tool_call( + &mut self, + call: ToolCall, + output_index: usize, + output: &mut [ProcessedResponseItem], + ) -> Result<(), CodexErr> { + let supports_parallel = self.router.tool_supports_parallel(&call.tool_name); + if supports_parallel { + self.spawn_parallel(call, output_index); + } else { + self.resolve_pending(output).await?; + let response = self.dispatch_serial(call).await?; + let slot = output.get_mut(output_index).ok_or_else(|| { + CodexErr::Fatal(format!("tool output index {output_index} out of bounds")) + })?; + slot.response = Some(response); + } + + Ok(()) + } + + pub(crate) fn abort_all(&mut self) { + while let Some(pending) = self.pending_calls.pop() { + pending.handle.abort(); + } + } + + pub(crate) async fn resolve_pending( + &mut self, + output: &mut [ProcessedResponseItem], + ) -> Result<(), CodexErr> { + while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() { + match handle.await { + Ok(Ok(response)) => { + if let Some(slot) = output.get_mut(index) { + slot.response = Some(response); + } + } + Ok(Err(FunctionCallError::Fatal(message))) => { + self.abort_all(); + return Err(CodexErr::Fatal(message)); + } + Ok(Err(other)) => { + self.abort_all(); + return Err(CodexErr::Fatal(other.to_string())); + } + Err(join_err) => { + self.abort_all(); + return Err(CodexErr::Fatal(format!( + "tool task failed to join: {join_err}" + ))); + } + } + } + + Ok(()) + } + + fn spawn_parallel(&mut self, call: ToolCall, index: usize) { + let router = Arc::clone(&self.router); + let session = Arc::clone(&self.session); + let turn = Arc::clone(&self.turn_context); + let tracker = Arc::clone(&self.tracker); + let sub_id = self.sub_id.clone(); + let handle = tokio::spawn(async move { + router + .dispatch_tool_call(session, turn, tracker, sub_id, call) + .await + }); + self.pending_calls.push(PendingToolCall { index, handle }); + } + + async fn dispatch_serial(&self, call: ToolCall) -> Result { + match self + .router + .dispatch_tool_call( + Arc::clone(&self.session), + Arc::clone(&self.turn_context), + Arc::clone(&self.tracker), + self.sub_id.clone(), + call, + ) + .await + { + Ok(response) => Ok(response), + Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), + Err(other) => Err(CodexErr::Fatal(other.to_string())), + } + } +} diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 7c7b1d25..c44cdbd8 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -32,8 +32,7 @@ pub trait ToolHandler: Send + Sync { ) } - async fn handle(&self, invocation: ToolInvocation<'_>) - -> Result; + async fn handle(&self, invocation: ToolInvocation) -> Result; } pub struct ToolRegistry { @@ -57,9 +56,9 @@ impl ToolRegistry { // } // } - pub async fn dispatch<'a>( + pub async fn dispatch( &self, - invocation: ToolInvocation<'a>, + invocation: ToolInvocation, ) -> Result { let tool_name = invocation.tool_name.clone(); let call_id_owned = invocation.call_id.clone(); @@ -137,9 +136,24 @@ impl ToolRegistry { } } +#[derive(Debug, Clone)] +pub struct ConfiguredToolSpec { + pub spec: 
ToolSpec, + pub supports_parallel_tool_calls: bool, +} + +impl ConfiguredToolSpec { + pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self { + Self { + spec, + supports_parallel_tool_calls, + } + } +} + pub struct ToolRegistryBuilder { handlers: HashMap>, - specs: Vec, + specs: Vec, } impl ToolRegistryBuilder { @@ -151,7 +165,16 @@ impl ToolRegistryBuilder { } pub fn push_spec(&mut self, spec: ToolSpec) { - self.specs.push(spec); + self.push_spec_with_parallel_support(spec, false); + } + + pub fn push_spec_with_parallel_support( + &mut self, + spec: ToolSpec, + supports_parallel_tool_calls: bool, + ) { + self.specs + .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls)); } pub fn register_handler(&mut self, name: impl Into, handler: Arc) { @@ -183,7 +206,7 @@ impl ToolRegistryBuilder { // } // } - pub fn build(self) -> (Vec, ToolRegistry) { + pub fn build(self) -> (Vec, ToolRegistry) { let registry = ToolRegistry::new(self.handlers); (self.specs, registry) } diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index 6ec62e20..fa6e38a4 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -1,15 +1,17 @@ use std::collections::HashMap; +use std::sync::Arc; use crate::client_common::tools::ToolSpec; use crate::codex::Session; use crate::codex::TurnContext; use crate::function_tool::FunctionCallError; +use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolPayload; +use crate::tools::registry::ConfiguredToolSpec; use crate::tools::registry::ToolRegistry; use crate::tools::spec::ToolsConfig; use crate::tools::spec::build_specs; -use crate::turn_diff_tracker::TurnDiffTracker; use codex_protocol::models::LocalShellAction; use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; @@ -24,7 +26,7 @@ pub struct ToolCall { pub struct ToolRouter { registry: ToolRegistry, - specs: Vec, + specs: Vec, } impl ToolRouter { @@ -34,11 +36,22 @@ impl ToolRouter { ) -> Self { let builder = build_specs(config, mcp_tools); let (specs, registry) = builder.build(); + Self { registry, specs } } - pub fn specs(&self) -> &[ToolSpec] { - &self.specs + pub fn specs(&self) -> Vec { + self.specs + .iter() + .map(|config| config.spec.clone()) + .collect() + } + + pub fn tool_supports_parallel(&self, tool_name: &str) -> bool { + self.specs + .iter() + .filter(|config| config.supports_parallel_tool_calls) + .any(|config| config.spec.name() == tool_name) } pub fn build_tool_call( @@ -118,10 +131,10 @@ impl ToolRouter { pub async fn dispatch_tool_call( &self, - session: &Session, - turn: &TurnContext, - tracker: &mut TurnDiffTracker, - sub_id: &str, + session: Arc, + turn: Arc, + tracker: SharedTurnDiffTracker, + sub_id: String, call: ToolCall, ) -> Result { let ToolCall { diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs index 5ea5b6a9..51124d41 100644 --- a/codex-rs/core/src/tools/spec.rs +++ b/codex-rs/core/src/tools/spec.rs @@ -258,6 +258,68 @@ fn create_view_image_tool() -> ToolSpec { }) } +fn create_test_sync_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "sleep_before_ms".to_string(), + JsonSchema::Number { + description: Some("Optional delay in milliseconds before any other action".to_string()), + }, + ); + properties.insert( + "sleep_after_ms".to_string(), + JsonSchema::Number { + description: Some( + "Optional delay in milliseconds after completing the 
barrier".to_string(), + ), + }, + ); + + let mut barrier_properties = BTreeMap::new(); + barrier_properties.insert( + "id".to_string(), + JsonSchema::String { + description: Some( + "Identifier shared by concurrent calls that should rendezvous".to_string(), + ), + }, + ); + barrier_properties.insert( + "participants".to_string(), + JsonSchema::Number { + description: Some( + "Number of tool calls that must arrive before the barrier opens".to_string(), + ), + }, + ); + barrier_properties.insert( + "timeout_ms".to_string(), + JsonSchema::Number { + description: Some("Maximum time in milliseconds to wait at the barrier".to_string()), + }, + ); + + properties.insert( + "barrier".to_string(), + JsonSchema::Object { + properties: barrier_properties, + required: Some(vec!["id".to_string(), "participants".to_string()]), + additional_properties: Some(false.into()), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "test_sync_tool".to_string(), + description: "Internal synchronization helper used by Codex integration tests.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: None, + additional_properties: Some(false.into()), + }, + }) +} + fn create_read_file_tool() -> ToolSpec { let mut properties = BTreeMap::new(); properties.insert( @@ -507,6 +569,7 @@ pub(crate) fn build_specs( use crate::tools::handlers::PlanHandler; use crate::tools::handlers::ReadFileHandler; use crate::tools::handlers::ShellHandler; + use crate::tools::handlers::TestSyncHandler; use crate::tools::handlers::UnifiedExecHandler; use crate::tools::handlers::ViewImageHandler; use std::sync::Arc; @@ -573,16 +636,26 @@ pub(crate) fn build_specs( .any(|tool| tool == "read_file") { let read_file_handler = Arc::new(ReadFileHandler); - builder.push_spec(create_read_file_tool()); + builder.push_spec_with_parallel_support(create_read_file_tool(), true); builder.register_handler("read_file", read_file_handler); } + if config + .experimental_supported_tools + .iter() + .any(|tool| tool == "test_sync_tool") + { + let test_sync_handler = Arc::new(TestSyncHandler); + builder.push_spec_with_parallel_support(create_test_sync_tool(), true); + builder.register_handler("test_sync_tool", test_sync_handler); + } + if config.web_search_request { builder.push_spec(ToolSpec::WebSearch {}); } if config.include_view_image_tool { - builder.push_spec(create_view_image_tool()); + builder.push_spec_with_parallel_support(create_view_image_tool(), true); builder.register_handler("view_image", view_image_handler); } @@ -610,20 +683,25 @@ pub(crate) fn build_specs( mod tests { use crate::client_common::tools::FreeformTool; use crate::model_family::find_family_for_model; + use crate::tools::registry::ConfiguredToolSpec; use mcp_types::ToolInputSchema; use pretty_assertions::assert_eq; use super::*; - fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) { + fn tool_name(tool: &ToolSpec) -> &str { + match tool { + ToolSpec::Function(ResponsesApiTool { name, .. }) => name, + ToolSpec::LocalShell {} => "local_shell", + ToolSpec::WebSearch {} => "web_search", + ToolSpec::Freeform(FreeformTool { name, .. }) => name, + } + } + + fn assert_eq_tool_names(tools: &[ConfiguredToolSpec], expected_names: &[&str]) { let tool_names = tools .iter() - .map(|tool| match tool { - ToolSpec::Function(ResponsesApiTool { name, .. }) => name, - ToolSpec::LocalShell {} => "local_shell", - ToolSpec::WebSearch {} => "web_search", - ToolSpec::Freeform(FreeformTool { name, .. 
}) => name, - }) + .map(|tool| tool_name(&tool.spec)) .collect::>(); assert_eq!( @@ -639,6 +717,16 @@ mod tests { } } + fn find_tool<'a>( + tools: &'a [ConfiguredToolSpec], + expected_name: &str, + ) -> &'a ConfiguredToolSpec { + tools + .iter() + .find(|tool| tool_name(&tool.spec) == expected_name) + .unwrap_or_else(|| panic!("expected tool {expected_name}")) + } + #[test] fn test_build_specs() { let model_family = find_family_for_model("codex-mini-latest") @@ -680,6 +768,53 @@ mod tests { ); } + #[test] + #[ignore] + fn test_parallel_support_flags() { + let model_family = find_family_for_model("gpt-5-codex") + .expect("codex-mini-latest should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: false, + use_streamable_shell_tool: false, + include_view_image_tool: false, + experimental_unified_exec_tool: true, + }); + let (tools, _) = build_specs(&config, None).build(); + + assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls); + assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls); + } + + #[test] + fn test_test_model_family_includes_sync_tool() { + let model_family = find_family_for_model("test-gpt-5-codex") + .expect("test-gpt-5-codex should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: false, + use_streamable_shell_tool: false, + include_view_image_tool: false, + experimental_unified_exec_tool: false, + }); + let (tools, _) = build_specs(&config, None).build(); + + assert!( + tools + .iter() + .any(|tool| tool_name(&tool.spec) == "test_sync_tool") + ); + assert!( + tools + .iter() + .any(|tool| tool_name(&tool.spec) == "read_file") + ); + } + #[test] fn test_build_specs_mcp_tools() { let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); @@ -742,7 +877,7 @@ mod tests { ); assert_eq!( - tools[3], + tools[3].spec, ToolSpec::Function(ResponsesApiTool { name: "test_server/do_something_cool".to_string(), parameters: JsonSchema::Object { @@ -911,7 +1046,7 @@ mod tests { ); assert_eq!( - tools[4], + tools[4].spec, ToolSpec::Function(ResponsesApiTool { name: "dash/search".to_string(), parameters: JsonSchema::Object { @@ -977,7 +1112,7 @@ mod tests { ], ); assert_eq!( - tools[4], + tools[4].spec, ToolSpec::Function(ResponsesApiTool { name: "dash/paginate".to_string(), parameters: JsonSchema::Object { @@ -1041,7 +1176,7 @@ mod tests { ], ); assert_eq!( - tools[4], + tools[4].spec, ToolSpec::Function(ResponsesApiTool { name: "dash/tags".to_string(), parameters: JsonSchema::Object { @@ -1108,7 +1243,7 @@ mod tests { ], ); assert_eq!( - tools[4], + tools[4].spec, ToolSpec::Function(ResponsesApiTool { name: "dash/value".to_string(), parameters: JsonSchema::Object { @@ -1213,7 +1348,7 @@ mod tests { ); assert_eq!( - tools[4], + tools[4].spec, ToolSpec::Function(ResponsesApiTool { name: "test_server/do_something_cool".to_string(), parameters: JsonSchema::Object { diff --git a/codex-rs/core/tests/suite/abort_tasks.rs b/codex-rs/core/tests/suite/abort_tasks.rs index 368f7f74..5122f661 100644 --- a/codex-rs/core/tests/suite/abort_tasks.rs +++ b/codex-rs/core/tests/suite/abort_tasks.rs @@ -3,14 +3,14 @@ use std::time::Duration; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; use codex_core::protocol::Op; 
+use core_test_support::responses::ev_completed; use core_test_support::responses::ev_function_call; -use core_test_support::responses::mount_sse_once_match; +use core_test_support::responses::mount_sse_once; use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event_with_timeout; use serde_json::json; -use wiremock::matchers::body_string_contains; /// Integration test: spawn a long‑running shell tool via a mocked Responses SSE /// function call, then interrupt the session and expect TurnAborted. @@ -27,10 +27,13 @@ async fn interrupt_long_running_tool_emits_turn_aborted() { "timeout_ms": 60_000 }) .to_string(); - let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]); + let body = sse(vec![ + ev_function_call("call_sleep", "shell", &args), + ev_completed("done"), + ]); let server = start_mock_server().await; - mount_sse_once_match(&server, body_string_contains("start sleep"), body).await; + mount_sse_once(&server, body).await; let codex = test_codex().build(&server).await.unwrap().codex; diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 27bddb7e..2abbb6fc 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -24,6 +24,7 @@ mod shell_serialization; mod stream_error_allows_next_turn; mod stream_no_completed; mod tool_harness; +mod tool_parallelism; mod tools; mod unified_exec; mod user_notification; diff --git a/codex-rs/core/tests/suite/tool_parallelism.rs b/codex-rs/core/tests/suite/tool_parallelism.rs new file mode 100644 index 00000000..e667df43 --- /dev/null +++ b/codex-rs/core/tests/suite/tool_parallelism.rs @@ -0,0 +1,178 @@ +#![cfg(not(target_os = "windows"))] +#![allow(clippy::unwrap_used)] + +use std::time::Duration; +use std::time::Instant; + +use codex_core::model_family::find_family_for_model; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::mount_sse_sequence; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use serde_json::json; + +async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> { + let session_model = test.session_configured.model.clone(); + + test.codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: prompt.into(), + }], + final_output_json_schema: None, + cwd: test.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + Ok(()) +} + +async fn run_turn_and_measure(test: &TestCodex, prompt: &str) -> anyhow::Result { + let start = Instant::now(); + run_turn(test, prompt).await?; + Ok(start.elapsed()) +} + +#[allow(clippy::expect_used)] +async fn build_codex_with_test_tool(server: 
&wiremock::MockServer) -> anyhow::Result { + let mut builder = test_codex().with_config(|config| { + config.model = "test-gpt-5-codex".to_string(); + config.model_family = + find_family_for_model("test-gpt-5-codex").expect("test-gpt-5-codex model family"); + }); + builder.build(server).await +} + +fn assert_parallel_duration(actual: Duration) { + assert!( + actual < Duration::from_millis(500), + "expected parallel execution to finish quickly, got {actual:?}" + ); +} + +fn assert_serial_duration(actual: Duration) { + assert!( + actual >= Duration::from_millis(500), + "expected serial execution to take longer, got {actual:?}" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn read_file_tools_run_in_parallel() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let test = build_codex_with_test_tool(&server).await?; + + let parallel_args = json!({ + "sleep_after_ms": 300, + "barrier": { + "id": "parallel-test-sync", + "participants": 2, + "timeout_ms": 1_000, + } + }) + .to_string(); + + let first_response = sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_function_call("call-1", "test_sync_tool", ¶llel_args), + ev_function_call("call-2", "test_sync_tool", ¶llel_args), + ev_completed("resp-1"), + ]); + let second_response = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + mount_sse_sequence(&server, vec![first_response, second_response]).await; + + let duration = run_turn_and_measure(&test, "exercise sync tool").await?; + assert_parallel_duration(duration); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn non_parallel_tools_run_serially() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let test = test_codex().build(&server).await?; + + let shell_args = json!({ + "command": ["/bin/sh", "-c", "sleep 0.3"], + "timeout_ms": 1_000, + }); + let args_one = serde_json::to_string(&shell_args)?; + let args_two = serde_json::to_string(&shell_args)?; + + let first_response = sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_function_call("call-1", "shell", &args_one), + ev_function_call("call-2", "shell", &args_two), + ev_completed("resp-1"), + ]); + let second_response = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + mount_sse_sequence(&server, vec![first_response, second_response]).await; + + let duration = run_turn_and_measure(&test, "run shell twice").await?; + assert_serial_duration(duration); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn mixed_tools_fall_back_to_serial() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let test = build_codex_with_test_tool(&server).await?; + + let sync_args = json!({ + "sleep_after_ms": 300 + }) + .to_string(); + let shell_args = serde_json::to_string(&json!({ + "command": ["/bin/sh", "-c", "sleep 0.3"], + "timeout_ms": 1_000, + }))?; + + let first_response = sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_function_call("call-1", "test_sync_tool", &sync_args), + ev_function_call("call-2", "shell", &shell_args), + ev_completed("resp-1"), + ]); + let second_response = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + mount_sse_sequence(&server, vec![first_response, second_response]).await; + + let 
duration = run_turn_and_measure(&test, "mix tools").await?; + assert_serial_duration(duration); + + Ok(()) +}
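// ---------------------------------------------------------------------------
// A minimal, self-contained sketch (not the codex-rs implementation) of the
// scheduling idea this patch introduces in ToolCallRuntime and that the
// tool_parallelism tests above assert on timing: calls whose tool supports
// parallel execution are spawned as tasks and joined later, while any call
// that does not support parallelism first drains the in-flight tasks and then
// runs inline so its effects stay ordered. All names below (Call,
// fake_dispatch, run_calls) are hypothetical; tokio with the rt/macros/time
// features is assumed.
use std::time::Duration;
use tokio::task::JoinHandle;

#[derive(Clone)]
struct Call {
    name: String,
    supports_parallel: bool,
}

async fn fake_dispatch(call: Call) -> String {
    // Stand-in for ToolRouter::dispatch_tool_call: pretend every tool takes 300ms.
    tokio::time::sleep(Duration::from_millis(300)).await;
    format!("{} done", call.name)
}

async fn run_calls(calls: Vec<Call>) -> Vec<String> {
    let mut results: Vec<Option<String>> = vec![None; calls.len()];
    let mut pending: Vec<(usize, JoinHandle<String>)> = Vec::new();

    for (index, call) in calls.into_iter().enumerate() {
        if call.supports_parallel {
            // Parallel-capable tools are spawned immediately; their output slot
            // is filled when the pending set is resolved.
            pending.push((index, tokio::spawn(fake_dispatch(call))));
        } else {
            // A serial tool acts as a barrier: finish everything in flight,
            // then run the call inline.
            for (i, handle) in pending.drain(..) {
                results[i] = Some(handle.await.expect("task panicked"));
            }
            results[index] = Some(fake_dispatch(call).await);
        }
    }
    // Resolve whatever is still pending at the end of the stream, mirroring
    // resolve_pending() being called on ResponseEvent::Completed.
    for (i, handle) in pending.drain(..) {
        results[i] = Some(handle.await.expect("task panicked"));
    }
    results.into_iter().map(|r| r.expect("slot filled")).collect()
}

#[tokio::main]
async fn main() {
    // Two parallel-capable calls overlap (~300ms total); adding a serial call
    // forces the drain-then-run path (~600ms+), which is the same timing split
    // that assert_parallel_duration / assert_serial_duration check above.
    let calls = vec![
        Call { name: "read_file#1".into(), supports_parallel: true },
        Call { name: "read_file#2".into(), supports_parallel: true },
        Call { name: "shell".into(), supports_parallel: false },
    ];
    let started = std::time::Instant::now();
    println!("{:?} in {:?}", run_calls(calls).await, started.elapsed());
}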