Handle cancelling/aborting while processing a turn (#5543)

Currently we collect all turn items in a vector, then add them to the
history on success. This results in losing those items on errors,
including aborts via `ctrl+c`.

This PR:
- Adds the ability for tool calls to handle cancellation (see the sketch after this list)
- Bubbles the turn items up to where we record this info
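
The first point follows the pattern in the `ToolCallRuntime` change below: the tool dispatch is raced against a `CancellationToken`, and on interrupt a synthesized `"aborted"` output is returned. A minimal, self-contained sketch of that pattern — the `String` return type and the `run_tool_call` stand-in are simplifications, not the real `ResponseInputItem`/`ToolRouter` types:

```rust
use std::time::Duration;
use tokio_util::sync::CancellationToken;

// Hypothetical stand-in for dispatching a long-running tool call.
async fn run_tool_call() -> String {
    tokio::time::sleep(Duration::from_secs(60)).await;
    "real output".to_string()
}

// Race the tool call against the turn's cancellation token; on
// interrupt, return a synthesized "aborted" output instead of hanging.
async fn handle_tool_call(cancellation_token: CancellationToken) -> String {
    tokio::select! {
        _ = cancellation_token.cancelled() => "aborted".to_string(),
        output = run_tool_call() => output,
    }
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    let task = tokio::spawn(handle_tool_call(token.child_token()));
    token.cancel(); // simulate ctrl+c / Op::Interrupt
    assert_eq!(task.await.unwrap(), "aborted");
}
```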

Admittedly, this logic is ad hoc and doesn't handle a lot of error edge
cases. The right thing to do is to record to the history on the spot as
`items`/`tool call outputs` arrive. However, this isn't possible because
different `task_kind`s have different `conversation_histories`, and
`try_run_turn` has no idea which thread we are using. We also cannot
pass an `Arc` to the `conversation_histories` because it's a private
field of `state`.

That said, `abort` is the most common case and we should cover it
until we remove `task_kind`.
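
The shape of the fix, reduced to a self-contained sketch: the real `ProcessedResponseItem` pairs a `ResponseItem` with an optional `ResponseInputItem`, and recording goes through `process_items`; the string-based items and the `record_items` helper here are stand-ins for illustration only.

```rust
// Simplified stand-in for codex-core's ProcessedResponseItem.
#[derive(Debug)]
struct ProcessedResponseItem(String);

// The aborted turn now carries whatever the turn produced so far.
#[derive(Debug)]
enum CodexErr {
    TurnAborted {
        dangling_artifacts: Vec<ProcessedResponseItem>,
    },
}

// Stand-in for recording into the session/review conversation history.
fn record_items(history: &mut Vec<String>, items: Vec<ProcessedResponseItem>) {
    history.extend(items.into_iter().map(|item| item.0));
}

fn main() {
    let mut history = Vec::new();

    // A turn interrupted after producing a tool call and its "aborted" output.
    let result: Result<(), CodexErr> = Err(CodexErr::TurnAborted {
        dangling_artifacts: vec![ProcessedResponseItem("shell call + aborted output".into())],
    });

    // Previously the Err arm dropped these items; now they still reach the history.
    if let Err(CodexErr::TurnAborted { dangling_artifacts }) = result {
        record_items(&mut history, dangling_artifacts);
    }

    assert_eq!(history, vec!["shell call + aborted output".to_string()]);
}
```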
Authored by Ahmed Ibrahim on 2025-10-23 08:47:10 -07:00; committed by GitHub.
parent 3ab6028e80 · commit f59978ed3d
7 changed files with 339 additions and 128 deletions


@@ -10,6 +10,7 @@ use crate::function_tool::FunctionCallError;
use crate::mcp::auth::McpAuthStatusEntry;
use crate::parse_command::parse_command;
use crate::parse_turn_item;
+use crate::response_processing::process_items;
use crate::review_format::format_review_findings_block;
use crate::terminal;
use crate::user_notification::UserNotifier;
@@ -855,7 +856,7 @@ impl Session {
/// Records input items: always append to conversation history and
/// persist these response items to rollout.
-async fn record_conversation_items(&self, items: &[ResponseItem]) {
+pub(crate) async fn record_conversation_items(&self, items: &[ResponseItem]) {
self.record_into_history(items).await;
self.persist_rollout_response_items(items).await;
}
@@ -1608,109 +1609,13 @@ pub(crate) async fn run_task(
let token_limit_reached = total_usage_tokens
.map(|tokens| tokens >= limit)
.unwrap_or(false);
-let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
-let mut responses = Vec::<ResponseInputItem>::new();
-for processed_response_item in processed_items {
-let ProcessedResponseItem { item, response } = processed_response_item;
-match (&item, &response) {
-(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
-// If the model returned a message, we need to record it.
-items_to_record_in_conversation_history.push(item);
-}
-(
-ResponseItem::LocalShellCall { .. },
-Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
-) => {
-items_to_record_in_conversation_history.push(item);
-items_to_record_in_conversation_history.push(
-ResponseItem::FunctionCallOutput {
-call_id: call_id.clone(),
-output: output.clone(),
-},
-);
-}
-(
-ResponseItem::FunctionCall { .. },
-Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
-) => {
-items_to_record_in_conversation_history.push(item);
-items_to_record_in_conversation_history.push(
-ResponseItem::FunctionCallOutput {
-call_id: call_id.clone(),
-output: output.clone(),
-},
-);
-}
-(
-ResponseItem::CustomToolCall { .. },
-Some(ResponseInputItem::CustomToolCallOutput { call_id, output }),
-) => {
-items_to_record_in_conversation_history.push(item);
-items_to_record_in_conversation_history.push(
-ResponseItem::CustomToolCallOutput {
-call_id: call_id.clone(),
-output: output.clone(),
-},
-);
-}
-(
-ResponseItem::FunctionCall { .. },
-Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
-) => {
-items_to_record_in_conversation_history.push(item);
-let output = match result {
-Ok(call_tool_result) => {
-convert_call_tool_result_to_function_call_output_payload(
-call_tool_result,
-)
-}
-Err(err) => FunctionCallOutputPayload {
-content: err.clone(),
-success: Some(false),
-},
-};
-items_to_record_in_conversation_history.push(
-ResponseItem::FunctionCallOutput {
-call_id: call_id.clone(),
-output,
-},
-);
-}
-(
-ResponseItem::Reasoning {
-id,
-summary,
-content,
-encrypted_content,
-},
-None,
-) => {
-items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
-id: id.clone(),
-summary: summary.clone(),
-content: content.clone(),
-encrypted_content: encrypted_content.clone(),
-});
-}
-_ => {
-warn!("Unexpected response item: {item:?} with response: {response:?}");
-}
-};
-if let Some(response) = response {
-responses.push(response);
-}
-}
-// Only attempt to take the lock if there is something to record.
-if !items_to_record_in_conversation_history.is_empty() {
-if is_review_mode {
-review_thread_history
-.record_items(items_to_record_in_conversation_history.iter());
-} else {
-sess.record_conversation_items(&items_to_record_in_conversation_history)
-.await;
-}
-}
+let (responses, items_to_record_in_conversation_history) = process_items(
+processed_items,
+is_review_mode,
+&mut review_thread_history,
+&sess,
+)
+.await;
if token_limit_reached {
if auto_compact_recently_attempted {
@@ -1749,7 +1654,16 @@ pub(crate) async fn run_task(
}
continue;
}
-Err(CodexErr::TurnAborted) => {
+Err(CodexErr::TurnAborted {
+dangling_artifacts: processed_items,
+}) => {
+let _ = process_items(
+processed_items,
+is_review_mode,
+&mut review_thread_history,
+&sess,
+)
+.await;
// Aborted turn is reported via a different event.
break;
}
@@ -1850,7 +1764,13 @@ async fn run_turn(
.await
{
Ok(output) => return Ok(output),
-Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
+Err(CodexErr::TurnAborted {
+dangling_artifacts: processed_items,
+}) => {
+return Err(CodexErr::TurnAborted {
+dangling_artifacts: processed_items,
+});
+}
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ CodexErr::Fatal(_)) => return Err(e),
@@ -1903,9 +1823,9 @@ async fn run_turn(
/// "handled" such that it produces a `ResponseInputItem` that needs to be
/// sent back to the model on the next turn.
#[derive(Debug)]
-pub(crate) struct ProcessedResponseItem {
-pub(crate) item: ResponseItem,
-pub(crate) response: Option<ResponseInputItem>,
+pub struct ProcessedResponseItem {
+pub item: ResponseItem,
+pub response: Option<ResponseInputItem>,
}
#[derive(Debug)]
@@ -1954,7 +1874,15 @@ async fn try_run_turn(
// Poll the next item from the model stream. We must inspect *both* Ok and Err
// cases so that transient stream failures (e.g., dropped SSE connection before
// `response.completed`) bubble up and trigger the caller's retry logic.
-let event = stream.next().or_cancel(&cancellation_token).await?;
+let event = match stream.next().or_cancel(&cancellation_token).await {
+Ok(event) => event,
+Err(codex_async_utils::CancelErr::Cancelled) => {
+let processed_items = output.try_collect().await?;
+return Err(CodexErr::TurnAborted {
+dangling_artifacts: processed_items,
+});
+}
+};
let event = match event {
Some(res) => res?,
@@ -1978,7 +1906,8 @@ async fn try_run_turn(
let payload_preview = call.payload.log_payload().into_owned();
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
-let response = tool_runtime.handle_tool_call(call);
+let response =
+tool_runtime.handle_tool_call(call, cancellation_token.child_token());
output.push_back(
async move {
@@ -2060,12 +1989,7 @@ async fn try_run_turn(
} => {
sess.update_token_usage_info(turn_context.as_ref(), token_usage.as_ref())
.await;
-let processed_items = output
-.try_collect()
-.or_cancel(&cancellation_token)
-.await??;
+let processed_items = output.try_collect().await?;
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
tracker.get_unified_diff()
@@ -2169,7 +2093,7 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
}
})
}
-fn convert_call_tool_result_to_function_call_output_payload(
+pub(crate) fn convert_call_tool_result_to_function_call_output_payload(
call_tool_result: &CallToolResult,
) -> FunctionCallOutputPayload {
let CallToolResult {


@@ -1,3 +1,4 @@
+use crate::codex::ProcessedResponseItem;
use crate::exec::ExecToolCallOutput;
use crate::token_data::KnownPlan;
use crate::token_data::PlanType;
@@ -53,8 +54,11 @@ pub enum SandboxErr {
#[derive(Error, Debug)]
pub enum CodexErr {
+// todo(aibrahim): git rid of this error carrying the dangling artifacts
#[error("turn aborted")]
-TurnAborted,
+TurnAborted {
+dangling_artifacts: Vec<ProcessedResponseItem>,
+},
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
/// handshake has succeeded but **before** it finished emitting `response.completed`.
@@ -158,7 +162,9 @@ pub enum CodexErr {
impl From<CancelErr> for CodexErr {
fn from(_: CancelErr) -> Self {
-CodexErr::TurnAborted
+CodexErr::TurnAborted {
+dangling_artifacts: Vec::new(),
+}
}
}


@@ -36,6 +36,7 @@ mod mcp_tool_call;
mod message_history;
mod model_provider_info;
pub mod parse_command;
+mod response_processing;
pub mod sandboxing;
pub mod token_data;
mod truncate;


@@ -0,0 +1,112 @@
use crate::codex::Session;
use crate::conversation_history::ConversationHistory;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use tracing::warn;
/// Process streamed `ResponseItem`s from the model into the pair of:
/// - items we should record in conversation history; and
/// - `ResponseInputItem`s to send back to the model on the next turn.
pub(crate) async fn process_items(
processed_items: Vec<crate::codex::ProcessedResponseItem>,
is_review_mode: bool,
review_thread_history: &mut ConversationHistory,
sess: &Session,
) -> (Vec<ResponseInputItem>, Vec<ResponseItem>) {
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
let mut responses = Vec::<ResponseInputItem>::new();
for processed_response_item in processed_items {
let crate::codex::ProcessedResponseItem { item, response } = processed_response_item;
match (&item, &response) {
(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
// If the model returned a message, we need to record it.
items_to_record_in_conversation_history.push(item);
}
(
ResponseItem::LocalShellCall { .. },
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
) => {
items_to_record_in_conversation_history.push(item);
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: output.clone(),
});
}
(
ResponseItem::FunctionCall { .. },
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
) => {
items_to_record_in_conversation_history.push(item);
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: output.clone(),
});
}
(
ResponseItem::CustomToolCall { .. },
Some(ResponseInputItem::CustomToolCallOutput { call_id, output }),
) => {
items_to_record_in_conversation_history.push(item);
items_to_record_in_conversation_history.push(ResponseItem::CustomToolCallOutput {
call_id: call_id.clone(),
output: output.clone(),
});
}
(
ResponseItem::FunctionCall { .. },
Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
) => {
items_to_record_in_conversation_history.push(item);
let output = match result {
Ok(call_tool_result) => {
crate::codex::convert_call_tool_result_to_function_call_output_payload(
call_tool_result,
)
}
Err(err) => FunctionCallOutputPayload {
content: err.clone(),
success: Some(false),
},
};
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output,
});
}
(
ResponseItem::Reasoning {
id,
summary,
content,
encrypted_content,
},
None,
) => {
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
id: id.clone(),
summary: summary.clone(),
content: content.clone(),
encrypted_content: encrypted_content.clone(),
});
}
_ => {
warn!("Unexpected response item: {item:?} with response: {response:?}");
}
};
if let Some(response) = response {
responses.push(response);
}
}
// Only attempt to take the lock if there is something to record.
if !items_to_record_in_conversation_history.is_empty() {
if is_review_mode {
review_thread_history.record_items(items_to_record_in_conversation_history.iter());
} else {
sess.record_conversation_items(&items_to_record_in_conversation_history)
.await;
}
}
(responses, items_to_record_in_conversation_history)
}


@@ -2,6 +2,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_util::either::Either;
+use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use crate::codex::Session;
@@ -9,8 +10,10 @@ use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
+use crate::tools::context::ToolPayload;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
+use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
pub(crate) struct ToolCallRuntime {
@@ -40,6 +43,7 @@ impl ToolCallRuntime {
pub(crate) fn handle_tool_call(
&self,
call: ToolCall,
+cancellation_token: CancellationToken,
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
@@ -48,18 +52,24 @@ impl ToolCallRuntime {
let turn = Arc::clone(&self.turn_context);
let tracker = Arc::clone(&self.tracker);
let lock = Arc::clone(&self.parallel_execution);
+let aborted_response = Self::aborted_response(&call);
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
AbortOnDropHandle::new(tokio::spawn(async move {
-let _guard = if supports_parallel {
-Either::Left(lock.read().await)
-} else {
-Either::Right(lock.write().await)
-};
+tokio::select! {
+_ = cancellation_token.cancelled() => Ok(aborted_response),
+res = async {
+let _guard = if supports_parallel {
+Either::Left(lock.read().await)
+} else {
+Either::Right(lock.write().await)
+};
router
.dispatch_tool_call(session, turn, tracker, call)
.await
+} => res,
+}
}));
async move {
@@ -74,3 +84,25 @@ impl ToolCallRuntime {
}
}
}
impl ToolCallRuntime {
fn aborted_response(call: &ToolCall) -> ResponseInputItem {
match &call.payload {
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
call_id: call.call_id.clone(),
output: "aborted".to_string(),
},
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
call_id: call.call_id.clone(),
result: Err("aborted".to_string()),
},
_ => ResponseInputItem::FunctionCallOutput {
call_id: call.call_id.clone(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
success: None,
},
},
}
}
}


@@ -35,6 +35,22 @@ impl ResponseMock {
pub fn requests(&self) -> Vec<ResponsesRequest> {
self.requests.lock().unwrap().clone()
}
/// Returns true if any captured request contains a `function_call` with the
/// provided `call_id`.
pub fn saw_function_call(&self, call_id: &str) -> bool {
self.requests()
.iter()
.any(|req| req.has_function_call(call_id))
}
/// Returns the `output` string for a matching `function_call_output` with
/// the provided `call_id`, searching across all captured requests.
pub fn function_call_output_text(&self, call_id: &str) -> Option<String> {
self.requests()
.iter()
.find_map(|req| req.function_call_output_text(call_id))
}
}
#[derive(Debug, Clone)]
@@ -70,6 +86,28 @@ impl ResponsesRequest {
.unwrap_or_else(|| panic!("function call output {call_id} item not found in request"))
}
/// Returns true if this request's `input` contains a `function_call` with
/// the specified `call_id`.
pub fn has_function_call(&self, call_id: &str) -> bool {
self.input().iter().any(|item| {
item.get("type").and_then(Value::as_str) == Some("function_call")
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
})
}
/// If present, returns the `output` string of the `function_call_output`
/// entry matching `call_id` in this request's `input`.
pub fn function_call_output_text(&self, call_id: &str) -> Option<String> {
let binding = self.input();
let item = binding.iter().find(|item| {
item.get("type").and_then(Value::as_str) == Some("function_call_output")
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
})?;
item.get("output")
.and_then(Value::as_str)
.map(str::to_string)
}
pub fn header(&self, name: &str) -> Option<String> {
self.0
.headers


@@ -1,3 +1,4 @@
+use std::sync::Arc;
use std::time::Duration;
use codex_core::protocol::EventMsg;
@@ -5,7 +6,9 @@ use codex_core::protocol::Op;
use codex_protocol::user_input::UserInput;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
+use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_once;
+use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::test_codex;
@@ -67,3 +70,98 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
)
.await;
}
/// After an interrupt we expect the next request to the model to include both
/// the original tool call and an `"aborted"` `function_call_output`. This test
/// exercises the follow-up flow: it sends another user turn, inspects the mock
/// responses server, and ensures the model receives the synthesized abort.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn interrupt_tool_records_history_entries() {
let command = vec![
"bash".to_string(),
"-lc".to_string(),
"sleep 60".to_string(),
];
let call_id = "call-history";
let args = json!({
"command": command,
"timeout_ms": 60_000
})
.to_string();
let first_body = sse(vec![
ev_response_created("resp-history"),
ev_function_call(call_id, "shell", &args),
ev_completed("resp-history"),
]);
let follow_up_body = sse(vec![
ev_response_created("resp-followup"),
ev_completed("resp-followup"),
]);
let server = start_mock_server().await;
let response_mock = mount_sse_sequence(&server, vec![first_body, follow_up_body]).await;
let fixture = test_codex().build(&server).await.unwrap();
let codex = Arc::clone(&fixture.codex);
let wait_timeout = Duration::from_millis(100);
codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: "start history recording".into(),
}],
})
.await
.unwrap();
wait_for_event_with_timeout(
&codex,
|ev| matches!(ev, EventMsg::ExecCommandBegin(_)),
wait_timeout,
)
.await;
codex.submit(Op::Interrupt).await.unwrap();
wait_for_event_with_timeout(
&codex,
|ev| matches!(ev, EventMsg::TurnAborted(_)),
wait_timeout,
)
.await;
codex
.submit(Op::UserInput {
items: vec![UserInput::Text {
text: "follow up".into(),
}],
})
.await
.unwrap();
wait_for_event_with_timeout(
&codex,
|ev| matches!(ev, EventMsg::TaskComplete(_)),
wait_timeout,
)
.await;
let requests = response_mock.requests();
assert!(
requests.len() == 2,
"expected two calls to the responses API, got {}",
requests.len()
);
assert!(
response_mock.saw_function_call(call_id),
"function call not recorded in responses payload"
);
assert_eq!(
response_mock.function_call_output_text(call_id).as_deref(),
Some("aborted"),
"aborted function call output not recorded in responses payload"
);
}