Support graceful agent interruption (#5287)

This commit is contained in:
pakrym-oai
2025-10-17 11:52:57 -07:00
committed by GitHub
parent 6915ba2100
commit c03e31ecf5
13 changed files with 309 additions and 55 deletions

34
codex-rs/Cargo.lock generated
View File

@@ -899,6 +899,16 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "codex-async-utils"
version = "0.0.0"
dependencies = [
"async-trait",
"pretty_assertions",
"tokio",
"tokio-util",
]
[[package]] [[package]]
name = "codex-backend-client" name = "codex-backend-client"
version = "0.0.0" version = "0.0.0"
@@ -1037,6 +1047,7 @@ dependencies = [
"chrono", "chrono",
"codex-app-server-protocol", "codex-app-server-protocol",
"codex-apply-patch", "codex-apply-patch",
"codex-async-utils",
"codex-file-search", "codex-file-search",
"codex-mcp-client", "codex-mcp-client",
"codex-otel", "codex-otel",
@@ -1073,6 +1084,7 @@ dependencies = [
"similar", "similar",
"strum_macros 0.27.2", "strum_macros 0.27.2",
"tempfile", "tempfile",
"test-log",
"thiserror 2.0.16", "thiserror 2.0.16",
"time", "time",
"tokio", "tokio",
@@ -6022,6 +6034,28 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
[[package]]
name = "test-log"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b"
dependencies = [
"env_logger",
"test-log-macros",
"tracing-subscriber",
]
[[package]]
name = "test-log-macros"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.104",
]
[[package]] [[package]]
name = "textwrap" name = "textwrap"
version = "0.11.0" version = "0.11.0"

View File

@@ -2,6 +2,7 @@
members = [ members = [
"backend-client", "backend-client",
"ansi-escape", "ansi-escape",
"async-utils",
"app-server", "app-server",
"app-server-protocol", "app-server-protocol",
"apply-patch", "apply-patch",
@@ -56,6 +57,7 @@ codex-arg0 = { path = "arg0" }
codex-chatgpt = { path = "chatgpt" } codex-chatgpt = { path = "chatgpt" }
codex-common = { path = "common" } codex-common = { path = "common" }
codex-core = { path = "core" } codex-core = { path = "core" }
codex-async-utils = { path = "async-utils" }
codex-exec = { path = "exec" } codex-exec = { path = "exec" }
codex-feedback = { path = "feedback" } codex-feedback = { path = "feedback" }
codex-file-search = { path = "file-search" } codex-file-search = { path = "file-search" }
@@ -164,6 +166,7 @@ strum_macros = "0.27.2"
supports-color = "3.0.2" supports-color = "3.0.2"
sys-locale = "0.3.2" sys-locale = "0.3.2"
tempfile = "3.23.0" tempfile = "3.23.0"
test-log = "0.2.18"
textwrap = "0.16.2" textwrap = "0.16.2"
thiserror = "2.0.16" thiserror = "2.0.16"
time = "0.3" time = "0.3"

View File

@@ -0,0 +1,15 @@
[package]
edition.workspace = true
name = "codex-async-utils"
version.workspace = true
[lints]
workspace = true
[dependencies]
async-trait.workspace = true
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "time"] }
tokio-util.workspace = true
[dev-dependencies]
pretty_assertions.workspace = true

View File

@@ -0,0 +1,86 @@
use async_trait::async_trait;
use std::future::Future;
use tokio_util::sync::CancellationToken;
#[derive(Debug, PartialEq, Eq)]
pub enum CancelErr {
Cancelled,
}
#[async_trait]
pub trait OrCancelExt: Sized {
type Output;
async fn or_cancel(self, token: &CancellationToken) -> Result<Self::Output, CancelErr>;
}
#[async_trait]
impl<F> OrCancelExt for F
where
F: Future + Send,
F::Output: Send,
{
type Output = F::Output;
async fn or_cancel(self, token: &CancellationToken) -> Result<Self::Output, CancelErr> {
tokio::select! {
_ = token.cancelled() => Err(CancelErr::Cancelled),
res = self => Ok(res),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use std::time::Duration;
use tokio::task;
use tokio::time::sleep;
#[tokio::test]
async fn returns_ok_when_future_completes_first() {
let token = CancellationToken::new();
let value = async { 42 };
let result = value.or_cancel(&token).await;
assert_eq!(Ok(42), result);
}
#[tokio::test]
async fn returns_err_when_token_cancelled_first() {
let token = CancellationToken::new();
let token_clone = token.clone();
let cancel_handle = task::spawn(async move {
sleep(Duration::from_millis(10)).await;
token_clone.cancel();
});
let result = async {
sleep(Duration::from_millis(100)).await;
7
}
.or_cancel(&token)
.await;
cancel_handle.await.expect("cancel task panicked");
assert_eq!(Err(CancelErr::Cancelled), result);
}
#[tokio::test]
async fn returns_err_when_token_already_cancelled() {
let token = CancellationToken::new();
token.cancel();
let result = async {
sleep(Duration::from_millis(50)).await;
5
}
.or_cancel(&token)
.await;
assert_eq!(Err(CancelErr::Cancelled), result);
}
}

View File

@@ -26,6 +26,7 @@ codex-mcp-client = { workspace = true }
codex-otel = { workspace = true, features = ["otel"] } codex-otel = { workspace = true, features = ["otel"] }
codex-protocol = { workspace = true } codex-protocol = { workspace = true }
codex-rmcp-client = { workspace = true } codex-rmcp-client = { workspace = true }
codex-async-utils = { workspace = true }
codex-utils-string = { workspace = true } codex-utils-string = { workspace = true }
dirs = { workspace = true } dirs = { workspace = true }
dunce = { workspace = true } dunce = { workspace = true }
@@ -47,6 +48,7 @@ shlex = { workspace = true }
similar = { workspace = true } similar = { workspace = true }
strum_macros = { workspace = true } strum_macros = { workspace = true }
tempfile = { workspace = true } tempfile = { workspace = true }
test-log = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
time = { workspace = true, features = [ time = { workspace = true, features = [
"formatting", "formatting",

View File

@@ -38,6 +38,7 @@ use serde_json;
use serde_json::Value; use serde_json::Value;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tracing::debug; use tracing::debug;
use tracing::error; use tracing::error;
use tracing::info; use tracing::info;
@@ -119,6 +120,7 @@ use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions; use crate::user_instructions::UserInstructions;
use crate::user_notification::UserNotification; use crate::user_notification::UserNotification;
use crate::util::backoff; use crate::util::backoff;
use codex_async_utils::OrCancelExt;
use codex_otel::otel_event_manager::OtelEventManager; use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
@@ -1170,19 +1172,6 @@ impl Session {
self.abort_all_tasks(TurnAbortReason::Interrupted).await; self.abort_all_tasks(TurnAbortReason::Interrupted).await;
} }
fn interrupt_task_sync(&self) {
if let Ok(mut active) = self.active_turn.try_lock()
&& let Some(at) = active.as_mut()
{
at.try_clear_pending_sync();
let tasks = at.drain_tasks();
*active = None;
for (_sub_id, task) in tasks {
task.handle.abort();
}
}
}
pub(crate) fn notifier(&self) -> &UserNotifier { pub(crate) fn notifier(&self) -> &UserNotifier {
&self.services.notifier &self.services.notifier
} }
@@ -1196,12 +1185,6 @@ impl Session {
} }
} }
impl Drop for Session {
fn drop(&mut self) {
self.interrupt_task_sync();
}
}
async fn submission_loop( async fn submission_loop(
sess: Arc<Session>, sess: Arc<Session>,
turn_context: TurnContext, turn_context: TurnContext,
@@ -1711,6 +1694,7 @@ pub(crate) async fn run_task(
sub_id: String, sub_id: String,
input: Vec<InputItem>, input: Vec<InputItem>,
task_kind: TaskKind, task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> Option<String> { ) -> Option<String> {
if input.is_empty() { if input.is_empty() {
return None; return None;
@@ -1795,6 +1779,7 @@ pub(crate) async fn run_task(
sub_id.clone(), sub_id.clone(),
turn_input, turn_input,
task_kind, task_kind,
cancellation_token.child_token(),
) )
.await .await
{ {
@@ -1956,6 +1941,10 @@ pub(crate) async fn run_task(
} }
continue; continue;
} }
Err(CodexErr::TurnAborted) => {
// Aborted turn is reported via a different event.
break;
}
Err(e) => { Err(e) => {
info!("Turn error: {e:#}"); info!("Turn error: {e:#}");
let event = Event { let event = Event {
@@ -2022,6 +2011,7 @@ async fn run_turn(
sub_id: String, sub_id: String,
input: Vec<ResponseItem>, input: Vec<ResponseItem>,
task_kind: TaskKind, task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> { ) -> CodexResult<TurnRunResult> {
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools(); let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
let router = Arc::new(ToolRouter::from_config( let router = Arc::new(ToolRouter::from_config(
@@ -2052,10 +2042,12 @@ async fn run_turn(
&sub_id, &sub_id,
&prompt, &prompt,
task_kind, task_kind,
cancellation_token.child_token(),
) )
.await .await
{ {
Ok(output) => return Ok(output), Ok(output) => return Ok(output),
Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ CodexErr::Fatal(_)) => return Err(e), Err(e @ CodexErr::Fatal(_)) => return Err(e),
@@ -2118,6 +2110,7 @@ struct TurnRunResult {
total_token_usage: Option<TokenUsage>, total_token_usage: Option<TokenUsage>,
} }
#[allow(clippy::too_many_arguments)]
async fn try_run_turn( async fn try_run_turn(
router: Arc<ToolRouter>, router: Arc<ToolRouter>,
sess: Arc<Session>, sess: Arc<Session>,
@@ -2126,6 +2119,7 @@ async fn try_run_turn(
sub_id: &str, sub_id: &str,
prompt: &Prompt, prompt: &Prompt,
task_kind: TaskKind, task_kind: TaskKind,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> { ) -> CodexResult<TurnRunResult> {
// call_ids that are part of this response. // call_ids that are part of this response.
let completed_call_ids = prompt let completed_call_ids = prompt
@@ -2195,7 +2189,8 @@ async fn try_run_turn(
.client .client
.clone() .clone()
.stream_with_task_kind(prompt.as_ref(), task_kind) .stream_with_task_kind(prompt.as_ref(), task_kind)
.await?; .or_cancel(&cancellation_token)
.await??;
let tool_runtime = ToolCallRuntime::new( let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router), Arc::clone(&router),
@@ -2211,7 +2206,8 @@ async fn try_run_turn(
// Poll the next item from the model stream. We must inspect *both* Ok and Err // Poll the next item from the model stream. We must inspect *both* Ok and Err
// cases so that transient stream failures (e.g., dropped SSE connection before // cases so that transient stream failures (e.g., dropped SSE connection before
// `response.completed`) bubble up and trigger the caller's retry logic. // `response.completed`) bubble up and trigger the caller's retry logic.
let event = stream.next().await; let event = stream.next().or_cancel(&cancellation_token).await?;
let event = match event { let event = match event {
Some(res) => res?, Some(res) => res?,
None => { None => {
@@ -2316,7 +2312,10 @@ async fn try_run_turn(
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref()) sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
.await; .await;
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?; let processed_items = output
.try_collect()
.or_cancel(&cancellation_token)
.await??;
let unified_diff = { let unified_diff = {
let mut tracker = turn_diff_tracker.lock().await; let mut tracker = turn_diff_tracker.lock().await;
@@ -2554,6 +2553,8 @@ mod tests {
use codex_app_server_protocol::AuthMode; use codex_app_server_protocol::AuthMode;
use codex_protocol::models::ContentItem; use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem; use codex_protocol::models::ResponseItem;
use std::time::Duration;
use tokio::time::sleep;
use mcp_types::ContentBlock; use mcp_types::ContentBlock;
use mcp_types::TextContent; use mcp_types::TextContent;
@@ -2563,8 +2564,6 @@ mod tests {
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration as StdDuration; use std::time::Duration as StdDuration;
use tokio::time::Duration;
use tokio::time::sleep;
#[test] #[test]
fn reconstruct_history_matches_live_compactions() { fn reconstruct_history_matches_live_compactions() {
@@ -2944,12 +2943,15 @@ mod tests {
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
struct NeverEndingTask(TaskKind); struct NeverEndingTask {
kind: TaskKind,
listen_to_cancellation_token: bool,
}
#[async_trait::async_trait] #[async_trait::async_trait]
impl SessionTask for NeverEndingTask { impl SessionTask for NeverEndingTask {
fn kind(&self) -> TaskKind { fn kind(&self) -> TaskKind {
self.0 self.kind
} }
async fn run( async fn run(
@@ -2958,20 +2960,26 @@ mod tests {
_ctx: Arc<TurnContext>, _ctx: Arc<TurnContext>,
_sub_id: String, _sub_id: String,
_input: Vec<InputItem>, _input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String> { ) -> Option<String> {
if self.listen_to_cancellation_token {
cancellation_token.cancelled().await;
return None;
}
loop { loop {
sleep(Duration::from_secs(60)).await; sleep(Duration::from_secs(60)).await;
} }
} }
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) { async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
if let TaskKind::Review = self.0 { if let TaskKind::Review = self.kind {
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await; exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
} }
} }
} }
#[tokio::test] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[test_log::test]
async fn abort_regular_task_emits_turn_aborted_only() { async fn abort_regular_task_emits_turn_aborted_only() {
let (sess, tc, rx) = make_session_and_context_with_rx(); let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-regular".to_string(); let sub_id = "sub-regular".to_string();
@@ -2982,7 +2990,41 @@ mod tests {
Arc::clone(&tc), Arc::clone(&tc),
sub_id.clone(), sub_id.clone(),
input, input,
NeverEndingTask(TaskKind::Regular), NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: false,
},
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for event")
.expect("event");
match evt.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected event: {other:?}"),
}
assert!(rx.try_recv().is_err());
}
#[tokio::test]
async fn abort_gracefuly_emits_turn_aborted_only() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-regular".to_string();
let input = vec![InputItem::Text {
text: "hello".to_string(),
}];
sess.spawn_task(
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask {
kind: TaskKind::Regular,
listen_to_cancellation_token: true,
},
) )
.await; .await;
@@ -2996,7 +3038,7 @@ mod tests {
assert!(rx.try_recv().is_err()); assert!(rx.try_recv().is_err());
} }
#[tokio::test] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn abort_review_task_emits_exited_then_aborted_and_records_history() { async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
let (sess, tc, rx) = make_session_and_context_with_rx(); let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-review".to_string(); let sub_id = "sub-review".to_string();
@@ -3007,18 +3049,27 @@ mod tests {
Arc::clone(&tc), Arc::clone(&tc),
sub_id.clone(), sub_id.clone(),
input, input,
NeverEndingTask(TaskKind::Review), NeverEndingTask {
kind: TaskKind::Review,
listen_to_cancellation_token: false,
},
) )
.await; .await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await; sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let first = rx.recv().await.expect("first event"); let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for first event")
.expect("first event");
match first.msg { match first.msg {
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()), EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
other => panic!("unexpected first event: {other:?}"), other => panic!("unexpected first event: {other:?}"),
} }
let second = rx.recv().await.expect("second event"); let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
.await
.expect("timeout waiting for second event")
.expect("second event");
match second.msg { match second.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected second event: {other:?}"), other => panic!("unexpected second event: {other:?}"),

View File

@@ -2,6 +2,7 @@ use crate::exec::ExecToolCallOutput;
use crate::token_data::KnownPlan; use crate::token_data::KnownPlan;
use crate::token_data::PlanType; use crate::token_data::PlanType;
use crate::truncate::truncate_middle; use crate::truncate::truncate_middle;
use codex_async_utils::CancelErr;
use codex_protocol::ConversationId; use codex_protocol::ConversationId;
use codex_protocol::protocol::RateLimitSnapshot; use codex_protocol::protocol::RateLimitSnapshot;
use reqwest::StatusCode; use reqwest::StatusCode;
@@ -50,6 +51,9 @@ pub enum SandboxErr {
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum CodexErr { pub enum CodexErr {
#[error("turn aborted")]
TurnAborted,
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP /// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
/// handshake has succeeded but **before** it finished emitting `response.completed`. /// handshake has succeeded but **before** it finished emitting `response.completed`.
/// ///
@@ -150,6 +154,12 @@ pub enum CodexErr {
EnvVar(EnvVarError), EnvVar(EnvVarError),
} }
impl From<CancelErr> for CodexErr {
fn from(_: CancelErr) -> Self {
CodexErr::TurnAborted
}
}
#[derive(Debug)] #[derive(Debug)]
pub struct ConnectionFailedError { pub struct ConnectionFailedError {
pub source: reqwest::Error, pub source: reqwest::Error,

View File

@@ -13,7 +13,6 @@ mod client;
mod client_common; mod client_common;
pub mod codex; pub mod codex;
mod codex_conversation; mod codex_conversation;
pub mod token_data;
pub use codex_conversation::CodexConversation; pub use codex_conversation::CodexConversation;
mod command_safety; mod command_safety;
pub mod config; pub mod config;
@@ -39,6 +38,7 @@ mod mcp_tool_call;
mod message_history; mod message_history;
mod model_provider_info; mod model_provider_info;
pub mod parse_command; pub mod parse_command;
pub mod token_data;
mod truncate; mod truncate;
mod unified_exec; mod unified_exec;
mod user_instructions; mod user_instructions;
@@ -107,5 +107,4 @@ pub use codex_protocol::models::LocalShellExecAction;
pub use codex_protocol::models::LocalShellStatus; pub use codex_protocol::models::LocalShellStatus;
pub use codex_protocol::models::ReasoningItemContent; pub use codex_protocol::models::ReasoningItemContent;
pub use codex_protocol::models::ResponseItem; pub use codex_protocol::models::ResponseItem;
pub mod otel_init; pub mod otel_init;

View File

@@ -4,7 +4,9 @@ use indexmap::IndexMap;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task::AbortHandle; use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseInputItem;
use tokio::sync::oneshot; use tokio::sync::oneshot;
@@ -46,9 +48,11 @@ impl TaskKind {
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct RunningTask { pub(crate) struct RunningTask {
pub(crate) handle: AbortHandle, pub(crate) done: Arc<Notify>,
pub(crate) kind: TaskKind, pub(crate) kind: TaskKind,
pub(crate) task: Arc<dyn SessionTask>, pub(crate) task: Arc<dyn SessionTask>,
pub(crate) cancellation_token: CancellationToken,
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
} }
impl ActiveTurn { impl ActiveTurn {
@@ -115,13 +119,6 @@ impl ActiveTurn {
let mut ts = self.turn_state.lock().await; let mut ts = self.turn_state.lock().await;
ts.clear_pending(); ts.clear_pending();
} }
/// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt).
pub(crate) fn try_clear_pending_sync(&self) {
if let Ok(mut ts) = self.turn_state.try_lock() {
ts.clear_pending();
}
}
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -1,6 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use crate::codex::TurnContext; use crate::codex::TurnContext;
use crate::codex::compact; use crate::codex::compact;
@@ -25,6 +26,7 @@ impl SessionTask for CompactTask {
ctx: Arc<TurnContext>, ctx: Arc<TurnContext>,
sub_id: String, sub_id: String,
input: Vec<InputItem>, input: Vec<InputItem>,
_cancellation_token: CancellationToken,
) -> Option<String> { ) -> Option<String> {
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
} }

View File

@@ -3,9 +3,15 @@ mod regular;
mod review; mod review;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait; use async_trait::async_trait;
use tokio::select;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use tracing::trace; use tracing::trace;
use tracing::warn;
use crate::codex::Session; use crate::codex::Session;
use crate::codex::TurnContext; use crate::codex::TurnContext;
@@ -23,6 +29,8 @@ pub(crate) use compact::CompactTask;
pub(crate) use regular::RegularTask; pub(crate) use regular::RegularTask;
pub(crate) use review::ReviewTask; pub(crate) use review::ReviewTask;
const GRACEFULL_INTERRUPTION_TIMEOUT_MS: u64 = 100;
/// Thin wrapper that exposes the parts of [`Session`] task runners need. /// Thin wrapper that exposes the parts of [`Session`] task runners need.
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct SessionTaskContext { pub(crate) struct SessionTaskContext {
@@ -49,6 +57,7 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
ctx: Arc<TurnContext>, ctx: Arc<TurnContext>,
sub_id: String, sub_id: String,
input: Vec<InputItem>, input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String>; ) -> Option<String>;
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) { async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
@@ -69,26 +78,42 @@ impl Session {
let task: Arc<dyn SessionTask> = Arc::new(task); let task: Arc<dyn SessionTask> = Arc::new(task);
let task_kind = task.kind(); let task_kind = task.kind();
let cancellation_token = CancellationToken::new();
let done = Arc::new(Notify::new());
let done_clone = Arc::clone(&done);
let handle = { let handle = {
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
let ctx = Arc::clone(&turn_context); let ctx = Arc::clone(&turn_context);
let task_for_run = Arc::clone(&task); let task_for_run = Arc::clone(&task);
let sub_clone = sub_id.clone(); let sub_clone = sub_id.clone();
let task_cancellation_token = cancellation_token.child_token();
tokio::spawn(async move { tokio::spawn(async move {
let last_agent_message = task_for_run let last_agent_message = task_for_run
.run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input) .run(
Arc::clone(&session_ctx),
ctx,
sub_clone.clone(),
input,
task_cancellation_token.child_token(),
)
.await; .await;
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
let sess = session_ctx.clone_session(); if !task_cancellation_token.is_cancelled() {
sess.on_task_finished(sub_clone, last_agent_message).await; // Emit completion uniformly from spawn site so all tasks share the same lifecycle.
let sess = session_ctx.clone_session();
sess.on_task_finished(sub_clone, last_agent_message).await;
}
done_clone.notify_waiters();
}) })
.abort_handle()
}; };
let running_task = RunningTask { let running_task = RunningTask {
handle, done,
handle: Arc::new(AbortOnDropHandle::new(handle)),
kind: task_kind, kind: task_kind,
task, task,
cancellation_token,
}; };
self.register_new_active_task(sub_id, running_task).await; self.register_new_active_task(sub_id, running_task).await;
} }
@@ -143,14 +168,24 @@ impl Session {
task: RunningTask, task: RunningTask,
reason: TurnAbortReason, reason: TurnAbortReason,
) { ) {
if task.handle.is_finished() { if task.cancellation_token.is_cancelled() {
return; return;
} }
trace!(task_kind = ?task.kind, sub_id, "aborting running task"); trace!(task_kind = ?task.kind, sub_id, "aborting running task");
task.cancellation_token.cancel();
let session_task = task.task; let session_task = task.task;
let handle = task.handle;
handle.abort(); select! {
_ = task.done.notified() => {
},
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
}
}
task.handle.abort();
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
session_task.abort(session_ctx, &sub_id).await; session_task.abort(session_ctx, &sub_id).await;

View File

@@ -1,6 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use crate::codex::TurnContext; use crate::codex::TurnContext;
use crate::codex::run_task; use crate::codex::run_task;
@@ -25,8 +26,17 @@ impl SessionTask for RegularTask {
ctx: Arc<TurnContext>, ctx: Arc<TurnContext>,
sub_id: String, sub_id: String,
input: Vec<InputItem>, input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String> { ) -> Option<String> {
let sess = session.clone_session(); let sess = session.clone_session();
run_task(sess, ctx, sub_id, input, TaskKind::Regular).await run_task(
sess,
ctx,
sub_id,
input,
TaskKind::Regular,
cancellation_token,
)
.await
} }
} }

View File

@@ -1,6 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use crate::codex::TurnContext; use crate::codex::TurnContext;
use crate::codex::exit_review_mode; use crate::codex::exit_review_mode;
@@ -26,9 +27,18 @@ impl SessionTask for ReviewTask {
ctx: Arc<TurnContext>, ctx: Arc<TurnContext>,
sub_id: String, sub_id: String,
input: Vec<InputItem>, input: Vec<InputItem>,
cancellation_token: CancellationToken,
) -> Option<String> { ) -> Option<String> {
let sess = session.clone_session(); let sess = session.clone_session();
run_task(sess, ctx, sub_id, input, TaskKind::Review).await run_task(
sess,
ctx,
sub_id,
input,
TaskKind::Review,
cancellation_token,
)
.await
} }
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) { async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {