From c03e31ecf5d4d9fed16ea163b3bcce83d030f74d Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Fri, 17 Oct 2025 11:52:57 -0700 Subject: [PATCH] Support graceful agent interruption (#5287) --- codex-rs/Cargo.lock | 34 +++++++++ codex-rs/Cargo.toml | 3 + codex-rs/async-utils/Cargo.toml | 15 ++++ codex-rs/async-utils/src/lib.rs | 86 +++++++++++++++++++++ codex-rs/core/Cargo.toml | 2 + codex-rs/core/src/codex.rs | 117 +++++++++++++++++++++-------- codex-rs/core/src/error.rs | 10 +++ codex-rs/core/src/lib.rs | 3 +- codex-rs/core/src/state/turn.rs | 15 ++-- codex-rs/core/src/tasks/compact.rs | 2 + codex-rs/core/src/tasks/mod.rs | 53 ++++++++++--- codex-rs/core/src/tasks/regular.rs | 12 ++- codex-rs/core/src/tasks/review.rs | 12 ++- 13 files changed, 309 insertions(+), 55 deletions(-) create mode 100644 codex-rs/async-utils/Cargo.toml create mode 100644 codex-rs/async-utils/src/lib.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 45a41b81..9d18361c 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -899,6 +899,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "codex-async-utils" +version = "0.0.0" +dependencies = [ + "async-trait", + "pretty_assertions", + "tokio", + "tokio-util", +] + [[package]] name = "codex-backend-client" version = "0.0.0" @@ -1037,6 +1047,7 @@ dependencies = [ "chrono", "codex-app-server-protocol", "codex-apply-patch", + "codex-async-utils", "codex-file-search", "codex-mcp-client", "codex-otel", @@ -1073,6 +1084,7 @@ dependencies = [ "similar", "strum_macros 0.27.2", "tempfile", + "test-log", "thiserror 2.0.16", "time", "tokio", @@ -6022,6 +6034,28 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" +[[package]] +name = "test-log" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "textwrap" version = "0.11.0" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 4ca4b7e6..64eebb62 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -2,6 +2,7 @@ members = [ "backend-client", "ansi-escape", + "async-utils", "app-server", "app-server-protocol", "apply-patch", @@ -56,6 +57,7 @@ codex-arg0 = { path = "arg0" } codex-chatgpt = { path = "chatgpt" } codex-common = { path = "common" } codex-core = { path = "core" } +codex-async-utils = { path = "async-utils" } codex-exec = { path = "exec" } codex-feedback = { path = "feedback" } codex-file-search = { path = "file-search" } @@ -164,6 +166,7 @@ strum_macros = "0.27.2" supports-color = "3.0.2" sys-locale = "0.3.2" tempfile = "3.23.0" +test-log = "0.2.18" textwrap = "0.16.2" thiserror = "2.0.16" time = "0.3" diff --git a/codex-rs/async-utils/Cargo.toml b/codex-rs/async-utils/Cargo.toml new file mode 100644 index 00000000..5203db0f --- /dev/null +++ b/codex-rs/async-utils/Cargo.toml @@ -0,0 +1,15 @@ +[package] +edition.workspace = true +name = "codex-async-utils" +version.workspace = true + +[lints] +workspace = true + +[dependencies] +async-trait.workspace = true +tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "time"] } +tokio-util.workspace = true + +[dev-dependencies] +pretty_assertions.workspace = true diff --git a/codex-rs/async-utils/src/lib.rs b/codex-rs/async-utils/src/lib.rs new file mode 100644 index 00000000..bd880ae1 --- /dev/null +++ b/codex-rs/async-utils/src/lib.rs @@ -0,0 +1,86 @@ +use async_trait::async_trait; +use std::future::Future; +use tokio_util::sync::CancellationToken; + +#[derive(Debug, PartialEq, Eq)] +pub enum CancelErr { + Cancelled, +} + +#[async_trait] +pub trait OrCancelExt: Sized { + type Output; + + async fn or_cancel(self, token: &CancellationToken) -> Result; +} + +#[async_trait] +impl OrCancelExt for F +where + F: Future + Send, + F::Output: Send, +{ + type Output = F::Output; + + async fn or_cancel(self, token: &CancellationToken) -> Result { + tokio::select! { + _ = token.cancelled() => Err(CancelErr::Cancelled), + res = self => Ok(res), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use std::time::Duration; + use tokio::task; + use tokio::time::sleep; + + #[tokio::test] + async fn returns_ok_when_future_completes_first() { + let token = CancellationToken::new(); + let value = async { 42 }; + + let result = value.or_cancel(&token).await; + + assert_eq!(Ok(42), result); + } + + #[tokio::test] + async fn returns_err_when_token_cancelled_first() { + let token = CancellationToken::new(); + let token_clone = token.clone(); + + let cancel_handle = task::spawn(async move { + sleep(Duration::from_millis(10)).await; + token_clone.cancel(); + }); + + let result = async { + sleep(Duration::from_millis(100)).await; + 7 + } + .or_cancel(&token) + .await; + + cancel_handle.await.expect("cancel task panicked"); + assert_eq!(Err(CancelErr::Cancelled), result); + } + + #[tokio::test] + async fn returns_err_when_token_already_cancelled() { + let token = CancellationToken::new(); + token.cancel(); + + let result = async { + sleep(Duration::from_millis(50)).await; + 5 + } + .or_cancel(&token) + .await; + + assert_eq!(Err(CancelErr::Cancelled), result); + } +} diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 4259e64f..d1320569 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -26,6 +26,7 @@ codex-mcp-client = { workspace = true } codex-otel = { workspace = true, features = ["otel"] } codex-protocol = { workspace = true } codex-rmcp-client = { workspace = true } +codex-async-utils = { workspace = true } codex-utils-string = { workspace = true } dirs = { workspace = true } dunce = { workspace = true } @@ -47,6 +48,7 @@ shlex = { workspace = true } similar = { workspace = true } strum_macros = { workspace = true } tempfile = { workspace = true } +test-log = { workspace = true } thiserror = { workspace = true } time = { workspace = true, features = [ "formatting", diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 93ba73eb..23751d1b 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -38,6 +38,7 @@ use serde_json; use serde_json::Value; use tokio::sync::Mutex; use tokio::sync::oneshot; +use tokio_util::sync::CancellationToken; use tracing::debug; use tracing::error; use tracing::info; @@ -119,6 +120,7 @@ use crate::unified_exec::UnifiedExecSessionManager; use crate::user_instructions::UserInstructions; use crate::user_notification::UserNotification; use crate::util::backoff; +use codex_async_utils::OrCancelExt; use codex_otel::otel_event_manager::OtelEventManager; use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig; use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; @@ -1170,19 +1172,6 @@ impl Session { self.abort_all_tasks(TurnAbortReason::Interrupted).await; } - fn interrupt_task_sync(&self) { - if let Ok(mut active) = self.active_turn.try_lock() - && let Some(at) = active.as_mut() - { - at.try_clear_pending_sync(); - let tasks = at.drain_tasks(); - *active = None; - for (_sub_id, task) in tasks { - task.handle.abort(); - } - } - } - pub(crate) fn notifier(&self) -> &UserNotifier { &self.services.notifier } @@ -1196,12 +1185,6 @@ impl Session { } } -impl Drop for Session { - fn drop(&mut self) { - self.interrupt_task_sync(); - } -} - async fn submission_loop( sess: Arc, turn_context: TurnContext, @@ -1711,6 +1694,7 @@ pub(crate) async fn run_task( sub_id: String, input: Vec, task_kind: TaskKind, + cancellation_token: CancellationToken, ) -> Option { if input.is_empty() { return None; @@ -1795,6 +1779,7 @@ pub(crate) async fn run_task( sub_id.clone(), turn_input, task_kind, + cancellation_token.child_token(), ) .await { @@ -1956,6 +1941,10 @@ pub(crate) async fn run_task( } continue; } + Err(CodexErr::TurnAborted) => { + // Aborted turn is reported via a different event. + break; + } Err(e) => { info!("Turn error: {e:#}"); let event = Event { @@ -2022,6 +2011,7 @@ async fn run_turn( sub_id: String, input: Vec, task_kind: TaskKind, + cancellation_token: CancellationToken, ) -> CodexResult { let mcp_tools = sess.services.mcp_connection_manager.list_all_tools(); let router = Arc::new(ToolRouter::from_config( @@ -2052,10 +2042,12 @@ async fn run_turn( &sub_id, &prompt, task_kind, + cancellation_token.child_token(), ) .await { Ok(output) => return Ok(output), + Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted), Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), Err(e @ CodexErr::Fatal(_)) => return Err(e), @@ -2118,6 +2110,7 @@ struct TurnRunResult { total_token_usage: Option, } +#[allow(clippy::too_many_arguments)] async fn try_run_turn( router: Arc, sess: Arc, @@ -2126,6 +2119,7 @@ async fn try_run_turn( sub_id: &str, prompt: &Prompt, task_kind: TaskKind, + cancellation_token: CancellationToken, ) -> CodexResult { // call_ids that are part of this response. let completed_call_ids = prompt @@ -2195,7 +2189,8 @@ async fn try_run_turn( .client .clone() .stream_with_task_kind(prompt.as_ref(), task_kind) - .await?; + .or_cancel(&cancellation_token) + .await??; let tool_runtime = ToolCallRuntime::new( Arc::clone(&router), @@ -2211,7 +2206,8 @@ async fn try_run_turn( // Poll the next item from the model stream. We must inspect *both* Ok and Err // cases so that transient stream failures (e.g., dropped SSE connection before // `response.completed`) bubble up and trigger the caller's retry logic. - let event = stream.next().await; + let event = stream.next().or_cancel(&cancellation_token).await?; + let event = match event { Some(res) => res?, None => { @@ -2316,7 +2312,10 @@ async fn try_run_turn( sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref()) .await; - let processed_items: Vec = output.try_collect().await?; + let processed_items = output + .try_collect() + .or_cancel(&cancellation_token) + .await??; let unified_diff = { let mut tracker = turn_diff_tracker.lock().await; @@ -2554,6 +2553,8 @@ mod tests { use codex_app_server_protocol::AuthMode; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; + use std::time::Duration; + use tokio::time::sleep; use mcp_types::ContentBlock; use mcp_types::TextContent; @@ -2563,8 +2564,6 @@ mod tests { use std::path::PathBuf; use std::sync::Arc; use std::time::Duration as StdDuration; - use tokio::time::Duration; - use tokio::time::sleep; #[test] fn reconstruct_history_matches_live_compactions() { @@ -2944,12 +2943,15 @@ mod tests { } #[derive(Clone, Copy)] - struct NeverEndingTask(TaskKind); + struct NeverEndingTask { + kind: TaskKind, + listen_to_cancellation_token: bool, + } #[async_trait::async_trait] impl SessionTask for NeverEndingTask { fn kind(&self) -> TaskKind { - self.0 + self.kind } async fn run( @@ -2958,20 +2960,26 @@ mod tests { _ctx: Arc, _sub_id: String, _input: Vec, + cancellation_token: CancellationToken, ) -> Option { + if self.listen_to_cancellation_token { + cancellation_token.cancelled().await; + return None; + } loop { sleep(Duration::from_secs(60)).await; } } async fn abort(&self, session: Arc, sub_id: &str) { - if let TaskKind::Review = self.0 { + if let TaskKind::Review = self.kind { exit_review_mode(session.clone_session(), sub_id.to_string(), None).await; } } } - #[tokio::test] + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[test_log::test] async fn abort_regular_task_emits_turn_aborted_only() { let (sess, tc, rx) = make_session_and_context_with_rx(); let sub_id = "sub-regular".to_string(); @@ -2982,7 +2990,41 @@ mod tests { Arc::clone(&tc), sub_id.clone(), input, - NeverEndingTask(TaskKind::Regular), + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: false, + }, + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for event") + .expect("event"); + match evt.msg { + EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), + other => panic!("unexpected event: {other:?}"), + } + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn abort_gracefuly_emits_turn_aborted_only() { + let (sess, tc, rx) = make_session_and_context_with_rx(); + let sub_id = "sub-regular".to_string(); + let input = vec![InputItem::Text { + text: "hello".to_string(), + }]; + sess.spawn_task( + Arc::clone(&tc), + sub_id.clone(), + input, + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, ) .await; @@ -2996,7 +3038,7 @@ mod tests { assert!(rx.try_recv().is_err()); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn abort_review_task_emits_exited_then_aborted_and_records_history() { let (sess, tc, rx) = make_session_and_context_with_rx(); let sub_id = "sub-review".to_string(); @@ -3007,18 +3049,27 @@ mod tests { Arc::clone(&tc), sub_id.clone(), input, - NeverEndingTask(TaskKind::Review), + NeverEndingTask { + kind: TaskKind::Review, + listen_to_cancellation_token: false, + }, ) .await; sess.abort_all_tasks(TurnAbortReason::Interrupted).await; - let first = rx.recv().await.expect("first event"); + let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for first event") + .expect("first event"); match first.msg { EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()), other => panic!("unexpected first event: {other:?}"), } - let second = rx.recv().await.expect("second event"); + let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv()) + .await + .expect("timeout waiting for second event") + .expect("second event"); match second.msg { EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), other => panic!("unexpected second event: {other:?}"), diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs index da2a868d..951ba395 100644 --- a/codex-rs/core/src/error.rs +++ b/codex-rs/core/src/error.rs @@ -2,6 +2,7 @@ use crate::exec::ExecToolCallOutput; use crate::token_data::KnownPlan; use crate::token_data::PlanType; use crate::truncate::truncate_middle; +use codex_async_utils::CancelErr; use codex_protocol::ConversationId; use codex_protocol::protocol::RateLimitSnapshot; use reqwest::StatusCode; @@ -50,6 +51,9 @@ pub enum SandboxErr { #[derive(Error, Debug)] pub enum CodexErr { + #[error("turn aborted")] + TurnAborted, + /// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP /// handshake has succeeded but **before** it finished emitting `response.completed`. /// @@ -150,6 +154,12 @@ pub enum CodexErr { EnvVar(EnvVarError), } +impl From for CodexErr { + fn from(_: CancelErr) -> Self { + CodexErr::TurnAborted + } +} + #[derive(Debug)] pub struct ConnectionFailedError { pub source: reqwest::Error, diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 22e1d4cd..8ddf8bff 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -13,7 +13,6 @@ mod client; mod client_common; pub mod codex; mod codex_conversation; -pub mod token_data; pub use codex_conversation::CodexConversation; mod command_safety; pub mod config; @@ -39,6 +38,7 @@ mod mcp_tool_call; mod message_history; mod model_provider_info; pub mod parse_command; +pub mod token_data; mod truncate; mod unified_exec; mod user_instructions; @@ -107,5 +107,4 @@ pub use codex_protocol::models::LocalShellExecAction; pub use codex_protocol::models::LocalShellStatus; pub use codex_protocol::models::ReasoningItemContent; pub use codex_protocol::models::ResponseItem; - pub mod otel_init; diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs index 89af13a1..3ed63c5b 100644 --- a/codex-rs/core/src/state/turn.rs +++ b/codex-rs/core/src/state/turn.rs @@ -4,7 +4,9 @@ use indexmap::IndexMap; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; -use tokio::task::AbortHandle; +use tokio::sync::Notify; +use tokio_util::sync::CancellationToken; +use tokio_util::task::AbortOnDropHandle; use codex_protocol::models::ResponseInputItem; use tokio::sync::oneshot; @@ -46,9 +48,11 @@ impl TaskKind { #[derive(Clone)] pub(crate) struct RunningTask { - pub(crate) handle: AbortHandle, + pub(crate) done: Arc, pub(crate) kind: TaskKind, pub(crate) task: Arc, + pub(crate) cancellation_token: CancellationToken, + pub(crate) handle: Arc>, } impl ActiveTurn { @@ -115,13 +119,6 @@ impl ActiveTurn { let mut ts = self.turn_state.lock().await; ts.clear_pending(); } - - /// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt). - pub(crate) fn try_clear_pending_sync(&self) { - if let Ok(mut ts) = self.turn_state.try_lock() { - ts.clear_pending(); - } - } } #[cfg(test)] diff --git a/codex-rs/core/src/tasks/compact.rs b/codex-rs/core/src/tasks/compact.rs index 823febfc..12e61b1f 100644 --- a/codex-rs/core/src/tasks/compact.rs +++ b/codex-rs/core/src/tasks/compact.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use async_trait::async_trait; +use tokio_util::sync::CancellationToken; use crate::codex::TurnContext; use crate::codex::compact; @@ -25,6 +26,7 @@ impl SessionTask for CompactTask { ctx: Arc, sub_id: String, input: Vec, + _cancellation_token: CancellationToken, ) -> Option { compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await } diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 464c1e63..15ec419f 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -3,9 +3,15 @@ mod regular; mod review; use std::sync::Arc; +use std::time::Duration; use async_trait::async_trait; +use tokio::select; +use tokio::sync::Notify; +use tokio_util::sync::CancellationToken; +use tokio_util::task::AbortOnDropHandle; use tracing::trace; +use tracing::warn; use crate::codex::Session; use crate::codex::TurnContext; @@ -23,6 +29,8 @@ pub(crate) use compact::CompactTask; pub(crate) use regular::RegularTask; pub(crate) use review::ReviewTask; +const GRACEFULL_INTERRUPTION_TIMEOUT_MS: u64 = 100; + /// Thin wrapper that exposes the parts of [`Session`] task runners need. #[derive(Clone)] pub(crate) struct SessionTaskContext { @@ -49,6 +57,7 @@ pub(crate) trait SessionTask: Send + Sync + 'static { ctx: Arc, sub_id: String, input: Vec, + cancellation_token: CancellationToken, ) -> Option; async fn abort(&self, session: Arc, sub_id: &str) { @@ -69,26 +78,42 @@ impl Session { let task: Arc = Arc::new(task); let task_kind = task.kind(); + let cancellation_token = CancellationToken::new(); + let done = Arc::new(Notify::new()); + + let done_clone = Arc::clone(&done); let handle = { let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); let ctx = Arc::clone(&turn_context); let task_for_run = Arc::clone(&task); let sub_clone = sub_id.clone(); + let task_cancellation_token = cancellation_token.child_token(); tokio::spawn(async move { let last_agent_message = task_for_run - .run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input) + .run( + Arc::clone(&session_ctx), + ctx, + sub_clone.clone(), + input, + task_cancellation_token.child_token(), + ) .await; - // Emit completion uniformly from spawn site so all tasks share the same lifecycle. - let sess = session_ctx.clone_session(); - sess.on_task_finished(sub_clone, last_agent_message).await; + + if !task_cancellation_token.is_cancelled() { + // Emit completion uniformly from spawn site so all tasks share the same lifecycle. + let sess = session_ctx.clone_session(); + sess.on_task_finished(sub_clone, last_agent_message).await; + } + done_clone.notify_waiters(); }) - .abort_handle() }; let running_task = RunningTask { - handle, + done, + handle: Arc::new(AbortOnDropHandle::new(handle)), kind: task_kind, task, + cancellation_token, }; self.register_new_active_task(sub_id, running_task).await; } @@ -143,14 +168,24 @@ impl Session { task: RunningTask, reason: TurnAbortReason, ) { - if task.handle.is_finished() { + if task.cancellation_token.is_cancelled() { return; } trace!(task_kind = ?task.kind, sub_id, "aborting running task"); + task.cancellation_token.cancel(); let session_task = task.task; - let handle = task.handle; - handle.abort(); + + select! { + _ = task.done.notified() => { + }, + _ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => { + warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS); + } + } + + task.handle.abort(); + let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); session_task.abort(session_ctx, &sub_id).await; diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs index b3758d5f..3a3fa267 100644 --- a/codex-rs/core/src/tasks/regular.rs +++ b/codex-rs/core/src/tasks/regular.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use async_trait::async_trait; +use tokio_util::sync::CancellationToken; use crate::codex::TurnContext; use crate::codex::run_task; @@ -25,8 +26,17 @@ impl SessionTask for RegularTask { ctx: Arc, sub_id: String, input: Vec, + cancellation_token: CancellationToken, ) -> Option { let sess = session.clone_session(); - run_task(sess, ctx, sub_id, input, TaskKind::Regular).await + run_task( + sess, + ctx, + sub_id, + input, + TaskKind::Regular, + cancellation_token, + ) + .await } } diff --git a/codex-rs/core/src/tasks/review.rs b/codex-rs/core/src/tasks/review.rs index cec92432..6b4b2175 100644 --- a/codex-rs/core/src/tasks/review.rs +++ b/codex-rs/core/src/tasks/review.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use async_trait::async_trait; +use tokio_util::sync::CancellationToken; use crate::codex::TurnContext; use crate::codex::exit_review_mode; @@ -26,9 +27,18 @@ impl SessionTask for ReviewTask { ctx: Arc, sub_id: String, input: Vec, + cancellation_token: CancellationToken, ) -> Option { let sess = session.clone_session(); - run_task(sess, ctx, sub_id, input, TaskKind::Review).await + run_task( + sess, + ctx, + sub_id, + input, + TaskKind::Review, + cancellation_token, + ) + .await } async fn abort(&self, session: Arc, sub_id: &str) {