feat: parallel tool calls (#4663)
Add parallel tool calls. This is configurable at the model level and at the tool level.
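In rough sketch form, the two configuration points introduced below fit together like this (identifiers are taken from the diff; the surrounding `turn_context`, `builder`, `spec`, and `serial_only_spec` values are assumed illustrative context, not part of this commit):

    // Model level: the model family opts in, and run_turn copies the flag into the Prompt.
    let parallel_tool_calls = turn_context
        .client
        .get_model_family()
        .supports_parallel_tool_calls;

    // Tool level: each registered spec records whether it may run concurrently;
    // ToolCallRuntime spawns such calls on tasks and resolves them before the turn completes.
    builder.push_spec_with_parallel_support(spec, true);
    builder.push_spec(serial_only_spec); // plain push_spec defaults to serial dispatch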
@@ -227,7 +227,7 @@ impl ModelClient {
             input: &input_with_instructions,
             tools: &tools_json,
             tool_choice: "auto",
-            parallel_tool_calls: false,
+            parallel_tool_calls: prompt.parallel_tool_calls,
             reasoning,
             store: azure_workaround,
             stream: true,
@@ -33,6 +33,9 @@ pub struct Prompt {
     /// external MCP servers.
     pub(crate) tools: Vec<ToolSpec>,

+    /// Whether parallel tool calls are permitted for this prompt.
+    pub(crate) parallel_tool_calls: bool,
+
     /// Optional override for the built-in BASE_INSTRUCTIONS.
     pub base_instructions_override: Option<String>,

@@ -288,6 +291,17 @@ pub(crate) mod tools {
         Freeform(FreeformTool),
     }

+    impl ToolSpec {
+        pub(crate) fn name(&self) -> &str {
+            match self {
+                ToolSpec::Function(tool) => tool.name.as_str(),
+                ToolSpec::LocalShell {} => "local_shell",
+                ToolSpec::WebSearch {} => "web_search",
+                ToolSpec::Freeform(tool) => tool.name.as_str(),
+            }
+        }
+    }
+
     #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
     pub struct FreeformTool {
         pub(crate) name: String,
@@ -433,7 +447,7 @@ mod tests {
             input: &input,
             tools: &tools,
             tool_choice: "auto",
-            parallel_tool_calls: false,
+            parallel_tool_calls: true,
             reasoning: None,
             store: false,
             stream: true,
@@ -474,7 +488,7 @@ mod tests {
             input: &input,
             tools: &tools,
             tool_choice: "auto",
-            parallel_tool_calls: false,
+            parallel_tool_calls: true,
             reasoning: None,
             store: false,
             stream: true,
@@ -510,7 +524,7 @@ mod tests {
             input: &input,
             tools: &tools,
             tool_choice: "auto",
-            parallel_tool_calls: false,
+            parallel_tool_calls: true,
             reasoning: None,
             store: false,
             stream: true,
@@ -100,7 +100,9 @@ use crate::tasks::CompactTask;
 use crate::tasks::RegularTask;
 use crate::tasks::ReviewTask;
 use crate::tools::ToolRouter;
+use crate::tools::context::SharedTurnDiffTracker;
 use crate::tools::format_exec_output_str;
+use crate::tools::parallel::ToolCallRuntime;
 use crate::turn_diff_tracker::TurnDiffTracker;
 use crate::unified_exec::UnifiedExecSessionManager;
 use crate::user_instructions::UserInstructions;
@@ -818,7 +820,7 @@ impl Session {

     async fn on_exec_command_begin(
         &self,
-        turn_diff_tracker: &mut TurnDiffTracker,
+        turn_diff_tracker: SharedTurnDiffTracker,
         exec_command_context: ExecCommandContext,
     ) {
         let ExecCommandContext {
@@ -834,7 +836,10 @@ impl Session {
                 user_explicitly_approved_this_action,
                 changes,
             }) => {
-                turn_diff_tracker.on_patch_begin(&changes);
+                {
+                    let mut tracker = turn_diff_tracker.lock().await;
+                    tracker.on_patch_begin(&changes);
+                }

                 EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
                     call_id,
@@ -861,7 +866,7 @@ impl Session {

     async fn on_exec_command_end(
         &self,
-        turn_diff_tracker: &mut TurnDiffTracker,
+        turn_diff_tracker: SharedTurnDiffTracker,
         sub_id: &str,
         call_id: &str,
         output: &ExecToolCallOutput,
@@ -909,7 +914,10 @@ impl Session {
         // If this is an apply_patch, after we emit the end patch, emit a second event
         // with the full turn diff if there is one.
         if is_apply_patch {
-            let unified_diff = turn_diff_tracker.get_unified_diff();
+            let unified_diff = {
+                let mut tracker = turn_diff_tracker.lock().await;
+                tracker.get_unified_diff()
+            };
             if let Ok(Some(unified_diff)) = unified_diff {
                 let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
                 let event = Event {
@@ -926,7 +934,7 @@ impl Session {
     /// Returns the output of the exec tool call.
     pub(crate) async fn run_exec_with_events(
         &self,
-        turn_diff_tracker: &mut TurnDiffTracker,
+        turn_diff_tracker: SharedTurnDiffTracker,
         prepared: PreparedExec,
         approval_policy: AskForApproval,
     ) -> Result<ExecToolCallOutput, ExecError> {
@@ -935,7 +943,7 @@ impl Session {
         let sub_id = context.sub_id.clone();
         let call_id = context.call_id.clone();

-        self.on_exec_command_begin(turn_diff_tracker, context.clone())
+        self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
             .await;

         let result = self
@@ -1644,7 +1652,7 @@ pub(crate) async fn run_task(
     let mut last_agent_message: Option<String> = None;
     // Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
     // many turns, from the perspective of the user, it is a single turn.
-    let mut turn_diff_tracker = TurnDiffTracker::new();
+    let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
     let mut auto_compact_recently_attempted = false;

     loop {
@@ -1692,9 +1700,9 @@ pub(crate) async fn run_task(
             })
             .collect();
         match run_turn(
-            &sess,
-            turn_context.as_ref(),
-            &mut turn_diff_tracker,
+            Arc::clone(&sess),
+            Arc::clone(&turn_context),
+            Arc::clone(&turn_diff_tracker),
             sub_id.clone(),
             turn_input,
         )
@@ -1917,18 +1925,27 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
 }

 async fn run_turn(
-    sess: &Session,
-    turn_context: &TurnContext,
-    turn_diff_tracker: &mut TurnDiffTracker,
+    sess: Arc<Session>,
+    turn_context: Arc<TurnContext>,
+    turn_diff_tracker: SharedTurnDiffTracker,
     sub_id: String,
     input: Vec<ResponseItem>,
 ) -> CodexResult<TurnRunResult> {
     let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
-    let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools));
+    let router = Arc::new(ToolRouter::from_config(
+        &turn_context.tools_config,
+        Some(mcp_tools),
+    ));

+    let model_supports_parallel = turn_context
+        .client
+        .get_model_family()
+        .supports_parallel_tool_calls;
+    let parallel_tool_calls = model_supports_parallel;
     let prompt = Prompt {
         input,
-        tools: router.specs().to_vec(),
+        tools: router.specs(),
+        parallel_tool_calls,
         base_instructions_override: turn_context.base_instructions.clone(),
         output_schema: turn_context.final_output_json_schema.clone(),
     };
@@ -1936,10 +1953,10 @@ async fn run_turn(
     let mut retries = 0;
     loop {
         match try_run_turn(
-            &router,
-            sess,
-            turn_context,
-            turn_diff_tracker,
+            Arc::clone(&router),
+            Arc::clone(&sess),
+            Arc::clone(&turn_context),
+            Arc::clone(&turn_diff_tracker),
             &sub_id,
             &prompt,
         )
@@ -1950,7 +1967,7 @@ async fn run_turn(
             Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
             Err(e @ CodexErr::Fatal(_)) => return Err(e),
             Err(e @ CodexErr::ContextWindowExceeded) => {
-                sess.set_total_tokens_full(&sub_id, turn_context).await;
+                sess.set_total_tokens_full(&sub_id, &turn_context).await;
                 return Err(e);
             }
             Err(CodexErr::UsageLimitReached(e)) => {
@@ -1999,9 +2016,9 @@ async fn run_turn(
 /// "handled" such that it produces a `ResponseInputItem` that needs to be
 /// sent back to the model on the next turn.
 #[derive(Debug)]
-struct ProcessedResponseItem {
-    item: ResponseItem,
-    response: Option<ResponseInputItem>,
+pub(crate) struct ProcessedResponseItem {
+    pub(crate) item: ResponseItem,
+    pub(crate) response: Option<ResponseInputItem>,
 }

 #[derive(Debug)]
@@ -2011,10 +2028,10 @@ struct TurnRunResult {
 }

 async fn try_run_turn(
-    router: &crate::tools::ToolRouter,
-    sess: &Session,
-    turn_context: &TurnContext,
-    turn_diff_tracker: &mut TurnDiffTracker,
+    router: Arc<ToolRouter>,
+    sess: Arc<Session>,
+    turn_context: Arc<TurnContext>,
+    turn_diff_tracker: SharedTurnDiffTracker,
     sub_id: &str,
     prompt: &Prompt,
 ) -> CodexResult<TurnRunResult> {
@@ -2085,24 +2102,34 @@ async fn try_run_turn(
     let mut stream = turn_context.client.clone().stream(&prompt).await?;

     let mut output = Vec::new();
+    let mut tool_runtime = ToolCallRuntime::new(
+        Arc::clone(&router),
+        Arc::clone(&sess),
+        Arc::clone(&turn_context),
+        Arc::clone(&turn_diff_tracker),
+        sub_id.to_string(),
+    );

     loop {
         // Poll the next item from the model stream. We must inspect *both* Ok and Err
         // cases so that transient stream failures (e.g., dropped SSE connection before
         // `response.completed`) bubble up and trigger the caller's retry logic.
         let event = stream.next().await;
-        let Some(event) = event else {
-            // Channel closed without yielding a final Completed event or explicit error.
-            // Treat as a disconnected stream so the caller can retry.
-            return Err(CodexErr::Stream(
-                "stream closed before response.completed".into(),
-                None,
-            ));
+        let event = match event {
+            Some(event) => event,
+            None => {
+                tool_runtime.abort_all();
+                return Err(CodexErr::Stream(
+                    "stream closed before response.completed".into(),
+                    None,
+                ));
+            }
         };

         let event = match event {
             Ok(ev) => ev,
             Err(e) => {
+                tool_runtime.abort_all();
                 // Propagate the underlying stream error to the caller (run_turn), which
                 // will apply the configured `stream_max_retries` policy.
                 return Err(e);
@@ -2112,16 +2139,66 @@ async fn try_run_turn(
         match event {
             ResponseEvent::Created => {}
             ResponseEvent::OutputItemDone(item) => {
-                let response = handle_response_item(
-                    router,
-                    sess,
-                    turn_context,
-                    turn_diff_tracker,
-                    sub_id,
-                    item.clone(),
-                )
-                .await?;
-                output.push(ProcessedResponseItem { item, response });
+                match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) {
+                    Ok(Some(call)) => {
+                        let payload_preview = call.payload.log_payload().into_owned();
+                        tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
+                        let index = output.len();
+                        output.push(ProcessedResponseItem {
+                            item,
+                            response: None,
+                        });
+                        tool_runtime
+                            .handle_tool_call(call, index, output.as_mut_slice())
+                            .await?;
+                    }
+                    Ok(None) => {
+                        let response = handle_non_tool_response_item(
+                            Arc::clone(&sess),
+                            Arc::clone(&turn_context),
+                            sub_id,
+                            item.clone(),
+                        )
+                        .await?;
+                        output.push(ProcessedResponseItem { item, response });
+                    }
+                    Err(FunctionCallError::MissingLocalShellCallId) => {
+                        let msg = "LocalShellCall without call_id or id";
+                        turn_context
+                            .client
+                            .get_otel_event_manager()
+                            .log_tool_failed("local_shell", msg);
+                        error!(msg);

+                        let response = ResponseInputItem::FunctionCallOutput {
+                            call_id: String::new(),
+                            output: FunctionCallOutputPayload {
+                                content: msg.to_string(),
+                                success: None,
+                            },
+                        };
+                        output.push(ProcessedResponseItem {
+                            item,
+                            response: Some(response),
+                        });
+                    }
+                    Err(FunctionCallError::RespondToModel(message)) => {
+                        let response = ResponseInputItem::FunctionCallOutput {
+                            call_id: String::new(),
+                            output: FunctionCallOutputPayload {
+                                content: message,
+                                success: None,
+                            },
+                        };
+                        output.push(ProcessedResponseItem {
+                            item,
+                            response: Some(response),
+                        });
+                    }
+                    Err(FunctionCallError::Fatal(message)) => {
+                        return Err(CodexErr::Fatal(message));
+                    }
+                }
             }
             ResponseEvent::WebSearchCallBegin { call_id } => {
                 let _ = sess
@@ -2141,10 +2218,15 @@ async fn try_run_turn(
                 response_id: _,
                 token_usage,
             } => {
-                sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
+                sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
                     .await;

-                let unified_diff = turn_diff_tracker.get_unified_diff();
+                tool_runtime.resolve_pending(output.as_mut_slice()).await?;
+
+                let unified_diff = {
+                    let mut tracker = turn_diff_tracker.lock().await;
+                    tracker.get_unified_diff()
+                };
                 if let Ok(Some(unified_diff)) = unified_diff {
                     let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
                     let event = Event {
@@ -2203,88 +2285,40 @@ async fn try_run_turn(
         }
 }

-async fn handle_response_item(
-    router: &crate::tools::ToolRouter,
-    sess: &Session,
-    turn_context: &TurnContext,
-    turn_diff_tracker: &mut TurnDiffTracker,
-    sub_id: &str,
-    item: ResponseItem,
-) -> CodexResult<Option<ResponseInputItem>> {
-    debug!(?item, "Output item");
-
-    match ToolRouter::build_tool_call(sess, item.clone()) {
-        Ok(Some(call)) => {
-            let payload_preview = call.payload.log_payload().into_owned();
-            tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
-            match router
-                .dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call)
-                .await
-            {
-                Ok(response) => Ok(Some(response)),
-                Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
-                Err(other) => unreachable!("non-fatal tool error returned: {other:?}"),
-            }
-        }
-        Ok(None) => {
-            match &item {
-                ResponseItem::Message { .. }
-                | ResponseItem::Reasoning { .. }
-                | ResponseItem::WebSearchCall { .. } => {
-                    let msgs = match &item {
-                        ResponseItem::Message { .. } if turn_context.is_review_mode => {
-                            trace!("suppressing assistant Message in review mode");
-                            Vec::new()
-                        }
-                        _ => map_response_item_to_event_messages(
-                            &item,
-                            sess.show_raw_agent_reasoning(),
-                        ),
-                    };
-                    for msg in msgs {
-                        let event = Event {
-                            id: sub_id.to_string(),
-                            msg,
-                        };
-                        sess.send_event(event).await;
-                    }
-                }
-                ResponseItem::FunctionCallOutput { .. }
-                | ResponseItem::CustomToolCallOutput { .. } => {
-                    debug!("unexpected tool output from stream");
-                }
-                _ => {}
-            }
-
-            Ok(None)
-        }
-        Err(FunctionCallError::MissingLocalShellCallId) => {
-            let msg = "LocalShellCall without call_id or id";
-            turn_context
-                .client
-                .get_otel_event_manager()
-                .log_tool_failed("local_shell", msg);
-            error!(msg);
-
-            Ok(Some(ResponseInputItem::FunctionCallOutput {
-                call_id: String::new(),
-                output: FunctionCallOutputPayload {
-                    content: msg.to_string(),
-                    success: None,
-                },
-            }))
-        }
-        Err(FunctionCallError::RespondToModel(msg)) => {
-            Ok(Some(ResponseInputItem::FunctionCallOutput {
-                call_id: String::new(),
-                output: FunctionCallOutputPayload {
-                    content: msg,
-                    success: None,
-                },
-            }))
-        }
-        Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
-    }
-}
+async fn handle_non_tool_response_item(
+    sess: Arc<Session>,
+    turn_context: Arc<TurnContext>,
+    sub_id: &str,
+    item: ResponseItem,
+) -> CodexResult<Option<ResponseInputItem>> {
+    debug!(?item, "Output item");
+
+    match &item {
+        ResponseItem::Message { .. }
+        | ResponseItem::Reasoning { .. }
+        | ResponseItem::WebSearchCall { .. } => {
+            let msgs = match &item {
+                ResponseItem::Message { .. } if turn_context.is_review_mode => {
+                    trace!("suppressing assistant Message in review mode");
+                    Vec::new()
+                }
+                _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
+            };
+            for msg in msgs {
+                let event = Event {
+                    id: sub_id.to_string(),
+                    msg,
+                };
+                sess.send_event(event).await;
+            }
+        }
+        ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
+            debug!("unexpected tool output from stream");
+        }
+        _ => {}
+    }
+
+    Ok(None)
+}

 pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
@@ -2927,13 +2961,10 @@ mod tests {
     #[tokio::test]
     async fn fatal_tool_error_stops_turn_and_reports_error() {
         let (session, turn_context, _rx) = make_session_and_context_with_rx();
-        let session_ref = session.as_ref();
-        let turn_context_ref = turn_context.as_ref();
         let router = ToolRouter::from_config(
-            &turn_context_ref.tools_config,
-            Some(session_ref.services.mcp_connection_manager.list_all_tools()),
+            &turn_context.tools_config,
+            Some(session.services.mcp_connection_manager.list_all_tools()),
         );
-        let mut tracker = TurnDiffTracker::new();
         let item = ResponseItem::CustomToolCall {
             id: None,
             status: None,
@@ -2942,22 +2973,26 @@ mod tests {
             input: "{}".to_string(),
         };

-        let err = handle_response_item(
-            &router,
-            session_ref,
-            turn_context_ref,
-            &mut tracker,
-            "sub-id",
-            item,
-        )
-        .await
-        .expect_err("expected fatal error");
+        let call = ToolRouter::build_tool_call(session.as_ref(), item.clone())
+            .expect("build tool call")
+            .expect("tool call present");
+        let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
+        let err = router
+            .dispatch_tool_call(
+                Arc::clone(&session),
+                Arc::clone(&turn_context),
+                tracker,
+                "sub-id".to_string(),
+                call,
+            )
+            .await
+            .expect_err("expected fatal error");

         match err {
-            CodexErr::Fatal(message) => {
+            FunctionCallError::Fatal(message) => {
                 assert_eq!(message, "tool shell invoked with incompatible payload");
             }
-            other => panic!("expected CodexErr::Fatal, got {other:?}"),
+            other => panic!("expected FunctionCallError::Fatal, got {other:?}"),
         }
     }

@@ -3071,9 +3106,11 @@ mod tests {
         use crate::turn_diff_tracker::TurnDiffTracker;
         use std::collections::HashMap;

-        let (session, mut turn_context) = make_session_and_context();
+        let (session, mut turn_context_raw) = make_session_and_context();
         // Ensure policy is NOT OnRequest so the early rejection path triggers
-        turn_context.approval_policy = AskForApproval::OnFailure;
+        turn_context_raw.approval_policy = AskForApproval::OnFailure;
+        let session = Arc::new(session);
+        let mut turn_context = Arc::new(turn_context_raw);

         let params = ExecParams {
             command: if cfg!(windows) {
@@ -3101,7 +3138,7 @@ mod tests {
             ..params.clone()
         };

-        let mut turn_diff_tracker = TurnDiffTracker::new();
+        let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));

         let tool_name = "shell";
         let sub_id = "test-sub".to_string();
@@ -3110,9 +3147,9 @@ mod tests {
         let resp = handle_container_exec_with_params(
             tool_name,
             params,
-            &session,
-            &turn_context,
-            &mut turn_diff_tracker,
+            Arc::clone(&session),
+            Arc::clone(&turn_context),
+            Arc::clone(&turn_diff_tracker),
             sub_id,
             call_id,
         )
@@ -3131,14 +3168,16 @@ mod tests {

         // Now retry the same command WITHOUT escalated permissions; should succeed.
         // Force DangerFullAccess to avoid platform sandbox dependencies in tests.
-        turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess;
+        Arc::get_mut(&mut turn_context)
+            .expect("unique turn context Arc")
+            .sandbox_policy = SandboxPolicy::DangerFullAccess;

         let resp2 = handle_container_exec_with_params(
             tool_name,
             params2,
-            &session,
-            &turn_context,
-            &mut turn_diff_tracker,
+            Arc::clone(&session),
+            Arc::clone(&turn_context),
+            Arc::clone(&turn_diff_tracker),
             "test-sub".to_string(),
             "test-call-2".to_string(),
         )
@@ -35,6 +35,10 @@ pub struct ModelFamily {
     // See https://platform.openai.com/docs/guides/tools-local-shell
     pub uses_local_shell_tool: bool,

+    /// Whether this model supports parallel tool calls when using the
+    /// Responses API.
+    pub supports_parallel_tool_calls: bool,
+
     /// Present if the model performs better when `apply_patch` is provided as
     /// a tool call instead of just a bash command
     pub apply_patch_tool_type: Option<ApplyPatchToolType>,
@@ -58,6 +62,7 @@ macro_rules! model_family {
             supports_reasoning_summaries: false,
             reasoning_summary_format: ReasoningSummaryFormat::None,
             uses_local_shell_tool: false,
+            supports_parallel_tool_calls: false,
             apply_patch_tool_type: None,
             base_instructions: BASE_INSTRUCTIONS.to_string(),
             experimental_supported_tools: Vec::new(),
@@ -103,6 +108,18 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
         model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
     } else if slug.starts_with("gpt-3.5") {
         model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
+    } else if slug.starts_with("test-gpt-5-codex") {
+        model_family!(
+            slug, slug,
+            supports_reasoning_summaries: true,
+            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
+            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
+            experimental_supported_tools: vec![
+                "read_file".to_string(),
+                "test_sync_tool".to_string()
+            ],
+            supports_parallel_tool_calls: true,
+        )
     } else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
         model_family!(
             slug, slug,
@@ -110,6 +127,8 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
             reasoning_summary_format: ReasoningSummaryFormat::Experimental,
             base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
             apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
+            // experimental_supported_tools: vec!["read_file".to_string()],
+            // supports_parallel_tool_calls: true,
         )
     } else if slug.starts_with("gpt-5") {
         model_family!(
@@ -130,6 +149,7 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
         supports_reasoning_summaries: false,
         reasoning_summary_format: ReasoningSummaryFormat::None,
         uses_local_shell_tool: false,
+        supports_parallel_tool_calls: false,
         apply_patch_tool_type: None,
         base_instructions: BASE_INSTRUCTIONS.to_string(),
         experimental_supported_tools: Vec::new(),
@@ -14,12 +14,17 @@ use mcp_types::CallToolResult;
 use std::borrow::Cow;
 use std::collections::HashMap;
 use std::path::PathBuf;
+use std::sync::Arc;
+use tokio::sync::Mutex;

-pub struct ToolInvocation<'a> {
-    pub session: &'a Session,
-    pub turn: &'a TurnContext,
-    pub tracker: &'a mut TurnDiffTracker,
-    pub sub_id: &'a str,
+pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;
+
+#[derive(Clone)]
+pub struct ToolInvocation {
+    pub session: Arc<Session>,
+    pub turn: Arc<TurnContext>,
+    pub tracker: SharedTurnDiffTracker,
+    pub sub_id: String,
     pub call_id: String,
     pub tool_name: String,
     pub payload: ToolPayload,
@@ -1,5 +1,6 @@
 use std::collections::BTreeMap;
 use std::collections::HashMap;
+use std::sync::Arc;

 use crate::client_common::tools::FreeformTool;
 use crate::client_common::tools::FreeformToolFormat;
@@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler {
         )
     }

-    async fn handle(
-        &self,
-        invocation: ToolInvocation<'_>,
-    ) -> Result<ToolOutput, FunctionCallError> {
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
         let ToolInvocation {
             session,
             turn,
@@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler {
         let content = handle_container_exec_with_params(
             tool_name.as_str(),
             exec_params,
-            session,
-            turn,
-            tracker,
-            sub_id.to_string(),
+            Arc::clone(&session),
+            Arc::clone(&turn),
+            Arc::clone(&tracker),
+            sub_id.clone(),
             call_id.clone(),
         )
         .await?;
@@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler {
         ToolKind::Function
     }

-    async fn handle(
-        &self,
-        invocation: ToolInvocation<'_>,
-    ) -> Result<ToolOutput, FunctionCallError> {
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
         let ToolInvocation {
             session,
             tool_name,
@@ -16,10 +16,7 @@ impl ToolHandler for McpHandler {
         ToolKind::Mcp
     }

-    async fn handle(
-        &self,
-        invocation: ToolInvocation<'_>,
-    ) -> Result<ToolOutput, FunctionCallError> {
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
         let ToolInvocation {
             session,
             sub_id,
@@ -45,8 +42,8 @@ impl ToolHandler for McpHandler {
         let arguments_str = raw_arguments;

         let response = handle_mcp_tool_call(
-            session,
-            sub_id,
+            session.as_ref(),
+            &sub_id,
             call_id.clone(),
             server,
             tool,
@@ -4,6 +4,7 @@ mod mcp;
 mod plan;
 mod read_file;
 mod shell;
+mod test_sync;
 mod unified_exec;
 mod view_image;

@@ -15,5 +16,6 @@ pub use mcp::McpHandler;
 pub use plan::PlanHandler;
 pub use read_file::ReadFileHandler;
 pub use shell::ShellHandler;
+pub use test_sync::TestSyncHandler;
 pub use unified_exec::UnifiedExecHandler;
 pub use view_image::ViewImageHandler;
@@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler {
         ToolKind::Function
     }

-    async fn handle(
-        &self,
-        invocation: ToolInvocation<'_>,
-    ) -> Result<ToolOutput, FunctionCallError> {
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
         let ToolInvocation {
             session,
             sub_id,
@@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler {
             }
         };

-        let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?;
+        let content =
+            handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;

         Ok(ToolOutput::Function {
             content,
@@ -42,10 +42,7 @@ impl ToolHandler for ReadFileHandler {
         ToolKind::Function
     }

-    async fn handle(
-        &self,
-        invocation: ToolInvocation<'_>,
-    ) -> Result<ToolOutput, FunctionCallError> {
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
         let ToolInvocation { payload, .. } = invocation;

         let arguments = match payload {
@@ -1,5 +1,6 @@
 use async_trait::async_trait;
 use codex_protocol::models::ShellToolCallParams;
+use std::sync::Arc;

 use crate::codex::TurnContext;
 use crate::exec::ExecParams;
@@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler {
         )
     }

-    async fn handle(
-        &self,
-        invocation: ToolInvocation<'_>,
-    ) -> Result<ToolOutput, FunctionCallError> {
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
         let ToolInvocation {
             session,
             turn,
@@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler {
                         "failed to parse function arguments: {e:?}"
                     ))
                 })?;
-                let exec_params = Self::to_exec_params(params, turn);
+                let exec_params = Self::to_exec_params(params, turn.as_ref());
                 let content = handle_container_exec_with_params(
                     tool_name.as_str(),
                     exec_params,
-                    session,
-                    turn,
-                    tracker,
-                    sub_id.to_string(),
+                    Arc::clone(&session),
+                    Arc::clone(&turn),
+                    Arc::clone(&tracker),
+                    sub_id.clone(),
                     call_id.clone(),
                 )
                 .await?;
@@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler {
                 })
             }
             ToolPayload::LocalShell { params } => {
-                let exec_params = Self::to_exec_params(params, turn);
+                let exec_params = Self::to_exec_params(params, turn.as_ref());
                 let content = handle_container_exec_with_params(
                     tool_name.as_str(),
                     exec_params,
-                    session,
-                    turn,
-                    tracker,
-                    sub_id.to_string(),
+                    Arc::clone(&session),
+                    Arc::clone(&turn),
+                    Arc::clone(&tracker),
+                    sub_id.clone(),
                     call_id.clone(),
                 )
                 .await?;
codex-rs/core/src/tools/handlers/test_sync.rs (new file, 158 lines)
@@ -0,0 +1,158 @@
+use std::collections::HashMap;
+use std::collections::hash_map::Entry;
+use std::sync::Arc;
+use std::sync::OnceLock;
+use std::time::Duration;
+
+use async_trait::async_trait;
+use serde::Deserialize;
+use tokio::sync::Barrier;
+use tokio::time::sleep;
+
+use crate::function_tool::FunctionCallError;
+use crate::tools::context::ToolInvocation;
+use crate::tools::context::ToolOutput;
+use crate::tools::context::ToolPayload;
+use crate::tools::registry::ToolHandler;
+use crate::tools::registry::ToolKind;
+
+pub struct TestSyncHandler;
+
+const DEFAULT_TIMEOUT_MS: u64 = 1_000;
+
+static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();
+
+struct BarrierState {
+    barrier: Arc<Barrier>,
+    participants: usize,
+}
+
+#[derive(Debug, Deserialize)]
+struct BarrierArgs {
+    id: String,
+    participants: usize,
+    #[serde(default = "default_timeout_ms")]
+    timeout_ms: u64,
+}
+
+#[derive(Debug, Deserialize)]
+struct TestSyncArgs {
+    #[serde(default)]
+    sleep_before_ms: Option<u64>,
+    #[serde(default)]
+    sleep_after_ms: Option<u64>,
+    #[serde(default)]
+    barrier: Option<BarrierArgs>,
+}
+
+fn default_timeout_ms() -> u64 {
+    DEFAULT_TIMEOUT_MS
+}
+
+fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
+    BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
+}
+
+#[async_trait]
+impl ToolHandler for TestSyncHandler {
+    fn kind(&self) -> ToolKind {
+        ToolKind::Function
+    }
+
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
+        let ToolInvocation { payload, .. } = invocation;
+
+        let arguments = match payload {
+            ToolPayload::Function { arguments } => arguments,
+            _ => {
+                return Err(FunctionCallError::RespondToModel(
+                    "test_sync_tool handler received unsupported payload".to_string(),
+                ));
+            }
+        };
+
+        let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
+            FunctionCallError::RespondToModel(format!(
+                "failed to parse function arguments: {err:?}"
+            ))
+        })?;
+
+        if let Some(delay) = args.sleep_before_ms
+            && delay > 0
+        {
+            sleep(Duration::from_millis(delay)).await;
+        }
+
+        if let Some(barrier) = args.barrier {
+            wait_on_barrier(barrier).await?;
+        }
+
+        if let Some(delay) = args.sleep_after_ms
+            && delay > 0
+        {
+            sleep(Duration::from_millis(delay)).await;
+        }
+
+        Ok(ToolOutput::Function {
+            content: "ok".to_string(),
+            success: Some(true),
+        })
+    }
+}
+
+async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
+    if args.participants == 0 {
+        return Err(FunctionCallError::RespondToModel(
+            "barrier participants must be greater than zero".to_string(),
+        ));
+    }
+
+    if args.timeout_ms == 0 {
+        return Err(FunctionCallError::RespondToModel(
+            "barrier timeout must be greater than zero".to_string(),
+        ));
+    }
+
+    let barrier_id = args.id.clone();
+    let barrier = {
+        let mut map = barrier_map().lock().await;
+        match map.entry(barrier_id.clone()) {
+            Entry::Occupied(entry) => {
+                let state = entry.get();
+                if state.participants != args.participants {
+                    let existing = state.participants;
+                    return Err(FunctionCallError::RespondToModel(format!(
+                        "barrier {barrier_id} already registered with {existing} participants"
+                    )));
+                }
+                state.barrier.clone()
+            }
+            Entry::Vacant(entry) => {
+                let barrier = Arc::new(Barrier::new(args.participants));
+                entry.insert(BarrierState {
+                    barrier: barrier.clone(),
+                    participants: args.participants,
+                });
+                barrier
+            }
+        }
+    };
+
+    let timeout = Duration::from_millis(args.timeout_ms);
+    let wait_result = tokio::time::timeout(timeout, barrier.wait())
+        .await
+        .map_err(|_| {
+            FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
+        })?;
+
+    if wait_result.is_leader() {
+        let mut map = barrier_map().lock().await;
+        if let Some(state) = map.get(&barrier_id)
+            && Arc::ptr_eq(&state.barrier, &barrier)
+        {
+            map.remove(&barrier_id);
+        }
+    }
+
+    Ok(())
+}
@@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler {
         )
     }

-    async fn handle(
-        &self,
-        invocation: ToolInvocation<'_>,
-    ) -> Result<ToolOutput, FunctionCallError> {
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
         let ToolInvocation {
             session, payload, ..
         } = invocation;
@@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler {
         ToolKind::Function
     }

-    async fn handle(
-        &self,
-        invocation: ToolInvocation<'_>,
-    ) -> Result<ToolOutput, FunctionCallError> {
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
         let ToolInvocation {
             session,
             turn,
@@ -1,5 +1,6 @@
 pub mod context;
 pub(crate) mod handlers;
+pub mod parallel;
 pub mod registry;
 pub mod router;
 pub mod spec;
@@ -21,7 +22,7 @@ use crate::executor::linkers::PreparedExec;
 use crate::function_tool::FunctionCallError;
 use crate::tools::context::ApplyPatchCommandContext;
 use crate::tools::context::ExecCommandContext;
-use crate::turn_diff_tracker::TurnDiffTracker;
+use crate::tools::context::SharedTurnDiffTracker;
 use codex_apply_patch::MaybeApplyPatchVerified;
 use codex_apply_patch::maybe_parse_apply_patch_verified;
 use codex_protocol::protocol::AskForApproval;
@@ -29,6 +30,7 @@ use codex_utils_string::take_bytes_at_char_boundary;
 use codex_utils_string::take_last_bytes_at_char_boundary;
 pub use router::ToolRouter;
 use serde::Serialize;
+use std::sync::Arc;
 use tracing::trace;

 // Model-formatting limits: clients get full streams; only content sent to the model is truncated.
@@ -48,9 +50,9 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
 pub(crate) async fn handle_container_exec_with_params(
     tool_name: &str,
     params: ExecParams,
-    sess: &Session,
-    turn_context: &TurnContext,
-    turn_diff_tracker: &mut TurnDiffTracker,
+    sess: Arc<Session>,
+    turn_context: Arc<TurnContext>,
+    turn_diff_tracker: SharedTurnDiffTracker,
     sub_id: String,
     call_id: String,
 ) -> Result<String, FunctionCallError> {
@@ -68,7 +70,15 @@ pub(crate) async fn handle_container_exec_with_params(
     // check if this was a patch, and apply it if so
     let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
         MaybeApplyPatchVerified::Body(changes) => {
-            match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
+            match apply_patch::apply_patch(
+                sess.as_ref(),
+                turn_context.as_ref(),
+                &sub_id,
+                &call_id,
+                changes,
+            )
+            .await
+            {
                 InternalApplyPatchInvocation::Output(item) => return item,
                 InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
                     Some(apply_patch_exec)
@@ -139,7 +149,7 @@ pub(crate) async fn handle_container_exec_with_params(

     let output_result = sess
         .run_exec_with_events(
-            turn_diff_tracker,
+            turn_diff_tracker.clone(),
             prepared_exec,
             turn_context.approval_policy,
         )
codex-rs/core/src/tools/parallel.rs (new file, 137 lines)
@@ -0,0 +1,137 @@
+use std::sync::Arc;
+
+use tokio::task::JoinHandle;
+
+use crate::codex::Session;
+use crate::codex::TurnContext;
+use crate::error::CodexErr;
+use crate::function_tool::FunctionCallError;
+use crate::tools::context::SharedTurnDiffTracker;
+use crate::tools::router::ToolCall;
+use crate::tools::router::ToolRouter;
+use codex_protocol::models::ResponseInputItem;
+
+use crate::codex::ProcessedResponseItem;
+
+struct PendingToolCall {
+    index: usize,
+    handle: JoinHandle<Result<ResponseInputItem, FunctionCallError>>,
+}
+
+pub(crate) struct ToolCallRuntime {
+    router: Arc<ToolRouter>,
+    session: Arc<Session>,
+    turn_context: Arc<TurnContext>,
+    tracker: SharedTurnDiffTracker,
+    sub_id: String,
+    pending_calls: Vec<PendingToolCall>,
+}
+
+impl ToolCallRuntime {
+    pub(crate) fn new(
+        router: Arc<ToolRouter>,
+        session: Arc<Session>,
+        turn_context: Arc<TurnContext>,
+        tracker: SharedTurnDiffTracker,
+        sub_id: String,
+    ) -> Self {
+        Self {
+            router,
+            session,
+            turn_context,
+            tracker,
+            sub_id,
+            pending_calls: Vec::new(),
+        }
+    }
+
+    pub(crate) async fn handle_tool_call(
+        &mut self,
+        call: ToolCall,
+        output_index: usize,
+        output: &mut [ProcessedResponseItem],
+    ) -> Result<(), CodexErr> {
+        let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
+        if supports_parallel {
+            self.spawn_parallel(call, output_index);
+        } else {
+            self.resolve_pending(output).await?;
+            let response = self.dispatch_serial(call).await?;
+            let slot = output.get_mut(output_index).ok_or_else(|| {
+                CodexErr::Fatal(format!("tool output index {output_index} out of bounds"))
+            })?;
+            slot.response = Some(response);
+        }
+
+        Ok(())
+    }
+
+    pub(crate) fn abort_all(&mut self) {
+        while let Some(pending) = self.pending_calls.pop() {
+            pending.handle.abort();
+        }
+    }
+
+    pub(crate) async fn resolve_pending(
+        &mut self,
+        output: &mut [ProcessedResponseItem],
+    ) -> Result<(), CodexErr> {
+        while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() {
+            match handle.await {
+                Ok(Ok(response)) => {
+                    if let Some(slot) = output.get_mut(index) {
+                        slot.response = Some(response);
+                    }
+                }
+                Ok(Err(FunctionCallError::Fatal(message))) => {
+                    self.abort_all();
+                    return Err(CodexErr::Fatal(message));
+                }
+                Ok(Err(other)) => {
+                    self.abort_all();
+                    return Err(CodexErr::Fatal(other.to_string()));
+                }
+                Err(join_err) => {
+                    self.abort_all();
+                    return Err(CodexErr::Fatal(format!(
+                        "tool task failed to join: {join_err}"
+                    )));
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    fn spawn_parallel(&mut self, call: ToolCall, index: usize) {
+        let router = Arc::clone(&self.router);
+        let session = Arc::clone(&self.session);
+        let turn = Arc::clone(&self.turn_context);
+        let tracker = Arc::clone(&self.tracker);
+        let sub_id = self.sub_id.clone();
+        let handle = tokio::spawn(async move {
+            router
+                .dispatch_tool_call(session, turn, tracker, sub_id, call)
+                .await
+        });
+        self.pending_calls.push(PendingToolCall { index, handle });
+    }
+
+    async fn dispatch_serial(&self, call: ToolCall) -> Result<ResponseInputItem, CodexErr> {
+        match self
+            .router
+            .dispatch_tool_call(
+                Arc::clone(&self.session),
+                Arc::clone(&self.turn_context),
+                Arc::clone(&self.tracker),
+                self.sub_id.clone(),
+                call,
+            )
+            .await
+        {
+            Ok(response) => Ok(response),
+            Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
+            Err(other) => Err(CodexErr::Fatal(other.to_string())),
+        }
+    }
+}
@@ -32,8 +32,7 @@ pub trait ToolHandler: Send + Sync {
         )
     }
 
-    async fn handle(&self, invocation: ToolInvocation<'_>)
-        -> Result<ToolOutput, FunctionCallError>;
+    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
 }
 
 pub struct ToolRegistry {
@@ -57,9 +56,9 @@ impl ToolRegistry {
     // }
     // }
 
-    pub async fn dispatch<'a>(
+    pub async fn dispatch(
         &self,
-        invocation: ToolInvocation<'a>,
+        invocation: ToolInvocation,
     ) -> Result<ResponseInputItem, FunctionCallError> {
         let tool_name = invocation.tool_name.clone();
         let call_id_owned = invocation.call_id.clone();
@@ -137,9 +136,24 @@ impl ToolRegistry {
     }
 }
 
+#[derive(Debug, Clone)]
+pub struct ConfiguredToolSpec {
+    pub spec: ToolSpec,
+    pub supports_parallel_tool_calls: bool,
+}
+
+impl ConfiguredToolSpec {
+    pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
+        Self {
+            spec,
+            supports_parallel_tool_calls,
+        }
+    }
+}
+
 pub struct ToolRegistryBuilder {
     handlers: HashMap<String, Arc<dyn ToolHandler>>,
-    specs: Vec<ToolSpec>,
+    specs: Vec<ConfiguredToolSpec>,
 }
 
 impl ToolRegistryBuilder {
@@ -151,7 +165,16 @@ impl ToolRegistryBuilder {
     }
 
     pub fn push_spec(&mut self, spec: ToolSpec) {
-        self.specs.push(spec);
+        self.push_spec_with_parallel_support(spec, false);
+    }
+
+    pub fn push_spec_with_parallel_support(
+        &mut self,
+        spec: ToolSpec,
+        supports_parallel_tool_calls: bool,
+    ) {
+        self.specs
+            .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
     }
 
     pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
@@ -183,7 +206,7 @@ impl ToolRegistryBuilder {
     // }
     // }
 
-    pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
+    pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
        let registry = ToolRegistry::new(self.handlers);
        (self.specs, registry)
    }
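Note (illustrative, not part of the diff): build() now hands back Vec<ConfiguredToolSpec> instead of Vec<ToolSpec>, so a caller can read the per-tool parallel flag directly. A minimal sketch, assuming a ToolsConfig constructed the same way as in the tests further down; parallel_tool_names is just an illustrative binding:

// Sketch only: collect the names of tools that opted into parallel execution.
let (tools, _registry) = build_specs(&config, None).build();
let parallel_tool_names: Vec<&str> = tools
    .iter()
    .filter(|configured| configured.supports_parallel_tool_calls)
    .map(|configured| configured.spec.name())
    .collect();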
@@ -1,15 +1,17 @@
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use crate::client_common::tools::ToolSpec;
 use crate::codex::Session;
 use crate::codex::TurnContext;
 use crate::function_tool::FunctionCallError;
+use crate::tools::context::SharedTurnDiffTracker;
 use crate::tools::context::ToolInvocation;
 use crate::tools::context::ToolPayload;
+use crate::tools::registry::ConfiguredToolSpec;
 use crate::tools::registry::ToolRegistry;
 use crate::tools::spec::ToolsConfig;
 use crate::tools::spec::build_specs;
-use crate::turn_diff_tracker::TurnDiffTracker;
 use codex_protocol::models::LocalShellAction;
 use codex_protocol::models::ResponseInputItem;
 use codex_protocol::models::ResponseItem;
@@ -24,7 +26,7 @@ pub struct ToolCall {
 
 pub struct ToolRouter {
     registry: ToolRegistry,
-    specs: Vec<ToolSpec>,
+    specs: Vec<ConfiguredToolSpec>,
 }
 
 impl ToolRouter {
@@ -34,11 +36,22 @@ impl ToolRouter {
     ) -> Self {
         let builder = build_specs(config, mcp_tools);
         let (specs, registry) = builder.build();
 
         Self { registry, specs }
     }
 
-    pub fn specs(&self) -> &[ToolSpec] {
-        &self.specs
+    pub fn specs(&self) -> Vec<ToolSpec> {
+        self.specs
+            .iter()
+            .map(|config| config.spec.clone())
+            .collect()
+    }
+
+    pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
+        self.specs
+            .iter()
+            .filter(|config| config.supports_parallel_tool_calls)
+            .any(|config| config.spec.name() == tool_name)
     }
 
     pub fn build_tool_call(
@@ -118,10 +131,10 @@ impl ToolRouter {
 
     pub async fn dispatch_tool_call(
         &self,
-        session: &Session,
-        turn: &TurnContext,
-        tracker: &mut TurnDiffTracker,
-        sub_id: &str,
+        session: Arc<Session>,
+        turn: Arc<TurnContext>,
+        tracker: SharedTurnDiffTracker,
+        sub_id: String,
         call: ToolCall,
     ) -> Result<ResponseInputItem, FunctionCallError> {
         let ToolCall {
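Note (illustrative, not part of the diff): tool_supports_parallel is the hook the turn logic can use to decide whether a batch of calls may run concurrently. A hedged sketch under that assumption; the helper name and the exact policy are illustrative, and the integration tests added later in this commit pin down the observable behaviour (a batch that mixes parallel-capable and serial-only tools runs serially):

// Sketch only: allow parallel dispatch when every requested tool opted in.
fn batch_may_run_in_parallel(router: &ToolRouter, tool_names: &[&str]) -> bool {
    !tool_names.is_empty()
        && tool_names
            .iter()
            .all(|name| router.tool_supports_parallel(name))
}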
@@ -258,6 +258,68 @@ fn create_view_image_tool() -> ToolSpec {
     })
 }
 
+fn create_test_sync_tool() -> ToolSpec {
+    let mut properties = BTreeMap::new();
+    properties.insert(
+        "sleep_before_ms".to_string(),
+        JsonSchema::Number {
+            description: Some("Optional delay in milliseconds before any other action".to_string()),
+        },
+    );
+    properties.insert(
+        "sleep_after_ms".to_string(),
+        JsonSchema::Number {
+            description: Some(
+                "Optional delay in milliseconds after completing the barrier".to_string(),
+            ),
+        },
+    );
+
+    let mut barrier_properties = BTreeMap::new();
+    barrier_properties.insert(
+        "id".to_string(),
+        JsonSchema::String {
+            description: Some(
+                "Identifier shared by concurrent calls that should rendezvous".to_string(),
+            ),
+        },
+    );
+    barrier_properties.insert(
+        "participants".to_string(),
+        JsonSchema::Number {
+            description: Some(
+                "Number of tool calls that must arrive before the barrier opens".to_string(),
+            ),
+        },
+    );
+    barrier_properties.insert(
+        "timeout_ms".to_string(),
+        JsonSchema::Number {
+            description: Some("Maximum time in milliseconds to wait at the barrier".to_string()),
+        },
+    );
+
+    properties.insert(
+        "barrier".to_string(),
+        JsonSchema::Object {
+            properties: barrier_properties,
+            required: Some(vec!["id".to_string(), "participants".to_string()]),
+            additional_properties: Some(false.into()),
+        },
+    );
+
+    ToolSpec::Function(ResponsesApiTool {
+        name: "test_sync_tool".to_string(),
+        description: "Internal synchronization helper used by Codex integration tests.".to_string(),
+        strict: false,
+        parameters: JsonSchema::Object {
+            properties,
+            required: None,
+            additional_properties: Some(false.into()),
+        },
+    })
+}
+
 fn create_read_file_tool() -> ToolSpec {
     let mut properties = BTreeMap::new();
     properties.insert(
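Note (illustrative, not part of the diff): a call exercising the schema above carries arguments shaped like the payload the tool_parallelism integration tests added later in this commit send, for example:

// Example arguments for test_sync_tool: sleep, then rendezvous with one peer.
let args = serde_json::json!({
    "sleep_after_ms": 300,
    "barrier": {
        "id": "parallel-test-sync",
        "participants": 2,
        "timeout_ms": 1_000,
    }
})
.to_string();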
@@ -507,6 +569,7 @@ pub(crate) fn build_specs(
 use crate::tools::handlers::PlanHandler;
 use crate::tools::handlers::ReadFileHandler;
 use crate::tools::handlers::ShellHandler;
+use crate::tools::handlers::TestSyncHandler;
 use crate::tools::handlers::UnifiedExecHandler;
 use crate::tools::handlers::ViewImageHandler;
 use std::sync::Arc;
@@ -573,16 +636,26 @@ pub(crate) fn build_specs(
         .any(|tool| tool == "read_file")
     {
         let read_file_handler = Arc::new(ReadFileHandler);
-        builder.push_spec(create_read_file_tool());
+        builder.push_spec_with_parallel_support(create_read_file_tool(), true);
         builder.register_handler("read_file", read_file_handler);
     }
 
+    if config
+        .experimental_supported_tools
+        .iter()
+        .any(|tool| tool == "test_sync_tool")
+    {
+        let test_sync_handler = Arc::new(TestSyncHandler);
+        builder.push_spec_with_parallel_support(create_test_sync_tool(), true);
+        builder.register_handler("test_sync_tool", test_sync_handler);
+    }
+
     if config.web_search_request {
         builder.push_spec(ToolSpec::WebSearch {});
     }
 
     if config.include_view_image_tool {
-        builder.push_spec(create_view_image_tool());
+        builder.push_spec_with_parallel_support(create_view_image_tool(), true);
         builder.register_handler("view_image", view_image_handler);
     }
@@ -610,20 +683,25 @@ pub(crate) fn build_specs(
 mod tests {
     use crate::client_common::tools::FreeformTool;
     use crate::model_family::find_family_for_model;
+    use crate::tools::registry::ConfiguredToolSpec;
     use mcp_types::ToolInputSchema;
     use pretty_assertions::assert_eq;
 
     use super::*;
 
-    fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) {
+    fn tool_name(tool: &ToolSpec) -> &str {
+        match tool {
+            ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
+            ToolSpec::LocalShell {} => "local_shell",
+            ToolSpec::WebSearch {} => "web_search",
+            ToolSpec::Freeform(FreeformTool { name, .. }) => name,
+        }
+    }
+
+    fn assert_eq_tool_names(tools: &[ConfiguredToolSpec], expected_names: &[&str]) {
         let tool_names = tools
             .iter()
-            .map(|tool| match tool {
-                ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
-                ToolSpec::LocalShell {} => "local_shell",
-                ToolSpec::WebSearch {} => "web_search",
-                ToolSpec::Freeform(FreeformTool { name, .. }) => name,
-            })
+            .map(|tool| tool_name(&tool.spec))
             .collect::<Vec<_>>();
 
         assert_eq!(
@@ -639,6 +717,16 @@ mod tests {
         }
     }
 
+    fn find_tool<'a>(
+        tools: &'a [ConfiguredToolSpec],
+        expected_name: &str,
+    ) -> &'a ConfiguredToolSpec {
+        tools
+            .iter()
+            .find(|tool| tool_name(&tool.spec) == expected_name)
+            .unwrap_or_else(|| panic!("expected tool {expected_name}"))
+    }
+
     #[test]
     fn test_build_specs() {
         let model_family = find_family_for_model("codex-mini-latest")
@@ -680,6 +768,53 @@ mod tests {
         );
     }
 
+    #[test]
+    #[ignore]
+    fn test_parallel_support_flags() {
+        let model_family = find_family_for_model("gpt-5-codex")
+            .expect("gpt-5-codex should be a valid model family");
+        let config = ToolsConfig::new(&ToolsConfigParams {
+            model_family: &model_family,
+            include_plan_tool: false,
+            include_apply_patch_tool: false,
+            include_web_search_request: false,
+            use_streamable_shell_tool: false,
+            include_view_image_tool: false,
+            experimental_unified_exec_tool: true,
+        });
+        let (tools, _) = build_specs(&config, None).build();
+
+        assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
+        assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
+    }
+
+    #[test]
+    fn test_test_model_family_includes_sync_tool() {
+        let model_family = find_family_for_model("test-gpt-5-codex")
+            .expect("test-gpt-5-codex should be a valid model family");
+        let config = ToolsConfig::new(&ToolsConfigParams {
+            model_family: &model_family,
+            include_plan_tool: false,
+            include_apply_patch_tool: false,
+            include_web_search_request: false,
+            use_streamable_shell_tool: false,
+            include_view_image_tool: false,
+            experimental_unified_exec_tool: false,
+        });
+        let (tools, _) = build_specs(&config, None).build();
+
+        assert!(
+            tools
+                .iter()
+                .any(|tool| tool_name(&tool.spec) == "test_sync_tool")
+        );
+        assert!(
+            tools
+                .iter()
+                .any(|tool| tool_name(&tool.spec) == "read_file")
+        );
+    }
+
     #[test]
     fn test_build_specs_mcp_tools() {
         let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
@@ -742,7 +877,7 @@ mod tests {
         );
 
         assert_eq!(
-            tools[3],
+            tools[3].spec,
             ToolSpec::Function(ResponsesApiTool {
                 name: "test_server/do_something_cool".to_string(),
                 parameters: JsonSchema::Object {
@@ -911,7 +1046,7 @@ mod tests {
         );
 
         assert_eq!(
-            tools[4],
+            tools[4].spec,
             ToolSpec::Function(ResponsesApiTool {
                 name: "dash/search".to_string(),
                 parameters: JsonSchema::Object {
@@ -977,7 +1112,7 @@ mod tests {
             ],
         );
         assert_eq!(
-            tools[4],
+            tools[4].spec,
             ToolSpec::Function(ResponsesApiTool {
                 name: "dash/paginate".to_string(),
                 parameters: JsonSchema::Object {
@@ -1041,7 +1176,7 @@ mod tests {
             ],
         );
        assert_eq!(
-            tools[4],
+            tools[4].spec,
             ToolSpec::Function(ResponsesApiTool {
                 name: "dash/tags".to_string(),
                 parameters: JsonSchema::Object {
@@ -1108,7 +1243,7 @@ mod tests {
             ],
         );
         assert_eq!(
-            tools[4],
+            tools[4].spec,
             ToolSpec::Function(ResponsesApiTool {
                 name: "dash/value".to_string(),
                 parameters: JsonSchema::Object {
@@ -1213,7 +1348,7 @@ mod tests {
         );
 
         assert_eq!(
-            tools[4],
+            tools[4].spec,
             ToolSpec::Function(ResponsesApiTool {
                 name: "test_server/do_something_cool".to_string(),
                 parameters: JsonSchema::Object {
@@ -3,14 +3,14 @@ use std::time::Duration;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::InputItem;
 use codex_core::protocol::Op;
+use core_test_support::responses::ev_completed;
 use core_test_support::responses::ev_function_call;
-use core_test_support::responses::mount_sse_once_match;
+use core_test_support::responses::mount_sse_once;
 use core_test_support::responses::sse;
 use core_test_support::responses::start_mock_server;
 use core_test_support::test_codex::test_codex;
 use core_test_support::wait_for_event_with_timeout;
 use serde_json::json;
-use wiremock::matchers::body_string_contains;
 
 /// Integration test: spawn a long‑running shell tool via a mocked Responses SSE
 /// function call, then interrupt the session and expect TurnAborted.
@@ -27,10 +27,13 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
         "timeout_ms": 60_000
     })
     .to_string();
-    let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]);
+    let body = sse(vec![
+        ev_function_call("call_sleep", "shell", &args),
+        ev_completed("done"),
+    ]);
 
     let server = start_mock_server().await;
-    mount_sse_once_match(&server, body_string_contains("start sleep"), body).await;
+    mount_sse_once(&server, body).await;
 
     let codex = test_codex().build(&server).await.unwrap().codex;
 
@@ -24,6 +24,7 @@ mod shell_serialization;
 mod stream_error_allows_next_turn;
 mod stream_no_completed;
 mod tool_harness;
+mod tool_parallelism;
 mod tools;
 mod unified_exec;
 mod user_notification;
178 codex-rs/core/tests/suite/tool_parallelism.rs Normal file
@@ -0,0 +1,178 @@
+#![cfg(not(target_os = "windows"))]
+#![allow(clippy::unwrap_used)]
+
+use std::time::Duration;
+use std::time::Instant;
+
+use codex_core::model_family::find_family_for_model;
+use codex_core::protocol::AskForApproval;
+use codex_core::protocol::EventMsg;
+use codex_core::protocol::InputItem;
+use codex_core::protocol::Op;
+use codex_core::protocol::SandboxPolicy;
+use codex_protocol::config_types::ReasoningSummary;
+use core_test_support::responses::ev_assistant_message;
+use core_test_support::responses::ev_completed;
+use core_test_support::responses::ev_function_call;
+use core_test_support::responses::mount_sse_sequence;
+use core_test_support::responses::sse;
+use core_test_support::responses::start_mock_server;
+use core_test_support::skip_if_no_network;
+use core_test_support::test_codex::TestCodex;
+use core_test_support::test_codex::test_codex;
+use core_test_support::wait_for_event;
+use serde_json::json;
+
+async fn run_turn(test: &TestCodex, prompt: &str) -> anyhow::Result<()> {
+    let session_model = test.session_configured.model.clone();
+
+    test.codex
+        .submit(Op::UserTurn {
+            items: vec![InputItem::Text {
+                text: prompt.into(),
+            }],
+            final_output_json_schema: None,
+            cwd: test.cwd.path().to_path_buf(),
+            approval_policy: AskForApproval::Never,
+            sandbox_policy: SandboxPolicy::DangerFullAccess,
+            model: session_model,
+            effort: None,
+            summary: ReasoningSummary::Auto,
+        })
+        .await?;
+
+    wait_for_event(&test.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+
+    Ok(())
+}
+
+async fn run_turn_and_measure(test: &TestCodex, prompt: &str) -> anyhow::Result<Duration> {
+    let start = Instant::now();
+    run_turn(test, prompt).await?;
+    Ok(start.elapsed())
+}
+
+#[allow(clippy::expect_used)]
+async fn build_codex_with_test_tool(server: &wiremock::MockServer) -> anyhow::Result<TestCodex> {
+    let mut builder = test_codex().with_config(|config| {
+        config.model = "test-gpt-5-codex".to_string();
+        config.model_family =
+            find_family_for_model("test-gpt-5-codex").expect("test-gpt-5-codex model family");
+    });
+    builder.build(server).await
+}
+
+fn assert_parallel_duration(actual: Duration) {
+    assert!(
+        actual < Duration::from_millis(500),
+        "expected parallel execution to finish quickly, got {actual:?}"
+    );
+}
+
+fn assert_serial_duration(actual: Duration) {
+    assert!(
+        actual >= Duration::from_millis(500),
+        "expected serial execution to take longer, got {actual:?}"
+    );
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn read_file_tools_run_in_parallel() -> anyhow::Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = start_mock_server().await;
+    let test = build_codex_with_test_tool(&server).await?;
+
+    let parallel_args = json!({
+        "sleep_after_ms": 300,
+        "barrier": {
+            "id": "parallel-test-sync",
+            "participants": 2,
+            "timeout_ms": 1_000,
+        }
+    })
+    .to_string();
+
+    let first_response = sse(vec![
+        json!({"type": "response.created", "response": {"id": "resp-1"}}),
+        ev_function_call("call-1", "test_sync_tool", &parallel_args),
+        ev_function_call("call-2", "test_sync_tool", &parallel_args),
+        ev_completed("resp-1"),
+    ]);
+    let second_response = sse(vec![
+        ev_assistant_message("msg-1", "done"),
+        ev_completed("resp-2"),
+    ]);
+    mount_sse_sequence(&server, vec![first_response, second_response]).await;
+
+    let duration = run_turn_and_measure(&test, "exercise sync tool").await?;
+    assert_parallel_duration(duration);
+
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn non_parallel_tools_run_serially() -> anyhow::Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = start_mock_server().await;
+    let test = test_codex().build(&server).await?;
+
+    let shell_args = json!({
+        "command": ["/bin/sh", "-c", "sleep 0.3"],
+        "timeout_ms": 1_000,
+    });
+    let args_one = serde_json::to_string(&shell_args)?;
+    let args_two = serde_json::to_string(&shell_args)?;
+
+    let first_response = sse(vec![
+        json!({"type": "response.created", "response": {"id": "resp-1"}}),
+        ev_function_call("call-1", "shell", &args_one),
+        ev_function_call("call-2", "shell", &args_two),
+        ev_completed("resp-1"),
+    ]);
+    let second_response = sse(vec![
+        ev_assistant_message("msg-1", "done"),
+        ev_completed("resp-2"),
+    ]);
+    mount_sse_sequence(&server, vec![first_response, second_response]).await;
+
+    let duration = run_turn_and_measure(&test, "run shell twice").await?;
+    assert_serial_duration(duration);
+
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn mixed_tools_fall_back_to_serial() -> anyhow::Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = start_mock_server().await;
+    let test = build_codex_with_test_tool(&server).await?;
+
+    let sync_args = json!({
+        "sleep_after_ms": 300
+    })
+    .to_string();
+    let shell_args = serde_json::to_string(&json!({
+        "command": ["/bin/sh", "-c", "sleep 0.3"],
+        "timeout_ms": 1_000,
+    }))?;
+
+    let first_response = sse(vec![
+        json!({"type": "response.created", "response": {"id": "resp-1"}}),
+        ev_function_call("call-1", "test_sync_tool", &sync_args),
+        ev_function_call("call-2", "shell", &shell_args),
+        ev_completed("resp-1"),
+    ]);
+    let second_response = sse(vec![
+        ev_assistant_message("msg-1", "done"),
+        ev_completed("resp-2"),
+    ]);
+    mount_sse_sequence(&server, vec![first_response, second_response]).await;
+
+    let duration = run_turn_and_measure(&test, "mix tools").await?;
+    assert_serial_duration(duration);
+
+    Ok(())
+}
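Note (illustrative, not part of the diff): the timing assertions rest on simple arithmetic. Each call sleeps 300 ms, so two calls take roughly 300 ms when they overlap and roughly 600 ms when they run back to back, which is what the 500 ms threshold separates; the barrier (participants: 2, 1 s timeout) additionally proves overlap, since neither call can pass it unless the other is running at the same time. The handler behind test_sync_tool is not shown in this excerpt; purely as a hedged sketch, a keyed barrier with a timeout could look like this (BarrierRegistry and its API are hypothetical, not the actual TestSyncHandler):

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Barrier;
use tokio::sync::Mutex;
use tokio::time::timeout;

/// Keyed barrier: calls sharing an `id` rendezvous once `participants` of them
/// have arrived, or give up after `timeout_ms`.
#[derive(Default)]
struct BarrierRegistry {
    barriers: Mutex<HashMap<String, Arc<Barrier>>>,
}

impl BarrierRegistry {
    async fn wait(&self, id: &str, participants: usize, timeout_ms: u64) -> bool {
        let barrier = {
            let mut map = self.barriers.lock().await;
            Arc::clone(
                map.entry(id.to_string())
                    .or_insert_with(|| Arc::new(Barrier::new(participants))),
            )
        };
        timeout(Duration::from_millis(timeout_ms), barrier.wait())
            .await
            .is_ok()
    }
}

#[tokio::main]
async fn main() {
    let registry = Arc::new(BarrierRegistry::default());
    let tasks: Vec<_> = (0..2)
        .map(|_| {
            let registry = Arc::clone(&registry);
            tokio::spawn(async move { registry.wait("parallel-test-sync", 2, 1_000).await })
        })
        .collect();
    for task in tasks {
        // Both waiters reach the barrier only if they run concurrently.
        assert!(task.await.unwrap());
    }
}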