Simplify parallel (#4829)
make tool processing return a future and then collect futures. handle cleanup on Drop
This commit is contained in:
@@ -23,7 +23,9 @@ use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::TaskStartedEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnContextItem;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::prelude::*;
|
||||
use futures::stream::FuturesOrdered;
|
||||
use mcp_types::CallToolResult;
|
||||
use serde_json;
|
||||
use serde_json::Value;
|
||||
@@ -2101,14 +2103,15 @@ async fn try_run_turn(
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context.client.clone().stream(&prompt).await?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
let mut tool_runtime = ToolCallRuntime::new(
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
sub_id.to_string(),
|
||||
);
|
||||
let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
|
||||
FuturesOrdered::new();
|
||||
|
||||
loop {
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
@@ -2116,9 +2119,8 @@ async fn try_run_turn(
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
let event = stream.next().await;
|
||||
let event = match event {
|
||||
Some(event) => event,
|
||||
Some(res) => res?,
|
||||
None => {
|
||||
tool_runtime.abort_all();
|
||||
return Err(CodexErr::Stream(
|
||||
"stream closed before response.completed".into(),
|
||||
None,
|
||||
@@ -2126,14 +2128,8 @@ async fn try_run_turn(
|
||||
}
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Ok(ev) => ev,
|
||||
Err(e) => {
|
||||
tool_runtime.abort_all();
|
||||
// Propagate the underlying stream error to the caller (run_turn), which
|
||||
// will apply the configured `stream_max_retries` policy.
|
||||
return Err(e);
|
||||
}
|
||||
let add_completed = &mut |response_item: ProcessedResponseItem| {
|
||||
output.push_back(future::ready(Ok(response_item)).boxed());
|
||||
};
|
||||
|
||||
match event {
|
||||
@@ -2143,14 +2139,18 @@ async fn try_run_turn(
|
||||
Ok(Some(call)) => {
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
let index = output.len();
|
||||
output.push(ProcessedResponseItem {
|
||||
item,
|
||||
response: None,
|
||||
});
|
||||
tool_runtime
|
||||
.handle_tool_call(call, index, output.as_mut_slice())
|
||||
.await?;
|
||||
|
||||
let response = tool_runtime.handle_tool_call(call);
|
||||
|
||||
output.push_back(
|
||||
async move {
|
||||
Ok(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(response.await?),
|
||||
})
|
||||
}
|
||||
.boxed(),
|
||||
);
|
||||
}
|
||||
Ok(None) => {
|
||||
let response = handle_non_tool_response_item(
|
||||
@@ -2160,7 +2160,7 @@ async fn try_run_turn(
|
||||
item.clone(),
|
||||
)
|
||||
.await?;
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
add_completed(ProcessedResponseItem { item, response });
|
||||
}
|
||||
Err(FunctionCallError::MissingLocalShellCallId) => {
|
||||
let msg = "LocalShellCall without call_id or id";
|
||||
@@ -2177,7 +2177,7 @@ async fn try_run_turn(
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
output.push(ProcessedResponseItem {
|
||||
add_completed(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(response),
|
||||
});
|
||||
@@ -2190,7 +2190,7 @@ async fn try_run_turn(
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
output.push(ProcessedResponseItem {
|
||||
add_completed(ProcessedResponseItem {
|
||||
item,
|
||||
response: Some(response),
|
||||
});
|
||||
@@ -2221,7 +2221,7 @@ async fn try_run_turn(
|
||||
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
||||
.await;
|
||||
|
||||
tool_runtime.resolve_pending(output.as_mut_slice()).await?;
|
||||
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
|
||||
|
||||
let unified_diff = {
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
@@ -2237,7 +2237,7 @@ async fn try_run_turn(
|
||||
}
|
||||
|
||||
let result = TurnRunResult {
|
||||
processed_items: output,
|
||||
processed_items,
|
||||
total_token_usage: token_usage.clone(),
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::either::Either;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
@@ -11,20 +13,13 @@ use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
|
||||
use crate::codex::ProcessedResponseItem;
|
||||
|
||||
struct PendingToolCall {
|
||||
index: usize,
|
||||
handle: JoinHandle<Result<ResponseInputItem, FunctionCallError>>,
|
||||
}
|
||||
|
||||
pub(crate) struct ToolCallRuntime {
|
||||
router: Arc<ToolRouter>,
|
||||
session: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
sub_id: String,
|
||||
pending_calls: Vec<PendingToolCall>,
|
||||
parallel_execution: Arc<RwLock<()>>,
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
@@ -41,97 +36,45 @@ impl ToolCallRuntime {
|
||||
turn_context,
|
||||
tracker,
|
||||
sub_id,
|
||||
pending_calls: Vec::new(),
|
||||
parallel_execution: Arc::new(RwLock::new(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_tool_call(
|
||||
&mut self,
|
||||
pub(crate) fn handle_tool_call(
|
||||
&self,
|
||||
call: ToolCall,
|
||||
output_index: usize,
|
||||
output: &mut [ProcessedResponseItem],
|
||||
) -> Result<(), CodexErr> {
|
||||
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
if supports_parallel {
|
||||
self.spawn_parallel(call, output_index);
|
||||
} else {
|
||||
self.resolve_pending(output).await?;
|
||||
let response = self.dispatch_serial(call).await?;
|
||||
let slot = output.get_mut(output_index).ok_or_else(|| {
|
||||
CodexErr::Fatal(format!("tool output index {output_index} out of bounds"))
|
||||
})?;
|
||||
slot.response = Some(response);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn abort_all(&mut self) {
|
||||
while let Some(pending) = self.pending_calls.pop() {
|
||||
pending.handle.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn resolve_pending(
|
||||
&mut self,
|
||||
output: &mut [ProcessedResponseItem],
|
||||
) -> Result<(), CodexErr> {
|
||||
while let Some(PendingToolCall { index, handle }) = self.pending_calls.pop() {
|
||||
match handle.await {
|
||||
Ok(Ok(response)) => {
|
||||
if let Some(slot) = output.get_mut(index) {
|
||||
slot.response = Some(response);
|
||||
}
|
||||
}
|
||||
Ok(Err(FunctionCallError::Fatal(message))) => {
|
||||
self.abort_all();
|
||||
return Err(CodexErr::Fatal(message));
|
||||
}
|
||||
Ok(Err(other)) => {
|
||||
self.abort_all();
|
||||
return Err(CodexErr::Fatal(other.to_string()));
|
||||
}
|
||||
Err(join_err) => {
|
||||
self.abort_all();
|
||||
return Err(CodexErr::Fatal(format!(
|
||||
"tool task failed to join: {join_err}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spawn_parallel(&mut self, call: ToolCall, index: usize) {
|
||||
let router = Arc::clone(&self.router);
|
||||
let session = Arc::clone(&self.session);
|
||||
let turn = Arc::clone(&self.turn_context);
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let sub_id = self.sub_id.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, sub_id, call)
|
||||
.await
|
||||
});
|
||||
self.pending_calls.push(PendingToolCall { index, handle });
|
||||
}
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
|
||||
async fn dispatch_serial(&self, call: ToolCall) -> Result<ResponseInputItem, CodexErr> {
|
||||
match self
|
||||
.router
|
||||
.dispatch_tool_call(
|
||||
Arc::clone(&self.session),
|
||||
Arc::clone(&self.turn_context),
|
||||
Arc::clone(&self.tracker),
|
||||
self.sub_id.clone(),
|
||||
call,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(response) => Ok(response),
|
||||
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
||||
Err(other) => Err(CodexErr::Fatal(other.to_string())),
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
Either::Right(lock.write().await)
|
||||
};
|
||||
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, sub_id, call)
|
||||
.await
|
||||
}));
|
||||
|
||||
async move {
|
||||
match handle.await {
|
||||
Ok(Ok(response)) => Ok(response),
|
||||
Ok(Err(FunctionCallError::Fatal(message))) => Err(CodexErr::Fatal(message)),
|
||||
Ok(Err(other)) => Err(CodexErr::Fatal(other.to_string())),
|
||||
Err(err) => Err(CodexErr::Fatal(format!(
|
||||
"tool task failed to receive: {err:?}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user