ref: state - 2 (#4229)

Extracting tasks in a module and start abstraction behind a Trait (more
to come on this but each task will be tackled in a dedicated PR)
The goal was to drop the ActiveTask and to have a (potentially) set of
tasks during each turn
This commit is contained in:
jif-oai
2025-09-26 15:49:08 +02:00
committed by GitHub
parent eb2b739d6a
commit 1fc3413a46
15 changed files with 617 additions and 226 deletions

2
codex-rs/Cargo.lock generated
View File

@@ -673,6 +673,7 @@ dependencies = [
"askama",
"assert_cmd",
"async-channel",
"async-trait",
"base64",
"bytes",
"chrono",
@@ -685,6 +686,7 @@ dependencies = [
"env-flags",
"eventsource-stream",
"futures",
"indexmap 2.10.0",
"landlock",
"libc",
"maplit",

View File

@@ -85,6 +85,7 @@ icu_decimal = "2.0.0"
icu_locale_core = "2.0.0"
ignore = "0.4.23"
image = { version = "^0.25.8", default-features = false }
indexmap = "2.6.0"
insta = "1.43.2"
itertools = "0.14.0"
landlock = "0.4.1"

View File

@@ -15,6 +15,7 @@ workspace = true
anyhow = { workspace = true }
askama = { workspace = true }
async-channel = { workspace = true }
async-trait = { workspace = true }
base64 = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
@@ -26,6 +27,7 @@ dirs = { workspace = true }
env-flags = { workspace = true }
eventsource-stream = { workspace = true }
futures = { workspace = true }
indexmap = { workspace = true }
libc = { workspace = true }
mcp-types = { workspace = true }
os_info = { workspace = true }

View File

@@ -24,7 +24,6 @@ use codex_protocol::protocol::ReviewRequest;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::TaskStartedEvent;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnAbortedEvent;
use codex_protocol::protocol::TurnContextItem;
use futures::prelude::*;
use mcp_types::CallToolResult;
@@ -34,7 +33,6 @@ use serde_json;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tokio::task::AbortHandle;
use tracing::debug;
use tracing::error;
use tracing::info;
@@ -107,7 +105,6 @@ use crate::protocol::SandboxPolicy;
use crate::protocol::SessionConfiguredEvent;
use crate::protocol::StreamErrorEvent;
use crate::protocol::Submission;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TokenCountEvent;
use crate::protocol::TokenUsage;
use crate::protocol::TurnDiffEvent;
@@ -120,6 +117,9 @@ use crate::safety::assess_safety_for_untrusted_command;
use crate::shell;
use crate::state::ActiveTurn;
use crate::state::SessionServices;
use crate::tasks::CompactTask;
use crate::tasks::RegularTask;
use crate::tasks::ReviewTask;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions;
@@ -262,7 +262,7 @@ pub(crate) struct Session {
conversation_id: ConversationId,
tx_event: Sender<Event>,
state: Mutex<SessionState>,
active_turn: Mutex<Option<ActiveTurn>>,
pub(crate) active_turn: Mutex<Option<ActiveTurn>>,
services: SessionServices,
next_internal_sub_id: AtomicU64,
}
@@ -495,38 +495,6 @@ impl Session {
Ok((sess, turn_context))
}
pub async fn set_task(&self, task: AgentTask) {
let mut state = self.state.lock().await;
if let Some(current_task) = state.current_task.take() {
current_task.abort(TurnAbortReason::Replaced);
}
state.current_task = Some(task);
if let Some(current_task) = &state.current_task {
let mut active = self.active_turn.lock().await;
*active = Some(ActiveTurn {
sub_id: current_task.sub_id.clone(),
turn_state: std::sync::Arc::new(tokio::sync::Mutex::new(
crate::state::TurnState::default(),
)),
});
}
}
pub async fn remove_task(&self, sub_id: &str) {
let mut state = self.state.lock().await;
if let Some(task) = &state.current_task
&& task.sub_id == sub_id
{
state.current_task.take();
}
let mut active = self.active_turn.lock().await;
if let Some(at) = &*active
&& at.sub_id == sub_id
{
*active = None;
}
}
fn next_internal_sub_id(&self) -> String {
let id = self
.next_internal_sub_id
@@ -1015,26 +983,25 @@ impl Session {
/// Returns the input if there was no task running to inject into
pub async fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
let state = self.state.lock().await;
if state.current_task.is_some() {
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut() {
let mut active = self.active_turn.lock().await;
match active.as_mut() {
Some(at) => {
let mut ts = at.turn_state.lock().await;
ts.push_pending_input(input.into());
Ok(())
}
Ok(())
} else {
Err(input)
None => Err(input),
}
}
pub async fn get_pending_input(&self) -> Vec<ResponseInputItem> {
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut() {
let mut ts = at.turn_state.lock().await;
ts.take_pending_input()
} else {
Vec::with_capacity(0)
match active.as_mut() {
Some(at) => {
let mut ts = at.turn_state.lock().await;
ts.take_pending_input()
}
None => Vec::with_capacity(0),
}
}
@@ -1050,29 +1017,20 @@ impl Session {
.await
}
pub async fn interrupt_task(&self) {
pub async fn interrupt_task(self: &Arc<Self>) {
info!("interrupt received: abort current task, if any");
let mut state = self.state.lock().await;
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut() {
let mut ts = at.turn_state.lock().await;
ts.clear_pending();
}
if let Some(task) = state.current_task.take() {
task.abort(TurnAbortReason::Interrupted);
}
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
}
fn interrupt_task_sync(&self) {
if let Ok(mut state) = self.state.try_lock() {
if let Ok(mut active) = self.active_turn.try_lock()
&& let Some(at) = active.as_mut()
&& let Ok(mut ts) = at.turn_state.try_lock()
{
ts.clear_pending();
}
if let Some(task) = state.current_task.take() {
task.abort(TurnAbortReason::Interrupted);
if let Ok(mut active) = self.active_turn.try_lock()
&& let Some(at) = active.as_mut()
{
at.try_clear_pending_sync();
let tasks = at.drain_tasks();
*active = None;
for (_sub_id, task) in tasks {
task.handle.abort();
}
}
}
@@ -1111,106 +1069,6 @@ pub(crate) struct ApplyPatchCommandContext {
pub(crate) changes: HashMap<PathBuf, FileChange>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
enum AgentTaskKind {
Regular,
Review,
Compact,
}
/// A series of Turns in response to user input.
pub(crate) struct AgentTask {
sess: Arc<Session>,
sub_id: String,
handle: AbortHandle,
kind: AgentTaskKind,
}
impl AgentTask {
fn spawn(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Self {
let handle = {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
};
Self {
sess,
sub_id,
handle,
kind: AgentTaskKind::Regular,
}
}
fn review(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Self {
let handle = {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
};
Self {
sess,
sub_id,
handle,
kind: AgentTaskKind::Review,
}
}
fn compact(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Self {
let handle = {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move { compact::run_compact_task(sess, tc, sub_id, input).await })
.abort_handle()
};
Self {
sess,
sub_id,
handle,
kind: AgentTaskKind::Compact,
}
}
fn abort(self, reason: TurnAbortReason) {
// TOCTOU?
if !self.handle.is_finished() {
self.handle.abort();
let sub_id = self.sub_id.clone();
let is_review = self.kind == AgentTaskKind::Review;
let sess = self.sess;
let event = Event {
id: sub_id.clone(),
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
};
tokio::spawn(async move {
if is_review {
exit_review_mode(sess.clone(), sub_id.clone(), None).await;
}
// Ensure active turn state is cleared when a task is aborted.
sess.remove_task(&sub_id).await;
sess.send_event(event).await;
});
}
}
}
async fn submission_loop(
sess: Arc<Session>,
turn_context: TurnContext,
@@ -1318,9 +1176,8 @@ async fn submission_loop(
// attempt to inject input into current task
if let Err(items) = sess.inject_input(items).await {
// no current task, spawn a new one
let task =
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
sess.set_task(task).await;
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask)
.await;
}
}
Op::UserTurn {
@@ -1396,10 +1253,9 @@ async fn submission_loop(
// Install the new persistent context for subsequent tasks/turns.
turn_context = Arc::new(fresh_turn_context);
// no current task, spawn a new one with the perturn context
let task =
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
sess.set_task(task).await;
// no current task, spawn a new one with the per-turn context
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask)
.await;
}
}
Op::ExecApproval { id, decision } => match decision {
@@ -1497,16 +1353,12 @@ async fn submission_loop(
}])
.await
{
compact::spawn_compact_task(
sess.clone(),
Arc::clone(&turn_context),
sub.id,
items,
)
.await;
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask)
.await;
}
}
Op::Shutdown => {
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
info!("Shutting down Codex instance");
// Gracefully flush and shutdown rollout recorder on session end so tests
@@ -1648,8 +1500,7 @@ async fn spawn_review_thread(
// Clone sub_id for the upcoming announcement before moving it into the task.
let sub_id_for_event = sub_id.clone();
let task = AgentTask::review(sess.clone(), tc.clone(), sub_id, input);
sess.set_task(task).await;
sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await;
// Announce entering review mode so UIs can switch modes.
sess.send_event(Event {
@@ -1676,14 +1527,14 @@ async fn spawn_review_thread(
/// Review mode: when `turn_context.is_review_mode` is true, the turn runs in an
/// isolated in-memory thread without the parent session's prior history or
/// user_instructions. Emits ExitedReviewMode upon final review message.
async fn run_task(
pub(crate) async fn run_task(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) {
) -> Option<String> {
if input.is_empty() {
return;
return None;
}
let event = Event {
id: sub_id.clone(),
@@ -1955,12 +1806,7 @@ async fn run_task(
.await;
}
sess.remove_task(&sub_id).await;
let event = Event {
id: sub_id,
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
};
sess.send_event(event).await;
last_agent_message
}
/// Parse the review output; when not valid JSON, build a structured
@@ -3219,7 +3065,7 @@ fn convert_call_tool_result_to_function_call_output_payload(
/// Emits an ExitedReviewMode Event with optional ReviewOutput,
/// and records a developer message with the review output.
async fn exit_review_mode(
pub(crate) async fn exit_review_mode(
session: Arc<Session>,
task_sub_id: String,
review_output: Option<ReviewOutputEvent>,
@@ -3283,6 +3129,9 @@ mod tests {
use crate::protocol::CompactedItem;
use crate::protocol::InitialHistory;
use crate::protocol::ResumedHistory;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use codex_protocol::models::ContentItem;
use mcp_types::ContentBlock;
use mcp_types::TextContent;
@@ -3292,6 +3141,8 @@ mod tests {
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration as StdDuration;
use tokio::time::Duration;
use tokio::time::sleep;
#[test]
fn reconstruct_history_matches_live_compactions() {
@@ -3577,6 +3428,174 @@ mod tests {
(session, turn_context)
}
// Like make_session_and_context, but returns Arc<Session> and the event receiver
// so tests can assert on emitted events.
fn make_session_and_context_with_rx() -> (
Arc<Session>,
Arc<TurnContext>,
async_channel::Receiver<Event>,
) {
let (tx_event, rx_event) = async_channel::unbounded();
let codex_home = tempfile::tempdir().expect("create temp dir");
let config = Config::load_from_base_config_with_overrides(
ConfigToml::default(),
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)
.expect("load default test config");
let config = Arc::new(config);
let conversation_id = ConversationId::default();
let client = ModelClient::new(
config.clone(),
None,
config.model_provider.clone(),
config.model_reasoning_effort,
config.model_reasoning_summary,
conversation_id,
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &config.model_family,
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
include_view_image_tool: config.include_view_image_tool,
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
});
let turn_context = Arc::new(TurnContext {
client,
cwd: config.cwd.clone(),
base_instructions: config.base_instructions.clone(),
user_instructions: config.user_instructions.clone(),
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
shell_environment_policy: config.shell_environment_policy.clone(),
tools_config,
is_review_mode: false,
final_output_json_schema: None,
});
let services = SessionServices {
mcp_connection_manager: McpConnectionManager::default(),
session_manager: ExecSessionManager::default(),
unified_exec_manager: UnifiedExecSessionManager::default(),
notifier: UserNotifier::default(),
rollout: Mutex::new(None),
codex_linux_sandbox_exe: None,
user_shell: shell::Shell::Unknown,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
};
let session = Arc::new(Session {
conversation_id,
tx_event,
state: Mutex::new(SessionState::new()),
active_turn: Mutex::new(None),
services,
next_internal_sub_id: AtomicU64::new(0),
});
(session, turn_context, rx_event)
}
#[derive(Clone, Copy)]
struct NeverEndingTask(TaskKind);
#[async_trait::async_trait]
impl SessionTask for NeverEndingTask {
fn kind(&self) -> TaskKind {
self.0
}
async fn run(
self: Arc<Self>,
_session: Arc<SessionTaskContext>,
_ctx: Arc<TurnContext>,
_sub_id: String,
_input: Vec<InputItem>,
) -> Option<String> {
loop {
sleep(Duration::from_secs(60)).await;
}
}
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
if let TaskKind::Review = self.0 {
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
}
}
}
#[tokio::test]
async fn abort_regular_task_emits_turn_aborted_only() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-regular".to_string();
let input = vec![InputItem::Text {
text: "hello".to_string(),
}];
sess.spawn_task(
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask(TaskKind::Regular),
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let evt = rx.recv().await.expect("event");
match evt.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected event: {other:?}"),
}
assert!(rx.try_recv().is_err());
}
#[tokio::test]
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-review".to_string();
let input = vec![InputItem::Text {
text: "start review".to_string(),
}];
sess.spawn_task(
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask(TaskKind::Review),
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let first = rx.recv().await.expect("first event");
match first.msg {
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
other => panic!("unexpected first event: {other:?}"),
}
let second = rx.recv().await.expect("second event");
match second.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected second event: {other:?}"),
}
let history = sess.history_snapshot().await;
let found = history.iter().any(|item| match item {
ResponseItem::Message { role, content, .. } if role == "user" => {
content.iter().any(|ci| match ci {
ContentItem::InputText { text } => {
text.contains("<user_action>")
&& text.contains("review")
&& text.contains("interrupted")
}
_ => false,
})
}
_ => false,
});
assert!(
found,
"synthetic review interruption not recorded in history"
);
}
fn sample_rollout(
session: &Session,
turn_context: &TurnContext,

View File

@@ -1,6 +1,5 @@
use std::sync::Arc;
use super::AgentTask;
use super::Session;
use super::TurnContext;
use super::get_last_assistant_message_from_turn;
@@ -15,7 +14,6 @@ use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::InputItem;
use crate::protocol::InputMessageKind;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TaskStartedEvent;
use crate::protocol::TurnContextItem;
use crate::truncate::truncate_middle;
@@ -37,17 +35,7 @@ struct HistoryBridgeTemplate<'a> {
summary_text: &'a str,
}
pub(super) async fn spawn_compact_task(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) {
let task = AgentTask::compact(sess.clone(), turn_context, sub_id, input);
sess.set_task(task).await;
}
pub(super) async fn run_inline_auto_compact_task(
pub(crate) async fn run_inline_auto_compact_task(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
) {
@@ -55,15 +43,15 @@ pub(super) async fn run_inline_auto_compact_task(
let input = vec![InputItem::Text {
text: SUMMARIZATION_PROMPT.to_string(),
}];
run_compact_task_inner(sess, turn_context, sub_id, input, false).await;
run_compact_task_inner(sess, turn_context, sub_id, input).await;
}
pub(super) async fn run_compact_task(
pub(crate) async fn run_compact_task(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) {
) -> Option<String> {
let start_event = Event {
id: sub_id.clone(),
msg: EventMsg::TaskStarted(TaskStartedEvent {
@@ -71,14 +59,8 @@ pub(super) async fn run_compact_task(
}),
};
sess.send_event(start_event).await;
run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input, true).await;
let event = Event {
id: sub_id,
msg: EventMsg::TaskComplete(TaskCompleteEvent {
last_agent_message: None,
}),
};
sess.send_event(event).await;
run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input).await;
None
}
async fn run_compact_task_inner(
@@ -86,7 +68,6 @@ async fn run_compact_task_inner(
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
remove_task_on_completion: bool,
) {
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
let turn_input = sess
@@ -148,9 +129,6 @@ async fn run_compact_task_inner(
}
}
if remove_task_on_completion {
sess.remove_task(&sub_id).await;
}
let history_snapshot = sess.history_snapshot().await;
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
let user_messages = collect_user_messages(&history_snapshot);

View File

@@ -77,6 +77,7 @@ pub use rollout::list::ConversationsPage;
pub use rollout::list::Cursor;
mod function_tool;
mod state;
mod tasks;
mod user_notification;
pub mod util;

View File

@@ -5,4 +5,5 @@ mod turn;
pub(crate) use service::SessionServices;
pub(crate) use session::SessionState;
pub(crate) use turn::ActiveTurn;
pub(crate) use turn::TurnState;
pub(crate) use turn::RunningTask;
pub(crate) use turn::TaskKind;

View File

@@ -4,7 +4,6 @@ use std::collections::HashSet;
use codex_protocol::models::ResponseItem;
use crate::codex::AgentTask;
use crate::conversation_history::ConversationHistory;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
@@ -14,7 +13,6 @@ use crate::protocol::TokenUsageInfo;
#[derive(Default)]
pub(crate) struct SessionState {
pub(crate) approved_commands: HashSet<Vec<String>>,
pub(crate) current_task: Option<AgentTask>,
pub(crate) history: ConversationHistory,
pub(crate) token_info: Option<TokenUsageInfo>,
pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,

View File

@@ -1,21 +1,61 @@
//! Turn-scoped state and active turn metadata scaffolding.
use indexmap::IndexMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::AbortHandle;
use codex_protocol::models::ResponseInputItem;
use tokio::sync::oneshot;
use crate::protocol::ReviewDecision;
use crate::tasks::SessionTask;
/// Metadata about the currently running turn.
#[derive(Default)]
pub(crate) struct ActiveTurn {
pub(crate) sub_id: String,
pub(crate) tasks: IndexMap<String, RunningTask>,
pub(crate) turn_state: Arc<Mutex<TurnState>>,
}
impl Default for ActiveTurn {
fn default() -> Self {
Self {
tasks: IndexMap::new(),
turn_state: Arc::new(Mutex::new(TurnState::default())),
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum TaskKind {
Regular,
Review,
Compact,
}
#[derive(Clone)]
pub(crate) struct RunningTask {
pub(crate) handle: AbortHandle,
pub(crate) kind: TaskKind,
pub(crate) task: Arc<dyn SessionTask>,
}
impl ActiveTurn {
pub(crate) fn add_task(&mut self, sub_id: String, task: RunningTask) {
self.tasks.insert(sub_id, task);
}
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
self.tasks.swap_remove(sub_id);
self.tasks.is_empty()
}
pub(crate) fn drain_tasks(&mut self) -> IndexMap<String, RunningTask> {
std::mem::take(&mut self.tasks)
}
}
/// Mutable state for a single turn.
#[derive(Default)]
pub(crate) struct TurnState {
@@ -58,3 +98,18 @@ impl TurnState {
}
}
}
impl ActiveTurn {
/// Clear any pending approvals and input buffered for the current turn.
pub(crate) async fn clear_pending(&self) {
let mut ts = self.turn_state.lock().await;
ts.clear_pending();
}
/// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt).
pub(crate) fn try_clear_pending_sync(&self) {
if let Ok(mut ts) = self.turn_state.try_lock() {
ts.clear_pending();
}
}
}

View File

@@ -0,0 +1,31 @@
use std::sync::Arc;
use async_trait::async_trait;
use crate::codex::TurnContext;
use crate::codex::compact;
use crate::protocol::InputItem;
use crate::state::TaskKind;
use super::SessionTask;
use super::SessionTaskContext;
#[derive(Clone, Copy, Default)]
pub(crate) struct CompactTask;
#[async_trait]
impl SessionTask for CompactTask {
fn kind(&self) -> TaskKind {
TaskKind::Compact
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Option<String> {
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
}
}

View File

@@ -0,0 +1,166 @@
mod compact;
mod regular;
mod review;
use std::sync::Arc;
use async_trait::async_trait;
use tracing::trace;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::InputItem;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TurnAbortReason;
use crate::protocol::TurnAbortedEvent;
use crate::state::ActiveTurn;
use crate::state::RunningTask;
use crate::state::TaskKind;
pub(crate) use compact::CompactTask;
pub(crate) use regular::RegularTask;
pub(crate) use review::ReviewTask;
/// Thin wrapper that exposes the parts of [`Session`] task runners need.
#[derive(Clone)]
pub(crate) struct SessionTaskContext {
session: Arc<Session>,
}
impl SessionTaskContext {
pub(crate) fn new(session: Arc<Session>) -> Self {
Self { session }
}
pub(crate) fn clone_session(&self) -> Arc<Session> {
Arc::clone(&self.session)
}
}
#[async_trait]
pub(crate) trait SessionTask: Send + Sync + 'static {
fn kind(&self) -> TaskKind;
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Option<String>;
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
let _ = (session, sub_id);
}
}
impl Session {
pub async fn spawn_task<T: SessionTask>(
self: &Arc<Self>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
task: T,
) {
self.abort_all_tasks(TurnAbortReason::Replaced).await;
let task: Arc<dyn SessionTask> = Arc::new(task);
let task_kind = task.kind();
let handle = {
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
let ctx = Arc::clone(&turn_context);
let task_for_run = Arc::clone(&task);
let sub_clone = sub_id.clone();
tokio::spawn(async move {
let last_agent_message = task_for_run
.run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input)
.await;
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
let sess = session_ctx.clone_session();
sess.on_task_finished(sub_clone, last_agent_message).await;
})
.abort_handle()
};
let running_task = RunningTask {
handle,
kind: task_kind,
task,
};
self.register_new_active_task(sub_id, running_task).await;
}
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
for (sub_id, task) in self.take_all_running_tasks().await {
self.handle_task_abort(sub_id, task, reason.clone()).await;
}
}
pub async fn on_task_finished(
self: &Arc<Self>,
sub_id: String,
last_agent_message: Option<String>,
) {
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut()
&& at.remove_task(&sub_id)
{
*active = None;
}
drop(active);
let event = Event {
id: sub_id,
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
};
self.send_event(event).await;
}
async fn register_new_active_task(&self, sub_id: String, task: RunningTask) {
let mut active = self.active_turn.lock().await;
let mut turn = ActiveTurn::default();
turn.add_task(sub_id, task);
*active = Some(turn);
}
async fn take_all_running_tasks(&self) -> Vec<(String, RunningTask)> {
let mut active = self.active_turn.lock().await;
match active.take() {
Some(mut at) => {
at.clear_pending().await;
let tasks = at.drain_tasks();
tasks.into_iter().collect()
}
None => Vec::new(),
}
}
async fn handle_task_abort(
self: &Arc<Self>,
sub_id: String,
task: RunningTask,
reason: TurnAbortReason,
) {
if task.handle.is_finished() {
return;
}
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
let session_task = task.task;
let handle = task.handle;
handle.abort();
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
session_task.abort(session_ctx, &sub_id).await;
let event = Event {
id: sub_id.clone(),
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
};
self.send_event(event).await;
}
}
#[cfg(test)]
mod tests {}

View File

@@ -0,0 +1,32 @@
use std::sync::Arc;
use async_trait::async_trait;
use crate::codex::TurnContext;
use crate::codex::run_task;
use crate::protocol::InputItem;
use crate::state::TaskKind;
use super::SessionTask;
use super::SessionTaskContext;
#[derive(Clone, Copy, Default)]
pub(crate) struct RegularTask;
#[async_trait]
impl SessionTask for RegularTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Option<String> {
let sess = session.clone_session();
run_task(sess, ctx, sub_id, input).await
}
}

View File

@@ -0,0 +1,37 @@
use std::sync::Arc;
use async_trait::async_trait;
use crate::codex::TurnContext;
use crate::codex::exit_review_mode;
use crate::codex::run_task;
use crate::protocol::InputItem;
use crate::state::TaskKind;
use super::SessionTask;
use super::SessionTaskContext;
#[derive(Clone, Copy, Default)]
pub(crate) struct ReviewTask;
#[async_trait]
impl SessionTask for ReviewTask {
fn kind(&self) -> TaskKind {
TaskKind::Review
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Option<String> {
let sess = session.clone_session();
run_task(sess, ctx, sub_id, input).await
}
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
}
}

View File

@@ -0,0 +1,66 @@
use std::time::Duration;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event_with_timeout;
use serde_json::json;
use wiremock::matchers::body_string_contains;
/// Integration test: spawn a longrunning shell tool via a mocked Responses SSE
/// function call, then interrupt the session and expect TurnAborted.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn interrupt_long_running_tool_emits_turn_aborted() {
let command = vec![
"bash".to_string(),
"-lc".to_string(),
"sleep 60".to_string(),
];
let args = json!({
"command": command,
"timeout_ms": 60_000
})
.to_string();
let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]);
let server = start_mock_server().await;
mount_sse_once(&server, body_string_contains("start sleep"), body).await;
let codex = test_codex().build(&server).await.unwrap().codex;
let wait_timeout = Duration::from_secs(5);
// Kick off a turn that triggers the function call.
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {
text: "start sleep".into(),
}],
})
.await
.unwrap();
// Wait until the exec begins to avoid a race, then interrupt.
wait_for_event_with_timeout(
&codex,
|ev| matches!(ev, EventMsg::ExecCommandBegin(_)),
wait_timeout,
)
.await;
codex.submit(Op::Interrupt).await.unwrap();
// Expect TurnAborted soon after.
wait_for_event_with_timeout(
&codex,
|ev| matches!(ev, EventMsg::TurnAborted(_)),
wait_timeout,
)
.await;
}

View File

@@ -1,5 +1,7 @@
// Aggregates all former standalone integration tests as modules.
#[cfg(not(target_os = "windows"))]
mod abort_tasks;
mod cli_stream;
mod client;
mod compact;