Support graceful agent interruption (#5287)

This commit is contained in:
pakrym-oai
2025-10-17 11:52:57 -07:00
committed by GitHub
parent 6915ba2100
commit c03e31ecf5
13 changed files with 309 additions and 55 deletions

34
codex-rs/Cargo.lock generated
View File

@@ -899,6 +899,16 @@ dependencies = [
"tokio",
]
[[package]]
name = "codex-async-utils"
version = "0.0.0"
dependencies = [
"async-trait",
"pretty_assertions",
"tokio",
"tokio-util",
]
[[package]]
name = "codex-backend-client"
version = "0.0.0"
@@ -1037,6 +1047,7 @@ dependencies = [
"chrono",
"codex-app-server-protocol",
"codex-apply-patch",
"codex-async-utils",
"codex-file-search",
"codex-mcp-client",
"codex-otel",
@@ -1073,6 +1084,7 @@ dependencies = [
"similar",
"strum_macros 0.27.2",
"tempfile",
"test-log",
"thiserror 2.0.16",
"time",
"tokio",
@@ -6022,6 +6034,28 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
[[package]]
name = "test-log"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b"
dependencies = [
"env_logger",
"test-log-macros",
"tracing-subscriber",
]
[[package]]
name = "test-log-macros"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]]
name = "textwrap"
version = "0.11.0"

View File

@@ -2,6 +2,7 @@
members = [
"backend-client",
"ansi-escape",
"async-utils",
"app-server",
"app-server-protocol",
"apply-patch",
@@ -56,6 +57,7 @@ codex-arg0 = { path = "arg0" }
codex-chatgpt = { path = "chatgpt" }
codex-common = { path = "common" }
codex-core = { path = "core" }
codex-async-utils = { path = "async-utils" }
codex-exec = { path = "exec" }
codex-feedback = { path = "feedback" }
codex-file-search = { path = "file-search" }
@@ -164,6 +166,7 @@ strum_macros = "0.27.2"
supports-color = "3.0.2"
sys-locale = "0.3.2"
tempfile = "3.23.0"
test-log = "0.2.18"
textwrap = "0.16.2"
thiserror = "2.0.16"
time = "0.3"

View File

@@ -0,0 +1,15 @@
[package]
edition.workspace = true
name = "codex-async-utils"
version.workspace = true
[lints]
workspace = true
[dependencies]
async-trait.workspace = true
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "time"] }
tokio-util.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true

View File

@@ -0,0 +1,86 @@
use async_trait::async_trait;
use std::future::Future;
use tokio_util::sync::CancellationToken;
#[derive(Debug, PartialEq, Eq)]
pub enum CancelErr {
Cancelled,
}
#[async_trait]
pub trait OrCancelExt: Sized {
type Output;
async fn or_cancel(self, token: &CancellationToken) -> Result<Self::Output, CancelErr>;
}
#[async_trait]
impl<F> OrCancelExt for F
where
F: Future + Send,
F::Output: Send,
{
type Output = F::Output;
async fn or_cancel(self, token: &CancellationToken) -> Result<Self::Output, CancelErr> {
tokio::select! {
_ = token.cancelled() => Err(CancelErr::Cancelled),
res = self => Ok(res),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use std::time::Duration;
use tokio::task;
use tokio::time::sleep;
#[tokio::test]
async fn returns_ok_when_future_completes_first() {
let token = CancellationToken::new();
let value = async { 42 };
let result = value.or_cancel(&token).await;
assert_eq!(Ok(42), result);
}
#[tokio::test]
async fn returns_err_when_token_cancelled_first() {
let token = CancellationToken::new();
let token_clone = token.clone();
let cancel_handle = task::spawn(async move {
sleep(Duration::from_millis(10)).await;
token_clone.cancel();
});
let result = async {
sleep(Duration::from_millis(100)).await;
7
}
.or_cancel(&token)
.await;
cancel_handle.await.expect("cancel task panicked");
assert_eq!(Err(CancelErr::Cancelled), result);
}
#[tokio::test]
async fn returns_err_when_token_already_cancelled() {
let token = CancellationToken::new();
token.cancel();
let result = async {
sleep(Duration::from_millis(50)).await;
5
}
.or_cancel(&token)
.await;
assert_eq!(Err(CancelErr::Cancelled), result);
}
}

View File

@@ -26,6 +26,7 @@ codex-mcp-client = { workspace = true }
codex-otel = { workspace = true, features = ["otel"] }
codex-protocol = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-async-utils = { workspace = true }
codex-utils-string = { workspace = true }
dirs = { workspace = true }
dunce = { workspace = true }
@@ -47,6 +48,7 @@ shlex = { workspace = true }
similar = { workspace = true }
strum_macros = { workspace = true }
tempfile = { workspace = true }
test-log = { workspace = true }
thiserror = { workspace = true }
time = { workspace = true, features = [
"formatting",

View File

@@ -38,6 +38,7 @@ use serde_json;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::error;
use tracing::info;
@@ -119,6 +120,7 @@ use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions;
use crate::user_notification::UserNotification;
use crate::util::backoff;
use codex_async_utils::OrCancelExt;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
@@ -1170,19 +1172,6 @@ impl Session {
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
}
fn interrupt_task_sync(&self) {
if let Ok(mut active) = self.active_turn.try_lock()
&& let Some(at) = active.as_mut()
{
at.try_clear_pending_sync();
let tasks = at.drain_tasks();
*active = None;
for (_sub_id, task) in tasks {
task.handle.abort();
}
}
}
pub(crate) fn notifier(&self) -> &UserNotifier {
&self.services.notifier
}
@@ -1196,12 +1185,6 @@ impl Session {
}
}
impl Drop for Session {
fn drop(&mut self) {
self.interrupt_task_sync();
}
}
async fn submission_loop(
sess: Arc<Session>,
turn_context: TurnContext,
@@ -1711,6 +1694,7 @@ pub(crate) async fn run_task(
sub_id: String,
input: Vec<InputItem>,
task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> Option<String> {
if input.is_empty() {
return None;
@@ -1795,6 +1779,7 @@ pub(crate) async fn run_task(
sub_id.clone(),
turn_input,
task_kind,
cancellation_token.child_token(),
)
.await
{
@@ -1956,6 +1941,10 @@ pub(crate) async fn run_task(
}
continue;
}
Err(CodexErr::TurnAborted) => {
// Aborted turn is reported via a different event.
break;
}
Err(e) => {
info!("Turn error: {e:#}");
let event = Event {
@@ -2022,6 +2011,7 @@ async fn run_turn(
sub_id: String,
input: Vec<ResponseItem>,
task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
let router = Arc::new(ToolRouter::from_config(
@@ -2052,10 +2042,12 @@ async fn run_turn(
&sub_id,
&prompt,
task_kind,
cancellation_token.child_token(),
)
.await
{
Ok(output) => return Ok(output),
Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ CodexErr::Fatal(_)) => return Err(e),
@@ -2118,6 +2110,7 @@ struct TurnRunResult {
total_token_usage: Option<TokenUsage>,
}
#[allow(clippy::too_many_arguments)]
async fn try_run_turn(
router: Arc<ToolRouter>,
sess: Arc<Session>,
@@ -2126,6 +2119,7 @@ async fn try_run_turn(
sub_id: &str,
prompt: &Prompt,
task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
// call_ids that are part of this response.
let completed_call_ids = prompt
@@ -2195,7 +2189,8 @@ async fn try_run_turn(
.client
.clone()
.stream_with_task_kind(prompt.as_ref(), task_kind)
.await?;
.or_cancel(&cancellation_token)
.await??;
let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
@@ -2211,7 +2206,8 @@ async fn try_run_turn(
// Poll the next item from the model stream. We must inspect *both* Ok and Err
// cases so that transient stream failures (e.g., dropped SSE connection before
// `response.completed`) bubble up and trigger the caller's retry logic.
let event = stream.next().await;
let event = stream.next().or_cancel(&cancellation_token).await?;
let event = match event {
Some(res) => res?,
None => {
@@ -2316,7 +2312,10 @@ async fn try_run_turn(
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
.await;
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
let processed_items = output
.try_collect()
.or_cancel(&cancellation_token)
.await??;
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
@@ -2554,6 +2553,8 @@ mod tests {
use codex_app_server_protocol::AuthMode;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use std::time::Duration;
use tokio::time::sleep;
use mcp_types::ContentBlock;
use mcp_types::TextContent;
@@ -2563,8 +2564,6 @@ mod tests {
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration as StdDuration;
use tokio::time::Duration;
use tokio::time::sleep;
#[test]
fn reconstruct_history_matches_live_compactions() {
@@ -2944,12 +2943,15 @@ mod tests {
}
#[derive(Clone, Copy)]
struct NeverEndingTask(TaskKind);
struct NeverEndingTask {
kind: TaskKind,
listen_to_cancellation_token: bool,
}
#[async_trait::async_trait]
impl SessionTask for NeverEndingTask {
fn kind(&self) -> TaskKind {
self.0
self.kind
}
async fn run(
@@ -2958,20 +2960,26 @@ mod tests {
_ctx: Arc<TurnContext>,
_sub_id: String,
_input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String> {
if self.listen_to_cancellation_token {
cancellation_token.cancelled().await;
return None;
}
loop {
sleep(Duration::from_secs(60)).await;
}
}
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
if let TaskKind::Review = self.0 {
if let TaskKind::Review = self.kind {
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
}
}
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[test_log::test]
async fn abort_regular_task_emits_turn_aborted_only() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-regular".to_string();
@@ -2982,7 +2990,41 @@ mod tests {
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask(TaskKind::Regular),
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: false,
},
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for event")
.expect("event");
match evt.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected event: {other:?}"),
}
assert!(rx.try_recv().is_err());
}
#[tokio::test]
async fn abort_gracefuly_emits_turn_aborted_only() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-regular".to_string();
let input = vec![InputItem::Text {
text: "hello".to_string(),
}];
sess.spawn_task(
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: true,
},
)
.await;
@@ -2996,7 +3038,7 @@ mod tests {
assert!(rx.try_recv().is_err());
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-review".to_string();
@@ -3007,18 +3049,27 @@ mod tests {
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask(TaskKind::Review),
NeverEndingTask {
kind: TaskKind::Review,
listen_to_cancellation_token: false,
},
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let first = rx.recv().await.expect("first event");
let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for first event")
.expect("first event");
match first.msg {
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
other => panic!("unexpected first event: {other:?}"),
}
let second = rx.recv().await.expect("second event");
let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for second event")
.expect("second event");
match second.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected second event: {other:?}"),

View File

@@ -2,6 +2,7 @@ use crate::exec::ExecToolCallOutput;
use crate::token_data::KnownPlan;
use crate::token_data::PlanType;
use crate::truncate::truncate_middle;
use codex_async_utils::CancelErr;
use codex_protocol::ConversationId;
use codex_protocol::protocol::RateLimitSnapshot;
use reqwest::StatusCode;
@@ -50,6 +51,9 @@ pub enum SandboxErr {
#[derive(Error, Debug)]
pub enum CodexErr {
#[error("turn aborted")]
TurnAborted,
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
/// handshake has succeeded but **before** it finished emitting `response.completed`.
///
@@ -150,6 +154,12 @@ pub enum CodexErr {
EnvVar(EnvVarError),
}
impl From<CancelErr> for CodexErr {
fn from(_: CancelErr) -> Self {
CodexErr::TurnAborted
}
}
#[derive(Debug)]
pub struct ConnectionFailedError {
pub source: reqwest::Error,

View File

@@ -13,7 +13,6 @@ mod client;
mod client_common;
pub mod codex;
mod codex_conversation;
pub mod token_data;
pub use codex_conversation::CodexConversation;
mod command_safety;
pub mod config;
@@ -39,6 +38,7 @@ mod mcp_tool_call;
mod message_history;
mod model_provider_info;
pub mod parse_command;
pub mod token_data;
mod truncate;
mod unified_exec;
mod user_instructions;
@@ -107,5 +107,4 @@ pub use codex_protocol::models::LocalShellExecAction;
pub use codex_protocol::models::LocalShellStatus;
pub use codex_protocol::models::ReasoningItemContent;
pub use codex_protocol::models::ResponseItem;
pub mod otel_init;

View File

@@ -4,7 +4,9 @@ use indexmap::IndexMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::AbortHandle;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use codex_protocol::models::ResponseInputItem;
use tokio::sync::oneshot;
@@ -46,9 +48,11 @@ impl TaskKind {
#[derive(Clone)]
pub(crate) struct RunningTask {
pub(crate) handle: AbortHandle,
pub(crate) done: Arc<Notify>,
pub(crate) kind: TaskKind,
pub(crate) task: Arc<dyn SessionTask>,
pub(crate) cancellation_token: CancellationToken,
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
}
impl ActiveTurn {
@@ -115,13 +119,6 @@ impl ActiveTurn {
let mut ts = self.turn_state.lock().await;
ts.clear_pending();
}
/// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt).
pub(crate) fn try_clear_pending_sync(&self) {
if let Ok(mut ts) = self.turn_state.try_lock() {
ts.clear_pending();
}
}
}
#[cfg(test)]

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use crate::codex::TurnContext;
use crate::codex::compact;
@@ -25,6 +26,7 @@ impl SessionTask for CompactTask {
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
_cancellation_token: CancellationToken,
) -> Option<String> {
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
}

View File

@@ -3,9 +3,15 @@ mod regular;
mod review;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::select;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use tracing::trace;
use tracing::warn;
use crate::codex::Session;
use crate::codex::TurnContext;
@@ -23,6 +29,8 @@ pub(crate) use compact::CompactTask;
pub(crate) use regular::RegularTask;
pub(crate) use review::ReviewTask;
const GRACEFULL_INTERRUPTION_TIMEOUT_MS: u64 = 100;
/// Thin wrapper that exposes the parts of [`Session`] task runners need.
#[derive(Clone)]
pub(crate) struct SessionTaskContext {
@@ -49,6 +57,7 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String>;
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
@@ -69,26 +78,42 @@ impl Session {
let task: Arc<dyn SessionTask> = Arc::new(task);
let task_kind = task.kind();
let cancellation_token = CancellationToken::new();
let done = Arc::new(Notify::new());
let done_clone = Arc::clone(&done);
let handle = {
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
let ctx = Arc::clone(&turn_context);
let task_for_run = Arc::clone(&task);
let sub_clone = sub_id.clone();
let task_cancellation_token = cancellation_token.child_token();
tokio::spawn(async move {
let last_agent_message = task_for_run
.run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input)
.run(
Arc::clone(&session_ctx),
ctx,
sub_clone.clone(),
input,
task_cancellation_token.child_token(),
)
.await;
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
let sess = session_ctx.clone_session();
sess.on_task_finished(sub_clone, last_agent_message).await;
if !task_cancellation_token.is_cancelled() {
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
let sess = session_ctx.clone_session();
sess.on_task_finished(sub_clone, last_agent_message).await;
}
done_clone.notify_waiters();
})
.abort_handle()
};
let running_task = RunningTask {
handle,
done,
handle: Arc::new(AbortOnDropHandle::new(handle)),
kind: task_kind,
task,
cancellation_token,
};
self.register_new_active_task(sub_id, running_task).await;
}
@@ -143,14 +168,24 @@ impl Session {
task: RunningTask,
reason: TurnAbortReason,
) {
if task.handle.is_finished() {
if task.cancellation_token.is_cancelled() {
return;
}
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
task.cancellation_token.cancel();
let session_task = task.task;
let handle = task.handle;
handle.abort();
select! {
_ = task.done.notified() => {
},
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
}
}
task.handle.abort();
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
session_task.abort(session_ctx, &sub_id).await;

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use crate::codex::TurnContext;
use crate::codex::run_task;
@@ -25,8 +26,17 @@ impl SessionTask for RegularTask {
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String> {
let sess = session.clone_session();
run_task(sess, ctx, sub_id, input, TaskKind::Regular).await
run_task(
sess,
ctx,
sub_id,
input,
TaskKind::Regular,
cancellation_token,
)
.await
}
}

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use crate::codex::TurnContext;
use crate::codex::exit_review_mode;
@@ -26,9 +27,18 @@ impl SessionTask for ReviewTask {
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String> {
let sess = session.clone_session();
run_task(sess, ctx, sub_id, input, TaskKind::Review).await
run_task(
sess,
ctx,
sub_id,
input,
TaskKind::Review,
cancellation_token,
)
.await
}
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {