Handle cancelling/aborting while processing a turn (#5543)
Currently we collect all all turn items in a vector, then we add it to the history on success. This result in losing those items on errors including aborting `ctrl+c`. This PR: - Adds the ability for the tool call to handle cancellation - bubble the turn items up to where we are recording this info Admittedly, this logic is an ad-hoc logic that doesn't handle a lot of error edge cases. The right thing to do is recording to the history on the spot as `items`/`tool calls output` come. However, this isn't possible because of having different `task_kind` that has different `conversation_histories`. The `try_run_turn` has no idea what thread are we using. We cannot also pass an `arc` to the `conversation_histories` because it's a private element of `state`. That's said, `abort` is the most common case and we should cover it until we remove `task kind`
This commit is contained in:
@@ -10,6 +10,7 @@ use crate::function_tool::FunctionCallError;
|
||||
use crate::mcp::auth::McpAuthStatusEntry;
|
||||
use crate::parse_command::parse_command;
|
||||
use crate::parse_turn_item;
|
||||
use crate::response_processing::process_items;
|
||||
use crate::review_format::format_review_findings_block;
|
||||
use crate::terminal;
|
||||
use crate::user_notification::UserNotifier;
|
||||
@@ -855,7 +856,7 @@ impl Session {
|
||||
|
||||
/// Records input items: always append to conversation history and
|
||||
/// persist these response items to rollout.
|
||||
async fn record_conversation_items(&self, items: &[ResponseItem]) {
|
||||
pub(crate) async fn record_conversation_items(&self, items: &[ResponseItem]) {
|
||||
self.record_into_history(items).await;
|
||||
self.persist_rollout_response_items(items).await;
|
||||
}
|
||||
@@ -1608,109 +1609,13 @@ pub(crate) async fn run_task(
|
||||
let token_limit_reached = total_usage_tokens
|
||||
.map(|tokens| tokens >= limit)
|
||||
.unwrap_or(false);
|
||||
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
||||
let mut responses = Vec::<ResponseInputItem>::new();
|
||||
for processed_response_item in processed_items {
|
||||
let ProcessedResponseItem { item, response } = processed_response_item;
|
||||
match (&item, &response) {
|
||||
(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
|
||||
// If the model returned a message, we need to record it.
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
}
|
||||
(
|
||||
ResponseItem::LocalShellCall { .. },
|
||||
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
items_to_record_in_conversation_history.push(
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
(
|
||||
ResponseItem::FunctionCall { .. },
|
||||
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
items_to_record_in_conversation_history.push(
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
(
|
||||
ResponseItem::CustomToolCall { .. },
|
||||
Some(ResponseInputItem::CustomToolCallOutput { call_id, output }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
items_to_record_in_conversation_history.push(
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
(
|
||||
ResponseItem::FunctionCall { .. },
|
||||
Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
let output = match result {
|
||||
Ok(call_tool_result) => {
|
||||
convert_call_tool_result_to_function_call_output_payload(
|
||||
call_tool_result,
|
||||
)
|
||||
}
|
||||
Err(err) => FunctionCallOutputPayload {
|
||||
content: err.clone(),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
items_to_record_in_conversation_history.push(
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output,
|
||||
},
|
||||
);
|
||||
}
|
||||
(
|
||||
ResponseItem::Reasoning {
|
||||
id,
|
||||
summary,
|
||||
content,
|
||||
encrypted_content,
|
||||
},
|
||||
None,
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
|
||||
id: id.clone(),
|
||||
summary: summary.clone(),
|
||||
content: content.clone(),
|
||||
encrypted_content: encrypted_content.clone(),
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected response item: {item:?} with response: {response:?}");
|
||||
}
|
||||
};
|
||||
if let Some(response) = response {
|
||||
responses.push(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Only attempt to take the lock if there is something to record.
|
||||
if !items_to_record_in_conversation_history.is_empty() {
|
||||
if is_review_mode {
|
||||
review_thread_history
|
||||
.record_items(items_to_record_in_conversation_history.iter());
|
||||
} else {
|
||||
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
let (responses, items_to_record_in_conversation_history) = process_items(
|
||||
processed_items,
|
||||
is_review_mode,
|
||||
&mut review_thread_history,
|
||||
&sess,
|
||||
)
|
||||
.await;
|
||||
|
||||
if token_limit_reached {
|
||||
if auto_compact_recently_attempted {
|
||||
@@ -1749,7 +1654,16 @@ pub(crate) async fn run_task(
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Err(CodexErr::TurnAborted) => {
|
||||
Err(CodexErr::TurnAborted {
|
||||
dangling_artifacts: processed_items,
|
||||
}) => {
|
||||
let _ = process_items(
|
||||
processed_items,
|
||||
is_review_mode,
|
||||
&mut review_thread_history,
|
||||
&sess,
|
||||
)
|
||||
.await;
|
||||
// Aborted turn is reported via a different event.
|
||||
break;
|
||||
}
|
||||
@@ -1850,7 +1764,13 @@ async fn run_turn(
|
||||
.await
|
||||
{
|
||||
Ok(output) => return Ok(output),
|
||||
Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
|
||||
Err(CodexErr::TurnAborted {
|
||||
dangling_artifacts: processed_items,
|
||||
}) => {
|
||||
return Err(CodexErr::TurnAborted {
|
||||
dangling_artifacts: processed_items,
|
||||
});
|
||||
}
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||
@@ -1903,9 +1823,9 @@ async fn run_turn(
|
||||
/// "handled" such that it produces a `ResponseInputItem` that needs to be
|
||||
/// sent back to the model on the next turn.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ProcessedResponseItem {
|
||||
pub(crate) item: ResponseItem,
|
||||
pub(crate) response: Option<ResponseInputItem>,
|
||||
pub struct ProcessedResponseItem {
|
||||
pub item: ResponseItem,
|
||||
pub response: Option<ResponseInputItem>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -1954,7 +1874,15 @@ async fn try_run_turn(
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
let event = stream.next().or_cancel(&cancellation_token).await?;
|
||||
let event = match stream.next().or_cancel(&cancellation_token).await {
|
||||
Ok(event) => event,
|
||||
Err(codex_async_utils::CancelErr::Cancelled) => {
|
||||
let processed_items = output.try_collect().await?;
|
||||
return Err(CodexErr::TurnAborted {
|
||||
dangling_artifacts: processed_items,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let event = match event {
|
||||
Some(res) => res?,
|
||||
@@ -1978,7 +1906,8 @@ async fn try_run_turn(
|
||||
let payload_preview = call.payload.log_payload().into_owned();
|
||||
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||
|
||||
let response = tool_runtime.handle_tool_call(call);
|
||||
let response =
|
||||
tool_runtime.handle_tool_call(call, cancellation_token.child_token());
|
||||
|
||||
output.push_back(
|
||||
async move {
|
||||
@@ -2060,12 +1989,7 @@ async fn try_run_turn(
|
||||
} => {
|
||||
sess.update_token_usage_info(turn_context.as_ref(), token_usage.as_ref())
|
||||
.await;
|
||||
|
||||
let processed_items = output
|
||||
.try_collect()
|
||||
.or_cancel(&cancellation_token)
|
||||
.await??;
|
||||
|
||||
let processed_items = output.try_collect().await?;
|
||||
let unified_diff = {
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
tracker.get_unified_diff()
|
||||
@@ -2169,7 +2093,7 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
|
||||
}
|
||||
})
|
||||
}
|
||||
fn convert_call_tool_result_to_function_call_output_payload(
|
||||
pub(crate) fn convert_call_tool_result_to_function_call_output_payload(
|
||||
call_tool_result: &CallToolResult,
|
||||
) -> FunctionCallOutputPayload {
|
||||
let CallToolResult {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::codex::ProcessedResponseItem;
|
||||
use crate::exec::ExecToolCallOutput;
|
||||
use crate::token_data::KnownPlan;
|
||||
use crate::token_data::PlanType;
|
||||
@@ -53,8 +54,11 @@ pub enum SandboxErr {
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum CodexErr {
|
||||
// todo(aibrahim): git rid of this error carrying the dangling artifacts
|
||||
#[error("turn aborted")]
|
||||
TurnAborted,
|
||||
TurnAborted {
|
||||
dangling_artifacts: Vec<ProcessedResponseItem>,
|
||||
},
|
||||
|
||||
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
|
||||
/// handshake has succeeded but **before** it finished emitting `response.completed`.
|
||||
@@ -158,7 +162,9 @@ pub enum CodexErr {
|
||||
|
||||
impl From<CancelErr> for CodexErr {
|
||||
fn from(_: CancelErr) -> Self {
|
||||
CodexErr::TurnAborted
|
||||
CodexErr::TurnAborted {
|
||||
dangling_artifacts: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ mod mcp_tool_call;
|
||||
mod message_history;
|
||||
mod model_provider_info;
|
||||
pub mod parse_command;
|
||||
mod response_processing;
|
||||
pub mod sandboxing;
|
||||
pub mod token_data;
|
||||
mod truncate;
|
||||
|
||||
112
codex-rs/core/src/response_processing.rs
Normal file
112
codex-rs/core/src/response_processing.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use crate::codex::Session;
|
||||
use crate::conversation_history::ConversationHistory;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use tracing::warn;
|
||||
|
||||
/// Process streamed `ResponseItem`s from the model into the pair of:
|
||||
/// - items we should record in conversation history; and
|
||||
/// - `ResponseInputItem`s to send back to the model on the next turn.
|
||||
pub(crate) async fn process_items(
|
||||
processed_items: Vec<crate::codex::ProcessedResponseItem>,
|
||||
is_review_mode: bool,
|
||||
review_thread_history: &mut ConversationHistory,
|
||||
sess: &Session,
|
||||
) -> (Vec<ResponseInputItem>, Vec<ResponseItem>) {
|
||||
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
||||
let mut responses = Vec::<ResponseInputItem>::new();
|
||||
for processed_response_item in processed_items {
|
||||
let crate::codex::ProcessedResponseItem { item, response } = processed_response_item;
|
||||
match (&item, &response) {
|
||||
(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
|
||||
// If the model returned a message, we need to record it.
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
}
|
||||
(
|
||||
ResponseItem::LocalShellCall { .. },
|
||||
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
});
|
||||
}
|
||||
(
|
||||
ResponseItem::FunctionCall { .. },
|
||||
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
});
|
||||
}
|
||||
(
|
||||
ResponseItem::CustomToolCall { .. },
|
||||
Some(ResponseInputItem::CustomToolCallOutput { call_id, output }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
items_to_record_in_conversation_history.push(ResponseItem::CustomToolCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
});
|
||||
}
|
||||
(
|
||||
ResponseItem::FunctionCall { .. },
|
||||
Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
let output = match result {
|
||||
Ok(call_tool_result) => {
|
||||
crate::codex::convert_call_tool_result_to_function_call_output_payload(
|
||||
call_tool_result,
|
||||
)
|
||||
}
|
||||
Err(err) => FunctionCallOutputPayload {
|
||||
content: err.clone(),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output,
|
||||
});
|
||||
}
|
||||
(
|
||||
ResponseItem::Reasoning {
|
||||
id,
|
||||
summary,
|
||||
content,
|
||||
encrypted_content,
|
||||
},
|
||||
None,
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
|
||||
id: id.clone(),
|
||||
summary: summary.clone(),
|
||||
content: content.clone(),
|
||||
encrypted_content: encrypted_content.clone(),
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected response item: {item:?} with response: {response:?}");
|
||||
}
|
||||
};
|
||||
if let Some(response) = response {
|
||||
responses.push(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Only attempt to take the lock if there is something to record.
|
||||
if !items_to_record_in_conversation_history.is_empty() {
|
||||
if is_review_mode {
|
||||
review_thread_history.record_items(items_to_record_in_conversation_history.iter());
|
||||
} else {
|
||||
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
(responses, items_to_record_in_conversation_history)
|
||||
}
|
||||
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::either::Either;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
|
||||
use crate::codex::Session;
|
||||
@@ -9,8 +10,10 @@ use crate::codex::TurnContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
|
||||
pub(crate) struct ToolCallRuntime {
|
||||
@@ -40,6 +43,7 @@ impl ToolCallRuntime {
|
||||
pub(crate) fn handle_tool_call(
|
||||
&self,
|
||||
call: ToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
|
||||
@@ -48,18 +52,24 @@ impl ToolCallRuntime {
|
||||
let turn = Arc::clone(&self.turn_context);
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
let aborted_response = Self::aborted_response(&call);
|
||||
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
Either::Right(lock.write().await)
|
||||
};
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => Ok(aborted_response),
|
||||
res = async {
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
Either::Right(lock.write().await)
|
||||
};
|
||||
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, call)
|
||||
.await
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, call)
|
||||
.await
|
||||
} => res,
|
||||
}
|
||||
}));
|
||||
|
||||
async move {
|
||||
@@ -74,3 +84,25 @@ impl ToolCallRuntime {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
fn aborted_response(call: &ToolCall) -> ResponseInputItem {
|
||||
match &call.payload {
|
||||
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: "aborted".to_string(),
|
||||
},
|
||||
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
result: Err("aborted".to_string()),
|
||||
},
|
||||
_ => ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user