feat: parallel tool calls (#4663)

Add parallel tool calls. This is configurable at model level and tool
level
This commit is contained in:
jif-oai
2025-10-05 17:10:49 +01:00
committed by GitHub
parent 3203862167
commit dc3c6bf62a
23 changed files with 961 additions and 244 deletions

View File

@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat;
@@ -36,10 +37,7 @@ impl ToolHandler for ApplyPatchHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
@@ -79,10 +77,10 @@ impl ToolHandler for ApplyPatchHandler {
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;

View File

@@ -19,10 +19,7 @@ impl ToolHandler for ExecStreamHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
tool_name,

View File

@@ -16,10 +16,7 @@ impl ToolHandler for McpHandler {
ToolKind::Mcp
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
sub_id,
@@ -45,8 +42,8 @@ impl ToolHandler for McpHandler {
let arguments_str = raw_arguments;
let response = handle_mcp_tool_call(
session,
sub_id,
session.as_ref(),
&sub_id,
call_id.clone(),
server,
tool,

View File

@@ -4,6 +4,7 @@ mod mcp;
mod plan;
mod read_file;
mod shell;
mod test_sync;
mod unified_exec;
mod view_image;
@@ -15,5 +16,6 @@ pub use mcp::McpHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
pub use test_sync::TestSyncHandler;
pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler;

View File

@@ -65,10 +65,7 @@ impl ToolHandler for PlanHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
sub_id,
@@ -86,7 +83,8 @@ impl ToolHandler for PlanHandler {
}
};
let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?;
let content =
handle_update_plan(session.as_ref(), arguments, sub_id.clone(), call_id).await?;
Ok(ToolOutput::Function {
content,

View File

@@ -42,10 +42,7 @@ impl ToolHandler for ReadFileHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation;
let arguments = match payload {

View File

@@ -1,5 +1,6 @@
use async_trait::async_trait;
use codex_protocol::models::ShellToolCallParams;
use std::sync::Arc;
use crate::codex::TurnContext;
use crate::exec::ExecParams;
@@ -40,10 +41,7 @@ impl ToolHandler for ShellHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
@@ -62,14 +60,14 @@ impl ToolHandler for ShellHandler {
"failed to parse function arguments: {e:?}"
))
})?;
let exec_params = Self::to_exec_params(params, turn);
let exec_params = Self::to_exec_params(params, turn.as_ref());
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;
@@ -79,14 +77,14 @@ impl ToolHandler for ShellHandler {
})
}
ToolPayload::LocalShell { params } => {
let exec_params = Self::to_exec_params(params, turn);
let exec_params = Self::to_exec_params(params, turn.as_ref());
let content = handle_container_exec_with_params(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
sub_id.to_string(),
Arc::clone(&session),
Arc::clone(&turn),
Arc::clone(&tracker),
sub_id.clone(),
call_id.clone(),
)
.await?;

View File

@@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use async_trait::async_trait;
use serde::Deserialize;
use tokio::sync::Barrier;
use tokio::time::sleep;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct TestSyncHandler;
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();
struct BarrierState {
barrier: Arc<Barrier>,
participants: usize,
}
#[derive(Debug, Deserialize)]
struct BarrierArgs {
id: String,
participants: usize,
#[serde(default = "default_timeout_ms")]
timeout_ms: u64,
}
#[derive(Debug, Deserialize)]
struct TestSyncArgs {
#[serde(default)]
sleep_before_ms: Option<u64>,
#[serde(default)]
sleep_after_ms: Option<u64>,
#[serde(default)]
barrier: Option<BarrierArgs>,
}
fn default_timeout_ms() -> u64 {
DEFAULT_TIMEOUT_MS
}
fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
}
#[async_trait]
impl ToolHandler for TestSyncHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"test_sync_tool handler received unsupported payload".to_string(),
));
}
};
let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?;
if let Some(delay) = args.sleep_before_ms
&& delay > 0
{
sleep(Duration::from_millis(delay)).await;
}
if let Some(barrier) = args.barrier {
wait_on_barrier(barrier).await?;
}
if let Some(delay) = args.sleep_after_ms
&& delay > 0
{
sleep(Duration::from_millis(delay)).await;
}
Ok(ToolOutput::Function {
content: "ok".to_string(),
success: Some(true),
})
}
}
async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
if args.participants == 0 {
return Err(FunctionCallError::RespondToModel(
"barrier participants must be greater than zero".to_string(),
));
}
if args.timeout_ms == 0 {
return Err(FunctionCallError::RespondToModel(
"barrier timeout must be greater than zero".to_string(),
));
}
let barrier_id = args.id.clone();
let barrier = {
let mut map = barrier_map().lock().await;
match map.entry(barrier_id.clone()) {
Entry::Occupied(entry) => {
let state = entry.get();
if state.participants != args.participants {
let existing = state.participants;
return Err(FunctionCallError::RespondToModel(format!(
"barrier {barrier_id} already registered with {existing} participants"
)));
}
state.barrier.clone()
}
Entry::Vacant(entry) => {
let barrier = Arc::new(Barrier::new(args.participants));
entry.insert(BarrierState {
barrier: barrier.clone(),
participants: args.participants,
});
barrier
}
}
};
let timeout = Duration::from_millis(args.timeout_ms);
let wait_result = tokio::time::timeout(timeout, barrier.wait())
.await
.map_err(|_| {
FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
})?;
if wait_result.is_leader() {
let mut map = barrier_map().lock().await;
if let Some(state) = map.get(&barrier_id)
&& Arc::ptr_eq(&state.barrier, &barrier)
{
map.remove(&barrier_id);
}
}
Ok(())
}

View File

@@ -33,10 +33,7 @@ impl ToolHandler for UnifiedExecHandler {
)
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session, payload, ..
} = invocation;

View File

@@ -26,10 +26,7 @@ impl ToolHandler for ViewImageHandler {
ToolKind::Function
}
async fn handle(
&self,
invocation: ToolInvocation<'_>,
) -> Result<ToolOutput, FunctionCallError> {
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,