Support graceful agent interruption (#5287)
This commit is contained in:
@@ -38,6 +38,7 @@ use serde_json;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
@@ -119,6 +120,7 @@ use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
use crate::user_notification::UserNotification;
|
||||
use crate::util::backoff;
|
||||
use codex_async_utils::OrCancelExt;
|
||||
use codex_otel::otel_event_manager::OtelEventManager;
|
||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
@@ -1170,19 +1172,6 @@ impl Session {
|
||||
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
}
|
||||
|
||||
fn interrupt_task_sync(&self) {
|
||||
if let Ok(mut active) = self.active_turn.try_lock()
|
||||
&& let Some(at) = active.as_mut()
|
||||
{
|
||||
at.try_clear_pending_sync();
|
||||
let tasks = at.drain_tasks();
|
||||
*active = None;
|
||||
for (_sub_id, task) in tasks {
|
||||
task.handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn notifier(&self) -> &UserNotifier {
|
||||
&self.services.notifier
|
||||
}
|
||||
@@ -1196,12 +1185,6 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Session {
|
||||
fn drop(&mut self) {
|
||||
self.interrupt_task_sync();
|
||||
}
|
||||
}
|
||||
|
||||
async fn submission_loop(
|
||||
sess: Arc<Session>,
|
||||
turn_context: TurnContext,
|
||||
@@ -1711,6 +1694,7 @@ pub(crate) async fn run_task(
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
if input.is_empty() {
|
||||
return None;
|
||||
@@ -1795,6 +1779,7 @@ pub(crate) async fn run_task(
|
||||
sub_id.clone(),
|
||||
turn_input,
|
||||
task_kind,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -1956,6 +1941,10 @@ pub(crate) async fn run_task(
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Err(CodexErr::TurnAborted) => {
|
||||
// Aborted turn is reported via a different event.
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
info!("Turn error: {e:#}");
|
||||
let event = Event {
|
||||
@@ -2022,6 +2011,7 @@ async fn run_turn(
|
||||
sub_id: String,
|
||||
input: Vec<ResponseItem>,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
@@ -2052,10 +2042,12 @@ async fn run_turn(
|
||||
&sub_id,
|
||||
&prompt,
|
||||
task_kind,
|
||||
cancellation_token.child_token(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(output) => return Ok(output),
|
||||
Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||
@@ -2118,6 +2110,7 @@ struct TurnRunResult {
|
||||
total_token_usage: Option<TokenUsage>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn try_run_turn(
|
||||
router: Arc<ToolRouter>,
|
||||
sess: Arc<Session>,
|
||||
@@ -2126,6 +2119,7 @@ async fn try_run_turn(
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
// call_ids that are part of this response.
|
||||
let completed_call_ids = prompt
|
||||
@@ -2195,7 +2189,8 @@ async fn try_run_turn(
|
||||
.client
|
||||
.clone()
|
||||
.stream_with_task_kind(prompt.as_ref(), task_kind)
|
||||
.await?;
|
||||
.or_cancel(&cancellation_token)
|
||||
.await??;
|
||||
|
||||
let tool_runtime = ToolCallRuntime::new(
|
||||
Arc::clone(&router),
|
||||
@@ -2211,7 +2206,8 @@ async fn try_run_turn(
|
||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||
let event = stream.next().await;
|
||||
let event = stream.next().or_cancel(&cancellation_token).await?;
|
||||
|
||||
let event = match event {
|
||||
Some(res) => res?,
|
||||
None => {
|
||||
@@ -2316,7 +2312,10 @@ async fn try_run_turn(
|
||||
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
||||
.await;
|
||||
|
||||
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
|
||||
let processed_items = output
|
||||
.try_collect()
|
||||
.or_cancel(&cancellation_token)
|
||||
.await??;
|
||||
|
||||
let unified_diff = {
|
||||
let mut tracker = turn_diff_tracker.lock().await;
|
||||
@@ -2554,6 +2553,8 @@ mod tests {
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::TextContent;
|
||||
@@ -2563,8 +2564,6 @@ mod tests {
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration as StdDuration;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[test]
|
||||
fn reconstruct_history_matches_live_compactions() {
|
||||
@@ -2944,12 +2943,15 @@ mod tests {
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct NeverEndingTask(TaskKind);
|
||||
struct NeverEndingTask {
|
||||
kind: TaskKind,
|
||||
listen_to_cancellation_token: bool,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SessionTask for NeverEndingTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
self.0
|
||||
self.kind
|
||||
}
|
||||
|
||||
async fn run(
|
||||
@@ -2958,20 +2960,26 @@ mod tests {
|
||||
_ctx: Arc<TurnContext>,
|
||||
_sub_id: String,
|
||||
_input: Vec<InputItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
if self.listen_to_cancellation_token {
|
||||
cancellation_token.cancelled().await;
|
||||
return None;
|
||||
}
|
||||
loop {
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
if let TaskKind::Review = self.0 {
|
||||
if let TaskKind::Review = self.kind {
|
||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[test_log::test]
|
||||
async fn abort_regular_task_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-regular".to_string();
|
||||
@@ -2982,7 +2990,41 @@ mod tests {
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask(TaskKind::Regular),
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Regular,
|
||||
listen_to_cancellation_token: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for event")
|
||||
.expect("event");
|
||||
match evt.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_gracefuly_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-regular".to_string();
|
||||
let input = vec![InputItem::Text {
|
||||
text: "hello".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Regular,
|
||||
listen_to_cancellation_token: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -2996,7 +3038,7 @@ mod tests {
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-review".to_string();
|
||||
@@ -3007,18 +3049,27 @@ mod tests {
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask(TaskKind::Review),
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Review,
|
||||
listen_to_cancellation_token: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let first = rx.recv().await.expect("first event");
|
||||
let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for first event")
|
||||
.expect("first event");
|
||||
match first.msg {
|
||||
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
|
||||
other => panic!("unexpected first event: {other:?}"),
|
||||
}
|
||||
let second = rx.recv().await.expect("second event");
|
||||
let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||
.await
|
||||
.expect("timeout waiting for second event")
|
||||
.expect("second event");
|
||||
match second.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
|
||||
Reference in New Issue
Block a user