feat: parallel tool calls (#4663)
Add parallel tool calls. This is configurable at model level and tool level
This commit is contained in:
@@ -227,7 +227,7 @@ impl ModelClient {
|
||||
input: &input_with_instructions,
|
||||
tools: &tools_json,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
parallel_tool_calls: prompt.parallel_tool_calls,
|
||||
reasoning,
|
||||
store: azure_workaround,
|
||||
stream: true,
|
||||
|
||||
@@ -33,6 +33,9 @@ pub struct Prompt {
|
||||
/// external MCP servers.
|
||||
pub(crate) tools: Vec<ToolSpec>,
|
||||
|
||||
/// Whether parallel tool calls are permitted for this prompt.
|
||||
pub(crate) parallel_tool_calls: bool,
|
||||
|
||||
/// Optional override for the built-in BASE_INSTRUCTIONS.
|
||||
pub base_instructions_override: Option<String>,
|
||||
|
||||
@@ -288,6 +291,17 @@ pub(crate) mod tools {
|
||||
Freeform(FreeformTool),
|
||||
}
|
||||
|
||||
impl ToolSpec {
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
match self {
|
||||
ToolSpec::Function(tool) => tool.name.as_str(),
|
||||
ToolSpec::LocalShell {} => "local_shell",
|
||||
ToolSpec::WebSearch {} => "web_search",
|
||||
ToolSpec::Freeform(tool) => tool.name.as_str(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct FreeformTool {
|
||||
pub(crate) name: String,
|
||||
@@ -433,7 +447,7 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
parallel_tool_calls: true,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
@@ -474,7 +488,7 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
parallel_tool_calls: true,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
@@ -510,7 +524,7 @@ mod tests {
|
||||
input: &input,
|
||||
tools: &tools,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: false,
|
||||
parallel_tool_calls: true,
|
||||
reasoning: None,
|
||||
store: false,
|
||||
stream: true,
|
||||
|
||||
@@ -100,7 +100,9 @@ use crate::tasks::CompactTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::format_exec_output_str;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
@@ -818,7 +820,7 @@ impl Session {
|
||||
|
||||
async fn on_exec_command_begin(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
exec_command_context: ExecCommandContext,
|
||||
) {
|
||||
let ExecCommandContext {
|
||||
@@ -834,7 +836,10 @@ impl Session {
|
||||
user_explicitly_approved_this_action,
|
||||
changes,
|
||||
}) => {
|
||||
turn_diff_tracker.on_patch_begin(&changes);
|
||||
{
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
tracker.on_patch_begin(&changes);
|
||||
}
|
||||
|
||||
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
|
||||
call_id,
|
||||
@@ -861,7 +866,7 @@ impl Session {
|
||||
|
||||
async fn on_exec_command_end(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
output: &ExecToolCallOutput,
|
||||
@@ -909,7 +914,10 @@ impl Session {
|
||||
// If this is an apply_patch, after we emit the end patch, emit a second event
|
||||
// with the full turn diff if there is one.
|
||||
if is_apply_patch {
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
let unified_diff = {
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
tracker.get_unified_diff()
|
||||
};
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
@@ -926,7 +934,7 @@ impl Session {
|
||||
/// Returns the output of the exec tool call.
|
||||
pub(crate) async fn run_exec_with_events(
|
||||
&self,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
prepared: PreparedExec,
|
||||
approval_policy: AskForApproval,
|
||||
) -> Result<ExecToolCallOutput, ExecError> {
|
||||
@@ -935,7 +943,7 @@ impl Session {
|
||||
let sub_id = context.sub_id.clone();
|
||||
let call_id = context.call_id.clone();
|
||||
|
||||
self.on_exec_command_begin(turn_diff_tracker, context.clone())
|
||||
self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
|
||||
.await;
|
||||
|
||||
let result = self
|
||||
@@ -1644,7 +1652,7 @@ pub(crate) async fn run_task(
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let mut turn_diff_tracker = TurnDiffTracker::new();
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let mut auto_compact_recently_attempted = false;
|
||||
|
||||
loop {
|
||||
@@ -1692,9 +1700,9 @@ pub(crate) async fn run_task(
|
||||
})
|
||||
.collect();
|
||||
match run_turn(
|
||||
&sess,
|
||||
turn_context.as_ref(),
|
||||
&mut turn_diff_tracker,
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id.clone(),
|
||||
turn_input,
|
||||
)
|
||||
@@ -1917,18 +1925,27 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
|
||||
}
|
||||
|
||||
async fn run_turn(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
|
||||
let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools));
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(mcp_tools),
|
||||
));
|
||||
|
||||
let model_supports_parallel = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.supports_parallel_tool_calls;
|
||||
let parallel_tool_calls = model_supports_parallel;
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
tools: router.specs().to_vec(),
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls,
|
||||
base_instructions_override: turn_context.base_instructions.clone(),
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
};
|
||||
@@ -1936,10 +1953,10 @@ async fn run_turn(
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
match try_run_turn(
|
||||
&router,
|
||||
sess,
|
||||
turn_context,
|
||||
turn_diff_tracker,
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&sub_id,
|
||||
&prompt,
|
||||
)
|
||||
@@ -1950,7 +1967,7 @@ async fn run_turn(
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||
sess.set_total_tokens_full(&sub_id, turn_context).await;
|
||||
sess.set_total_tokens_full(&sub_id, &turn_context).await;
|
||||
return Err(e);
|
||||
}
|
||||
Err(CodexErr::UsageLimitReached(e)) => {
|
||||
@@ -1999,9 +2016,9 @@ async fn run_turn(
|
||||
/// "handled" such that it produces a `ResponseInputItem` that needs to be
|
||||
/// sent back to the model on the next turn.
|
||||
#[derive(Debug)]
|
||||
struct ProcessedResponseItem {
|
||||
item: ResponseItem,
|
||||
response: Option<ResponseInputItem>,
|
||||
pub(crate) struct ProcessedResponseItem {
|
||||
pub(crate) item: ResponseItem,
|
||||
pub(crate) response: Option<ResponseInputItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -2011,10 +2028,10 @@ struct TurnRunResult {
|
||||
}
|
||||
|
||||
async fn try_run_turn(
|
||||
router: &crate::tools::ToolRouter,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
router: Arc<ToolRouter>,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
@@ -2085,24 +2102,34 @@ async fn try_run_turn(
|
||||
let mut stream = turn_context.client.clone().stream(&prompt).await?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
let mut tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id.to_string(),
|
||||
);
|
||||
|
||||
loop {
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
let event = stream.next().await;
|
||||
let Some(event) = event else {
|
||||
// Channel closed without yielding a final Completed event or explicit error.
|
||||
// Treat as a disconnected stream so the caller can retry.
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
None,
|
||||
));
|
||||
let event = match event {
|
||||
Some(event) => event,
|
||||
None => {
|
||||
tool_runtime.abort_all();
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
None,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(ev) => ev,
|
||||
Err(e) => {
|
||||
tool_runtime.abort_all();
|
||||
// Propagate the underlying stream error to the caller (run_turn), which
|
||||
// will apply the configured `stream_max_retries` policy.
|
||||
return Err(e);
|
||||
@@ -2112,16 +2139,66 @@ async fn try_run_turn(
|
||||
match event {
|
||||
ResponseEvent::Created => {}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
let response = handle_response_item(
|
||||
router,
|
||||
sess,
|
||||
turn_context,
|
||||
turn_diff_tracker,
|
||||
sub_id,
|
||||
item.clone(),
|
||||
)
|
||||
.await?;
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) {
|
||||
Ok(Some(call)) => {
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
let index = output.len();
|
||||
output.push(ProcessedResponseItem {
|
||||
item,
|
||||
response: None,
|
||||
});
|
||||
tool_runtime
|
||||
.handle_tool_call(call, index, output.as_mut_slice())
|
||||
.await?;
|
||||
}
|
||||
Ok(None) => {
|
||||
let response = handle_non_tool_response_item(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
sub_id,
|
||||
item.clone(),
|
||||
)
|
||||
.await?;
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
turn_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
.log_tool_failed("local_shell", msg);
|
||||
error!(msg);
|
||||
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg.to_string(),
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
output.push(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(response),
|
||||
});
|
||||
}
|
||||
Err(FunctionCallError::RespondToModel(message)) => {
|
||||
let response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: message,
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
output.push(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(response),
|
||||
});
|
||||
}
|
||||
Err(FunctionCallError::Fatal(message)) => {
|
||||
return Err(CodexErr::Fatal(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
ResponseEvent::WebSearchCallBegin { call_id } => {
|
||||
let _ = sess
|
||||
@@ -2141,10 +2218,15 @@ async fn try_run_turn(
|
||||
response_id: _,
|
||||
token_usage,
|
||||
} => {
|
||||
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
|
||||
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
||||
.await;
|
||||
|
||||
let unified_diff = turn_diff_tracker.get_unified_diff();
|
||||
tool_runtime.resolve_pending(output.as_mut_slice()).await?;
|
||||
|
||||
let unified_diff = {
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
tracker.get_unified_diff()
|
||||
};
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
|
||||
let event = Event {
|
||||
@@ -2203,88 +2285,40 @@ async fn try_run_turn(
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_response_item(
|
||||
router: &crate::tools::ToolRouter,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
async fn handle_non_tool_response_item(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: &str,
|
||||
item: ResponseItem,
|
||||
) -> CodexResult<Option<ResponseInputItem>> {
|
||||
debug!(?item, "Output item");
|
||||
|
||||
match ToolRouter::build_tool_call(sess, item.clone()) {
|
||||
Ok(Some(call)) => {
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
match router
|
||||
.dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(Some(response)),
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
Err(other) => unreachable!("non-fatal tool error returned: {other:?}"),
|
||||
match &item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => {
|
||||
let msgs = match &item {
|
||||
ResponseItem::Message { .. } if turn_context.is_review_mode => {
|
||||
trace!("suppressing assistant Message in review mode");
|
||||
Vec::new()
|
||||
}
|
||||
_ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
|
||||
};
|
||||
for msg in msgs {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
match &item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => {
|
||||
let msgs = match &item {
|
||||
ResponseItem::Message { .. } if turn_context.is_review_mode => {
|
||||
trace!("suppressing assistant Message in review mode");
|
||||
Vec::new()
|
||||
}
|
||||
_ => map_response_item_to_event_messages(
|
||||
&item,
|
||||
sess.show_raw_agent_reasoning(),
|
||||
),
|
||||
};
|
||||
for msg in msgs {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. } => {
|
||||
debug!("unexpected tool output from stream");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
|
||||
debug!("unexpected tool output from stream");
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
turn_context
|
||||
.client
|
||||
.get_otel_event_manager()
|
||||
.log_tool_failed("local_shell", msg);
|
||||
error!(msg);
|
||||
|
||||
Ok(Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg.to_string(),
|
||||
success: None,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(FunctionCallError::RespondToModel(msg)) => {
|
||||
Ok(Some(ResponseInputItem::FunctionCallOutput {
|
||||
call_id: String::new(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: msg,
|
||||
success: None,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
|
||||
@@ -2927,13 +2961,10 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn fatal_tool_error_stops_turn_and_reports_error() {
|
||||
let (session, turn_context, _rx) = make_session_and_context_with_rx();
|
||||
let session_ref = session.as_ref();
|
||||
let turn_context_ref = turn_context.as_ref();
|
||||
let router = ToolRouter::from_config(
|
||||
&turn_context_ref.tools_config,
|
||||
Some(session_ref.services.mcp_connection_manager.list_all_tools()),
|
||||
&turn_context.tools_config,
|
||||
Some(session.services.mcp_connection_manager.list_all_tools()),
|
||||
);
|
||||
let mut tracker = TurnDiffTracker::new();
|
||||
let item = ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
@@ -2942,22 +2973,26 @@ mod tests {
|
||||
input: "{}".to_string(),
|
||||
};
|
||||
|
||||
let err = handle_response_item(
|
||||
&router,
|
||||
session_ref,
|
||||
turn_context_ref,
|
||||
&mut tracker,
|
||||
"sub-id",
|
||||
item,
|
||||
)
|
||||
.await
|
||||
.expect_err("expected fatal error");
|
||||
let call = ToolRouter::build_tool_call(session.as_ref(), item.clone())
|
||||
.expect("build tool call")
|
||||
.expect("tool call present");
|
||||
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let err = router
|
||||
.dispatch_tool_call(
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
tracker,
|
||||
"sub-id".to_string(),
|
||||
call,
|
||||
)
|
||||
.await
|
||||
.expect_err("expected fatal error");
|
||||
|
||||
match err {
|
||||
CodexErr::Fatal(message) => {
|
||||
FunctionCallError::Fatal(message) => {
|
||||
assert_eq!(message, "tool shell invoked with incompatible payload");
|
||||
}
|
||||
other => panic!("expected CodexErr::Fatal, got {other:?}"),
|
||||
other => panic!("expected FunctionCallError::Fatal, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3071,9 +3106,11 @@ mod tests {
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let (session, mut turn_context) = make_session_and_context();
|
||||
let (session, mut turn_context_raw) = make_session_and_context();
|
||||
// Ensure policy is NOT OnRequest so the early rejection path triggers
|
||||
turn_context.approval_policy = AskForApproval::OnFailure;
|
||||
turn_context_raw.approval_policy = AskForApproval::OnFailure;
|
||||
let session = Arc::new(session);
|
||||
let mut turn_context = Arc::new(turn_context_raw);
|
||||
|
||||
let params = ExecParams {
|
||||
command: if cfg!(windows) {
|
||||
@@ -3101,7 +3138,7 @@ mod tests {
|
||||
..params.clone()
|
||||
};
|
||||
|
||||
let mut turn_diff_tracker = TurnDiffTracker::new();
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
|
||||
let tool_name = "shell";
|
||||
let sub_id = "test-sub".to_string();
|
||||
@@ -3110,9 +3147,9 @@ mod tests {
|
||||
let resp = handle_container_exec_with_params(
|
||||
tool_name,
|
||||
params,
|
||||
&session,
|
||||
&turn_context,
|
||||
&mut turn_diff_tracker,
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id,
|
||||
call_id,
|
||||
)
|
||||
@@ -3131,14 +3168,16 @@ mod tests {
|
||||
|
||||
// Now retry the same command WITHOUT escalated permissions; should succeed.
|
||||
// Force DangerFullAccess to avoid platform sandbox dependencies in tests.
|
||||
turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess;
|
||||
Arc::get_mut(&mut turn_context)
|
||||
.expect("unique turn context Arc")
|
||||
.sandbox_policy = SandboxPolicy::DangerFullAccess;
|
||||
|
||||
let resp2 = handle_container_exec_with_params(
|
||||
tool_name,
|
||||
params2,
|
||||
&session,
|
||||
&turn_context,
|
||||
&mut turn_diff_tracker,
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
"test-sub".to_string(),
|
||||
"test-call-2".to_string(),
|
||||
)
|
||||
|
||||
@@ -35,6 +35,10 @@ pub struct ModelFamily {
|
||||
// See https://platform.openai.com/docs/guides/tools-local-shell
|
||||
pub uses_local_shell_tool: bool,
|
||||
|
||||
/// Whether this model supports parallel tool calls when using the
|
||||
/// Responses API.
|
||||
pub supports_parallel_tool_calls: bool,
|
||||
|
||||
/// Present if the model performs better when `apply_patch` is provided as
|
||||
/// a tool call instead of just a bash command
|
||||
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
||||
@@ -58,6 +62,7 @@ macro_rules! model_family {
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
supports_parallel_tool_calls: false,
|
||||
apply_patch_tool_type: None,
|
||||
base_instructions: BASE_INSTRUCTIONS.to_string(),
|
||||
experimental_supported_tools: Vec::new(),
|
||||
@@ -103,6 +108,18 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
|
||||
} else if slug.starts_with("gpt-3.5") {
|
||||
model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
|
||||
} else if slug.starts_with("test-gpt-5-codex") {
|
||||
model_family!(
|
||||
slug, slug,
|
||||
supports_reasoning_summaries: true,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
experimental_supported_tools: vec![
|
||||
"read_file".to_string(),
|
||||
"test_sync_tool".to_string()
|
||||
],
|
||||
supports_parallel_tool_calls: true,
|
||||
)
|
||||
} else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
|
||||
model_family!(
|
||||
slug, slug,
|
||||
@@ -110,6 +127,8 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
|
||||
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
|
||||
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
|
||||
apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
|
||||
// experimental_supported_tools: vec!["read_file".to_string()],
|
||||
// supports_parallel_tool_calls: true,
|
||||
)
|
||||
} else if slug.starts_with("gpt-5") {
|
||||
model_family!(
|
||||
@@ -130,6 +149,7 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
|
||||
supports_reasoning_summaries: false,
|
||||
reasoning_summary_format: ReasoningSummaryFormat::None,
|
||||
uses_local_shell_tool: false,
|
||||
supports_parallel_tool_calls: false,
|
||||
apply_patch_tool_type: None,
|
||||
base_instructions: BASE_INSTRUCTIONS.to_string(),
|
||||
experimental_supported_tools: Vec::new(),
|
||||
|
||||
@@ -14,12 +14,17 @@ use mcp_types::CallToolResult;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub struct ToolInvocation<'a> {
|
||||
pub session: &'a Session,
|
||||
pub turn: &'a TurnContext,
|
||||
pub tracker: &'a mut TurnDiffTracker,
|
||||
pub sub_id: &'a str,
|
||||
pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ToolInvocation {
|
||||
pub session: Arc<Session>,
|
||||
pub turn: Arc<TurnContext>,
|
||||
pub tracker: SharedTurnDiffTracker,
|
||||
pub sub_id: String,
|
||||
pub call_id: String,
|
||||
pub tool_name: String,
|
||||
pub payload: ToolPayload,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::client_common::tools::FreeformTool;
|
||||
use crate::client_common::tools::FreeformToolFormat;
|
||||
@@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
@@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler {
|
||||
let content = handle_container_exec_with_params(
|
||||
tool_name.as_str(),
|
||||
exec_params,
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id.to_string(),
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
tool_name,
|
||||
|
||||
@@ -16,10 +16,7 @@ impl ToolHandler for McpHandler {
|
||||
ToolKind::Mcp
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
sub_id,
|
||||
@@ -45,8 +42,8 @@ impl ToolHandler for McpHandler {
|
||||
let arguments_str = raw_arguments;
|
||||
|
||||
let response = handle_mcp_tool_call(
|
||||
session,
|
||||
sub_id,
|
||||
session.as_ref(),
|
||||
&sub_id,
|
||||
call_id.clone(),
|
||||
server,
|
||||
tool,
|
||||
|
||||
@@ -4,6 +4,7 @@ mod mcp;
|
||||
mod plan;
|
||||
mod read_file;
|
||||
mod shell;
|
||||
mod test_sync;
|
||||
mod unified_exec;
|
||||
mod view_image;
|
||||
|
||||
@@ -15,5 +16,6 @@ pub use mcp::McpHandler;
|
||||
pub use plan::PlanHandler;
|
||||
pub use read_file::ReadFileHandler;
|
||||
pub use shell::ShellHandler;
|
||||
pub use test_sync::TestSyncHandler;
|
||||
pub use unified_exec::UnifiedExecHandler;
|
||||
pub use view_image::ViewImageHandler;
|
||||
|
||||
@@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
sub_id,
|
||||
@@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler {
|
||||
}
|
||||
};
|
||||
|
||||
let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?;
|
||||
let content =
|
||||
handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content,
|
||||
|
||||
@@ -42,10 +42,7 @@ impl ToolHandler for ReadFileHandler {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation { payload, .. } = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use async_trait::async_trait;
|
||||
use codex_protocol::models::ShellToolCallParams;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::exec::ExecParams;
|
||||
@@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler {
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
@@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler {
|
||||
"failed to parse function arguments: {e:?}"
|
||||
))
|
||||
})?;
|
||||
let exec_params = Self::to_exec_params(params, turn);
|
||||
let exec_params = Self::to_exec_params(params, turn.as_ref());
|
||||
let content = handle_container_exec_with_params(
|
||||
tool_name.as_str(),
|
||||
exec_params,
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id.to_string(),
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
@@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler {
|
||||
})
|
||||
}
|
||||
ToolPayload::LocalShell { params } => {
|
||||
let exec_params = Self::to_exec_params(params, turn);
|
||||
let exec_params = Self::to_exec_params(params, turn.as_ref());
|
||||
let content = handle_container_exec_with_params(
|
||||
tool_name.as_str(),
|
||||
exec_params,
|
||||
session,
|
||||
turn,
|
||||
tracker,
|
||||
sub_id.to_string(),
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
Arc::clone(&tracker),
|
||||
sub_id.clone(),
|
||||
call_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
158
codex-rs/core/src/tools/handlers/test_sync.rs
Normal file
158
codex-rs/core/src/tools/handlers/test_sync.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ToolHandler;
|
||||
use crate::tools::registry::ToolKind;
|
||||
|
||||
pub struct TestSyncHandler;
|
||||
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
|
||||
|
||||
static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();
|
||||
|
||||
struct BarrierState {
|
||||
barrier: Arc<Barrier>,
|
||||
participants: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BarrierArgs {
|
||||
id: String,
|
||||
participants: usize,
|
||||
#[serde(default = "default_timeout_ms")]
|
||||
timeout_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TestSyncArgs {
|
||||
#[serde(default)]
|
||||
sleep_before_ms: Option<u64>,
|
||||
#[serde(default)]
|
||||
sleep_after_ms: Option<u64>,
|
||||
#[serde(default)]
|
||||
barrier: Option<BarrierArgs>,
|
||||
}
|
||||
|
||||
fn default_timeout_ms() -> u64 {
|
||||
DEFAULT_TIMEOUT_MS
|
||||
}
|
||||
|
||||
fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
|
||||
BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ToolHandler for TestSyncHandler {
|
||||
fn kind(&self) -> ToolKind {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation { payload, .. } = invocation;
|
||||
|
||||
let arguments = match payload {
|
||||
ToolPayload::Function { arguments } => arguments,
|
||||
_ => {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"test_sync_tool handler received unsupported payload".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
|
||||
FunctionCallError::RespondToModel(format!(
|
||||
"failed to parse function arguments: {err:?}"
|
||||
))
|
||||
})?;
|
||||
|
||||
if let Some(delay) = args.sleep_before_ms
|
||||
&& delay > 0
|
||||
{
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
|
||||
if let Some(barrier) = args.barrier {
|
||||
wait_on_barrier(barrier).await?;
|
||||
}
|
||||
|
||||
if let Some(delay) = args.sleep_after_ms
|
||||
&& delay > 0
|
||||
{
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
|
||||
Ok(ToolOutput::Function {
|
||||
content: "ok".to_string(),
|
||||
success: Some(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
|
||||
if args.participants == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"barrier participants must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if args.timeout_ms == 0 {
|
||||
return Err(FunctionCallError::RespondToModel(
|
||||
"barrier timeout must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let barrier_id = args.id.clone();
|
||||
let barrier = {
|
||||
let mut map = barrier_map().lock().await;
|
||||
match map.entry(barrier_id.clone()) {
|
||||
Entry::Occupied(entry) => {
|
||||
let state = entry.get();
|
||||
if state.participants != args.participants {
|
||||
let existing = state.participants;
|
||||
return Err(FunctionCallError::RespondToModel(format!(
|
||||
"barrier {barrier_id} already registered with {existing} participants"
|
||||
)));
|
||||
}
|
||||
state.barrier.clone()
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
let barrier = Arc::new(Barrier::new(args.participants));
|
||||
entry.insert(BarrierState {
|
||||
barrier: barrier.clone(),
|
||||
participants: args.participants,
|
||||
});
|
||||
barrier
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let timeout = Duration::from_millis(args.timeout_ms);
|
||||
let wait_result = tokio::time::timeout(timeout, barrier.wait())
|
||||
.await
|
||||
.map_err(|_| {
|
||||
FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
|
||||
})?;
|
||||
|
||||
if wait_result.is_leader() {
|
||||
let mut map = barrier_map().lock().await;
|
||||
if let Some(state) = map.get(&barrier_id)
|
||||
&& Arc::ptr_eq(&state.barrier, &barrier)
|
||||
{
|
||||
map.remove(&barrier_id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler {
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session, payload, ..
|
||||
} = invocation;
|
||||
|
||||
@@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler {
|
||||
ToolKind::Function
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
invocation: ToolInvocation<'_>,
|
||||
) -> Result<ToolOutput, FunctionCallError> {
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
|
||||
let ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod context;
|
||||
pub(crate) mod handlers;
|
||||
pub mod parallel;
|
||||
pub mod registry;
|
||||
pub mod router;
|
||||
pub mod spec;
|
||||
@@ -21,7 +22,7 @@ use crate::executor::linkers::PreparedExec;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ApplyPatchCommandContext;
|
||||
use crate::tools::context::ExecCommandContext;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use codex_apply_patch::MaybeApplyPatchVerified;
|
||||
use codex_apply_patch::maybe_parse_apply_patch_verified;
|
||||
use codex_protocol::protocol::AskForApproval;
|
||||
@@ -29,6 +30,7 @@ use codex_utils_string::take_bytes_at_char_boundary;
|
||||
use codex_utils_string::take_last_bytes_at_char_boundary;
|
||||
pub use router::ToolRouter;
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
use tracing::trace;
|
||||
|
||||
// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
|
||||
@@ -48,9 +50,9 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
|
||||
pub(crate) async fn handle_container_exec_with_params(
|
||||
tool_name: &str,
|
||||
params: ExecParams,
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
turn_diff_tracker: &mut TurnDiffTracker,
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
) -> Result<String, FunctionCallError> {
|
||||
@@ -68,7 +70,15 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
// check if this was a patch, and apply it if so
|
||||
let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) {
|
||||
MaybeApplyPatchVerified::Body(changes) => {
|
||||
match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
|
||||
match apply_patch::apply_patch(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
&sub_id,
|
||||
&call_id,
|
||||
changes,
|
||||
)
|
||||
.await
|
||||
{
|
||||
InternalApplyPatchInvocation::Output(item) => return item,
|
||||
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
|
||||
Some(apply_patch_exec)
|
||||
@@ -139,7 +149,7 @@ pub(crate) async fn handle_container_exec_with_params(
|
||||
|
||||
let output_result = sess
|
||||
.run_exec_with_events(
|
||||
turn_diff_tracker,
|
||||
turn_diff_tracker.clone(),
|
||||
prepared_exec,
|
||||
turn_context.approval_policy,
|
||||
)
|
||||
|
||||
137
codex-rs/core/src/tools/parallel.rs
Normal file
137
codex-rs/core/src/tools/parallel.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
|
||||
use crate::codex::ProcessedResponseItem;
|
||||
|
||||
struct PendingToolCall {
|
||||
index: usize,
|
||||
handle: JoinHandle<Result<ResponseInputItem, FunctionCallError>>,
|
||||
}
|
||||
|
||||
pub(crate) struct ToolCallRuntime {
|
||||
router: Arc<ToolRouter>,
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
pending_calls: Vec<PendingToolCall>,
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
pub(crate) fn new(
|
||||
router: Arc<ToolRouter>,
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
router,
|
||||
session,
|
||||
turn_context,
|
||||
tracker,
|
||||
sub_id,
|
||||
pending_calls: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_tool_call(
|
||||
&mut self,
|
||||
call: ToolCall,
|
||||
output_index: usize,
|
||||
output: &mut [ProcessedResponseItem],
|
||||
) -> Result<(), CodexErr> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
if supports_parallel {
|
||||
self.spawn_parallel(call, output_index);
|
||||
} else {
|
||||
self.resolve_pending(output).await?;
|
||||
let response = self.dispatch_serial(call).await?;
|
||||
let slot = output.get_mut(output_index).ok_or_else(|| {
|
||||
CodexErr::Fatal(format!("tool output index {output_index} out of bounds"))
|
||||
})?;
|
||||
slot.response = Some(response);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn abort_all(&mut self) {
|
||||
while let Some(pending) = self.pending_calls.pop() {
|
||||
pending.handle.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn resolve_pending(
|
||||
&mut self,
|
||||
output: &mut [ProcessedResponseItem],
|
||||
) -> Result<(), CodexErr> {
|
||||
while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() {
|
||||
match handle.await {
|
||||
Ok(Ok(response)) => {
|
||||
if let Some(slot) = output.get_mut(index) {
|
||||
slot.response = Some(response);
|
||||
}
|
||||
}
|
||||
Ok(Err(FunctionCallError::Fatal(message))) => {
|
||||
self.abort_all();
|
||||
return Err(CodexErr::Fatal(message));
|
||||
}
|
||||
Ok(Err(other)) => {
|
||||
self.abort_all();
|
||||
return Err(CodexErr::Fatal(other.to_string()));
|
||||
}
|
||||
Err(join_err) => {
|
||||
self.abort_all();
|
||||
return Err(CodexErr::Fatal(format!(
|
||||
"tool task failed to join: {join_err}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spawn_parallel(&mut self, call: ToolCall, index: usize) {
|
||||
let router = Arc::clone(&self.router);
|
||||
let session = Arc::clone(&self.session);
|
||||
let turn = Arc::clone(&self.turn_context);
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let sub_id = self.sub_id.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, sub_id, call)
|
||||
.await
|
||||
});
|
||||
self.pending_calls.push(PendingToolCall { index, handle });
|
||||
}
|
||||
|
||||
async fn dispatch_serial(&self, call: ToolCall) -> Result<ResponseInputItem, CodexErr> {
|
||||
match self
|
||||
.router
|
||||
.dispatch_tool_call(
|
||||
Arc::clone(&self.session),
|
||||
Arc::clone(&self.turn_context),
|
||||
Arc::clone(&self.tracker),
|
||||
self.sub_id.clone(),
|
||||
call,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
Err(other) => Err(CodexErr::Fatal(other.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -32,8 +32,7 @@ pub trait ToolHandler: Send + Sync {
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation<'_>)
|
||||
-> Result<ToolOutput, FunctionCallError>;
|
||||
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
|
||||
}
|
||||
|
||||
pub struct ToolRegistry {
|
||||
@@ -57,9 +56,9 @@ impl ToolRegistry {
|
||||
// }
|
||||
// }
|
||||
|
||||
pub async fn dispatch<'a>(
|
||||
pub async fn dispatch(
|
||||
&self,
|
||||
invocation: ToolInvocation<'a>,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||
let tool_name = invocation.tool_name.clone();
|
||||
let call_id_owned = invocation.call_id.clone();
|
||||
@@ -137,9 +136,24 @@ impl ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConfiguredToolSpec {
|
||||
pub spec: ToolSpec,
|
||||
pub supports_parallel_tool_calls: bool,
|
||||
}
|
||||
|
||||
impl ConfiguredToolSpec {
|
||||
pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
|
||||
Self {
|
||||
spec,
|
||||
supports_parallel_tool_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolRegistryBuilder {
|
||||
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
||||
specs: Vec<ToolSpec>,
|
||||
specs: Vec<ConfiguredToolSpec>,
|
||||
}
|
||||
|
||||
impl ToolRegistryBuilder {
|
||||
@@ -151,7 +165,16 @@ impl ToolRegistryBuilder {
|
||||
}
|
||||
|
||||
pub fn push_spec(&mut self, spec: ToolSpec) {
|
||||
self.specs.push(spec);
|
||||
self.push_spec_with_parallel_support(spec, false);
|
||||
}
|
||||
|
||||
pub fn push_spec_with_parallel_support(
|
||||
&mut self,
|
||||
spec: ToolSpec,
|
||||
supports_parallel_tool_calls: bool,
|
||||
) {
|
||||
self.specs
|
||||
.push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
|
||||
}
|
||||
|
||||
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
||||
@@ -183,7 +206,7 @@ impl ToolRegistryBuilder {
|
||||
// }
|
||||
// }
|
||||
|
||||
pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
|
||||
pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
|
||||
let registry = ToolRegistry::new(self.handlers);
|
||||
(self.specs, registry)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::client_common::tools::ToolSpec;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::registry::ConfiguredToolSpec;
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::tools::spec::ToolsConfig;
|
||||
use crate::tools::spec::build_specs;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use codex_protocol::models::LocalShellAction;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -24,7 +26,7 @@ pub struct ToolCall {
|
||||
|
||||
pub struct ToolRouter {
|
||||
registry: ToolRegistry,
|
||||
specs: Vec<ToolSpec>,
|
||||
specs: Vec<ConfiguredToolSpec>,
|
||||
}
|
||||
|
||||
impl ToolRouter {
|
||||
@@ -34,11 +36,22 @@ impl ToolRouter {
|
||||
) -> Self {
|
||||
let builder = build_specs(config, mcp_tools);
|
||||
let (specs, registry) = builder.build();
|
||||
|
||||
Self { registry, specs }
|
||||
}
|
||||
|
||||
pub fn specs(&self) -> &[ToolSpec] {
|
||||
&self.specs
|
||||
pub fn specs(&self) -> Vec<ToolSpec> {
|
||||
self.specs
|
||||
.iter()
|
||||
.map(|config| config.spec.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
|
||||
self.specs
|
||||
.iter()
|
||||
.filter(|config| config.supports_parallel_tool_calls)
|
||||
.any(|config| config.spec.name() == tool_name)
|
||||
}
|
||||
|
||||
pub fn build_tool_call(
|
||||
@@ -118,10 +131,10 @@ impl ToolRouter {
|
||||
|
||||
pub async fn dispatch_tool_call(
|
||||
&self,
|
||||
session: &Session,
|
||||
turn: &TurnContext,
|
||||
tracker: &mut TurnDiffTracker,
|
||||
sub_id: &str,
|
||||
session: Arc<Session>,
|
||||
turn: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
call: ToolCall,
|
||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||
let ToolCall {
|
||||
|
||||
@@ -258,6 +258,68 @@ fn create_view_image_tool() -> ToolSpec {
|
||||
})
|
||||
}
|
||||
|
||||
fn create_test_sync_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
"sleep_before_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("Optional delay in milliseconds before any other action".to_string()),
|
||||
},
|
||||
);
|
||||
properties.insert(
|
||||
"sleep_after_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Optional delay in milliseconds after completing the barrier".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
let mut barrier_properties = BTreeMap::new();
|
||||
barrier_properties.insert(
|
||||
"id".to_string(),
|
||||
JsonSchema::String {
|
||||
description: Some(
|
||||
"Identifier shared by concurrent calls that should rendezvous".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
barrier_properties.insert(
|
||||
"participants".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some(
|
||||
"Number of tool calls that must arrive before the barrier opens".to_string(),
|
||||
),
|
||||
},
|
||||
);
|
||||
barrier_properties.insert(
|
||||
"timeout_ms".to_string(),
|
||||
JsonSchema::Number {
|
||||
description: Some("Maximum time in milliseconds to wait at the barrier".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
properties.insert(
|
||||
"barrier".to_string(),
|
||||
JsonSchema::Object {
|
||||
properties: barrier_properties,
|
||||
required: Some(vec!["id".to_string(), "participants".to_string()]),
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
);
|
||||
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_sync_tool".to_string(),
|
||||
description: "Internal synchronization helper used by Codex integration tests.".to_string(),
|
||||
strict: false,
|
||||
parameters: JsonSchema::Object {
|
||||
properties,
|
||||
required: None,
|
||||
additional_properties: Some(false.into()),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn create_read_file_tool() -> ToolSpec {
|
||||
let mut properties = BTreeMap::new();
|
||||
properties.insert(
|
||||
@@ -507,6 +569,7 @@ pub(crate) fn build_specs(
|
||||
use crate::tools::handlers::PlanHandler;
|
||||
use crate::tools::handlers::ReadFileHandler;
|
||||
use crate::tools::handlers::ShellHandler;
|
||||
use crate::tools::handlers::TestSyncHandler;
|
||||
use crate::tools::handlers::UnifiedExecHandler;
|
||||
use crate::tools::handlers::ViewImageHandler;
|
||||
use std::sync::Arc;
|
||||
@@ -573,16 +636,26 @@ pub(crate) fn build_specs(
|
||||
.any(|tool| tool == "read_file")
|
||||
{
|
||||
let read_file_handler = Arc::new(ReadFileHandler);
|
||||
builder.push_spec(create_read_file_tool());
|
||||
builder.push_spec_with_parallel_support(create_read_file_tool(), true);
|
||||
builder.register_handler("read_file", read_file_handler);
|
||||
}
|
||||
|
||||
if config
|
||||
.experimental_supported_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "test_sync_tool")
|
||||
{
|
||||
let test_sync_handler = Arc::new(TestSyncHandler);
|
||||
builder.push_spec_with_parallel_support(create_test_sync_tool(), true);
|
||||
builder.register_handler("test_sync_tool", test_sync_handler);
|
||||
}
|
||||
|
||||
if config.web_search_request {
|
||||
builder.push_spec(ToolSpec::WebSearch {});
|
||||
}
|
||||
|
||||
if config.include_view_image_tool {
|
||||
builder.push_spec(create_view_image_tool());
|
||||
builder.push_spec_with_parallel_support(create_view_image_tool(), true);
|
||||
builder.register_handler("view_image", view_image_handler);
|
||||
}
|
||||
|
||||
@@ -610,20 +683,25 @@ pub(crate) fn build_specs(
|
||||
mod tests {
|
||||
use crate::client_common::tools::FreeformTool;
|
||||
use crate::model_family::find_family_for_model;
|
||||
use crate::tools::registry::ConfiguredToolSpec;
|
||||
use mcp_types::ToolInputSchema;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) {
|
||||
fn tool_name(tool: &ToolSpec) -> &str {
|
||||
match tool {
|
||||
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
|
||||
ToolSpec::LocalShell {} => "local_shell",
|
||||
ToolSpec::WebSearch {} => "web_search",
|
||||
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_eq_tool_names(tools: &[ConfiguredToolSpec], expected_names: &[&str]) {
|
||||
let tool_names = tools
|
||||
.iter()
|
||||
.map(|tool| match tool {
|
||||
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
|
||||
ToolSpec::LocalShell {} => "local_shell",
|
||||
ToolSpec::WebSearch {} => "web_search",
|
||||
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
|
||||
})
|
||||
.map(|tool| tool_name(&tool.spec))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(
|
||||
@@ -639,6 +717,16 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn find_tool<'a>(
|
||||
tools: &'a [ConfiguredToolSpec],
|
||||
expected_name: &str,
|
||||
) -> &'a ConfiguredToolSpec {
|
||||
tools
|
||||
.iter()
|
||||
.find(|tool| tool_name(&tool.spec) == expected_name)
|
||||
.unwrap_or_else(|| panic!("expected tool {expected_name}"))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_specs() {
|
||||
let model_family = find_family_for_model("codex-mini-latest")
|
||||
@@ -680,6 +768,53 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn test_parallel_support_flags() {
|
||||
let model_family = find_family_for_model("gpt-5-codex")
|
||||
.expect("codex-mini-latest should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: false,
|
||||
experimental_unified_exec_tool: true,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, None).build();
|
||||
|
||||
assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
|
||||
assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_test_model_family_includes_sync_tool() {
|
||||
let model_family = find_family_for_model("test-gpt-5-codex")
|
||||
.expect("test-gpt-5-codex should be a valid model family");
|
||||
let config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: false,
|
||||
include_web_search_request: false,
|
||||
use_streamable_shell_tool: false,
|
||||
include_view_image_tool: false,
|
||||
experimental_unified_exec_tool: false,
|
||||
});
|
||||
let (tools, _) = build_specs(&config, None).build();
|
||||
|
||||
assert!(
|
||||
tools
|
||||
.iter()
|
||||
.any(|tool| tool_name(&tool.spec) == "test_sync_tool")
|
||||
);
|
||||
assert!(
|
||||
tools
|
||||
.iter()
|
||||
.any(|tool| tool_name(&tool.spec) == "read_file")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_specs_mcp_tools() {
|
||||
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
|
||||
@@ -742,7 +877,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
tools[3],
|
||||
tools[3].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_server/do_something_cool".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -911,7 +1046,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/search".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -977,7 +1112,7 @@ mod tests {
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/paginate".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1041,7 +1176,7 @@ mod tests {
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/tags".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1108,7 +1243,7 @@ mod tests {
|
||||
],
|
||||
);
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "dash/value".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
@@ -1213,7 +1348,7 @@ mod tests {
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
tools[4],
|
||||
tools[4].spec,
|
||||
ToolSpec::Function(ResponsesApiTool {
|
||||
name: "test_server/do_something_cool".to_string(),
|
||||
parameters: JsonSchema::Object {
|
||||
|
||||
Reference in New Issue
Block a user