feat: parallel tool calls (#4663)

Add parallel tool calls. This is configurable at model level and tool
level
This commit is contained in:
jif-oai
2025-10-05 17:10:49 +01:00
committed by GitHub
parent 3203862167
commit dc3c6bf62a
23 changed files with 961 additions and 244 deletions

View File

@@ -227,7 +227,7 @@ impl ModelClient {
input: &input_with_instructions,
tools: &tools_json,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: prompt.parallel_tool_calls,
reasoning,
store: azure_workaround,
stream: true,

View File

@@ -33,6 +33,9 @@ pub struct Prompt {
/// external MCP servers.
pub(crate) tools: Vec<ToolSpec>,
/// Whether parallel tool calls are permitted for this prompt.
pub(crate) parallel_tool_calls: bool,
/// Optional override for the built-in BASE_INSTRUCTIONS.
pub base_instructions_override: Option<String>,
@@ -288,6 +291,17 @@ pub(crate) mod tools {
Freeform(FreeformTool),
}
impl ToolSpec {
pub(crate) fn name(&self) -> &str {
match self {
ToolSpec::Function(tool) => tool.name.as_str(),
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(tool) => tool.name.as_str(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformTool {
pub(crate) name: String,
@@ -433,7 +447,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
@@ -474,7 +488,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
@@ -510,7 +524,7 @@ mod tests {
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: false,
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,

View File

@@ -100,7 +100,9 @@ use crate::tasks::CompactTask;
use crate::tasks::RegularTask;
use crate::tasks::ReviewTask;
use crate::tools::ToolRouter;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::format_exec_output_str;
use crate::tools::parallel::ToolCallRuntime;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions;
@@ -818,7 +820,7 @@ impl Session {
async fn on_exec_command_begin(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
turn_diff_tracker: SharedTurnDiffTracker,
exec_command_context: ExecCommandContext,
) {
let ExecCommandContext {
@@ -834,7 +836,10 @@ impl Session {
user_explicitly_approved_this_action,
changes,
}) => {
turn_diff_tracker.on_patch_begin(&changes);
{
let mut tracker = turn_diff_tracker.lock().await;
tracker.on_patch_begin(&changes);
}
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
call_id,
@@ -861,7 +866,7 @@ impl Session {
async fn on_exec_command_end(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: &str,
call_id: &str,
output: &ExecToolCallOutput,
@@ -909,7 +914,10 @@ impl Session {
// If this is an apply_patch, after we emit the end patch, emit a second event
// with the full turn diff if there is one.
if is_apply_patch {
let unified_diff = turn_diff_tracker.get_unified_diff();
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
tracker.get_unified_diff()
};
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
@@ -926,7 +934,7 @@ impl Session {
/// Returns the output of the exec tool call.
pub(crate) async fn run_exec_with_events(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
turn_diff_tracker: SharedTurnDiffTracker,
prepared: PreparedExec,
approval_policy: AskForApproval,
) -> Result<ExecToolCallOutput, ExecError> {
@@ -935,7 +943,7 @@ impl Session {
let sub_id = context.sub_id.clone();
let call_id = context.call_id.clone();
self.on_exec_command_begin(turn_diff_tracker, context.clone())
self.on_exec_command_begin(turn_diff_tracker.clone(), context.clone())
.await;
let result = self
@@ -1644,7 +1652,7 @@ pub(crate) async fn run_task(
let mut last_agent_message: Option<String> = None;
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
// many turns, from the perspective of the user, it is a single turn.
let mut turn_diff_tracker = TurnDiffTracker::new();
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let mut auto_compact_recently_attempted = false;
loop {
@@ -1692,9 +1700,9 @@ pub(crate) async fn run_task(
})
.collect();
match run_turn(
&sess,
turn_context.as_ref(),
&mut turn_diff_tracker,
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
sub_id.clone(),
turn_input,
)
@@ -1917,18 +1925,27 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
}
async fn run_turn(
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: String,
input: Vec<ResponseItem>,
) -> CodexResult<TurnRunResult> {
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools));
let router = Arc::new(ToolRouter::from_config(
&turn_context.tools_config,
Some(mcp_tools),
));
let model_supports_parallel = turn_context
.client
.get_model_family()
.supports_parallel_tool_calls;
let parallel_tool_calls = model_supports_parallel;
let prompt = Prompt {
input,
tools: router.specs().to_vec(),
tools: router.specs(),
parallel_tool_calls,
base_instructions_override: turn_context.base_instructions.clone(),
output_schema: turn_context.final_output_json_schema.clone(),
};
@@ -1936,10 +1953,10 @@ async fn run_turn(
let mut retries = 0;
loop {
match try_run_turn(
&router,
sess,
turn_context,
turn_diff_tracker,
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
&sub_id,
&prompt,
)
@@ -1950,7 +1967,7 @@ async fn run_turn(
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ CodexErr::Fatal(_)) => return Err(e),
Err(e @ CodexErr::ContextWindowExceeded) => {
sess.set_total_tokens_full(&sub_id, turn_context).await;
sess.set_total_tokens_full(&sub_id, &turn_context).await;
return Err(e);
}
Err(CodexErr::UsageLimitReached(e)) => {
@@ -1999,9 +2016,9 @@ async fn run_turn(
/// "handled" such that it produces a `ResponseInputItem` that needs to be
/// sent back to the model on the next turn.
#[derive(Debug)]
struct ProcessedResponseItem {
item: ResponseItem,
response: Option<ResponseInputItem>,
pub(crate) struct ProcessedResponseItem {
pub(crate) item: ResponseItem,
pub(crate) response: Option<ResponseInputItem>,
}
#[derive(Debug)]
@@ -2011,10 +2028,10 @@ struct TurnRunResult {
}
async fn try_run_turn(
router: &crate::tools::ToolRouter,
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
router: Arc<ToolRouter>,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<TurnRunResult> {
@@ -2085,24 +2102,34 @@ async fn try_run_turn(
let mut stream = turn_context.client.clone().stream(&prompt).await?;
let mut output = Vec::new();
let mut tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
sub_id.to_string(),
);
loop {
// Poll the next item from the model stream. We must inspect *both* Ok and Err
// cases so that transient stream failures (e.g., dropped SSE connection before
// `response.completed`) bubble up and trigger the caller's retry logic.
let event = stream.next().await;
let Some(event) = event else {
// Channel closed without yielding a final Completed event or explicit error.
// Treat as a disconnected stream so the caller can retry.
return Err(CodexErr::Stream(
"stream closed before response.completed".into(),
None,
));
let event = match event {
Some(event) => event,
None => {
tool_runtime.abort_all();
return Err(CodexErr::Stream(
"stream closed before response.completed".into(),
None,
));
}
};
let event = match event {
Ok(ev) => ev,
Err(e) => {
tool_runtime.abort_all();
// Propagate the underlying stream error to the caller (run_turn), which
// will apply the configured `stream_max_retries` policy.
return Err(e);
@@ -2112,16 +2139,66 @@ async fn try_run_turn(
match event {
ResponseEvent::Created => {}
ResponseEvent::OutputItemDone(item) => {
let response = handle_response_item(
router,
sess,
turn_context,
turn_diff_tracker,
sub_id,
item.clone(),
)
.await?;
output.push(ProcessedResponseItem { item, response });
match ToolRouter::build_tool_call(sess.as_ref(), item.clone()) {
Ok(Some(call)) => {
let payload_preview = call.payload.log_payload().into_owned();
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
let index = output.len();
output.push(ProcessedResponseItem {
item,
response: None,
});
tool_runtime
.handle_tool_call(call, index, output.as_mut_slice())
.await?;
}
Ok(None) => {
let response = handle_non_tool_response_item(
Arc::clone(&sess),
Arc::clone(&turn_context),
sub_id,
item.clone(),
)
.await?;
output.push(ProcessedResponseItem { item, response });
}
Err(FunctionCallError::MissingLocalShellCallId) => {
let msg = "LocalShellCall without call_id or id";
turn_context
.client
.get_otel_event_manager()
.log_tool_failed("local_shell", msg);
error!(msg);
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg.to_string(),
success: None,
},
};
output.push(ProcessedResponseItem {
item,
response: Some(response),
});
}
Err(FunctionCallError::RespondToModel(message)) => {
let response = ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: message,
success: None,
},
};
output.push(ProcessedResponseItem {
item,
response: Some(response),
});
}
Err(FunctionCallError::Fatal(message)) => {
return Err(CodexErr::Fatal(message));
}
}
}
ResponseEvent::WebSearchCallBegin { call_id } => {
let _ = sess
@@ -2141,10 +2218,15 @@ async fn try_run_turn(
response_id: _,
token_usage,
} => {
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
.await;
let unified_diff = turn_diff_tracker.get_unified_diff();
tool_runtime.resolve_pending(output.as_mut_slice()).await?;
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
tracker.get_unified_diff()
};
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
@@ -2203,88 +2285,40 @@ async fn try_run_turn(
}
}
async fn handle_response_item(
router: &crate::tools::ToolRouter,
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
async fn handle_non_tool_response_item(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: &str,
item: ResponseItem,
) -> CodexResult<Option<ResponseInputItem>> {
debug!(?item, "Output item");
match ToolRouter::build_tool_call(sess, item.clone()) {
Ok(Some(call)) => {
let payload_preview = call.payload.log_payload().into_owned();
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
match router
.dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call)
.await
{
Ok(response) => Ok(Some(response)),
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
Err(other) => unreachable!("non-fatal tool error returned: {other:?}"),
match &item {
ResponseItem::Message { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. } => {
let msgs = match &item {
ResponseItem::Message { .. } if turn_context.is_review_mode => {
trace!("suppressing assistant Message in review mode");
Vec::new()
}
_ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()),
};
for msg in msgs {
let event = Event {
id: sub_id.to_string(),
msg,
};
sess.send_event(event).await;
}
}
Ok(None) => {
match &item {
ResponseItem::Message { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. } => {
let msgs = match &item {
ResponseItem::Message { .. } if turn_context.is_review_mode => {
trace!("suppressing assistant Message in review mode");
Vec::new()
}
_ => map_response_item_to_event_messages(
&item,
sess.show_raw_agent_reasoning(),
),
};
for msg in msgs {
let event = Event {
id: sub_id.to_string(),
msg,
};
sess.send_event(event).await;
}
}
ResponseItem::FunctionCallOutput { .. }
| ResponseItem::CustomToolCallOutput { .. } => {
debug!("unexpected tool output from stream");
}
_ => {}
}
Ok(None)
ResponseItem::FunctionCallOutput { .. } | ResponseItem::CustomToolCallOutput { .. } => {
debug!("unexpected tool output from stream");
}
Err(FunctionCallError::MissingLocalShellCallId) => {
let msg = "LocalShellCall without call_id or id";
turn_context
.client
.get_otel_event_manager()
.log_tool_failed("local_shell", msg);
error!(msg);
Ok(Some(ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg.to_string(),
success: None,
},
}))
}
Err(FunctionCallError::RespondToModel(msg)) => {
Ok(Some(ResponseInputItem::FunctionCallOutput {
call_id: String::new(),
output: FunctionCallOutputPayload {
content: msg,
success: None,
},
}))
}
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
_ => {}
}
Ok(None)
}
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
@@ -2927,13 +2961,10 @@ mod tests {
#[tokio::test]
async fn fatal_tool_error_stops_turn_and_reports_error() {
let (session, turn_context, _rx) = make_session_and_context_with_rx();
let session_ref = session.as_ref();
let turn_context_ref = turn_context.as_ref();
let router = ToolRouter::from_config(
&turn_context_ref.tools_config,
Some(session_ref.services.mcp_connection_manager.list_all_tools()),
&turn_context.tools_config,
Some(session.services.mcp_connection_manager.list_all_tools()),
);
let mut tracker = TurnDiffTracker::new();
let item = ResponseItem::CustomToolCall {
id: None,
status: None,
@@ -2942,22 +2973,26 @@ mod tests {
input: "{}".to_string(),
};
let err = handle_response_item(
&router,
session_ref,
turn_context_ref,
&mut tracker,
"sub-id",
item,
)
.await
.expect_err("expected fatal error");
let call = ToolRouter::build_tool_call(session.as_ref(), item.clone())
.expect("build tool call")
.expect("tool call present");
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let err = router
.dispatch_tool_call(
Arc::clone(&session),
Arc::clone(&turn_context),
tracker,
"sub-id".to_string(),
call,
)
.await
.expect_err("expected fatal error");
match err {
CodexErr::Fatal(message) => {
FunctionCallError::Fatal(message) => {
assert_eq!(message, "tool shell invoked with incompatible payload");
}
other => panic!("expected CodexErr::Fatal, got {other:?}"),
other => panic!("expected FunctionCallError::Fatal, got {other:?}"),
}
}
@@ -3071,9 +3106,11 @@ mod tests {
use crate::turn_diff_tracker::TurnDiffTracker;
use std::collections::HashMap;
let (session, mut turn_context) = make_session_and_context();
let (session, mut turn_context_raw) = make_session_and_context();
// Ensure policy is NOT OnRequest so the early rejection path triggers
turn_context.approval_policy = AskForApproval::OnFailure;
turn_context_raw.approval_policy = AskForApproval::OnFailure;
let session = Arc::new(session);
let mut turn_context = Arc::new(turn_context_raw);
let params = ExecParams {
command: if cfg!(windows) {
@@ -3101,7 +3138,7 @@ mod tests {
..params.clone()
};
let mut turn_diff_tracker = TurnDiffTracker::new();
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let tool_name = "shell";
let sub_id = "test-sub".to_string();
@@ -3110,9 +3147,9 @@ mod tests {
let resp = handle_container_exec_with_params(
tool_name,
params,
&session,
&turn_context,
&mut turn_diff_tracker,
Arc::clone(&session),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
sub_id,
call_id,
)
@@ -3131,14 +3168,16 @@ mod tests {
// Now retry the same command WITHOUT escalated permissions; should succeed.
// Force DangerFullAccess to avoid platform sandbox dependencies in tests.
turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess;
Arc::get_mut(&mut turn_context)
.expect("unique turn context Arc")
.sandbox_policy = SandboxPolicy::DangerFullAccess;
let resp2 = handle_container_exec_with_params(
tool_name,
params2,
&session,
&turn_context,
&mut turn_diff_tracker,
Arc::clone(&session),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
"test-sub".to_string(),
"test-call-2".to_string(),
)

View File

@@ -35,6 +35,10 @@ pub struct ModelFamily {
// See https://platform.openai.com/docs/guides/tools-local-shell
pub uses_local_shell_tool: bool,
/// Whether this model supports parallel tool calls when using the
/// Responses API.
pub supports_parallel_tool_calls: bool,
/// Present if the model performs better when `apply_patch` is provided as
/// a tool call instead of just a bash command
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
@@ -58,6 +62,7 @@ macro_rules! model_family {
supports_reasoning_summaries: false,
reasoning_summary_format: ReasoningSummaryFormat::None,
uses_local_shell_tool: false,
supports_parallel_tool_calls: false,
apply_patch_tool_type: None,
base_instructions: BASE_INSTRUCTIONS.to_string(),
experimental_supported_tools: Vec::new(),
@@ -103,6 +108,18 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
} else if slug.starts_with("gpt-3.5") {
model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
} else if slug.starts_with("test-gpt-5-codex") {
model_family!(
slug, slug,
supports_reasoning_summaries: true,
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
experimental_supported_tools: vec![
"read_file".to_string(),
"test_sync_tool".to_string()
],
supports_parallel_tool_calls: true,
)
} else if slug.starts_with("codex-") || slug.starts_with("gpt-5-codex") {
model_family!(
slug, slug,
@@ -110,6 +127,8 @@ pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
// experimental_supported_tools: vec!["read_file".to_string()],
// supports_parallel_tool_calls: true,
)
} else if slug.starts_with("gpt-5") {
model_family!(
@@ -130,6 +149,7 @@ pub fn derive_default_model_family(model: &str) -> ModelFamily {
supports_reasoning_summaries: false,
reasoning_summary_format: ReasoningSummaryFormat::None,
uses_local_shell_tool: false,
supports_parallel_tool_calls: false,
apply_patch_tool_type: None,
base_instructions: BASE_INSTRUCTIONS.to_string(),
experimental_supported_tools: Vec::new(),

View File

@@ -14,12 +14,17 @@ use mcp_types::CallToolResult;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
pub struct ToolInvocation<'a> {
pub session: &'a Session,
pub turn: &'a TurnContext,
pub tracker: &'a mut TurnDiffTracker,
pub sub_id: &'a str,
pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;
#[derive(Clone)]
pub struct ToolInvocation {
pub session: Arc<Session>,
pub turn: Arc<TurnContext>,
pub tracker: SharedTurnDiffTracker,
pub sub_id: String,
pub call_id: String,
pub tool_name: String,
pub payload: ToolPayload,

View File

@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat;
@@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
@@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler {
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;

View File

@@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
tool_name,

View File

@@ -16,10 +16,7 @@ impl ToolHandler for McpHandler {
ToolKind::Mcp
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
sub_id,
@@ -45,8 +42,8 @@ impl ToolHandler for McpHandler {
let arguments_str = raw_arguments;
let response = handle_mcp_tool_call(
session,
sub_id,
session.as_ref(),
&sub_id,
call_id.clone(),
server,
tool,

View File

@@ -4,6 +4,7 @@ mod mcp;
mod plan;
mod read_file;
mod shell;
mod test_sync;
mod unified_exec;
mod view_image;
@@ -15,5 +16,6 @@ pub use mcp::McpHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
pub use test_sync::TestSyncHandler;
pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler;

View File

@@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
sub_id,
@@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler {
}
};
let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?;
let content =
handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;
Ok(ToolOutput::Function {
content,

View File

@@ -42,10 +42,7 @@ impl ToolHandler for ReadFileHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation;
let arguments = match payload {

View File

@@ -1,5 +1,6 @@
use async_trait::async_trait;
use codex_protocol::models::ShellToolCallParams;
use std::sync::Arc;
use crate::codex::TurnContext;
use crate::exec::ExecParams;
@@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
@@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler {
"failed to parse function arguments: {e:?}"
))
})?;
let exec_params = Self::to_exec_params(params, turn);
let exec_params = Self::to_exec_params(params, turn.as_ref());
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;
@@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler {
})
}
ToolPayload::LocalShell { params } => {
let exec_params = Self::to_exec_params(params, turn);
let exec_params = Self::to_exec_params(params, turn.as_ref());
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;

View File

@@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use async_trait::async_trait;
use serde::Deserialize;
use tokio::sync::Barrier;
use tokio::time::sleep;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct TestSyncHandler;
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();
struct BarrierState {
barrier: Arc<Barrier>,
participants: usize,
}
#[derive(Debug, Deserialize)]
struct BarrierArgs {
id: String,
participants: usize,
#[serde(default = "default_timeout_ms")]
timeout_ms: u64,
}
#[derive(Debug, Deserialize)]
struct TestSyncArgs {
#[serde(default)]
sleep_before_ms: Option<u64>,
#[serde(default)]
sleep_after_ms: Option<u64>,
#[serde(default)]
barrier: Option<BarrierArgs>,
}
fn default_timeout_ms() -> u64 {
DEFAULT_TIMEOUT_MS
}
fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
}
#[async_trait]
impl ToolHandler for TestSyncHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"test_sync_tool handler received unsupported payload".to_string(),
));
}
};
let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?;
if let Some(delay) = args.sleep_before_ms
&& delay > 0
{
sleep(Duration::from_millis(delay)).await;
}
if let Some(barrier) = args.barrier {
wait_on_barrier(barrier).await?;
}
if let Some(delay) = args.sleep_after_ms
&& delay > 0
{
sleep(Duration::from_millis(delay)).await;
}
Ok(ToolOutput::Function {
content: "ok".to_string(),
success: Some(true),
})
}
}
async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
if args.participants == 0 {
return Err(FunctionCallError::RespondToModel(
"barrier participants must be greater than zero".to_string(),
));
}
if args.timeout_ms == 0 {
return Err(FunctionCallError::RespondToModel(
"barrier timeout must be greater than zero".to_string(),
));
}
let barrier_id = args.id.clone();
let barrier = {
let mut map = barrier_map().lock().await;
match map.entry(barrier_id.clone()) {
Entry::Occupied(entry) => {
let state = entry.get();
if state.participants != args.participants {
let existing = state.participants;
return Err(FunctionCallError::RespondToModel(format!(
"barrier {barrier_id} already registered with {existing} participants"
)));
}
state.barrier.clone()
}
Entry::Vacant(entry) => {
let barrier = Arc::new(Barrier::new(args.participants));
entry.insert(BarrierState {
barrier: barrier.clone(),
participants: args.participants,
});
barrier
}
}
};
let timeout = Duration::from_millis(args.timeout_ms);
let wait_result = tokio::time::timeout(timeout, barrier.wait())
.await
.map_err(|_| {
FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
})?;
if wait_result.is_leader() {
let mut map = barrier_map().lock().await;
if let Some(state) = map.get(&barrier_id)
&& Arc::ptr_eq(&state.barrier, &barrier)
{
map.remove(&barrier_id);
}
}
Ok(())
}

View File

@@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session, payload, ..
} = invocation;

View File

@@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,

View File

@@ -1,5 +1,6 @@
pub mod context;
pub(crate) mod handlers;
pub mod parallel;
pub mod registry;
pub mod router;
pub mod spec;
@@ -21,7 +22,7 @@ use crate::executor::linkers::PreparedExec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ApplyPatchCommandContext;
use crate::tools::context::ExecCommandContext;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::tools::context::SharedTurnDiffTracker;
use codex_apply_patch::MaybeApplyPatchVerified;
use codex_apply_patch::maybe_parse_apply_patch_verified;
use codex_protocol::protocol::AskForApproval;
@@ -29,6 +30,7 @@ use codex_utils_string::take_bytes_at_char_boundary;
use codex_utils_string::take_last_bytes_at_char_boundary;
pub use router::ToolRouter;
use serde::Serialize;
use std::sync::Arc;
use tracing::trace;
// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
@@ -48,9 +50,9 @@ pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
pub(crate) async fn handle_container_exec_with_params(
tool_name: &str,
params: ExecParams,
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: String,
call_id: String,
) -> Result<String, FunctionCallError> {
@@ -68,7 +70,15 @@ pub(crate) async fn handle_container_exec_with_params(
// check if this was a patch, and apply it if so
let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
MaybeApplyPatchVerified::Body(changes) => {
match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
match apply_patch::apply_patch(
sess.as_ref(),
turn_context.as_ref(),
&sub_id,
&call_id,
changes,
)
.await
{
InternalApplyPatchInvocation::Output(item) => return item,
InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
Some(apply_patch_exec)
@@ -139,7 +149,7 @@ pub(crate) async fn handle_container_exec_with_params(
let output_result = sess
.run_exec_with_events(
turn_diff_tracker,
turn_diff_tracker.clone(),
prepared_exec,
turn_context.approval_policy,
)

View File

@@ -0,0 +1,137 @@
use std::sync::Arc;
use tokio::task::JoinHandle;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::ResponseInputItem;
use crate::codex::ProcessedResponseItem;
struct PendingToolCall {
index: usize,
handle: JoinHandle<Result<ResponseInputItem, FunctionCallError>>,
}
pub(crate) struct ToolCallRuntime {
router: Arc<ToolRouter>,
session: Arc<Session>,
turn_context: Arc<TurnContext>,
tracker: SharedTurnDiffTracker,
sub_id: String,
pending_calls: Vec<PendingToolCall>,
}
impl ToolCallRuntime {
pub(crate) fn new(
router: Arc<ToolRouter>,
session: Arc<Session>,
turn_context: Arc<TurnContext>,
tracker: SharedTurnDiffTracker,
sub_id: String,
) -> Self {
Self {
router,
session,
turn_context,
tracker,
sub_id,
pending_calls: Vec::new(),
}
}
pub(crate) async fn handle_tool_call(
&mut self,
call: ToolCall,
output_index: usize,
output: &mut [ProcessedResponseItem],
) -> Result<(), CodexErr> {
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
if supports_parallel {
self.spawn_parallel(call, output_index);
} else {
self.resolve_pending(output).await?;
let response = self.dispatch_serial(call).await?;
let slot = output.get_mut(output_index).ok_or_else(|| {
CodexErr::Fatal(format!("tool output index {output_index} out of bounds"))
})?;
slot.response = Some(response);
}
Ok(())
}
pub(crate) fn abort_all(&mut self) {
while let Some(pending) = self.pending_calls.pop() {
pending.handle.abort();
}
}
pub(crate) async fn resolve_pending(
&mut self,
output: &mut [ProcessedResponseItem],
) -> Result<(), CodexErr> {
while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() {
match handle.await {
Ok(Ok(response)) => {
if let Some(slot) = output.get_mut(index) {
slot.response = Some(response);
}
}
Ok(Err(FunctionCallError::Fatal(message))) => {
self.abort_all();
return Err(CodexErr::Fatal(message));
}
Ok(Err(other)) => {
self.abort_all();
return Err(CodexErr::Fatal(other.to_string()));
}
Err(join_err) => {
self.abort_all();
return Err(CodexErr::Fatal(format!(
"tool task failed to join: {join_err}"
)));
}
}
}
Ok(())
}
fn spawn_parallel(&mut self, call: ToolCall, index: usize) {
let router = Arc::clone(&self.router);
let session = Arc::clone(&self.session);
let turn = Arc::clone(&self.turn_context);
let tracker = Arc::clone(&self.tracker);
let sub_id = self.sub_id.clone();
let handle = tokio::spawn(async move {
router
.dispatch_tool_call(session, turn, tracker, sub_id, call)
.await
});
self.pending_calls.push(PendingToolCall { index, handle });
}
async fn dispatch_serial(&self, call: ToolCall) -> Result<ResponseInputItem, CodexErr> {
match self
.router
.dispatch_tool_call(
Arc::clone(&self.session),
Arc::clone(&self.turn_context),
Arc::clone(&self.tracker),
self.sub_id.clone(),
call,
)
.await
{
Ok(response) => Ok(response),
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
Err(other) => Err(CodexErr::Fatal(other.to_string())),
}
}
}

View File

@@ -32,8 +32,7 @@ pub trait ToolHandler: Send + Sync {
)
}
async fn handle(&self, invocation: ToolInvocation<'_>)
-> Result<ToolOutput, FunctionCallError>;
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
}
pub struct ToolRegistry {
@@ -57,9 +56,9 @@ impl ToolRegistry {
// }
// }
pub async fn dispatch<'a>(
pub async fn dispatch(
&self,
invocation: ToolInvocation<'a>,
invocation: ToolInvocation,
) -> Result<ResponseInputItem, FunctionCallError> {
let tool_name = invocation.tool_name.clone();
let call_id_owned = invocation.call_id.clone();
@@ -137,9 +136,24 @@ impl ToolRegistry {
}
}
#[derive(Debug, Clone)]
pub struct ConfiguredToolSpec {
pub spec: ToolSpec,
pub supports_parallel_tool_calls: bool,
}
impl ConfiguredToolSpec {
pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
Self {
spec,
supports_parallel_tool_calls,
}
}
}
pub struct ToolRegistryBuilder {
handlers: HashMap<String, Arc<dyn ToolHandler>>,
specs: Vec<ToolSpec>,
specs: Vec<ConfiguredToolSpec>,
}
impl ToolRegistryBuilder {
@@ -151,7 +165,16 @@ impl ToolRegistryBuilder {
}
pub fn push_spec(&mut self, spec: ToolSpec) {
self.specs.push(spec);
self.push_spec_with_parallel_support(spec, false);
}
pub fn push_spec_with_parallel_support(
&mut self,
spec: ToolSpec,
supports_parallel_tool_calls: bool,
) {
self.specs
.push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
}
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
@@ -183,7 +206,7 @@ impl ToolRegistryBuilder {
// }
// }
pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
let registry = ToolRegistry::new(self.handlers);
(self.specs, registry)
}

View File

@@ -1,15 +1,17 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ConfiguredToolSpec;
use crate::tools::registry::ToolRegistry;
use crate::tools::spec::ToolsConfig;
use crate::tools::spec::build_specs;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::models::LocalShellAction;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
@@ -24,7 +26,7 @@ pub struct ToolCall {
pub struct ToolRouter {
registry: ToolRegistry,
specs: Vec<ToolSpec>,
specs: Vec<ConfiguredToolSpec>,
}
impl ToolRouter {
@@ -34,11 +36,22 @@ impl ToolRouter {
) -> Self {
let builder = build_specs(config, mcp_tools);
let (specs, registry) = builder.build();
Self { registry, specs }
}
pub fn specs(&self) -> &[ToolSpec] {
&self.specs
pub fn specs(&self) -> Vec<ToolSpec> {
self.specs
.iter()
.map(|config| config.spec.clone())
.collect()
}
pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
self.specs
.iter()
.filter(|config| config.supports_parallel_tool_calls)
.any(|config| config.spec.name() == tool_name)
}
pub fn build_tool_call(
@@ -118,10 +131,10 @@ impl ToolRouter {
pub async fn dispatch_tool_call(
&self,
session: &Session,
turn: &TurnContext,
tracker: &mut TurnDiffTracker,
sub_id: &str,
session: Arc<Session>,
turn: Arc<TurnContext>,
tracker: SharedTurnDiffTracker,
sub_id: String,
call: ToolCall,
) -> Result<ResponseInputItem, FunctionCallError> {
let ToolCall {

View File

@@ -258,6 +258,68 @@ fn create_view_image_tool() -> ToolSpec {
})
}
fn create_test_sync_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
"sleep_before_ms".to_string(),
JsonSchema::Number {
description: Some("Optional delay in milliseconds before any other action".to_string()),
},
);
properties.insert(
"sleep_after_ms".to_string(),
JsonSchema::Number {
description: Some(
"Optional delay in milliseconds after completing the barrier".to_string(),
),
},
);
let mut barrier_properties = BTreeMap::new();
barrier_properties.insert(
"id".to_string(),
JsonSchema::String {
description: Some(
"Identifier shared by concurrent calls that should rendezvous".to_string(),
),
},
);
barrier_properties.insert(
"participants".to_string(),
JsonSchema::Number {
description: Some(
"Number of tool calls that must arrive before the barrier opens".to_string(),
),
},
);
barrier_properties.insert(
"timeout_ms".to_string(),
JsonSchema::Number {
description: Some("Maximum time in milliseconds to wait at the barrier".to_string()),
},
);
properties.insert(
"barrier".to_string(),
JsonSchema::Object {
properties: barrier_properties,
required: Some(vec!["id".to_string(), "participants".to_string()]),
additional_properties: Some(false.into()),
},
);
ToolSpec::Function(ResponsesApiTool {
name: "test_sync_tool".to_string(),
description: "Internal synchronization helper used by Codex integration tests.".to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: None,
additional_properties: Some(false.into()),
},
})
}
fn create_read_file_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
@@ -507,6 +569,7 @@ pub(crate) fn build_specs(
use crate::tools::handlers::PlanHandler;
use crate::tools::handlers::ReadFileHandler;
use crate::tools::handlers::ShellHandler;
use crate::tools::handlers::TestSyncHandler;
use crate::tools::handlers::UnifiedExecHandler;
use crate::tools::handlers::ViewImageHandler;
use std::sync::Arc;
@@ -573,16 +636,26 @@ pub(crate) fn build_specs(
.any(|tool| tool == "read_file")
{
let read_file_handler = Arc::new(ReadFileHandler);
builder.push_spec(create_read_file_tool());
builder.push_spec_with_parallel_support(create_read_file_tool(), true);
builder.register_handler("read_file", read_file_handler);
}
if config
.experimental_supported_tools
.iter()
.any(|tool| tool == "test_sync_tool")
{
let test_sync_handler = Arc::new(TestSyncHandler);
builder.push_spec_with_parallel_support(create_test_sync_tool(), true);
builder.register_handler("test_sync_tool", test_sync_handler);
}
if config.web_search_request {
builder.push_spec(ToolSpec::WebSearch {});
}
if config.include_view_image_tool {
builder.push_spec(create_view_image_tool());
builder.push_spec_with_parallel_support(create_view_image_tool(), true);
builder.register_handler("view_image", view_image_handler);
}
@@ -610,20 +683,25 @@ pub(crate) fn build_specs(
mod tests {
use crate::client_common::tools::FreeformTool;
use crate::model_family::find_family_for_model;
use crate::tools::registry::ConfiguredToolSpec;
use mcp_types::ToolInputSchema;
use pretty_assertions::assert_eq;
use super::*;
fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) {
fn tool_name(tool: &ToolSpec) -> &str {
match tool {
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
}
}
fn assert_eq_tool_names(tools: &[ConfiguredToolSpec], expected_names: &[&str]) {
let tool_names = tools
.iter()
.map(|tool| match tool {
ToolSpec::Function(ResponsesApiTool { name, .. }) => name,
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(FreeformTool { name, .. }) => name,
})
.map(|tool| tool_name(&tool.spec))
.collect::<Vec<_>>();
assert_eq!(
@@ -639,6 +717,16 @@ mod tests {
}
}
fn find_tool<'a>(
tools: &'a [ConfiguredToolSpec],
expected_name: &str,
) -> &'a ConfiguredToolSpec {
tools
.iter()
.find(|tool| tool_name(&tool.spec) == expected_name)
.unwrap_or_else(|| panic!("expected tool {expected_name}"))
}
#[test]
fn test_build_specs() {
let model_family = find_family_for_model("codex-mini-latest")
@@ -680,6 +768,53 @@ mod tests {
);
}
#[test]
#[ignore]
fn test_parallel_support_flags() {
let model_family = find_family_for_model("gpt-5-codex")
.expect("codex-mini-latest should be a valid model family");
let config = ToolsConfig::new(&ToolsConfigParams {
model_family: &model_family,
include_plan_tool: false,
include_apply_patch_tool: false,
include_web_search_request: false,
use_streamable_shell_tool: false,
include_view_image_tool: false,
experimental_unified_exec_tool: true,
});
let (tools, _) = build_specs(&config, None).build();
assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
}
#[test]
fn test_test_model_family_includes_sync_tool() {
let model_family = find_family_for_model("test-gpt-5-codex")
.expect("test-gpt-5-codex should be a valid model family");
let config = ToolsConfig::new(&ToolsConfigParams {
model_family: &model_family,
include_plan_tool: false,
include_apply_patch_tool: false,
include_web_search_request: false,
use_streamable_shell_tool: false,
include_view_image_tool: false,
experimental_unified_exec_tool: false,
});
let (tools, _) = build_specs(&config, None).build();
assert!(
tools
.iter()
.any(|tool| tool_name(&tool.spec) == "test_sync_tool")
);
assert!(
tools
.iter()
.any(|tool| tool_name(&tool.spec) == "read_file")
);
}
#[test]
fn test_build_specs_mcp_tools() {
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
@@ -742,7 +877,7 @@ mod tests {
);
assert_eq!(
tools[3],
tools[3].spec,
ToolSpec::Function(ResponsesApiTool {
name: "test_server/do_something_cool".to_string(),
parameters: JsonSchema::Object {
@@ -911,7 +1046,7 @@ mod tests {
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/search".to_string(),
parameters: JsonSchema::Object {
@@ -977,7 +1112,7 @@ mod tests {
],
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/paginate".to_string(),
parameters: JsonSchema::Object {
@@ -1041,7 +1176,7 @@ mod tests {
],
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/tags".to_string(),
parameters: JsonSchema::Object {
@@ -1108,7 +1243,7 @@ mod tests {
],
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/value".to_string(),
parameters: JsonSchema::Object {
@@ -1213,7 +1348,7 @@ mod tests {
);
assert_eq!(
tools[4],
tools[4].spec,
ToolSpec::Function(ResponsesApiTool {
name: "test_server/do_something_cool".to_string(),
parameters: JsonSchema::Object {