From f59978ed3dab17a21147615c6ac2d07d9e480fb7 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Thu, 23 Oct 2025 08:47:10 -0700 Subject: [PATCH] Handle cancelling/aborting while processing a turn (#5543) Currently we collect all turn items in a vector, then add them to the history on success. This results in losing those items on errors, including aborting with `ctrl+c`. This PR: - Adds the ability for the tool call to handle cancellation - Bubbles the turn items up to where we are recording this info Admittedly, this logic is ad hoc and doesn't handle a lot of error edge cases. The right thing to do is to record to the history on the spot as `items`/`tool calls output` come in. However, this isn't possible because different `task_kind`s have different `conversation_histories`. `try_run_turn` has no idea which thread we are using. We also cannot pass an `arc` to the `conversation_histories` because it's a private element of `state`. That said, `abort` is the most common case and we should cover it until we remove `task kind` --- codex-rs/core/src/codex.rs | 160 ++++++----------------- codex-rs/core/src/error.rs | 10 +- codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/response_processing.rs | 112 ++++++++++++++++ codex-rs/core/src/tools/parallel.rs | 48 +++++-- codex-rs/core/tests/common/responses.rs | 38 ++++++ codex-rs/core/tests/suite/abort_tasks.rs | 98 ++++++++++++++ 7 files changed, 339 insertions(+), 128 deletions(-) create mode 100644 codex-rs/core/src/response_processing.rs diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index b261b083..77ac35ae 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -10,6 +10,7 @@ use crate::function_tool::FunctionCallError; use crate::mcp::auth::McpAuthStatusEntry; use crate::parse_command::parse_command; use crate::parse_turn_item; +use crate::response_processing::process_items; use crate::review_format::format_review_findings_block; use crate::terminal; use 
crate::user_notification::UserNotifier; @@ -855,7 +856,7 @@ impl Session { /// Records input items: always append to conversation history and /// persist these response items to rollout. - async fn record_conversation_items(&self, items: &[ResponseItem]) { + pub(crate) async fn record_conversation_items(&self, items: &[ResponseItem]) { self.record_into_history(items).await; self.persist_rollout_response_items(items).await; } @@ -1608,109 +1609,13 @@ pub(crate) async fn run_task( let token_limit_reached = total_usage_tokens .map(|tokens| tokens >= limit) .unwrap_or(false); - let mut items_to_record_in_conversation_history = Vec::::new(); - let mut responses = Vec::::new(); - for processed_response_item in processed_items { - let ProcessedResponseItem { item, response } = processed_response_item; - match (&item, &response) { - (ResponseItem::Message { role, .. }, None) if role == "assistant" => { - // If the model returned a message, we need to record it. - items_to_record_in_conversation_history.push(item); - } - ( - ResponseItem::LocalShellCall { .. }, - Some(ResponseInputItem::FunctionCallOutput { call_id, output }), - ) => { - items_to_record_in_conversation_history.push(item); - items_to_record_in_conversation_history.push( - ResponseItem::FunctionCallOutput { - call_id: call_id.clone(), - output: output.clone(), - }, - ); - } - ( - ResponseItem::FunctionCall { .. }, - Some(ResponseInputItem::FunctionCallOutput { call_id, output }), - ) => { - items_to_record_in_conversation_history.push(item); - items_to_record_in_conversation_history.push( - ResponseItem::FunctionCallOutput { - call_id: call_id.clone(), - output: output.clone(), - }, - ); - } - ( - ResponseItem::CustomToolCall { .. 
}, - Some(ResponseInputItem::CustomToolCallOutput { call_id, output }), - ) => { - items_to_record_in_conversation_history.push(item); - items_to_record_in_conversation_history.push( - ResponseItem::CustomToolCallOutput { - call_id: call_id.clone(), - output: output.clone(), - }, - ); - } - ( - ResponseItem::FunctionCall { .. }, - Some(ResponseInputItem::McpToolCallOutput { call_id, result }), - ) => { - items_to_record_in_conversation_history.push(item); - let output = match result { - Ok(call_tool_result) => { - convert_call_tool_result_to_function_call_output_payload( - call_tool_result, - ) - } - Err(err) => FunctionCallOutputPayload { - content: err.clone(), - success: Some(false), - }, - }; - items_to_record_in_conversation_history.push( - ResponseItem::FunctionCallOutput { - call_id: call_id.clone(), - output, - }, - ); - } - ( - ResponseItem::Reasoning { - id, - summary, - content, - encrypted_content, - }, - None, - ) => { - items_to_record_in_conversation_history.push(ResponseItem::Reasoning { - id: id.clone(), - summary: summary.clone(), - content: content.clone(), - encrypted_content: encrypted_content.clone(), - }); - } - _ => { - warn!("Unexpected response item: {item:?} with response: {response:?}"); - } - }; - if let Some(response) = response { - responses.push(response); - } - } - - // Only attempt to take the lock if there is something to record. 
- if !items_to_record_in_conversation_history.is_empty() { - if is_review_mode { - review_thread_history - .record_items(items_to_record_in_conversation_history.iter()); - } else { - sess.record_conversation_items(&items_to_record_in_conversation_history) - .await; - } - } + let (responses, items_to_record_in_conversation_history) = process_items( + processed_items, + is_review_mode, + &mut review_thread_history, + &sess, + ) + .await; if token_limit_reached { if auto_compact_recently_attempted { @@ -1749,7 +1654,16 @@ pub(crate) async fn run_task( } continue; } - Err(CodexErr::TurnAborted) => { + Err(CodexErr::TurnAborted { + dangling_artifacts: processed_items, + }) => { + let _ = process_items( + processed_items, + is_review_mode, + &mut review_thread_history, + &sess, + ) + .await; // Aborted turn is reported via a different event. break; } @@ -1850,7 +1764,13 @@ async fn run_turn( .await { Ok(output) => return Ok(output), - Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted), + Err(CodexErr::TurnAborted { + dangling_artifacts: processed_items, + }) => { + return Err(CodexErr::TurnAborted { + dangling_artifacts: processed_items, + }); + } Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(e @ CodexErr::Fatal(_)) => return Err(e), @@ -1903,9 +1823,9 @@ async fn run_turn( /// "handled" such that it produces a `ResponseInputItem` that needs to be /// sent back to the model on the next turn. #[derive(Debug)] -pub(crate) struct ProcessedResponseItem { - pub(crate) item: ResponseItem, - pub(crate) response: Option, +pub struct ProcessedResponseItem { + pub item: ResponseItem, + pub response: Option, } #[derive(Debug)] @@ -1954,7 +1874,15 @@ async fn try_run_turn( // Poll the next item from the model stream. 
We must inspect *both* Ok and Err // cases so that transient stream failures (e.g., dropped SSE connection before // `response.completed`) bubble up and trigger the caller's retry logic. - let event = stream.next().or_cancel(&cancellation_token).await?; + let event = match stream.next().or_cancel(&cancellation_token).await { + Ok(event) => event, + Err(codex_async_utils::CancelErr::Cancelled) => { + let processed_items = output.try_collect().await?; + return Err(CodexErr::TurnAborted { + dangling_artifacts: processed_items, + }); + } + }; let event = match event { Some(res) => res?, @@ -1978,7 +1906,8 @@ async fn try_run_turn( let payload_preview = call.payload.log_payload().into_owned(); tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview); - let response = tool_runtime.handle_tool_call(call); + let response = + tool_runtime.handle_tool_call(call, cancellation_token.child_token()); output.push_back( async move { @@ -2060,12 +1989,7 @@ async fn try_run_turn( } => { sess.update_token_usage_info(turn_context.as_ref(), token_usage.as_ref()) .await; - - let processed_items = output - .try_collect() - .or_cancel(&cancellation_token) - .await??; - + let processed_items = output.try_collect().await?; let unified_diff = { let mut tracker = turn_diff_tracker.lock().await; tracker.get_unified_diff() @@ -2169,7 +2093,7 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) - } }) } -fn convert_call_tool_result_to_function_call_output_payload( +pub(crate) fn convert_call_tool_result_to_function_call_output_payload( call_tool_result: &CallToolResult, ) -> FunctionCallOutputPayload { let CallToolResult { diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs index 459cc175..e733b3c6 100644 --- a/codex-rs/core/src/error.rs +++ b/codex-rs/core/src/error.rs @@ -1,3 +1,4 @@ +use crate::codex::ProcessedResponseItem; use crate::exec::ExecToolCallOutput; use crate::token_data::KnownPlan; use crate::token_data::PlanType; @@ -53,8 
+54,11 @@ pub enum SandboxErr { #[derive(Error, Debug)] pub enum CodexErr { + // todo(aibrahim): get rid of this error carrying the dangling artifacts #[error("turn aborted")] - TurnAborted, + TurnAborted { + dangling_artifacts: Vec, + }, /// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP /// handshake has succeeded but **before** it finished emitting `response.completed`. @@ -158,7 +162,9 @@ pub enum CodexErr { impl From for CodexErr { fn from(_: CancelErr) -> Self { - CodexErr::TurnAborted + CodexErr::TurnAborted { + dangling_artifacts: Vec::new(), + } } } diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 5e5b4e44..34b6df4a 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -36,6 +36,7 @@ mod mcp_tool_call; mod message_history; mod model_provider_info; pub mod parse_command; +mod response_processing; pub mod sandboxing; pub mod token_data; mod truncate; diff --git a/codex-rs/core/src/response_processing.rs b/codex-rs/core/src/response_processing.rs new file mode 100644 index 00000000..b9139ce6 --- /dev/null +++ b/codex-rs/core/src/response_processing.rs @@ -0,0 +1,112 @@ +use crate::codex::Session; +use crate::conversation_history::ConversationHistory; +use codex_protocol::models::FunctionCallOutputPayload; +use codex_protocol::models::ResponseInputItem; +use codex_protocol::models::ResponseItem; +use tracing::warn; + +/// Process streamed `ResponseItem`s from the model into the pair of: +/// - `ResponseInputItem`s to send back to the model on the next turn; and +/// - items we should record in conversation history. 
+pub(crate) async fn process_items( + processed_items: Vec, + is_review_mode: bool, + review_thread_history: &mut ConversationHistory, + sess: &Session, +) -> (Vec, Vec) { + let mut items_to_record_in_conversation_history = Vec::::new(); + let mut responses = Vec::::new(); + for processed_response_item in processed_items { + let crate::codex::ProcessedResponseItem { item, response } = processed_response_item; + match (&item, &response) { + (ResponseItem::Message { role, .. }, None) if role == "assistant" => { + // If the model returned a message, we need to record it. + items_to_record_in_conversation_history.push(item); + } + ( + ResponseItem::LocalShellCall { .. }, + Some(ResponseInputItem::FunctionCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }); + } + ( + ResponseItem::FunctionCall { .. }, + Some(ResponseInputItem::FunctionCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }); + } + ( + ResponseItem::CustomToolCall { .. }, + Some(ResponseInputItem::CustomToolCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push(ResponseItem::CustomToolCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }); + } + ( + ResponseItem::FunctionCall { .. 
}, + Some(ResponseInputItem::McpToolCallOutput { call_id, result }), + ) => { + items_to_record_in_conversation_history.push(item); + let output = match result { + Ok(call_tool_result) => { + crate::codex::convert_call_tool_result_to_function_call_output_payload( + call_tool_result, + ) + } + Err(err) => FunctionCallOutputPayload { + content: err.clone(), + success: Some(false), + }, + }; + items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output, + }); + } + ( + ResponseItem::Reasoning { + id, + summary, + content, + encrypted_content, + }, + None, + ) => { + items_to_record_in_conversation_history.push(ResponseItem::Reasoning { + id: id.clone(), + summary: summary.clone(), + content: content.clone(), + encrypted_content: encrypted_content.clone(), + }); + } + _ => { + warn!("Unexpected response item: {item:?} with response: {response:?}"); + } + }; + if let Some(response) = response { + responses.push(response); + } + } + + // Only attempt to take the lock if there is something to record. 
+ if !items_to_record_in_conversation_history.is_empty() { + if is_review_mode { + review_thread_history.record_items(items_to_record_in_conversation_history.iter()); + } else { + sess.record_conversation_items(&items_to_record_in_conversation_history) + .await; + } + } + (responses, items_to_record_in_conversation_history) +} diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index eae181c1..7f42bf5b 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use tokio::sync::RwLock; use tokio_util::either::Either; +use tokio_util::sync::CancellationToken; use tokio_util::task::AbortOnDropHandle; use crate::codex::Session; @@ -9,8 +10,10 @@ use crate::codex::TurnContext; use crate::error::CodexErr; use crate::function_tool::FunctionCallError; use crate::tools::context::SharedTurnDiffTracker; +use crate::tools::context::ToolPayload; use crate::tools::router::ToolCall; use crate::tools::router::ToolRouter; +use codex_protocol::models::FunctionCallOutputPayload; use codex_protocol::models::ResponseInputItem; pub(crate) struct ToolCallRuntime { @@ -40,6 +43,7 @@ impl ToolCallRuntime { pub(crate) fn handle_tool_call( &self, call: ToolCall, + cancellation_token: CancellationToken, ) -> impl std::future::Future> { let supports_parallel = self.router.tool_supports_parallel(&call.tool_name); @@ -48,18 +52,24 @@ impl ToolCallRuntime { let turn = Arc::clone(&self.turn_context); let tracker = Arc::clone(&self.tracker); let lock = Arc::clone(&self.parallel_execution); + let aborted_response = Self::aborted_response(&call); let handle: AbortOnDropHandle> = AbortOnDropHandle::new(tokio::spawn(async move { - let _guard = if supports_parallel { - Either::Left(lock.read().await) - } else { - Either::Right(lock.write().await) - }; + tokio::select! 
{ + _ = cancellation_token.cancelled() => Ok(aborted_response), + res = async { + let _guard = if supports_parallel { + Either::Left(lock.read().await) + } else { + Either::Right(lock.write().await) + }; - router - .dispatch_tool_call(session, turn, tracker, call) - .await + router + .dispatch_tool_call(session, turn, tracker, call) + .await + } => res, + } })); async move { @@ -74,3 +84,25 @@ impl ToolCallRuntime { } } } + +impl ToolCallRuntime { + fn aborted_response(call: &ToolCall) -> ResponseInputItem { + match &call.payload { + ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput { + call_id: call.call_id.clone(), + output: "aborted".to_string(), + }, + ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput { + call_id: call.call_id.clone(), + result: Err("aborted".to_string()), + }, + _ => ResponseInputItem::FunctionCallOutput { + call_id: call.call_id.clone(), + output: FunctionCallOutputPayload { + content: "aborted".to_string(), + success: None, + }, + }, + } + } +} diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 102be353..511c0b5b 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -35,6 +35,22 @@ impl ResponseMock { pub fn requests(&self) -> Vec { self.requests.lock().unwrap().clone() } + + /// Returns true if any captured request contains a `function_call` with the + /// provided `call_id`. + pub fn saw_function_call(&self, call_id: &str) -> bool { + self.requests() + .iter() + .any(|req| req.has_function_call(call_id)) + } + + /// Returns the `output` string for a matching `function_call_output` with + /// the provided `call_id`, searching across all captured requests. 
+ pub fn function_call_output_text(&self, call_id: &str) -> Option { + self.requests() + .iter() + .find_map(|req| req.function_call_output_text(call_id)) + } } #[derive(Debug, Clone)] @@ -70,6 +86,28 @@ impl ResponsesRequest { .unwrap_or_else(|| panic!("function call output {call_id} item not found in request")) } + /// Returns true if this request's `input` contains a `function_call` with + /// the specified `call_id`. + pub fn has_function_call(&self, call_id: &str) -> bool { + self.input().iter().any(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call") + && item.get("call_id").and_then(Value::as_str) == Some(call_id) + }) + } + + /// If present, returns the `output` string of the `function_call_output` + /// entry matching `call_id` in this request's `input`. + pub fn function_call_output_text(&self, call_id: &str) -> Option { + let binding = self.input(); + let item = binding.iter().find(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call_output") + && item.get("call_id").and_then(Value::as_str) == Some(call_id) + })?; + item.get("output") + .and_then(Value::as_str) + .map(str::to_string) + } + pub fn header(&self, name: &str) -> Option { self.0 .headers diff --git a/codex-rs/core/tests/suite/abort_tasks.rs b/codex-rs/core/tests/suite/abort_tasks.rs index dcc65fc7..c9d59508 100644 --- a/codex-rs/core/tests/suite/abort_tasks.rs +++ b/codex-rs/core/tests/suite/abort_tasks.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use std::time::Duration; use codex_core::protocol::EventMsg; @@ -5,7 +6,9 @@ use codex_core::protocol::Op; use codex_protocol::user_input::UserInput; use core_test_support::responses::ev_completed; use core_test_support::responses::ev_function_call; +use core_test_support::responses::ev_response_created; use core_test_support::responses::mount_sse_once; +use core_test_support::responses::mount_sse_sequence; use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use 
core_test_support::test_codex::test_codex; @@ -67,3 +70,98 @@ async fn interrupt_long_running_tool_emits_turn_aborted() { ) .await; } + +/// After an interrupt we expect the next request to the model to include both +/// the original tool call and an `"aborted"` `function_call_output`. This test +/// exercises the follow-up flow: it sends another user turn, inspects the mock +/// responses server, and ensures the model receives the synthesized abort. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn interrupt_tool_records_history_entries() { + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "sleep 60".to_string(), + ]; + let call_id = "call-history"; + + let args = json!({ + "command": command, + "timeout_ms": 60_000 + }) + .to_string(); + let first_body = sse(vec![ + ev_response_created("resp-history"), + ev_function_call(call_id, "shell", &args), + ev_completed("resp-history"), + ]); + let follow_up_body = sse(vec![ + ev_response_created("resp-followup"), + ev_completed("resp-followup"), + ]); + + let server = start_mock_server().await; + let response_mock = mount_sse_sequence(&server, vec![first_body, follow_up_body]).await; + + let fixture = test_codex().build(&server).await.unwrap(); + let codex = Arc::clone(&fixture.codex); + + let wait_timeout = Duration::from_millis(100); + + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "start history recording".into(), + }], + }) + .await + .unwrap(); + + wait_for_event_with_timeout( + &codex, + |ev| matches!(ev, EventMsg::ExecCommandBegin(_)), + wait_timeout, + ) + .await; + + codex.submit(Op::Interrupt).await.unwrap(); + + wait_for_event_with_timeout( + &codex, + |ev| matches!(ev, EventMsg::TurnAborted(_)), + wait_timeout, + ) + .await; + + codex + .submit(Op::UserInput { + items: vec![UserInput::Text { + text: "follow up".into(), + }], + }) + .await + .unwrap(); + + wait_for_event_with_timeout( + &codex, + |ev| matches!(ev, EventMsg::TaskComplete(_)), + 
wait_timeout, + ) + .await; + + let requests = response_mock.requests(); + assert!( + requests.len() == 2, + "expected two calls to the responses API, got {}", + requests.len() + ); + + assert!( + response_mock.saw_function_call(call_id), + "function call not recorded in responses payload" + ); + assert_eq!( + response_mock.function_call_output_text(call_id).as_deref(), + Some("aborted"), + "aborted function call output not recorded in responses payload" + ); +}