From 1fc3413a46f8b07e2175cbee2bba3e9585f37ccf Mon Sep 17 00:00:00 2001 From: jif-oai Date: Fri, 26 Sep 2025 15:49:08 +0200 Subject: [PATCH] ref: state - 2 (#4229) Extracting tasks in a module and start abstraction behind a Trait (more to come on this but each task will be tackled in a dedicated PR) The goal was to drop the ActiveTask and to have a (potentially) set of tasks during each turn --- codex-rs/Cargo.lock | 2 + codex-rs/Cargo.toml | 1 + codex-rs/core/Cargo.toml | 2 + codex-rs/core/src/codex.rs | 405 ++++++++++++----------- codex-rs/core/src/codex/compact.rs | 34 +- codex-rs/core/src/lib.rs | 1 + codex-rs/core/src/state/mod.rs | 3 +- codex-rs/core/src/state/session.rs | 2 - codex-rs/core/src/state/turn.rs | 59 +++- codex-rs/core/src/tasks/compact.rs | 31 ++ codex-rs/core/src/tasks/mod.rs | 166 ++++++++++ codex-rs/core/src/tasks/regular.rs | 32 ++ codex-rs/core/src/tasks/review.rs | 37 +++ codex-rs/core/tests/suite/abort_tasks.rs | 66 ++++ codex-rs/core/tests/suite/mod.rs | 2 + 15 files changed, 617 insertions(+), 226 deletions(-) create mode 100644 codex-rs/core/src/tasks/compact.rs create mode 100644 codex-rs/core/src/tasks/mod.rs create mode 100644 codex-rs/core/src/tasks/regular.rs create mode 100644 codex-rs/core/src/tasks/review.rs create mode 100644 codex-rs/core/tests/suite/abort_tasks.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 8b71f139..a54222a8 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -673,6 +673,7 @@ dependencies = [ "askama", "assert_cmd", "async-channel", + "async-trait", "base64", "bytes", "chrono", @@ -685,6 +686,7 @@ dependencies = [ "env-flags", "eventsource-stream", "futures", + "indexmap 2.10.0", "landlock", "libc", "maplit", diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 7b4db5fc..237f5ea0 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -85,6 +85,7 @@ icu_decimal = "2.0.0" icu_locale_core = "2.0.0" ignore = "0.4.23" image = { version = "^0.25.8", default-features = false } +indexmap = "2.6.0" insta = "1.43.2" itertools = "0.14.0" landlock = "0.4.1" diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index d9ded082..a1e7876a 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -15,6 +15,7 @@ workspace = true anyhow = { workspace = true } askama = { workspace = true } async-channel = { workspace = true } +async-trait = { workspace = true } base64 = { workspace = true } bytes = { workspace = true } chrono = { workspace = true, features = ["serde"] } @@ -26,6 +27,7 @@ dirs = { workspace = true } env-flags = { workspace = true } eventsource-stream = { workspace = true } futures = { workspace = true } +indexmap = { workspace = true } libc = { workspace = true } mcp-types = { workspace = true } os_info = { workspace = true } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 48aa3dbd..93850a65 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -24,7 +24,6 @@ use codex_protocol::protocol::ReviewRequest; use codex_protocol::protocol::RolloutItem; use codex_protocol::protocol::TaskStartedEvent; use codex_protocol::protocol::TurnAbortReason; -use codex_protocol::protocol::TurnAbortedEvent; use codex_protocol::protocol::TurnContextItem; use futures::prelude::*; use mcp_types::CallToolResult; @@ -34,7 +33,6 @@ use serde_json; use serde_json::Value; use tokio::sync::Mutex; use tokio::sync::oneshot; -use tokio::task::AbortHandle; use tracing::debug; use tracing::error; use tracing::info; @@ -107,7 +105,6 @@ use crate::protocol::SandboxPolicy; use crate::protocol::SessionConfiguredEvent; use crate::protocol::StreamErrorEvent; use crate::protocol::Submission; -use crate::protocol::TaskCompleteEvent; use crate::protocol::TokenCountEvent; use crate::protocol::TokenUsage; use crate::protocol::TurnDiffEvent; @@ -120,6 +117,9 @@ use crate::safety::assess_safety_for_untrusted_command; use crate::shell; use crate::state::ActiveTurn; use crate::state::SessionServices; +use crate::tasks::CompactTask; +use crate::tasks::RegularTask; +use crate::tasks::ReviewTask; use crate::turn_diff_tracker::TurnDiffTracker; use crate::unified_exec::UnifiedExecSessionManager; use crate::user_instructions::UserInstructions; @@ -262,7 +262,7 @@ pub(crate) struct Session { conversation_id: ConversationId, tx_event: Sender, state: Mutex, - active_turn: Mutex>, + pub(crate) active_turn: Mutex>, services: SessionServices, next_internal_sub_id: AtomicU64, } @@ -495,38 +495,6 @@ impl Session { Ok((sess, turn_context)) } - pub async fn set_task(&self, task: AgentTask) { - let mut state = self.state.lock().await; - if let Some(current_task) = state.current_task.take() { - current_task.abort(TurnAbortReason::Replaced); - } - state.current_task = Some(task); - if let Some(current_task) = &state.current_task { - let mut active = self.active_turn.lock().await; - *active = Some(ActiveTurn { - sub_id: current_task.sub_id.clone(), - turn_state: std::sync::Arc::new(tokio::sync::Mutex::new( - crate::state::TurnState::default(), - )), - }); - } - } - - pub async fn remove_task(&self, sub_id: &str) { - let mut state = self.state.lock().await; - if let Some(task) = &state.current_task - && task.sub_id == sub_id - { - state.current_task.take(); - } - let mut active = self.active_turn.lock().await; - if let Some(at) = &*active - && at.sub_id == sub_id - { - *active = None; - } - } - fn next_internal_sub_id(&self) -> String { let id = self .next_internal_sub_id @@ -1015,26 +983,25 @@ impl Session { /// Returns the input if there was no task running to inject into pub async fn inject_input(&self, input: Vec) -> Result<(), Vec> { - let state = self.state.lock().await; - if state.current_task.is_some() { - let mut active = self.active_turn.lock().await; - if let Some(at) = active.as_mut() { + let mut active = self.active_turn.lock().await; + match active.as_mut() { + Some(at) => { let mut ts = at.turn_state.lock().await; ts.push_pending_input(input.into()); + Ok(()) } - Ok(()) - } else { - Err(input) + None => Err(input), } } pub async fn get_pending_input(&self) -> Vec { let mut active = self.active_turn.lock().await; - if let Some(at) = active.as_mut() { - let mut ts = at.turn_state.lock().await; - ts.take_pending_input() - } else { - Vec::with_capacity(0) + match active.as_mut() { + Some(at) => { + let mut ts = at.turn_state.lock().await; + ts.take_pending_input() + } + None => Vec::with_capacity(0), } } @@ -1050,29 +1017,20 @@ impl Session { .await } - pub async fn interrupt_task(&self) { + pub async fn interrupt_task(self: &Arc) { info!("interrupt received: abort current task, if any"); - let mut state = self.state.lock().await; - let mut active = self.active_turn.lock().await; - if let Some(at) = active.as_mut() { - let mut ts = at.turn_state.lock().await; - ts.clear_pending(); - } - if let Some(task) = state.current_task.take() { - task.abort(TurnAbortReason::Interrupted); - } + self.abort_all_tasks(TurnAbortReason::Interrupted).await; } fn interrupt_task_sync(&self) { - if let Ok(mut state) = self.state.try_lock() { - if let Ok(mut active) = self.active_turn.try_lock() - && let Some(at) = active.as_mut() - && let Ok(mut ts) = at.turn_state.try_lock() - { - ts.clear_pending(); - } - if let Some(task) = state.current_task.take() { - task.abort(TurnAbortReason::Interrupted); + if let Ok(mut active) = self.active_turn.try_lock() + && let Some(at) = active.as_mut() + { + at.try_clear_pending_sync(); + let tasks = at.drain_tasks(); + *active = None; + for (_sub_id, task) in tasks { + task.handle.abort(); } } } @@ -1111,106 +1069,6 @@ pub(crate) struct ApplyPatchCommandContext { pub(crate) changes: HashMap, } -#[derive(Clone, Debug, Eq, PartialEq)] -enum AgentTaskKind { - Regular, - Review, - Compact, -} - -/// A series of Turns in response to user input. -pub(crate) struct AgentTask { - sess: Arc, - sub_id: String, - handle: AbortHandle, - kind: AgentTaskKind, -} - -impl AgentTask { - fn spawn( - sess: Arc, - turn_context: Arc, - sub_id: String, - input: Vec, - ) -> Self { - let handle = { - let sess = sess.clone(); - let sub_id = sub_id.clone(); - let tc = Arc::clone(&turn_context); - tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle() - }; - Self { - sess, - sub_id, - handle, - kind: AgentTaskKind::Regular, - } - } - - fn review( - sess: Arc, - turn_context: Arc, - sub_id: String, - input: Vec, - ) -> Self { - let handle = { - let sess = sess.clone(); - let sub_id = sub_id.clone(); - let tc = Arc::clone(&turn_context); - tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle() - }; - Self { - sess, - sub_id, - handle, - kind: AgentTaskKind::Review, - } - } - - fn compact( - sess: Arc, - turn_context: Arc, - sub_id: String, - input: Vec, - ) -> Self { - let handle = { - let sess = sess.clone(); - let sub_id = sub_id.clone(); - let tc = Arc::clone(&turn_context); - tokio::spawn(async move { compact::run_compact_task(sess, tc, sub_id, input).await }) - .abort_handle() - }; - Self { - sess, - sub_id, - handle, - kind: AgentTaskKind::Compact, - } - } - - fn abort(self, reason: TurnAbortReason) { - // TOCTOU? - if !self.handle.is_finished() { - self.handle.abort(); - let sub_id = self.sub_id.clone(); - let is_review = self.kind == AgentTaskKind::Review; - let sess = self.sess; - let event = Event { - id: sub_id.clone(), - msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }), - }; - tokio::spawn(async move { - if is_review { - exit_review_mode(sess.clone(), sub_id.clone(), None).await; - } - // Ensure active turn state is cleared when a task is aborted. - sess.remove_task(&sub_id).await; - sess.send_event(event).await; - }); - } - } -} - async fn submission_loop( sess: Arc, turn_context: TurnContext, @@ -1318,9 +1176,8 @@ async fn submission_loop( // attempt to inject input into current task if let Err(items) = sess.inject_input(items).await { // no current task, spawn a new one - let task = - AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items); - sess.set_task(task).await; + sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask) + .await; } } Op::UserTurn { @@ -1396,10 +1253,9 @@ async fn submission_loop( // Install the new persistent context for subsequent tasks/turns. turn_context = Arc::new(fresh_turn_context); - // no current task, spawn a new one with the per‑turn context - let task = - AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items); - sess.set_task(task).await; + // no current task, spawn a new one with the per-turn context + sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask) + .await; } } Op::ExecApproval { id, decision } => match decision { @@ -1497,16 +1353,12 @@ async fn submission_loop( }]) .await { - compact::spawn_compact_task( - sess.clone(), - Arc::clone(&turn_context), - sub.id, - items, - ) - .await; + sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask) + .await; } } Op::Shutdown => { + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; info!("Shutting down Codex instance"); // Gracefully flush and shutdown rollout recorder on session end so tests @@ -1648,8 +1500,7 @@ async fn spawn_review_thread( // Clone sub_id for the upcoming announcement before moving it into the task. let sub_id_for_event = sub_id.clone(); - let task = AgentTask::review(sess.clone(), tc.clone(), sub_id, input); - sess.set_task(task).await; + sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await; // Announce entering review mode so UIs can switch modes. sess.send_event(Event { @@ -1676,14 +1527,14 @@ async fn spawn_review_thread( /// Review mode: when `turn_context.is_review_mode` is true, the turn runs in an /// isolated in-memory thread without the parent session's prior history or /// user_instructions. Emits ExitedReviewMode upon final review message. -async fn run_task( +pub(crate) async fn run_task( sess: Arc, turn_context: Arc, sub_id: String, input: Vec, -) { +) -> Option { if input.is_empty() { - return; + return None; } let event = Event { id: sub_id.clone(), @@ -1955,12 +1806,7 @@ async fn run_task( .await; } - sess.remove_task(&sub_id).await; - let event = Event { - id: sub_id, - msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }), - }; - sess.send_event(event).await; + last_agent_message } /// Parse the review output; when not valid JSON, build a structured @@ -3219,7 +3065,7 @@ fn convert_call_tool_result_to_function_call_output_payload( /// Emits an ExitedReviewMode Event with optional ReviewOutput, /// and records a developer message with the review output. -async fn exit_review_mode( +pub(crate) async fn exit_review_mode( session: Arc, task_sub_id: String, review_output: Option, @@ -3283,6 +3129,9 @@ mod tests { use crate::protocol::CompactedItem; use crate::protocol::InitialHistory; use crate::protocol::ResumedHistory; + use crate::state::TaskKind; + use crate::tasks::SessionTask; + use crate::tasks::SessionTaskContext; use codex_protocol::models::ContentItem; use mcp_types::ContentBlock; use mcp_types::TextContent; @@ -3292,6 +3141,8 @@ mod tests { use std::path::PathBuf; use std::sync::Arc; use std::time::Duration as StdDuration; + use tokio::time::Duration; + use tokio::time::sleep; #[test] fn reconstruct_history_matches_live_compactions() { @@ -3577,6 +3428,174 @@ mod tests { (session, turn_context) } + // Like make_session_and_context, but returns Arc and the event receiver + // so tests can assert on emitted events. + fn make_session_and_context_with_rx() -> ( + Arc, + Arc, + async_channel::Receiver, + ) { + let (tx_event, rx_event) = async_channel::unbounded(); + let codex_home = tempfile::tempdir().expect("create temp dir"); + let config = Config::load_from_base_config_with_overrides( + ConfigToml::default(), + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + ) + .expect("load default test config"); + let config = Arc::new(config); + let conversation_id = ConversationId::default(); + let client = ModelClient::new( + config.clone(), + None, + config.model_provider.clone(), + config.model_reasoning_effort, + config.model_reasoning_summary, + conversation_id, + ); + let tools_config = ToolsConfig::new(&ToolsConfigParams { + model_family: &config.model_family, + include_plan_tool: config.include_plan_tool, + include_apply_patch_tool: config.include_apply_patch_tool, + include_web_search_request: config.tools_web_search_request, + use_streamable_shell_tool: config.use_experimental_streamable_shell_tool, + include_view_image_tool: config.include_view_image_tool, + experimental_unified_exec_tool: config.use_experimental_unified_exec_tool, + }); + let turn_context = Arc::new(TurnContext { + client, + cwd: config.cwd.clone(), + base_instructions: config.base_instructions.clone(), + user_instructions: config.user_instructions.clone(), + approval_policy: config.approval_policy, + sandbox_policy: config.sandbox_policy.clone(), + shell_environment_policy: config.shell_environment_policy.clone(), + tools_config, + is_review_mode: false, + final_output_json_schema: None, + }); + let services = SessionServices { + mcp_connection_manager: McpConnectionManager::default(), + session_manager: ExecSessionManager::default(), + unified_exec_manager: UnifiedExecSessionManager::default(), + notifier: UserNotifier::default(), + rollout: Mutex::new(None), + codex_linux_sandbox_exe: None, + user_shell: shell::Shell::Unknown, + show_raw_agent_reasoning: config.show_raw_agent_reasoning, + }; + let session = Arc::new(Session { + conversation_id, + tx_event, + state: Mutex::new(SessionState::new()), + active_turn: Mutex::new(None), + services, + next_internal_sub_id: AtomicU64::new(0), + }); + (session, turn_context, rx_event) + } + + #[derive(Clone, Copy)] + struct NeverEndingTask(TaskKind); + + #[async_trait::async_trait] + impl SessionTask for NeverEndingTask { + fn kind(&self) -> TaskKind { + self.0 + } + + async fn run( + self: Arc, + _session: Arc, + _ctx: Arc, + _sub_id: String, + _input: Vec, + ) -> Option { + loop { + sleep(Duration::from_secs(60)).await; + } + } + + async fn abort(&self, session: Arc, sub_id: &str) { + if let TaskKind::Review = self.0 { + exit_review_mode(session.clone_session(), sub_id.to_string(), None).await; + } + } + } + + #[tokio::test] + async fn abort_regular_task_emits_turn_aborted_only() { + let (sess, tc, rx) = make_session_and_context_with_rx(); + let sub_id = "sub-regular".to_string(); + let input = vec![InputItem::Text { + text: "hello".to_string(), + }]; + sess.spawn_task( + Arc::clone(&tc), + sub_id.clone(), + input, + NeverEndingTask(TaskKind::Regular), + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + let evt = rx.recv().await.expect("event"); + match evt.msg { + EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), + other => panic!("unexpected event: {other:?}"), + } + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn abort_review_task_emits_exited_then_aborted_and_records_history() { + let (sess, tc, rx) = make_session_and_context_with_rx(); + let sub_id = "sub-review".to_string(); + let input = vec![InputItem::Text { + text: "start review".to_string(), + }]; + sess.spawn_task( + Arc::clone(&tc), + sub_id.clone(), + input, + NeverEndingTask(TaskKind::Review), + ) + .await; + + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; + + let first = rx.recv().await.expect("first event"); + match first.msg { + EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()), + other => panic!("unexpected first event: {other:?}"), + } + let second = rx.recv().await.expect("second event"); + match second.msg { + EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason), + other => panic!("unexpected second event: {other:?}"), + } + + let history = sess.history_snapshot().await; + let found = history.iter().any(|item| match item { + ResponseItem::Message { role, content, .. } if role == "user" => { + content.iter().any(|ci| match ci { + ContentItem::InputText { text } => { + text.contains("") + && text.contains("review") + && text.contains("interrupted") + } + _ => false, + }) + } + _ => false, + }); + assert!( + found, + "synthetic review interruption not recorded in history" + ); + } + fn sample_rollout( session: &Session, turn_context: &TurnContext, diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs index 4facd45d..8689bd02 100644 --- a/codex-rs/core/src/codex/compact.rs +++ b/codex-rs/core/src/codex/compact.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use super::AgentTask; use super::Session; use super::TurnContext; use super::get_last_assistant_message_from_turn; @@ -15,7 +14,6 @@ use crate::protocol::Event; use crate::protocol::EventMsg; use crate::protocol::InputItem; use crate::protocol::InputMessageKind; -use crate::protocol::TaskCompleteEvent; use crate::protocol::TaskStartedEvent; use crate::protocol::TurnContextItem; use crate::truncate::truncate_middle; @@ -37,17 +35,7 @@ struct HistoryBridgeTemplate<'a> { summary_text: &'a str, } -pub(super) async fn spawn_compact_task( - sess: Arc, - turn_context: Arc, - sub_id: String, - input: Vec, -) { - let task = AgentTask::compact(sess.clone(), turn_context, sub_id, input); - sess.set_task(task).await; -} - -pub(super) async fn run_inline_auto_compact_task( +pub(crate) async fn run_inline_auto_compact_task( sess: Arc, turn_context: Arc, ) { @@ -55,15 +43,15 @@ pub(super) async fn run_inline_auto_compact_task( let input = vec![InputItem::Text { text: SUMMARIZATION_PROMPT.to_string(), }]; - run_compact_task_inner(sess, turn_context, sub_id, input, false).await; + run_compact_task_inner(sess, turn_context, sub_id, input).await; } -pub(super) async fn run_compact_task( +pub(crate) async fn run_compact_task( sess: Arc, turn_context: Arc, sub_id: String, input: Vec, -) { +) -> Option { let start_event = Event { id: sub_id.clone(), msg: EventMsg::TaskStarted(TaskStartedEvent { @@ -71,14 +59,8 @@ pub(super) async fn run_compact_task( }), }; sess.send_event(start_event).await; - run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input, true).await; - let event = Event { - id: sub_id, - msg: EventMsg::TaskComplete(TaskCompleteEvent { - last_agent_message: None, - }), - }; - sess.send_event(event).await; + run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input).await; + None } async fn run_compact_task_inner( @@ -86,7 +68,6 @@ async fn run_compact_task_inner( turn_context: Arc, sub_id: String, input: Vec, - remove_task_on_completion: bool, ) { let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input); let turn_input = sess @@ -148,9 +129,6 @@ async fn run_compact_task_inner( } } - if remove_task_on_completion { - sess.remove_task(&sub_id).await; - } let history_snapshot = sess.history_snapshot().await; let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default(); let user_messages = collect_user_messages(&history_snapshot); diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 36287c1a..ad040ec8 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -77,6 +77,7 @@ pub use rollout::list::ConversationsPage; pub use rollout::list::Cursor; mod function_tool; mod state; +mod tasks; mod user_notification; pub mod util; diff --git a/codex-rs/core/src/state/mod.rs b/codex-rs/core/src/state/mod.rs index 927f5981..642433a7 100644 --- a/codex-rs/core/src/state/mod.rs +++ b/codex-rs/core/src/state/mod.rs @@ -5,4 +5,5 @@ mod turn; pub(crate) use service::SessionServices; pub(crate) use session::SessionState; pub(crate) use turn::ActiveTurn; -pub(crate) use turn::TurnState; +pub(crate) use turn::RunningTask; +pub(crate) use turn::TaskKind; diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs index f8afbdec..ee0c5fc9 100644 --- a/codex-rs/core/src/state/session.rs +++ b/codex-rs/core/src/state/session.rs @@ -4,7 +4,6 @@ use std::collections::HashSet; use codex_protocol::models::ResponseItem; -use crate::codex::AgentTask; use crate::conversation_history::ConversationHistory; use crate::protocol::RateLimitSnapshot; use crate::protocol::TokenUsage; @@ -14,7 +13,6 @@ use crate::protocol::TokenUsageInfo; #[derive(Default)] pub(crate) struct SessionState { pub(crate) approved_commands: HashSet>, - pub(crate) current_task: Option, pub(crate) history: ConversationHistory, pub(crate) token_info: Option, pub(crate) latest_rate_limits: Option, diff --git a/codex-rs/core/src/state/turn.rs b/codex-rs/core/src/state/turn.rs index b49c86b5..f715d548 100644 --- a/codex-rs/core/src/state/turn.rs +++ b/codex-rs/core/src/state/turn.rs @@ -1,21 +1,61 @@ //! Turn-scoped state and active turn metadata scaffolding. +use indexmap::IndexMap; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Mutex; +use tokio::task::AbortHandle; use codex_protocol::models::ResponseInputItem; use tokio::sync::oneshot; use crate::protocol::ReviewDecision; +use crate::tasks::SessionTask; /// Metadata about the currently running turn. -#[derive(Default)] pub(crate) struct ActiveTurn { - pub(crate) sub_id: String, + pub(crate) tasks: IndexMap, pub(crate) turn_state: Arc>, } +impl Default for ActiveTurn { + fn default() -> Self { + Self { + tasks: IndexMap::new(), + turn_state: Arc::new(Mutex::new(TurnState::default())), + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum TaskKind { + Regular, + Review, + Compact, +} + +#[derive(Clone)] +pub(crate) struct RunningTask { + pub(crate) handle: AbortHandle, + pub(crate) kind: TaskKind, + pub(crate) task: Arc, +} + +impl ActiveTurn { + pub(crate) fn add_task(&mut self, sub_id: String, task: RunningTask) { + self.tasks.insert(sub_id, task); + } + + pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool { + self.tasks.swap_remove(sub_id); + self.tasks.is_empty() + } + + pub(crate) fn drain_tasks(&mut self) -> IndexMap { + std::mem::take(&mut self.tasks) + } +} + /// Mutable state for a single turn. #[derive(Default)] pub(crate) struct TurnState { @@ -58,3 +98,18 @@ impl TurnState { } } } + +impl ActiveTurn { + /// Clear any pending approvals and input buffered for the current turn. + pub(crate) async fn clear_pending(&self) { + let mut ts = self.turn_state.lock().await; + ts.clear_pending(); + } + + /// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt). + pub(crate) fn try_clear_pending_sync(&self) { + if let Ok(mut ts) = self.turn_state.try_lock() { + ts.clear_pending(); + } + } +} diff --git a/codex-rs/core/src/tasks/compact.rs b/codex-rs/core/src/tasks/compact.rs new file mode 100644 index 00000000..823febfc --- /dev/null +++ b/codex-rs/core/src/tasks/compact.rs @@ -0,0 +1,31 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::codex::TurnContext; +use crate::codex::compact; +use crate::protocol::InputItem; +use crate::state::TaskKind; + +use super::SessionTask; +use super::SessionTaskContext; + +#[derive(Clone, Copy, Default)] +pub(crate) struct CompactTask; + +#[async_trait] +impl SessionTask for CompactTask { + fn kind(&self) -> TaskKind { + TaskKind::Compact + } + + async fn run( + self: Arc, + session: Arc, + ctx: Arc, + sub_id: String, + input: Vec, + ) -> Option { + compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await + } +} diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs new file mode 100644 index 00000000..464c1e63 --- /dev/null +++ b/codex-rs/core/src/tasks/mod.rs @@ -0,0 +1,166 @@ +mod compact; +mod regular; +mod review; + +use std::sync::Arc; + +use async_trait::async_trait; +use tracing::trace; + +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::protocol::Event; +use crate::protocol::EventMsg; +use crate::protocol::InputItem; +use crate::protocol::TaskCompleteEvent; +use crate::protocol::TurnAbortReason; +use crate::protocol::TurnAbortedEvent; +use crate::state::ActiveTurn; +use crate::state::RunningTask; +use crate::state::TaskKind; + +pub(crate) use compact::CompactTask; +pub(crate) use regular::RegularTask; +pub(crate) use review::ReviewTask; + +/// Thin wrapper that exposes the parts of [`Session`] task runners need. +#[derive(Clone)] +pub(crate) struct SessionTaskContext { + session: Arc, +} + +impl SessionTaskContext { + pub(crate) fn new(session: Arc) -> Self { + Self { session } + } + + pub(crate) fn clone_session(&self) -> Arc { + Arc::clone(&self.session) + } +} + +#[async_trait] +pub(crate) trait SessionTask: Send + Sync + 'static { + fn kind(&self) -> TaskKind; + + async fn run( + self: Arc, + session: Arc, + ctx: Arc, + sub_id: String, + input: Vec, + ) -> Option; + + async fn abort(&self, session: Arc, sub_id: &str) { + let _ = (session, sub_id); + } +} + +impl Session { + pub async fn spawn_task( + self: &Arc, + turn_context: Arc, + sub_id: String, + input: Vec, + task: T, + ) { + self.abort_all_tasks(TurnAbortReason::Replaced).await; + + let task: Arc = Arc::new(task); + let task_kind = task.kind(); + + let handle = { + let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); + let ctx = Arc::clone(&turn_context); + let task_for_run = Arc::clone(&task); + let sub_clone = sub_id.clone(); + tokio::spawn(async move { + let last_agent_message = task_for_run + .run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input) + .await; + // Emit completion uniformly from spawn site so all tasks share the same lifecycle. + let sess = session_ctx.clone_session(); + sess.on_task_finished(sub_clone, last_agent_message).await; + }) + .abort_handle() + }; + + let running_task = RunningTask { + handle, + kind: task_kind, + task, + }; + self.register_new_active_task(sub_id, running_task).await; + } + + pub async fn abort_all_tasks(self: &Arc, reason: TurnAbortReason) { + for (sub_id, task) in self.take_all_running_tasks().await { + self.handle_task_abort(sub_id, task, reason.clone()).await; + } + } + + pub async fn on_task_finished( + self: &Arc, + sub_id: String, + last_agent_message: Option, + ) { + let mut active = self.active_turn.lock().await; + if let Some(at) = active.as_mut() + && at.remove_task(&sub_id) + { + *active = None; + } + drop(active); + let event = Event { + id: sub_id, + msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }), + }; + self.send_event(event).await; + } + + async fn register_new_active_task(&self, sub_id: String, task: RunningTask) { + let mut active = self.active_turn.lock().await; + let mut turn = ActiveTurn::default(); + turn.add_task(sub_id, task); + *active = Some(turn); + } + + async fn take_all_running_tasks(&self) -> Vec<(String, RunningTask)> { + let mut active = self.active_turn.lock().await; + match active.take() { + Some(mut at) => { + at.clear_pending().await; + let tasks = at.drain_tasks(); + tasks.into_iter().collect() + } + None => Vec::new(), + } + } + + async fn handle_task_abort( + self: &Arc, + sub_id: String, + task: RunningTask, + reason: TurnAbortReason, + ) { + if task.handle.is_finished() { + return; + } + + trace!(task_kind = ?task.kind, sub_id, "aborting running task"); + let session_task = task.task; + let handle = task.handle; + handle.abort(); + let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self))); + session_task.abort(session_ctx, &sub_id).await; + + let event = Event { + id: sub_id.clone(), + msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }), + }; + self.send_event(event).await; + } +} + +#[cfg(test)] +mod tests {} diff --git a/codex-rs/core/src/tasks/regular.rs b/codex-rs/core/src/tasks/regular.rs new file mode 100644 index 00000000..9d240997 --- /dev/null +++ b/codex-rs/core/src/tasks/regular.rs @@ -0,0 +1,32 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::codex::TurnContext; +use crate::codex::run_task; +use crate::protocol::InputItem; +use crate::state::TaskKind; + +use super::SessionTask; +use super::SessionTaskContext; + +#[derive(Clone, Copy, Default)] +pub(crate) struct RegularTask; + +#[async_trait] +impl SessionTask for RegularTask { + fn kind(&self) -> TaskKind { + TaskKind::Regular + } + + async fn run( + self: Arc, + session: Arc, + ctx: Arc, + sub_id: String, + input: Vec, + ) -> Option { + let sess = session.clone_session(); + run_task(sess, ctx, sub_id, input).await + } +} diff --git a/codex-rs/core/src/tasks/review.rs b/codex-rs/core/src/tasks/review.rs new file mode 100644 index 00000000..047a2f40 --- /dev/null +++ b/codex-rs/core/src/tasks/review.rs @@ -0,0 +1,37 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::codex::TurnContext; +use crate::codex::exit_review_mode; +use crate::codex::run_task; +use crate::protocol::InputItem; +use crate::state::TaskKind; + +use super::SessionTask; +use super::SessionTaskContext; + +#[derive(Clone, Copy, Default)] +pub(crate) struct ReviewTask; + +#[async_trait] +impl SessionTask for ReviewTask { + fn kind(&self) -> TaskKind { + TaskKind::Review + } + + async fn run( + self: Arc, + session: Arc, + ctx: Arc, + sub_id: String, + input: Vec, + ) -> Option { + let sess = session.clone_session(); + run_task(sess, ctx, sub_id, input).await + } + + async fn abort(&self, session: Arc, sub_id: &str) { + exit_review_mode(session.clone_session(), sub_id.to_string(), None).await; + } +} diff --git a/codex-rs/core/tests/suite/abort_tasks.rs b/codex-rs/core/tests/suite/abort_tasks.rs new file mode 100644 index 00000000..2fa3d4df --- /dev/null +++ b/codex-rs/core/tests/suite/abort_tasks.rs @@ -0,0 +1,66 @@ +use std::time::Duration; + +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::mount_sse_once; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event_with_timeout; +use serde_json::json; +use wiremock::matchers::body_string_contains; + +/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE +/// function call, then interrupt the session and expect TurnAborted. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn interrupt_long_running_tool_emits_turn_aborted() { + let command = vec![ + "bash".to_string(), + "-lc".to_string(), + "sleep 60".to_string(), + ]; + + let args = json!({ + "command": command, + "timeout_ms": 60_000 + }) + .to_string(); + let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]); + + let server = start_mock_server().await; + mount_sse_once(&server, body_string_contains("start sleep"), body).await; + + let codex = test_codex().build(&server).await.unwrap().codex; + + let wait_timeout = Duration::from_secs(5); + + // Kick off a turn that triggers the function call. + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "start sleep".into(), + }], + }) + .await + .unwrap(); + + // Wait until the exec begins to avoid a race, then interrupt. + wait_for_event_with_timeout( + &codex, + |ev| matches!(ev, EventMsg::ExecCommandBegin(_)), + wait_timeout, + ) + .await; + + codex.submit(Op::Interrupt).await.unwrap(); + + // Expect TurnAborted soon after. + wait_for_event_with_timeout( + &codex, + |ev| matches!(ev, EventMsg::TurnAborted(_)), + wait_timeout, + ) + .await; +} diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 2d91e330..0e4e725c 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -1,5 +1,7 @@ // Aggregates all former standalone integration tests as modules. +#[cfg(not(target_os = "windows"))] +mod abort_tasks; mod cli_stream; mod client; mod compact;