feat: parallel tool calls (#4663)

Add parallel tool calls. This is configurable at model level and tool
level
This commit is contained in:
jif-oai
2025-10-05 17:10:49 +01:00
committed by GitHub
parent 3203862167
commit dc3c6bf62a
23 changed files with 961 additions and 244 deletions

View File

@@ -227,7 +227,7 @@ impl ModelClient {
input: &input_with_instructions, input: &input_with_instructions,
tools: &tools_json, tools: &tools_json,
tool_choice: "auto", tool_choice: "auto",
parallel_tool_calls: false, parallel_tool_calls: prompt.parallel_tool_calls,
reasoning, reasoning,
store: azure_workaround, store: azure_workaround,
stream: true, stream: true,

View File

@@ -33,6 +33,9 @@ pub struct Prompt {
/// external MCP servers. /// external MCP servers.
pub(crate) tools: Vec<ToolSpec>, pub(crate) tools: Vec<ToolSpec>,
/// Whether parallel tool calls are permitted for this prompt.
pub(crate) parallel_tool_calls: bool,
/// Optional override for the built-in BASE_INSTRUCTIONS. /// Optional override for the built-in BASE_INSTRUCTIONS.
pub base_instructions_override: Option<String>, pub base_instructions_override: Option<String>,
@@ -288,6 +291,17 @@ pub(crate) mod tools {
Freeform(FreeformTool), Freeform(FreeformTool),
} }
impl ToolSpec {
pub(crate) fn name(&self) -> &str {
match self {
ToolSpec::Function(tool) => tool.name.as_str(),
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(tool) => tool.name.as_str(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformTool { pub struct FreeformTool {
pub(crate) name: String, pub(crate) name: String,
@@ -433,7 +447,7 @@ mod tests {
input: &input, input: &input,
tools: &tools, tools: &tools,
tool_choice: "auto", tool_choice: "auto",
parallel_tool_calls: false, parallel_tool_calls: true,
reasoning: None, reasoning: None,
store: false, store: false,
stream: true, stream: true,
@@ -474,7 +488,7 @@ mod tests {
input: &input, input: &input,
tools: &tools, tools: &tools,
tool_choice: "auto", tool_choice: "auto",
parallel_tool_calls: false, parallel_tool_calls: true,
reasoning: None, reasoning: None,
store: false, store: false,
stream: true, stream: true,
@@ -510,7 +524,7 @@ mod tests {
input: &input, input: &input,
tools: &tools, tools: &tools,
tool_choice: "auto", tool_choice: "auto",
parallel_tool_calls: false, parallel_tool_calls: true,
reasoning: None, reasoning: None,
store: false, store: false,
stream: true, stream: true,

View File

@@ -100,7 +100,9 @@ use crate::tasks::CompactTask;
use crate::tasks::RegularTask; use crate::tasks::RegularTask;
use crate::tasks::ReviewTask; use crate::tasks::ReviewTask;
use crate::tools::ToolRouter; use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::format_exec_output_str; use crate::tools::format_exec_output_str;
use crate::tools::parallel::ToolCallRuntime;
use crate::turn_diff_tracker::TurnDiffTracker; use crate::turn_diff_tracker::TurnDiffTracker;
use crate::unified_exec::UnifiedExecSessionManager; use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions; use crate::user_instructions::UserInstructions;
@@ -818,7 +820,7 @@ impl Session {
async fn on_exec_command_begin( async fn on_exec_command_begin(
&self, &self,
turn_diff_tracker: &mut TurnDiffTracker, turn_diff_tracker: SharedTurnDiffTracker,
exec_command_context: ExecCommandContext, exec_command_context: ExecCommandContext,
) { ) {
let ExecCommandContext { let ExecCommandContext {
@@ -834,7 +836,10 @@ impl Session {
user_explicitly_approved_this_action, user_explicitly_approved_this_action,
changes, changes,
}) => { }) => {
turn_diff_tracker.on_patch_begin(&changes); {
let mut tracker = turn_diff_tracker.lock().await;
tracker.on_patch_begin(&changes);
}
EventMsg::PatchApplyBegin(PatchApplyBeginEvent { EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
call_id, call_id,
@@ -861,7 +866,7 @@ impl Session {
async fn on_exec_command_end( async fn on_exec_command_end(
&self, &self,
turn_diff_tracker: &mut TurnDiffTracker, turn_diff_tracker: SharedTurnDiffTracker,
sub_id: &str, sub_id: &str,
call_id: &str, call_id: &str,
output: &ExecToolCallOutput, output: &ExecToolCallOutput,
@@ -909,7 +914,10 @@ impl Session {
// If this is an apply_patch, after we emit the end patch, emit a second event // If this is an apply_patch, after we emit the end patch, emit a second event
// with the full turn diff if there is one. // with the full turn diff if there is one.
if is_apply_patch { if is_apply_patch {
let unified_diff = turn_diff_tracker.get_unified_diff(); let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
tracker.get_unified_diff()
};
if let Ok(Some(unified_diff)) = unified_diff { if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event { let event = Event {
@@ -926,7 +934,7 @@ impl Session {
/// Returns the output of the exec tool call. /// Returns the output of the exec tool call.
pub(crate) async fn run_exec_with_events( pub(crate) async fn run_exec_with_events(
&self, &self,
turn_diff_tracker: &mut TurnDiffTracker, turn_diff_tracker: SharedTurnDiffTracker,
prepared: PreparedExec, prepared: PreparedExec,
approval_policy: AskForApproval, approval_policy: AskForApproval,
) -> Result<ExecToolCallOutput, ExecError> { ) -> Result<ExecToolCallOutput, ExecError> {
@@ -935,7 +943,7 @@ impl Session {
let sub_id = context.sub_id.clone(); let sub_id = context.sub_id.clone();
let call_id = context.call_id.clone(); let call_id = context.call_id.clone();
self.on_exec_command_begin(turn_diff_tracker, context.clone()) self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
.await; .await;
let result = self let result = self
@@ -1644,7 +1652,7 @@ pub(crate) async fn run_task(
let mut last_agent_message: Option<String> = None; let mut last_agent_message: Option<String> = None;
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains // Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
// many turns, from the perspective of the user, it is a single turn. // many turns, from the perspective of the user, it is a single turn.
let mut turn_diff_tracker = TurnDiffTracker::new(); let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let mut auto_compact_recently_attempted = false; let mut auto_compact_recently_attempted = false;
loop { loop {
@@ -1692,9 +1700,9 @@ pub(crate) async fn run_task(
}) })
.collect(); .collect();
match run_turn( match run_turn(
&sess, Arc::clone(&sess),
turn_context.as_ref(), Arc::clone(&turn_context),
&mut turn_diff_tracker, Arc::clone(&turn_diff_tracker),
sub_id.clone(), sub_id.clone(),
turn_input, turn_input,
) )
@@ -1917,18 +1925,27 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
} }
async fn run_turn( async fn run_turn(
sess: &Session, sess: Arc<Session>,
turn_context: &TurnContext, turn_context: Arc<TurnContext>,
turn_diff_tracker: &mut TurnDiffTracker, turn_diff_tracker: SharedTurnDiffTracker,
sub_id: String, sub_id: String,
input: Vec<ResponseItem>, input: Vec<ResponseItem>,
) -> CodexResult<TurnRunResult> { ) -> CodexResult<TurnRunResult> {
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools(); let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools)); let router = Arc::new(ToolRouter::from_config(
&turn_context.tools_config,
Some(mcp_tools),
));
let model_supports_parallel = turn_context
.client
.get_model_family()
.supports_parallel_tool_calls;
let parallel_tool_calls = model_supports_parallel;
let prompt = Prompt { let prompt = Prompt {
input, input,
tools: router.specs().to_vec(), tools: router.specs(),
parallel_tool_calls,
base_instructions_override: turn_context.base_instructions.clone(), base_instructions_override: turn_context.base_instructions.clone(),
output_schema: turn_context.final_output_json_schema.clone(), output_schema: turn_context.final_output_json_schema.clone(),
}; };
@@ -1936,10 +1953,10 @@ async fn run_turn(
let mut retries = 0; let mut retries = 0;
loop { loop {
match try_run_turn( match try_run_turn(
&router, Arc::clone(&router),
sess, Arc::clone(&sess),
turn_context, Arc::clone(&turn_context),
turn_diff_tracker, Arc::clone(&turn_diff_tracker),
&sub_id, &sub_id,
&prompt, &prompt,
) )
@@ -1950,7 +1967,7 @@ async fn run_turn(
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ CodexErr::Fatal(_)) => return Err(e), Err(e @ CodexErr::Fatal(_)) => return Err(e),
Err(e @ CodexErr::ContextWindowExceeded) => { Err(e @ CodexErr::ContextWindowExceeded) => {
sess.set_total_tokens_full(&sub_id, turn_context).await; sess.set_total_tokens_full(&sub_id, &turn_context).await;
return Err(e); return Err(e);
} }
Err(CodexErr::UsageLimitReached(e)) => { Err(CodexErr::UsageLimitReached(e)) => {
@@ -1999,9 +2016,9 @@ async fn run_turn(
/// "handled" such that it produces a `ResponseInputItem` that needs to be /// "handled" such that it produces a `ResponseInputItem` that needs to be
/// sent back to the model on the next turn. /// sent back to the model on the next turn.
#[derive(Debug)] #[derive(Debug)]
struct ProcessedResponseItem { pub(crate) struct ProcessedResponseItem {
item: ResponseItem, pub(crate) item: ResponseItem,
response: Option<ResponseInputItem>, pub(crate) response: Option<ResponseInputItem>,
} }
#[derive(Debug)] #[derive(Debug)]
@@ -2011,10 +2028,10 @@ struct TurnRunResult {
} }
async fn try_run_turn( async fn try_run_turn(
router: &crate::tools::ToolRouter, router: Arc<ToolRouter>,
sess: &Session, sess: Arc<Session>,
turn_context: &TurnContext, turn_context: Arc<TurnContext>,
turn_diff_tracker: &mut TurnDiffTracker, turn_diff_tracker: SharedTurnDiffTracker,
sub_id: &str, sub_id: &str,
prompt: &Prompt, prompt: &Prompt,
) -> CodexResult<TurnRunResult> { ) -> CodexResult<TurnRunResult> {
@@ -2085,24 +2102,34 @@ async fn try_run_turn(
let mut stream = turn_context.client.clone().stream(&prompt).await?; let mut stream = turn_context.client.clone().stream(&prompt).await?;
let mut output = Vec::new(); let mut output = Vec::new();
let mut tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
sub_id.to_string(),
);
loop { loop {
// Poll the next item from the model stream. We must inspect *both* Ok and Err // Poll the next item from the model stream. We must inspect *both* Ok and Err
// cases so that transient stream failures (e.g., dropped SSE connection before // cases so that transient stream failures (e.g., dropped SSE connection before
// `response.completed`) bubble up and trigger the caller's retry logic. // `response.completed`) bubble up and trigger the caller's retry logic.
let event = stream.next().await; let event = stream.next().await;
let Some(event) = event else { let event = match event {
// Channel closed without yielding a final Completed event or explicit error. Some(event) => event,
// Treat as a disconnected stream so the caller can retry. None => {
return Err(CodexErr::Stream( tool_runtime.abort_all();
"stream closed before response.completed".into(), return Err(CodexErr::Stream(
None, "stream closed before response.completed".into(),
)); None,
));
}
}; };
let event = match event { let event = match event {
Ok(ev) => ev, Ok(ev) => ev,
Err(e) => { Err(e) => {
tool_runtime.abort_all();
// Propagate the underlying stream error to the caller (run_turn), which // Propagate the underlying stream error to the caller (run_turn), which
// will apply the configured `stream_max_retries` policy. // will apply the configured `stream_max_retries` policy.
return Err(e); return Err(e);
@@ -2112,16 +2139,66 @@ async fn try_run_turn(
match event { match event {
ResponseEvent::Created => {} ResponseEvent::Created => {}
ResponseEvent::OutputItemDone(item) => { ResponseEvent::OutputItemDone(item) => {
let response = handle_response_item( match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) {
router, Ok(Some(call)) => {
sess, let payload_preview = call.payload.log_payload().into_owned();
turn_context, tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
turn_diff_tracker, let index = output.len();
sub_id, output.push(ProcessedResponseItem {
item.clone(), item,
) response: None,
.await?; });
output.push(ProcessedResponseItem { item, response }); tool_runtime
.handle_tool_call(call, index, output.as_mut_slice())
.await?;
}
Ok(None) => {
let response = handle_non_tool_response_item(
Arc::clone(&sess),
Arc::clone(&turn_context),
sub_id,
item.clone(),
)
.await?;
output.push(ProcessedResponseItem { item, response });
}
Err(FunctionCallError::MissingLocalShellCallId) => {
let msg = "LocalShellCall without call_id or id";
turn_context
.client
.get_otel_event_manager()
.log_tool_failed("local_shell", msg);
error!(msg);
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg.to_string(),
success: None,
},
};
output.push(ProcessedResponseItem {
item,
response: Some(response),
});
}
Err(FunctionCallError::RespondToModel(message)) => {
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: message,
success: None,
},
};
output.push(ProcessedResponseItem {
item,
response: Some(response),
});
}
Err(FunctionCallError::Fatal(message)) => {
return Err(CodexErr::Fatal(message));
}
}
} }
ResponseEvent::WebSearchCallBegin { call_id } => { ResponseEvent::WebSearchCallBegin { call_id } => {
let _ = sess let _ = sess
@@ -2141,10 +2218,15 @@ async fn try_run_turn(
response_id: _, response_id: _,
token_usage, token_usage,
} => { } => {
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref()) sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
.await; .await;
let unified_diff = turn_diff_tracker.get_unified_diff(); tool_runtime.resolve_pending(output.as_mut_slice()).await?;
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
tracker.get_unified_diff()
};
if let Ok(Some(unified_diff)) = unified_diff { if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff }); let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event { let event = Event {
@@ -2203,88 +2285,40 @@ async fn try_run_turn(
} }
} }
async fn handle_response_item( async fn handle_non_tool_response_item(
router: &crate::tools::ToolRouter, sess: Arc<Session>,
sess: &Session, turn_context: Arc<TurnContext>,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str, sub_id: &str,
item: ResponseItem, item: ResponseItem,
) -> CodexResult<Option<ResponseInputItem>> { ) -> CodexResult<Option<ResponseInputItem>> {
debug!(?item, "Output item"); debug!(?item, "Output item");
match ToolRouter::build_tool_call(sess, item.clone()) { match &item {
Ok(Some(call)) => { ResponseItem::Message { .. }
let payload_preview = call.payload.log_payload().into_owned(); | ResponseItem::Reasoning { .. }
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview); | ResponseItem::WebSearchCall { .. } => {
match router let msgs = match &item {
.dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call) ResponseItem::Message { .. } if turn_context.is_review_mode => {
.await trace!("suppressing assistant Message in review mode");
{ Vec::new()
Ok(response) => Ok(Some(response)), }
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
Err(other) => unreachable!("non-fatal tool error returned: {other:?}"), };
for msg in msgs {
let event = Event {
id: sub_id.to_string(),
msg,
};
sess.send_event(event).await;
} }
} }
Ok(None) => { ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
match &item { debug!("unexpected tool output from stream");
ResponseItem::Message { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. } => {
let msgs = match &item {
ResponseItem::Message { .. } if turn_context.is_review_mode => {
trace!("suppressing assistant Message in review mode");
Vec::new()
}
_ => map_response_item_to_event_messages(
&item,
sess.show_raw_agent_reasoning(),
),
};
for msg in msgs {
let event = Event {
id: sub_id.to_string(),
msg,
};
sess.send_event(event).await;
}
}
ResponseItem::FunctionCallOutput { .. }
| ResponseItem::CustomToolCallOutput { .. } => {
debug!("unexpected tool output from stream");
}
_ => {}
}
Ok(None)
} }
Err(FunctionCallError::MissingLocalShellCallId) => { _ => {}
let msg = "LocalShellCall without call_id or id";
turn_context
.client
.get_otel_event_manager()
.log_tool_failed("local_shell", msg);
error!(msg);
Ok(Some(ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg.to_string(),
success: None,
},
}))
}
Err(FunctionCallError::RespondToModel(msg)) => {
Ok(Some(ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg,
success: None,
},
}))
}
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
} }
Ok(None)
} }
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> { pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
@@ -2927,13 +2961,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn fatal_tool_error_stops_turn_and_reports_error() { async fn fatal_tool_error_stops_turn_and_reports_error() {
let (session, turn_context, _rx) = make_session_and_context_with_rx(); let (session, turn_context, _rx) = make_session_and_context_with_rx();
let session_ref = session.as_ref();
let turn_context_ref = turn_context.as_ref();
let router = ToolRouter::from_config( let router = ToolRouter::from_config(
&turn_context_ref.tools_config, &turn_context.tools_config,
Some(session_ref.services.mcp_connection_manager.list_all_tools()), Some(session.services.mcp_connection_manager.list_all_tools()),
); );
let mut tracker = TurnDiffTracker::new();
let item = ResponseItem::CustomToolCall { let item = ResponseItem::CustomToolCall {
id: None, id: None,
status: None, status: None,
@@ -2942,22 +2973,26 @@ mod tests {
input: "{}".to_string(), input: "{}".to_string(),
}; };
let err = handle_response_item( let call = ToolRouter::build_tool_call(session.as_ref(), item.clone())
&router, .expect("build tool call")
session_ref, .expect("tool call present");
turn_context_ref, let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
&mut tracker, let err = router
"sub-id", .dispatch_tool_call(
item, Arc::clone(&session),
) Arc::clone(&turn_context),
.await tracker,
.expect_err("expected fatal error"); "sub-id".to_string(),
call,
)
.await
.expect_err("expected fatal error");
match err { match err {
CodexErr::Fatal(message) => { FunctionCallError::Fatal(message) => {
assert_eq!(message, "tool shell invoked with incompatible payload"); assert_eq!(message, "tool shell invoked with incompatible payload");
} }
other => panic!("expected CodexErr::Fatal, got {other:?}"), other => panic!("expected FunctionCallError::Fatal, got {other:?}"),
} }
} }
@@ -3071,9 +3106,11 @@ mod tests {
use crate::turn_diff_tracker::TurnDiffTracker; use crate::turn_diff_tracker::TurnDiffTracker;
use std::collections::HashMap; use std::collections::HashMap;
let (session, mut turn_context) = make_session_and_context(); let (session, mut turn_context_raw) = make_session_and_context();
// Ensure policy is NOT OnRequest so the early rejection path triggers // Ensure policy is NOT OnRequest so the early rejection path triggers
turn_context.approval_policy = AskForApproval::OnFailure; turn_context_raw.approval_policy = AskForApproval::OnFailure;
let session = Arc::new(session);
let mut turn_context = Arc::new(turn_context_raw);
let params = ExecParams { let params = ExecParams {
command: if cfg!(windows) { command: if cfg!(windows) {
@@ -3101,7 +3138,7 @@ mod tests {
..params.clone() ..params.clone()
}; };
let mut turn_diff_tracker = TurnDiffTracker::new(); let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let tool_name = "shell"; let tool_name = "shell";
let sub_id = "test-sub".to_string(); let sub_id = "test-sub".to_string();
@@ -3110,9 +3147,9 @@ mod tests {
let resp = handle_container_exec_with_params( let resp = handle_container_exec_with_params(
tool_name, tool_name,
params, params,
&session, Arc::clone(&session),
&turn_context, Arc::clone(&turn_context),
&mut turn_diff_tracker, Arc::clone(&turn_diff_tracker),
sub_id, sub_id,
call_id, call_id,
) )
@@ -3131,14 +3168,16 @@ mod tests {
// Now retry the same command WITHOUT escalated permissions; should succeed. // Now retry the same command WITHOUT escalated permissions; should succeed.
// Force DangerFullAccess to avoid platform sandbox dependencies in tests. // Force DangerFullAccess to avoid platform sandbox dependencies in tests.
turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess; Arc::get_mut(&mut turn_context)
.expect("unique turn context Arc")
.sandbox_policy = SandboxPolicy::DangerFullAccess;
let resp2 = handle_container_exec_with_params( let resp2 = handle_container_exec_with_params(
tool_name, tool_name,
params2, params2,
&session, Arc::clone(&session),
&turn_context, Arc::clone(&turn_context),
&mut turn_diff_tracker, Arc::clone(&turn_diff_tracker),
"test-sub".to_string(), "test-sub".to_string(),
"test-call-2".to_string(), "test-call-2".to_string(),
) )

View File

@@ -35,6 +35,10 @@ pub struct ModelFamily {
// See https://platform.openai.com/docs/guides/tools-local-shell // See https://platform.openai.com/docs/guides/tools-local-shell
pub uses_local_shell_tool: bool, pub uses_local_shell_tool: bool,
/// Whether this model supports parallel tool calls when using the
/// Responses API.
pub supports_parallel_tool_calls: bool,
/// Present if the model performs better when `apply_patch` is provided as /// Present if the model performs better when `apply_patch` is provided as
/// a tool call instead of just a bash command /// a tool call instead of just a bash command
pub apply_patch_tool_type: Option<ApplyPatchToolType>, pub apply_patch_tool_type: Option<ApplyPatchToolType>,
@@ -58,6 +62,7 @@ macro_rules! model_family {
supports_reasoning_summaries: false, supports_reasoning_summaries: false,
reasoning_summary_format: ReasoningSummaryFormat::None, reasoning_summary_format: ReasoningSummaryFormat::None,
uses_local_shell_tool: false, uses_local_shell_tool: false,
supports_parallel_tool_calls: false,
apply_patch_tool_type: None, apply_patch_tool_type: None,
base_instructions: BASE_INSTRUCTIONS.to_string(), base_instructions: BASE_INSTRUCTIONS.to_string(),
experimental_supported_tools: Vec::new(), experimental_supported_tools: Vec::new(),
@@ -103,6 +108,18 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true) model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
} else if slug.starts_with("gpt-3.5") { } else if slug.starts_with("gpt-3.5") {
model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true) model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
} else if slug.starts_with("test-gpt-5-codex") {
model_family!(
slug, slug,
supports_reasoning_summaries: true,
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
experimental_supported_tools: vec![
"read_file".to_string(),
"test_sync_tool".to_string()
],
supports_parallel_tool_calls: true,
)
} else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") { } else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
model_family!( model_family!(
slug, slug, slug, slug,
@@ -110,6 +127,8 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
reasoning_summary_format: ReasoningSummaryFormat::Experimental, reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(), base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
apply_patch_tool_type: Some(ApplyPatchToolType::Freeform), apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
// experimental_supported_tools: vec!["read_file".to_string()],
// supports_parallel_tool_calls: true,
) )
} else if slug.starts_with("gpt-5") { } else if slug.starts_with("gpt-5") {
model_family!( model_family!(
@@ -130,6 +149,7 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
supports_reasoning_summaries: false, supports_reasoning_summaries: false,
reasoning_summary_format: ReasoningSummaryFormat::None, reasoning_summary_format: ReasoningSummaryFormat::None,
uses_local_shell_tool: false, uses_local_shell_tool: false,
supports_parallel_tool_calls: false,
apply_patch_tool_type: None, apply_patch_tool_type: None,
base_instructions: BASE_INSTRUCTIONS.to_string(), base_instructions: BASE_INSTRUCTIONS.to_string(),
experimental_supported_tools: Vec::new(), experimental_supported_tools: Vec::new(),

View File

@@ -14,12 +14,17 @@ use mcp_types::CallToolResult;
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::HashMap; use std::collections::HashMap;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
pub struct ToolInvocation<'a> { pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;
pub session: &'a Session,
pub turn: &'a TurnContext, #[derive(Clone)]
pub tracker: &'a mut TurnDiffTracker, pub struct ToolInvocation {
pub sub_id: &'a str, pub session: Arc<Session>,
pub turn: Arc<TurnContext>,
pub tracker: SharedTurnDiffTracker,
pub sub_id: String,
pub call_id: String, pub call_id: String,
pub tool_name: String, pub tool_name: String,
pub payload: ToolPayload, pub payload: ToolPayload,

View File

@@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use crate::client_common::tools::FreeformTool; use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat; use crate::client_common::tools::FreeformToolFormat;
@@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler {
) )
} }
async fn handle( async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { let ToolInvocation {
session, session,
turn, turn,
@@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler {
let content = handle_container_exec_with_params( let content = handle_container_exec_with_params(
tool_name.as_str(), tool_name.as_str(),
exec_params, exec_params,
session, Arc::clone(&session),
turn, Arc::clone(&turn),
tracker, Arc::clone(&tracker),
sub_id.to_string(), sub_id.clone(),
call_id.clone(), call_id.clone(),
) )
.await?; .await?;

View File

@@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler {
ToolKind::Function ToolKind::Function
} }
async fn handle( async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { let ToolInvocation {
session, session,
tool_name, tool_name,

View File

@@ -16,10 +16,7 @@ impl ToolHandler for McpHandler {
ToolKind::Mcp ToolKind::Mcp
} }
async fn handle( async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { let ToolInvocation {
session, session,
sub_id, sub_id,
@@ -45,8 +42,8 @@ impl ToolHandler for McpHandler {
let arguments_str = raw_arguments; let arguments_str = raw_arguments;
let response = handle_mcp_tool_call( let response = handle_mcp_tool_call(
session, session.as_ref(),
sub_id, &sub_id,
call_id.clone(), call_id.clone(),
server, server,
tool, tool,

View File

@@ -4,6 +4,7 @@ mod mcp;
mod plan; mod plan;
mod read_file; mod read_file;
mod shell; mod shell;
mod test_sync;
mod unified_exec; mod unified_exec;
mod view_image; mod view_image;
@@ -15,5 +16,6 @@ pub use mcp::McpHandler;
pub use plan::PlanHandler; pub use plan::PlanHandler;
pub use read_file::ReadFileHandler; pub use read_file::ReadFileHandler;
pub use shell::ShellHandler; pub use shell::ShellHandler;
pub use test_sync::TestSyncHandler;
pub use unified_exec::UnifiedExecHandler; pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler; pub use view_image::ViewImageHandler;

View File

@@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler {
ToolKind::Function ToolKind::Function
} }
async fn handle( async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { let ToolInvocation {
session, session,
sub_id, sub_id,
@@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler {
} }
}; };
let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?; let content =
handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;
Ok(ToolOutput::Function { Ok(ToolOutput::Function {
content, content,

View File

@@ -42,10 +42,7 @@ impl ToolHandler for ReadFileHandler {
ToolKind::Function ToolKind::Function
} }
async fn handle( async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation; let ToolInvocation { payload, .. } = invocation;
let arguments = match payload { let arguments = match payload {

View File

@@ -1,5 +1,6 @@
use async_trait::async_trait; use async_trait::async_trait;
use codex_protocol::models::ShellToolCallParams; use codex_protocol::models::ShellToolCallParams;
use std::sync::Arc;
use crate::codex::TurnContext; use crate::codex::TurnContext;
use crate::exec::ExecParams; use crate::exec::ExecParams;
@@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler {
) )
} }
async fn handle( async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { let ToolInvocation {
session, session,
turn, turn,
@@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler {
"failed to parse function arguments: {e:?}" "failed to parse function arguments: {e:?}"
)) ))
})?; })?;
let exec_params = Self::to_exec_params(params, turn); let exec_params = Self::to_exec_params(params, turn.as_ref());
let content = handle_container_exec_with_params( let content = handle_container_exec_with_params(
tool_name.as_str(), tool_name.as_str(),
exec_params, exec_params,
session, Arc::clone(&session),
turn, Arc::clone(&turn),
tracker, Arc::clone(&tracker),
sub_id.to_string(), sub_id.clone(),
call_id.clone(), call_id.clone(),
) )
.await?; .await?;
@@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler {
}) })
} }
ToolPayload::LocalShell { params } => { ToolPayload::LocalShell { params } => {
let exec_params = Self::to_exec_params(params, turn); let exec_params = Self::to_exec_params(params, turn.as_ref());
let content = handle_container_exec_with_params( let content = handle_container_exec_with_params(
tool_name.as_str(), tool_name.as_str(),
exec_params, exec_params,
session, Arc::clone(&session),
turn, Arc::clone(&turn),
tracker, Arc::clone(&tracker),
sub_id.to_string(), sub_id.clone(),
call_id.clone(), call_id.clone(),
) )
.await?; .await?;

View File

@@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use async_trait::async_trait;
use serde::Deserialize;
use tokio::sync::Barrier;
use tokio::time::sleep;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
/// Test-only tool handler that lets integration tests control timing: it can
/// sleep before/after its work and rendezvous with concurrent calls at a
/// named barrier, which makes parallel vs. serial scheduling observable.
pub struct TestSyncHandler;
/// Default barrier wait timeout (ms) used when the caller omits `timeout_ms`.
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
/// Process-wide registry of in-flight barriers keyed by caller-chosen id.
/// Lazily initialized; guarded by an async mutex so concurrent tool calls can
/// register and look up barriers safely.
static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();
/// A registered barrier plus the participant count it was created with, kept
/// so later joiners with a mismatched count can be rejected.
struct BarrierState {
    barrier: Arc<Barrier>,
    participants: usize,
}
/// Barrier configuration embedded in the tool arguments.
#[derive(Debug, Deserialize)]
struct BarrierArgs {
    // Identifier shared by the concurrent calls that should rendezvous.
    id: String,
    // Number of calls that must arrive before the barrier opens.
    participants: usize,
    // Maximum wait at the barrier; defaults to DEFAULT_TIMEOUT_MS.
    #[serde(default = "default_timeout_ms")]
    timeout_ms: u64,
}
/// Arguments accepted by `test_sync_tool`; every field is optional.
#[derive(Debug, Deserialize)]
struct TestSyncArgs {
    #[serde(default)]
    sleep_before_ms: Option<u64>,
    #[serde(default)]
    sleep_after_ms: Option<u64>,
    #[serde(default)]
    barrier: Option<BarrierArgs>,
}
/// Serde default for `BarrierArgs::timeout_ms`.
fn default_timeout_ms() -> u64 {
    DEFAULT_TIMEOUT_MS
}
/// Returns the global barrier registry, initializing it on first access.
fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
    BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
}
#[async_trait]
impl ToolHandler for TestSyncHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    /// Parses the tool arguments, then performs the optional pre-sleep,
    /// barrier rendezvous, and post-sleep before reporting success.
    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;
        let ToolPayload::Function { arguments } = payload else {
            return Err(FunctionCallError::RespondToModel(
                "test_sync_tool handler received unsupported payload".to_string(),
            ));
        };
        let parsed: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;
        if let Some(ms) = parsed.sleep_before_ms.filter(|ms| *ms > 0) {
            sleep(Duration::from_millis(ms)).await;
        }
        if let Some(barrier_args) = parsed.barrier {
            wait_on_barrier(barrier_args).await?;
        }
        if let Some(ms) = parsed.sleep_after_ms.filter(|ms| *ms > 0) {
            sleep(Duration::from_millis(ms)).await;
        }
        Ok(ToolOutput::Function {
            content: "ok".to_string(),
            success: Some(true),
        })
    }
}
async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
if args.participants == 0 {
return Err(FunctionCallError::RespondToModel(
"barrier participants must be greater than zero".to_string(),
));
}
if args.timeout_ms == 0 {
return Err(FunctionCallError::RespondToModel(
"barrier timeout must be greater than zero".to_string(),
));
}
let barrier_id = args.id.clone();
let barrier = {
let mut map = barrier_map().lock().await;
match map.entry(barrier_id.clone()) {
Entry::Occupied(entry) => {
let state = entry.get();
if state.participants != args.participants {
let existing = state.participants;
return Err(FunctionCallError::RespondToModel(format!(
"barrier {barrier_id} already registered with {existing} participants"
)));
}
state.barrier.clone()
}
Entry::Vacant(entry) => {
let barrier = Arc::new(Barrier::new(args.participants));
entry.insert(BarrierState {
barrier: barrier.clone(),
participants: args.participants,
});
barrier
}
}
};
let timeout = Duration::from_millis(args.timeout_ms);
let wait_result = tokio::time::timeout(timeout, barrier.wait())
.await
.map_err(|_| {
FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
})?;
if wait_result.is_leader() {
let mut map = barrier_map().lock().await;
if let Some(state) = map.get(&barrier_id)
&& Arc::ptr_eq(&state.barrier, &barrier)
{
map.remove(&barrier_id);
}
}
Ok(())
}

View File

@@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler {
) )
} }
async fn handle( async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { let ToolInvocation {
session, payload, .. session, payload, ..
} = invocation; } = invocation;

View File

@@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler {
ToolKind::Function ToolKind::Function
} }
async fn handle( async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { let ToolInvocation {
session, session,
turn, turn,

View File

@@ -1,5 +1,6 @@
pub mod context; pub mod context;
pub(crate) mod handlers; pub(crate) mod handlers;
pub mod parallel;
pub mod registry; pub mod registry;
pub mod router; pub mod router;
pub mod spec; pub mod spec;
@@ -21,7 +22,7 @@ use crate::executor::linkers::PreparedExec;
use crate::function_tool::FunctionCallError; use crate::function_tool::FunctionCallError;
use crate::tools::context::ApplyPatchCommandContext; use crate::tools::context::ApplyPatchCommandContext;
use crate::tools::context::ExecCommandContext; use crate::tools::context::ExecCommandContext;
use crate::turn_diff_tracker::TurnDiffTracker; use crate::tools::context::SharedTurnDiffTracker;
use codex_apply_patch::MaybeApplyPatchVerified; use codex_apply_patch::MaybeApplyPatchVerified;
use codex_apply_patch::maybe_parse_apply_patch_verified; use codex_apply_patch::maybe_parse_apply_patch_verified;
use codex_protocol::protocol::AskForApproval; use codex_protocol::protocol::AskForApproval;
@@ -29,6 +30,7 @@ use codex_utils_string::take_bytes_at_char_boundary;
use codex_utils_string::take_last_bytes_at_char_boundary; use codex_utils_string::take_last_bytes_at_char_boundary;
pub use router::ToolRouter; pub use router::ToolRouter;
use serde::Serialize; use serde::Serialize;
use std::sync::Arc;
use tracing::trace; use tracing::trace;
// Model-formatting limits: clients get full streams; only content sent to the model is truncated. // Model-formatting limits: clients get full streams; only content sent to the model is truncated.
@@ -48,9 +50,9 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
pub(crate) async fn handle_container_exec_with_params( pub(crate) async fn handle_container_exec_with_params(
tool_name: &str, tool_name: &str,
params: ExecParams, params: ExecParams,
sess: &Session, sess: Arc<Session>,
turn_context: &TurnContext, turn_context: Arc<TurnContext>,
turn_diff_tracker: &mut TurnDiffTracker, turn_diff_tracker: SharedTurnDiffTracker,
sub_id: String, sub_id: String,
call_id: String, call_id: String,
) -> Result<String, FunctionCallError> { ) -> Result<String, FunctionCallError> {
@@ -68,7 +70,15 @@ pub(crate) async fn handle_container_exec_with_params(
// check if this was a patch, and apply it if so // check if this was a patch, and apply it if so
let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) { let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
MaybeApplyPatchVerified::Body(changes) => { MaybeApplyPatchVerified::Body(changes) => {
match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await { match apply_patch::apply_patch(
sess.as_ref(),
turn_context.as_ref(),
&sub_id,
&call_id,
changes,
)
.await
{
InternalApplyPatchInvocation::Output(item) => return item, InternalApplyPatchInvocation::Output(item) => return item,
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => { InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
Some(apply_patch_exec) Some(apply_patch_exec)
@@ -139,7 +149,7 @@ pub(crate) async fn handle_container_exec_with_params(
let output_result = sess let output_result = sess
.run_exec_with_events( .run_exec_with_events(
turn_diff_tracker, turn_diff_tracker.clone(),
prepared_exec, prepared_exec,
turn_context.approval_policy, turn_context.approval_policy,
) )

View File

@@ -0,0 +1,137 @@
use std::sync::Arc;
use tokio::task::JoinHandle;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::ResponseInputItem;
use crate::codex::ProcessedResponseItem;
/// A tool call spawned onto a background task; `index` records the slot in
/// the turn's output vector where its response belongs.
struct PendingToolCall {
    index: usize,
    handle: JoinHandle<Result<ResponseInputItem, FunctionCallError>>,
}
/// Drives tool-call execution for one turn: calls whose tools support it run
/// on background tasks in parallel, everything else is dispatched serially.
pub(crate) struct ToolCallRuntime {
    router: Arc<ToolRouter>,
    session: Arc<Session>,
    turn_context: Arc<TurnContext>,
    tracker: SharedTurnDiffTracker,
    sub_id: String,
    // Background tasks not yet awaited; drained by `resolve_pending`.
    pending_calls: Vec<PendingToolCall>,
}
impl ToolCallRuntime {
    /// Creates a runtime with no pending background calls.
    pub(crate) fn new(
        router: Arc<ToolRouter>,
        session: Arc<Session>,
        turn_context: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
        sub_id: String,
    ) -> Self {
        Self {
            router,
            session,
            turn_context,
            tracker,
            sub_id,
            pending_calls: Vec::new(),
        }
    }

    /// Routes a single tool call. Parallel-capable tools are spawned in the
    /// background; any other tool first drains the pending backlog and then
    /// runs inline, writing its response into `output[output_index]`.
    pub(crate) async fn handle_tool_call(
        &mut self,
        call: ToolCall,
        output_index: usize,
        output: &mut [ProcessedResponseItem],
    ) -> Result<(), CodexErr> {
        let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
        if supports_parallel {
            self.spawn_parallel(call, output_index);
        } else {
            // Serial tools must observe the effects of every spawned call,
            // so settle the backlog before dispatching inline.
            self.resolve_pending(output).await?;
            let response = self.dispatch_serial(call).await?;
            let slot = output.get_mut(output_index).ok_or_else(|| {
                CodexErr::Fatal(format!("tool output index {output_index} out of bounds"))
            })?;
            slot.response = Some(response);
        }
        Ok(())
    }

    /// Aborts every still-running background call without awaiting it.
    pub(crate) fn abort_all(&mut self) {
        while let Some(pending) = self.pending_calls.pop() {
            pending.handle.abort();
        }
    }

    /// Awaits every spawned call (completion order does not matter) and
    /// stores each response into its recorded output slot. The first failure
    /// aborts the remaining tasks and surfaces as a fatal turn error.
    pub(crate) async fn resolve_pending(
        &mut self,
        output: &mut [ProcessedResponseItem],
    ) -> Result<(), CodexErr> {
        while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() {
            match handle.await {
                Ok(Ok(response)) => {
                    if let Some(slot) = output.get_mut(index) {
                        slot.response = Some(response);
                    }
                }
                Ok(Err(FunctionCallError::Fatal(message))) => {
                    self.abort_all();
                    return Err(CodexErr::Fatal(message));
                }
                Ok(Err(other)) => {
                    self.abort_all();
                    return Err(CodexErr::Fatal(other.to_string()));
                }
                Err(join_err) => {
                    // The task panicked or was cancelled before completing.
                    self.abort_all();
                    return Err(CodexErr::Fatal(format!(
                        "tool task failed to join: {join_err}"
                    )));
                }
            }
        }
        Ok(())
    }

    /// Spawns `call` on a background task, remembering which output slot the
    /// eventual response belongs to.
    fn spawn_parallel(&mut self, call: ToolCall, index: usize) {
        let router = Arc::clone(&self.router);
        let session = Arc::clone(&self.session);
        let turn = Arc::clone(&self.turn_context);
        let tracker = Arc::clone(&self.tracker);
        let sub_id = self.sub_id.clone();
        let handle = tokio::spawn(async move {
            router
                .dispatch_tool_call(session, turn, tracker, sub_id, call)
                .await
        });
        self.pending_calls.push(PendingToolCall { index, handle });
    }

    /// Dispatches `call` inline on the current task, mapping tool failures to
    /// fatal errors for the turn.
    async fn dispatch_serial(&self, call: ToolCall) -> Result<ResponseInputItem, CodexErr> {
        match self
            .router
            .dispatch_tool_call(
                Arc::clone(&self.session),
                Arc::clone(&self.turn_context),
                Arc::clone(&self.tracker),
                self.sub_id.clone(),
                call,
            )
            .await
        {
            Ok(response) => Ok(response),
            Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
            Err(other) => Err(CodexErr::Fatal(other.to_string())),
        }
    }
}

View File

@@ -32,8 +32,7 @@ pub trait ToolHandler: Send + Sync {
) )
} }
async fn handle(&self, invocation: ToolInvocation<'_>) async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
-> Result<ToolOutput, FunctionCallError>;
} }
pub struct ToolRegistry { pub struct ToolRegistry {
@@ -57,9 +56,9 @@ impl ToolRegistry {
// } // }
// } // }
pub async fn dispatch<'a>( pub async fn dispatch(
&self, &self,
invocation: ToolInvocation<'a>, invocation: ToolInvocation,
) -> Result<ResponseInputItem, FunctionCallError> { ) -> Result<ResponseInputItem, FunctionCallError> {
let tool_name = invocation.tool_name.clone(); let tool_name = invocation.tool_name.clone();
let call_id_owned = invocation.call_id.clone(); let call_id_owned = invocation.call_id.clone();
@@ -137,9 +136,24 @@ impl ToolRegistry {
} }
} }
/// A tool spec plus scheduling metadata: whether calls to this tool may run
/// concurrently with other parallel-capable tool calls in the same turn.
#[derive(Debug, Clone)]
pub struct ConfiguredToolSpec {
    pub spec: ToolSpec,
    pub supports_parallel_tool_calls: bool,
}
impl ConfiguredToolSpec {
    /// Pairs a spec with its parallel-execution capability flag.
    pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
        Self {
            spec,
            supports_parallel_tool_calls,
        }
    }
}
pub struct ToolRegistryBuilder { pub struct ToolRegistryBuilder {
handlers: HashMap<String, Arc<dyn ToolHandler>>, handlers: HashMap<String, Arc<dyn ToolHandler>>,
specs: Vec<ToolSpec>, specs: Vec<ConfiguredToolSpec>,
} }
impl ToolRegistryBuilder { impl ToolRegistryBuilder {
@@ -151,7 +165,16 @@ impl ToolRegistryBuilder {
} }
pub fn push_spec(&mut self, spec: ToolSpec) { pub fn push_spec(&mut self, spec: ToolSpec) {
self.specs.push(spec); self.push_spec_with_parallel_support(spec, false);
}
pub fn push_spec_with_parallel_support(
&mut self,
spec: ToolSpec,
supports_parallel_tool_calls: bool,
) {
self.specs
.push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
} }
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) { pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
@@ -183,7 +206,7 @@ impl ToolRegistryBuilder {
// } // }
// } // }
pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) { pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
let registry = ToolRegistry::new(self.handlers); let registry = ToolRegistry::new(self.handlers);
(self.specs, registry) (self.specs, registry)
} }

View File

@@ -1,15 +1,17 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use crate::client_common::tools::ToolSpec; use crate::client_common::tools::ToolSpec;
use crate::codex::Session; use crate::codex::Session;
use crate::codex::TurnContext; use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError; use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolInvocation; use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload; use crate::tools::context::ToolPayload;
use crate::tools::registry::ConfiguredToolSpec;
use crate::tools::registry::ToolRegistry; use crate::tools::registry::ToolRegistry;
use crate::tools::spec::ToolsConfig; use crate::tools::spec::ToolsConfig;
use crate::tools::spec::build_specs; use crate::tools::spec::build_specs;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::models::LocalShellAction; use codex_protocol::models::LocalShellAction;
use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem; use codex_protocol::models::ResponseItem;
@@ -24,7 +26,7 @@ pub struct ToolCall {
pub struct ToolRouter { pub struct ToolRouter {
registry: ToolRegistry, registry: ToolRegistry,
specs: Vec<ToolSpec>, specs: Vec<ConfiguredToolSpec>,
} }
impl ToolRouter { impl ToolRouter {
@@ -34,11 +36,22 @@ impl ToolRouter {
) -> Self { ) -> Self {
let builder = build_specs(config, mcp_tools); let builder = build_specs(config, mcp_tools);
let (specs, registry) = builder.build(); let (specs, registry) = builder.build();
Self { registry, specs } Self { registry, specs }
} }
pub fn specs(&self) -> &[ToolSpec] { pub fn specs(&self) -> Vec<ToolSpec> {
&self.specs self.specs
.iter()
.map(|config| config.spec.clone())
.collect()
}
pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
self.specs
.iter()
.filter(|config| config.supports_parallel_tool_calls)
.any(|config| config.spec.name() == tool_name)
} }
pub fn build_tool_call( pub fn build_tool_call(
@@ -118,10 +131,10 @@ impl ToolRouter {
pub async fn dispatch_tool_call( pub async fn dispatch_tool_call(
&self, &self,
session: &Session, session: Arc<Session>,
turn: &TurnContext, turn: Arc<TurnContext>,
tracker: &mut TurnDiffTracker, tracker: SharedTurnDiffTracker,
sub_id: &str, sub_id: String,
call: ToolCall, call: ToolCall,
) -> Result<ResponseInputItem, FunctionCallError> { ) -> Result<ResponseInputItem, FunctionCallError> {
let ToolCall { let ToolCall {

View File

@@ -258,6 +258,68 @@ fn create_view_image_tool() -> ToolSpec {
}) })
} }
fn create_test_sync_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
"sleep_before_ms".to_string(),
JsonSchema::Number {
description: Some("Optional delay in milliseconds before any other action".to_string()),
},
);
properties.insert(
"sleep_after_ms".to_string(),
JsonSchema::Number {
description: Some(
"Optional delay in milliseconds after completing the barrier".to_string(),
),
},
);
let mut barrier_properties = BTreeMap::new();
barrier_properties.insert(
"id".to_string(),
JsonSchema::String {
description: Some(
"Identifier shared by concurrent calls that should rendezvous".to_string(),
),
},
);
barrier_properties.insert(
"participants".to_string(),
JsonSchema::Number {
description: Some(
"Number of tool calls that must arrive before the barrier opens".to_string(),
),
},
);
barrier_properties.insert(
"timeout_ms".to_string(),
JsonSchema::Number {
description: Some("Maximum time in milliseconds to wait at the barrier".to_string()),
},
);
properties.insert(
"barrier".to_string(),
JsonSchema::Object {
properties: barrier_properties,
required: Some(vec!["id".to_string(), "participants".to_string()]),
additional_properties: Some(false.into()),
},
);
ToolSpec::Function(ResponsesApiTool {
name: "test_sync_tool".to_string(),
description: "Internal synchronization helper used by Codex integration tests.".to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: None,
additional_properties: Some(false.into()),
},
})
}
fn create_read_file_tool() -> ToolSpec { fn create_read_file_tool() -> ToolSpec {
let mut properties = BTreeMap::new(); let mut properties = BTreeMap::new();
properties.insert( properties.insert(
@@ -507,6 +569,7 @@ pub(crate) fn build_specs(
use crate::tools::handlers::PlanHandler; use crate::tools::handlers::PlanHandler;
use crate::tools::handlers::ReadFileHandler; use crate::tools::handlers::ReadFileHandler;
use crate::tools::handlers::ShellHandler; use crate::tools::handlers::ShellHandler;
use crate::tools::handlers::TestSyncHandler;
use crate::tools::handlers::UnifiedExecHandler; use crate::tools::handlers::UnifiedExecHandler;
use crate::tools::handlers::ViewImageHandler; use crate::tools::handlers::ViewImageHandler;
use std::sync::Arc; use std::sync::Arc;
@@ -573,16 +636,26 @@ pub(crate) fn build_specs(
.any(|tool| tool == "read_file") .any(|tool| tool == "read_file")
{ {
let read_file_handler = Arc::new(ReadFileHandler); let read_file_handler = Arc::new(ReadFileHandler);
builder.push_spec(create_read_file_tool()); builder.push_spec_with_parallel_support(create_read_file_tool(), true);
builder.register_handler("read_file", read_file_handler); builder.register_handler("read_file", read_file_handler);
} }
if config
.experimental_supported_tools
.iter()
.any(|tool| tool == "test_sync_tool")
{
let test_sync_handler = Arc::new(TestSyncHandler);
builder.push_spec_with_parallel_support(create_test_sync_tool(), true);
builder.register_handler("test_sync_tool", test_sync_handler);
}
if config.web_search_request { if config.web_search_request {
builder.push_spec(ToolSpec::WebSearch {}); builder.push_spec(ToolSpec::WebSearch {});
} }
if config.include_view_image_tool { if config.include_view_image_tool {
builder.push_spec(create_view_image_tool()); builder.push_spec_with_parallel_support(create_view_image_tool(), true);
builder.register_handler("view_image", view_image_handler); builder.register_handler("view_image", view_image_handler);
} }
@@ -610,20 +683,25 @@ pub(crate) fn build_specs(
mod tests { mod tests {
use crate::client_common::tools::FreeformTool; use crate::client_common::tools::FreeformTool;
use crate::model_family::find_family_for_model; use crate::model_family::find_family_for_model;
use crate::tools::registry::ConfiguredToolSpec;
use mcp_types::ToolInputSchema; use mcp_types::ToolInputSchema;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use super::*; use super::*;
fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) { fn tool_name(tool: &ToolSpec) -> &str {
match tool {
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
}
}
fn assert_eq_tool_names(tools: &[ConfiguredToolSpec], expected_names: &[&str]) {
let tool_names = tools let tool_names = tools
.iter() .iter()
.map(|tool| match tool { .map(|tool| tool_name(&tool.spec))
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
})
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert_eq!( assert_eq!(
@@ -639,6 +717,16 @@ mod tests {
} }
} }
/// Looks up a tool by name in the built spec list; panics with the missing
/// tool name so test failures are self-describing.
fn find_tool<'a>(
    tools: &'a [ConfiguredToolSpec],
    expected_name: &str,
) -> &'a ConfiguredToolSpec {
    for candidate in tools {
        if tool_name(&candidate.spec) == expected_name {
            return candidate;
        }
    }
    panic!("expected tool {expected_name}")
}
#[test] #[test]
fn test_build_specs() { fn test_build_specs() {
let model_family = find_family_for_model("codex-mini-latest") let model_family = find_family_for_model("codex-mini-latest")
@@ -680,6 +768,53 @@ mod tests {
); );
} }
/// Verifies the per-tool parallel flags: the exec tool must stay serial while
/// `read_file` may run concurrently.
#[test]
#[ignore]
fn test_parallel_support_flags() {
    // Fix: the expect message previously said "codex-mini-latest", a
    // copy-paste leftover; the family actually requested is gpt-5-codex.
    let model_family = find_family_for_model("gpt-5-codex")
        .expect("gpt-5-codex should be a valid model family");
    let config = ToolsConfig::new(&ToolsConfigParams {
        model_family: &model_family,
        include_plan_tool: false,
        include_apply_patch_tool: false,
        include_web_search_request: false,
        use_streamable_shell_tool: false,
        include_view_image_tool: false,
        experimental_unified_exec_tool: true,
    });
    let (tools, _) = build_specs(&config, None).build();
    // Exec mutates shared state and must not advertise parallel support;
    // read_file is side-effect-free and may run concurrently.
    assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
    assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
}
/// The dedicated test model family must expose both experimental tools that
/// the parallelism integration tests rely on.
#[test]
fn test_test_model_family_includes_sync_tool() {
    let model_family = find_family_for_model("test-gpt-5-codex")
        .expect("test-gpt-5-codex should be a valid model family");
    let config = ToolsConfig::new(&ToolsConfigParams {
        model_family: &model_family,
        include_plan_tool: false,
        include_apply_patch_tool: false,
        include_web_search_request: false,
        use_streamable_shell_tool: false,
        include_view_image_tool: false,
        experimental_unified_exec_tool: false,
    });
    let (tools, _) = build_specs(&config, None).build();
    let names: Vec<&str> = tools.iter().map(|tool| tool_name(&tool.spec)).collect();
    assert!(names.contains(&"test_sync_tool"));
    assert!(names.contains(&"read_file"));
}
#[test] #[test]
fn test_build_specs_mcp_tools() { fn test_build_specs_mcp_tools() {
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
@@ -742,7 +877,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
tools[3], tools[3].spec,
ToolSpec::Function(ResponsesApiTool { ToolSpec::Function(ResponsesApiTool {
name: "test_server/do_something_cool".to_string(), name: "test_server/do_something_cool".to_string(),
parameters: JsonSchema::Object { parameters: JsonSchema::Object {
@@ -911,7 +1046,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
tools[4], tools[4].spec,
ToolSpec::Function(ResponsesApiTool { ToolSpec::Function(ResponsesApiTool {
name: "dash/search".to_string(), name: "dash/search".to_string(),
parameters: JsonSchema::Object { parameters: JsonSchema::Object {
@@ -977,7 +1112,7 @@ mod tests {
], ],
); );
assert_eq!( assert_eq!(
tools[4], tools[4].spec,
ToolSpec::Function(ResponsesApiTool { ToolSpec::Function(ResponsesApiTool {
name: "dash/paginate".to_string(), name: "dash/paginate".to_string(),
parameters: JsonSchema::Object { parameters: JsonSchema::Object {
@@ -1041,7 +1176,7 @@ mod tests {
], ],
); );
assert_eq!( assert_eq!(
tools[4], tools[4].spec,
ToolSpec::Function(ResponsesApiTool { ToolSpec::Function(ResponsesApiTool {
name: "dash/tags".to_string(), name: "dash/tags".to_string(),
parameters: JsonSchema::Object { parameters: JsonSchema::Object {
@@ -1108,7 +1243,7 @@ mod tests {
], ],
); );
assert_eq!( assert_eq!(
tools[4], tools[4].spec,
ToolSpec::Function(ResponsesApiTool { ToolSpec::Function(ResponsesApiTool {
name: "dash/value".to_string(), name: "dash/value".to_string(),
parameters: JsonSchema::Object { parameters: JsonSchema::Object {
@@ -1213,7 +1348,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
tools[4], tools[4].spec,
ToolSpec::Function(ResponsesApiTool { ToolSpec::Function(ResponsesApiTool {
name: "test_server/do_something_cool".to_string(), name: "test_server/do_something_cool".to_string(),
parameters: JsonSchema::Object { parameters: JsonSchema::Object {

View File

@@ -3,14 +3,14 @@ use std::time::Duration;
use codex_core::protocol::EventMsg; use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem; use codex_core::protocol::InputItem;
use codex_core::protocol::Op; use codex_core::protocol::Op;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call; use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_once_match; use core_test_support::responses::mount_sse_once;
use core_test_support::responses::sse; use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server; use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::test_codex; use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event_with_timeout; use core_test_support::wait_for_event_with_timeout;
use serde_json::json; use serde_json::json;
use wiremock::matchers::body_string_contains;
/// Integration test: spawn a longrunning shell tool via a mocked Responses SSE /// Integration test: spawn a longrunning shell tool via a mocked Responses SSE
/// function call, then interrupt the session and expect TurnAborted. /// function call, then interrupt the session and expect TurnAborted.
@@ -27,10 +27,13 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
"timeout_ms": 60_000 "timeout_ms": 60_000
}) })
.to_string(); .to_string();
let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]); let body = sse(vec![
ev_function_call("call_sleep", "shell", &args),
ev_completed("done"),
]);
let server = start_mock_server().await; let server = start_mock_server().await;
mount_sse_once_match(&server, body_string_contains("start sleep"), body).await; mount_sse_once(&server, body).await;
let codex = test_codex().build(&server).await.unwrap().codex; let codex = test_codex().build(&server).await.unwrap().codex;

View File

@@ -24,6 +24,7 @@ mod shell_serialization;
mod stream_error_allows_next_turn; mod stream_error_allows_next_turn;
mod stream_no_completed; mod stream_no_completed;
mod tool_harness; mod tool_harness;
mod tool_parallelism;
mod tools; mod tools;
mod unified_exec; mod unified_exec;
mod user_notification; mod user_notification;

View File

@@ -0,0 +1,178 @@
#![cfg(not(target_os = "windows"))]
#![allow(clippy::unwrap_used)]
use std::time::Duration;
use std::time::Instant;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use serde_json::json;
/// Submits `prompt` as a full user turn with permissive sandbox/approval
/// settings and blocks until the task completes.
async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> {
    let model = test.session_configured.model.clone();
    let turn = Op::UserTurn {
        items: vec![InputItem::Text {
            text: prompt.into(),
        }],
        final_output_json_schema: None,
        cwd: test.cwd.path().to_path_buf(),
        approval_policy: AskForApproval::Never,
        sandbox_policy: SandboxPolicy::DangerFullAccess,
        model,
        effort: None,
        summary: ReasoningSummary::Auto,
    };
    test.codex.submit(turn).await?;
    wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
    Ok(())
}
/// Runs a turn and reports how long it took end to end.
async fn run_turn_and_measure(test: &TestCodex, prompt: &str) -> anyhow::Result<Duration> {
    let started_at = Instant::now();
    run_turn(test, prompt).await?;
    Ok(started_at.elapsed())
}
/// Builds a Codex test fixture pinned to the `test-gpt-5-codex` model family,
/// which registers the `test_sync_tool` synchronization helper.
#[allow(clippy::expect_used)]
async fn build_codex_with_test_tool(server: &wiremock::MockServer) -> anyhow::Result<TestCodex> {
    let mut builder = test_codex().with_config(|config| {
        "test-gpt-5-codex".clone_into(&mut config.model);
        config.model_family =
            find_family_for_model("test-gpt-5-codex").expect("test-gpt-5-codex model family");
    });
    builder.build(server).await
}
/// Two 300 ms tools overlapped should finish well under the 500 ms line that
/// separates parallel from serial execution in these tests.
fn assert_parallel_duration(actual: Duration) {
    let threshold = Duration::from_millis(500);
    assert!(
        actual < threshold,
        "expected parallel execution to finish quickly, got {actual:?}"
    );
}
/// Two 300 ms tools run back-to-back should need at least 500 ms in total.
fn assert_serial_duration(actual: Duration) {
    let threshold = Duration::from_millis(500);
    assert!(
        actual >= threshold,
        "expected serial execution to take longer, got {actual:?}"
    );
}
/// Two parallel-capable `test_sync_tool` calls that each sleep 300 ms after a
/// two-party barrier can only finish under the 500 ms threshold if the
/// runtime actually overlaps them.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn read_file_tools_run_in_parallel() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
    let server = start_mock_server().await;
    let test = build_codex_with_test_tool(&server).await?;
    // Both calls must rendezvous at the barrier, proving they ran concurrently.
    let parallel_args = json!({
        "sleep_after_ms": 300,
        "barrier": {
            "id": "parallel-test-sync",
            "participants": 2,
            "timeout_ms": 1_000,
        }
    })
    .to_string();
    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "test_sync_tool", &parallel_args),
        ev_function_call("call-2", "test_sync_tool", &parallel_args),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;
    let duration = run_turn_and_measure(&test, "exercise sync tool").await?;
    assert_parallel_duration(duration);
    Ok(())
}
/// Two `shell` calls (a tool without parallel support) that each sleep 300 ms
/// must run back-to-back, so the turn takes at least the serial threshold.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn non_parallel_tools_run_serially() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
    let server = start_mock_server().await;
    let test = test_codex().build(&server).await?;
    let shell_args = json!({
        "command": ["/bin/sh", "-c", "sleep 0.3"],
        "timeout_ms": 1_000,
    });
    let args_one = serde_json::to_string(&shell_args)?;
    let args_two = serde_json::to_string(&shell_args)?;
    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "shell", &args_one),
        ev_function_call("call-2", "shell", &args_two),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;
    let duration = run_turn_and_measure(&test, "run shell twice").await?;
    assert_serial_duration(duration);
    Ok(())
}
/// Mixing a parallel-capable call with a serial `shell` call must serialize
/// the turn: the serial tool drains the pending parallel work before running.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn mixed_tools_fall_back_to_serial() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));
    let server = start_mock_server().await;
    let test = build_codex_with_test_tool(&server).await?;
    // No barrier here: the sync tool just sleeps, so only serialization with
    // the shell call can push total duration past the threshold.
    let sync_args = json!({
        "sleep_after_ms": 300
    })
    .to_string();
    let shell_args = serde_json::to_string(&json!({
        "command": ["/bin/sh", "-c", "sleep 0.3"],
        "timeout_ms": 1_000,
    }))?;
    let first_response = sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_function_call("call-1", "test_sync_tool", &sync_args),
        ev_function_call("call-2", "shell", &shell_args),
        ev_completed("resp-1"),
    ]);
    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    mount_sse_sequence(&server, vec![first_response, second_response]).await;
    let duration = run_turn_and_measure(&test, "mix tools").await?;
    assert_serial_duration(duration);
    Ok(())
}
}