Support graceful agent interruption (#5287)

This commit is contained in:
pakrym-oai
2025-10-17 11:52:57 -07:00
committed by GitHub
parent 6915ba2100
commit c03e31ecf5
13 changed files with 309 additions and 55 deletions

View File

@@ -38,6 +38,7 @@ use serde_json;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::error;
use tracing::info;
@@ -119,6 +120,7 @@ use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions;
use crate::user_notification::UserNotification;
use crate::util::backoff;
use codex_async_utils::OrCancelExt;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
@@ -1170,19 +1172,6 @@ impl Session {
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
}
fn interrupt_task_sync(&self) {
if let Ok(mut active) = self.active_turn.try_lock()
&& let Some(at) = active.as_mut()
{
at.try_clear_pending_sync();
let tasks = at.drain_tasks();
*active = None;
for (_sub_id, task) in tasks {
task.handle.abort();
}
}
}
pub(crate) fn notifier(&self) -> &UserNotifier {
&self.services.notifier
}
@@ -1196,12 +1185,6 @@ impl Session {
}
}
impl Drop for Session {
fn drop(&mut self) {
self.interrupt_task_sync();
}
}
async fn submission_loop(
sess: Arc<Session>,
turn_context: TurnContext,
@@ -1711,6 +1694,7 @@ pub(crate) async fn run_task(
sub_id: String,
input: Vec<InputItem>,
task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> Option<String> {
if input.is_empty() {
return None;
@@ -1795,6 +1779,7 @@ pub(crate) async fn run_task(
sub_id.clone(),
turn_input,
task_kind,
cancellation_token.child_token(),
)
.await
{
@@ -1956,6 +1941,10 @@ pub(crate) async fn run_task(
}
continue;
}
Err(CodexErr::TurnAborted) => {
// Aborted turn is reported via a different event.
break;
}
Err(e) => {
info!("Turn error: {e:#}");
let event = Event {
@@ -2022,6 +2011,7 @@ async fn run_turn(
sub_id: String,
input: Vec<ResponseItem>,
task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
let router = Arc::new(ToolRouter::from_config(
@@ -2052,10 +2042,12 @@ async fn run_turn(
&sub_id,
&prompt,
task_kind,
cancellation_token.child_token(),
)
.await
{
Ok(output) => return Ok(output),
Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ CodexErr::Fatal(_)) => return Err(e),
@@ -2118,6 +2110,7 @@ struct TurnRunResult {
total_token_usage: Option<TokenUsage>,
}
#[allow(clippy::too_many_arguments)]
async fn try_run_turn(
router: Arc<ToolRouter>,
sess: Arc<Session>,
@@ -2126,6 +2119,7 @@ async fn try_run_turn(
sub_id: &str,
prompt: &Prompt,
task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
// call_ids that are part of this response.
let completed_call_ids = prompt
@@ -2195,7 +2189,8 @@ async fn try_run_turn(
.client
.clone()
.stream_with_task_kind(prompt.as_ref(), task_kind)
.await?;
.or_cancel(&cancellation_token)
.await??;
let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),
@@ -2211,7 +2206,8 @@ async fn try_run_turn(
// Poll the next item from the model stream. We must inspect *both* Ok and Err
// cases so that transient stream failures (e.g., dropped SSE connection before
// `response.completed`) bubble up and trigger the caller's retry logic.
let event = stream.next().await;
let event = stream.next().or_cancel(&cancellation_token).await?;
let event = match event {
Some(res) => res?,
None => {
@@ -2316,7 +2312,10 @@ async fn try_run_turn(
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
.await;
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
let processed_items = output
.try_collect()
.or_cancel(&cancellation_token)
.await??;
let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await;
@@ -2554,6 +2553,8 @@ mod tests {
use codex_app_server_protocol::AuthMode;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use std::time::Duration;
use tokio::time::sleep;
use mcp_types::ContentBlock;
use mcp_types::TextContent;
@@ -2563,8 +2564,6 @@ mod tests {
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration as StdDuration;
use tokio::time::Duration;
use tokio::time::sleep;
#[test]
fn reconstruct_history_matches_live_compactions() {
@@ -2944,12 +2943,15 @@ mod tests {
}
#[derive(Clone, Copy)]
struct NeverEndingTask(TaskKind);
struct NeverEndingTask {
kind: TaskKind,
listen_to_cancellation_token: bool,
}
#[async_trait::async_trait]
impl SessionTask for NeverEndingTask {
fn kind(&self) -> TaskKind {
self.0
self.kind
}
async fn run(
@@ -2958,20 +2960,26 @@ mod tests {
_ctx: Arc<TurnContext>,
_sub_id: String,
_input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String> {
if self.listen_to_cancellation_token {
cancellation_token.cancelled().await;
return None;
}
loop {
sleep(Duration::from_secs(60)).await;
}
}
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
if let TaskKind::Review = self.0 {
if let TaskKind::Review = self.kind {
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
}
}
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[test_log::test]
async fn abort_regular_task_emits_turn_aborted_only() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-regular".to_string();
@@ -2982,7 +2990,41 @@ mod tests {
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask(TaskKind::Regular),
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: false,
},
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for event")
.expect("event");
match evt.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected event: {other:?}"),
}
assert!(rx.try_recv().is_err());
}
#[tokio::test]
async fn abort_gracefuly_emits_turn_aborted_only() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-regular".to_string();
let input = vec![InputItem::Text {
text: "hello".to_string(),
}];
sess.spawn_task(
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: true,
},
)
.await;
@@ -2996,7 +3038,7 @@ mod tests {
assert!(rx.try_recv().is_err());
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-review".to_string();
@@ -3007,18 +3049,27 @@ mod tests {
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask(TaskKind::Review),
NeverEndingTask {
kind: TaskKind::Review,
listen_to_cancellation_token: false,
},
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let first = rx.recv().await.expect("first event");
let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for first event")
.expect("first event");
match first.msg {
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
other => panic!("unexpected first event: {other:?}"),
}
let second = rx.recv().await.expect("second event");
let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for second event")
.expect("second event");
match second.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected second event: {other:?}"),