Simplify parallel (#4829)

make tool processing return a future and then collect futures.
handle cleanup on Drop
This commit is contained in:
pakrym-oai
2025-10-07 10:12:38 -07:00
committed by GitHub
parent 27f169bb91
commit f2555422b9
4 changed files with 58 additions and 114 deletions

View File

@@ -23,7 +23,9 @@ use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::TaskStartedEvent;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnContextItem;
use futures::future::BoxFuture;
use futures::prelude::*;
use futures::stream::FuturesOrdered;
use mcp_types::CallToolResult;
use serde_json;
use serde_json::Value;
@@ -2101,14 +2103,15 @@ async fn try_run_turn(
sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context.client.clone().stream(&prompt).await?;
let mut output = Vec::new();
let mut tool_runtime = ToolCallRuntime::new(
let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
sub_id.to_string(),
);
let mut output: FuturesOrdered<BoxFuture<CodexResult<ProcessedResponseItem>>> =
FuturesOrdered::new();
loop {
// Poll the next item from the model stream. We must inspect *both* Ok and Err
@@ -2116,9 +2119,8 @@ async fn try_run_turn(
// `response.completed`) bubble up and trigger the caller's retry logic.
let event = stream.next().await;
let event = match event {
Some(event) => event,
Some(res) => res?,
None => {
tool_runtime.abort_all();
return Err(CodexErr::Stream(
"stream closed before response.completed".into(),
None,
@@ -2126,14 +2128,8 @@ async fn try_run_turn(
}
};
let event = match event {
Ok(ev) => ev,
Err(e) => {
tool_runtime.abort_all();
// Propagate the underlying stream error to the caller (run_turn), which
// will apply the configured `stream_max_retries` policy.
return Err(e);
}
let add_completed = &mut |response_item: ProcessedResponseItem| {
output.push_back(future::ready(Ok(response_item)).boxed());
};
match event {
@@ -2143,14 +2139,18 @@ async fn try_run_turn(
Ok(Some(call)) => {
let payload_preview = call.payload.log_payload().into_owned();
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
let index = output.len();
output.push(ProcessedResponseItem {
item,
response: None,
});
tool_runtime
.handle_tool_call(call, index, output.as_mut_slice())
.await?;
let response = tool_runtime.handle_tool_call(call);
output.push_back(
async move {
Ok(ProcessedResponseItem {
item,
response: Some(response.await?),
})
}
.boxed(),
);
}
Ok(None) => {
let response = handle_non_tool_response_item(
@@ -2160,7 +2160,7 @@ async fn try_run_turn(
item.clone(),
)
.await?;
output.push(ProcessedResponseItem { item, response });
add_completed(ProcessedResponseItem { item, response });
}
Err(FunctionCallError::MissingLocalShellCallId) => {
let msg = "LocalShellCall without call_id or id";
@@ -2177,7 +2177,7 @@ async fn try_run_turn(
success: None,
},
};
output.push(ProcessedResponseItem {
add_completed(ProcessedResponseItem {
item,
response: Some(response),
});
@@ -2190,7 +2190,7 @@ async fn try_run_turn(
success: None,
},
};
output.push(ProcessedResponseItem {
add_completed(ProcessedResponseItem {
item,
response: Some(response),
});
@@ -2221,7 +2221,7 @@ async fn try_run_turn(
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
.await;
tool_runtime.resolve_pending(output.as_mut_slice()).await?;
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
@@ -2237,7 +2237,7 @@ async fn try_run_turn(
}
let result = TurnRunResult {
processed_items: output,
processed_items,
total_token_usage: token_usage.clone(),
};