ref: state - 2 (#4229)

Extracting tasks in a module and start abstraction behind a Trait (more
to come on this but each task will be tackled in a dedicated PR)
The goal was to drop the ActiveTask and to have a (potentially) set of
tasks during each turn
This commit is contained in:
jif-oai
2025-09-26 15:49:08 +02:00
committed by GitHub
parent eb2b739d6a
commit 1fc3413a46
15 changed files with 617 additions and 226 deletions

View File

@@ -24,7 +24,6 @@ use codex_protocol::protocol::ReviewRequest;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::TaskStartedEvent;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::protocol::TurnAbortedEvent;
use codex_protocol::protocol::TurnContextItem;
use futures::prelude::*;
use mcp_types::CallToolResult;
@@ -34,7 +33,6 @@ use serde_json;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tokio::task::AbortHandle;
use tracing::debug;
use tracing::error;
use tracing::info;
@@ -107,7 +105,6 @@ use crate::protocol::SandboxPolicy;
use crate::protocol::SessionConfiguredEvent;
use crate::protocol::StreamErrorEvent;
use crate::protocol::Submission;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TokenCountEvent;
use crate::protocol::TokenUsage;
use crate::protocol::TurnDiffEvent;
@@ -120,6 +117,9 @@ use crate::safety::assess_safety_for_untrusted_command;
use crate::shell;
use crate::state::ActiveTurn;
use crate::state::SessionServices;
use crate::tasks::CompactTask;
use crate::tasks::RegularTask;
use crate::tasks::ReviewTask;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_instructions::UserInstructions;
@@ -262,7 +262,7 @@ pub(crate) struct Session {
conversation_id: ConversationId,
tx_event: Sender<Event>,
state: Mutex<SessionState>,
active_turn: Mutex<Option<ActiveTurn>>,
pub(crate) active_turn: Mutex<Option<ActiveTurn>>,
services: SessionServices,
next_internal_sub_id: AtomicU64,
}
@@ -495,38 +495,6 @@ impl Session {
Ok((sess, turn_context))
}
pub async fn set_task(&self, task: AgentTask) {
let mut state = self.state.lock().await;
if let Some(current_task) = state.current_task.take() {
current_task.abort(TurnAbortReason::Replaced);
}
state.current_task = Some(task);
if let Some(current_task) = &state.current_task {
let mut active = self.active_turn.lock().await;
*active = Some(ActiveTurn {
sub_id: current_task.sub_id.clone(),
turn_state: std::sync::Arc::new(tokio::sync::Mutex::new(
crate::state::TurnState::default(),
)),
});
}
}
pub async fn remove_task(&self, sub_id: &str) {
let mut state = self.state.lock().await;
if let Some(task) = &state.current_task
&& task.sub_id == sub_id
{
state.current_task.take();
}
let mut active = self.active_turn.lock().await;
if let Some(at) = &*active
&& at.sub_id == sub_id
{
*active = None;
}
}
fn next_internal_sub_id(&self) -> String {
let id = self
.next_internal_sub_id
@@ -1015,26 +983,25 @@ impl Session {
/// Returns the input if there was no task running to inject into
pub async fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
let state = self.state.lock().await;
if state.current_task.is_some() {
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut() {
let mut active = self.active_turn.lock().await;
match active.as_mut() {
Some(at) => {
let mut ts = at.turn_state.lock().await;
ts.push_pending_input(input.into());
Ok(())
}
Ok(())
} else {
Err(input)
None => Err(input),
}
}
pub async fn get_pending_input(&self) -> Vec<ResponseInputItem> {
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut() {
let mut ts = at.turn_state.lock().await;
ts.take_pending_input()
} else {
Vec::with_capacity(0)
match active.as_mut() {
Some(at) => {
let mut ts = at.turn_state.lock().await;
ts.take_pending_input()
}
None => Vec::with_capacity(0),
}
}
@@ -1050,29 +1017,20 @@ impl Session {
.await
}
pub async fn interrupt_task(&self) {
pub async fn interrupt_task(self: &Arc<Self>) {
info!("interrupt received: abort current task, if any");
let mut state = self.state.lock().await;
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut() {
let mut ts = at.turn_state.lock().await;
ts.clear_pending();
}
if let Some(task) = state.current_task.take() {
task.abort(TurnAbortReason::Interrupted);
}
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
}
fn interrupt_task_sync(&self) {
if let Ok(mut state) = self.state.try_lock() {
if let Ok(mut active) = self.active_turn.try_lock()
&& let Some(at) = active.as_mut()
&& let Ok(mut ts) = at.turn_state.try_lock()
{
ts.clear_pending();
}
if let Some(task) = state.current_task.take() {
task.abort(TurnAbortReason::Interrupted);
if let Ok(mut active) = self.active_turn.try_lock()
&& let Some(at) = active.as_mut()
{
at.try_clear_pending_sync();
let tasks = at.drain_tasks();
*active = None;
for (_sub_id, task) in tasks {
task.handle.abort();
}
}
}
@@ -1111,106 +1069,6 @@ pub(crate) struct ApplyPatchCommandContext {
pub(crate) changes: HashMap<PathBuf, FileChange>,
}
#[derive(Clone, Debug, Eq, PartialEq)]
enum AgentTaskKind {
Regular,
Review,
Compact,
}
/// A series of Turns in response to user input.
pub(crate) struct AgentTask {
sess: Arc<Session>,
sub_id: String,
handle: AbortHandle,
kind: AgentTaskKind,
}
impl AgentTask {
fn spawn(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Self {
let handle = {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
};
Self {
sess,
sub_id,
handle,
kind: AgentTaskKind::Regular,
}
}
fn review(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Self {
let handle = {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
};
Self {
sess,
sub_id,
handle,
kind: AgentTaskKind::Review,
}
}
fn compact(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) -> Self {
let handle = {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move { compact::run_compact_task(sess, tc, sub_id, input).await })
.abort_handle()
};
Self {
sess,
sub_id,
handle,
kind: AgentTaskKind::Compact,
}
}
fn abort(self, reason: TurnAbortReason) {
// TOCTOU?
if !self.handle.is_finished() {
self.handle.abort();
let sub_id = self.sub_id.clone();
let is_review = self.kind == AgentTaskKind::Review;
let sess = self.sess;
let event = Event {
id: sub_id.clone(),
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
};
tokio::spawn(async move {
if is_review {
exit_review_mode(sess.clone(), sub_id.clone(), None).await;
}
// Ensure active turn state is cleared when a task is aborted.
sess.remove_task(&sub_id).await;
sess.send_event(event).await;
});
}
}
}
async fn submission_loop(
sess: Arc<Session>,
turn_context: TurnContext,
@@ -1318,9 +1176,8 @@ async fn submission_loop(
// attempt to inject input into current task
if let Err(items) = sess.inject_input(items).await {
// no current task, spawn a new one
let task =
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
sess.set_task(task).await;
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask)
.await;
}
}
Op::UserTurn {
@@ -1396,10 +1253,9 @@ async fn submission_loop(
// Install the new persistent context for subsequent tasks/turns.
turn_context = Arc::new(fresh_turn_context);
// no current task, spawn a new one with the perturn context
let task =
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
sess.set_task(task).await;
// no current task, spawn a new one with the per-turn context
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask)
.await;
}
}
Op::ExecApproval { id, decision } => match decision {
@@ -1497,16 +1353,12 @@ async fn submission_loop(
}])
.await
{
compact::spawn_compact_task(
sess.clone(),
Arc::clone(&turn_context),
sub.id,
items,
)
.await;
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask)
.await;
}
}
Op::Shutdown => {
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
info!("Shutting down Codex instance");
// Gracefully flush and shutdown rollout recorder on session end so tests
@@ -1648,8 +1500,7 @@ async fn spawn_review_thread(
// Clone sub_id for the upcoming announcement before moving it into the task.
let sub_id_for_event = sub_id.clone();
let task = AgentTask::review(sess.clone(), tc.clone(), sub_id, input);
sess.set_task(task).await;
sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await;
// Announce entering review mode so UIs can switch modes.
sess.send_event(Event {
@@ -1676,14 +1527,14 @@ async fn spawn_review_thread(
/// Review mode: when `turn_context.is_review_mode` is true, the turn runs in an
/// isolated in-memory thread without the parent session's prior history or
/// user_instructions. Emits ExitedReviewMode upon final review message.
async fn run_task(
pub(crate) async fn run_task(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) {
) -> Option<String> {
if input.is_empty() {
return;
return None;
}
let event = Event {
id: sub_id.clone(),
@@ -1955,12 +1806,7 @@ async fn run_task(
.await;
}
sess.remove_task(&sub_id).await;
let event = Event {
id: sub_id,
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
};
sess.send_event(event).await;
last_agent_message
}
/// Parse the review output; when not valid JSON, build a structured
@@ -3219,7 +3065,7 @@ fn convert_call_tool_result_to_function_call_output_payload(
/// Emits an ExitedReviewMode Event with optional ReviewOutput,
/// and records a developer message with the review output.
async fn exit_review_mode(
pub(crate) async fn exit_review_mode(
session: Arc<Session>,
task_sub_id: String,
review_output: Option<ReviewOutputEvent>,
@@ -3283,6 +3129,9 @@ mod tests {
use crate::protocol::CompactedItem;
use crate::protocol::InitialHistory;
use crate::protocol::ResumedHistory;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use codex_protocol::models::ContentItem;
use mcp_types::ContentBlock;
use mcp_types::TextContent;
@@ -3292,6 +3141,8 @@ mod tests {
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration as StdDuration;
use tokio::time::Duration;
use tokio::time::sleep;
#[test]
fn reconstruct_history_matches_live_compactions() {
@@ -3577,6 +3428,174 @@ mod tests {
(session, turn_context)
}
// Like make_session_and_context, but returns Arc<Session> and the event receiver
// so tests can assert on emitted events.
fn make_session_and_context_with_rx() -> (
Arc<Session>,
Arc<TurnContext>,
async_channel::Receiver<Event>,
) {
let (tx_event, rx_event) = async_channel::unbounded();
let codex_home = tempfile::tempdir().expect("create temp dir");
let config = Config::load_from_base_config_with_overrides(
ConfigToml::default(),
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)
.expect("load default test config");
let config = Arc::new(config);
let conversation_id = ConversationId::default();
let client = ModelClient::new(
config.clone(),
None,
config.model_provider.clone(),
config.model_reasoning_effort,
config.model_reasoning_summary,
conversation_id,
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &config.model_family,
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
include_view_image_tool: config.include_view_image_tool,
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
});
let turn_context = Arc::new(TurnContext {
client,
cwd: config.cwd.clone(),
base_instructions: config.base_instructions.clone(),
user_instructions: config.user_instructions.clone(),
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
shell_environment_policy: config.shell_environment_policy.clone(),
tools_config,
is_review_mode: false,
final_output_json_schema: None,
});
let services = SessionServices {
mcp_connection_manager: McpConnectionManager::default(),
session_manager: ExecSessionManager::default(),
unified_exec_manager: UnifiedExecSessionManager::default(),
notifier: UserNotifier::default(),
rollout: Mutex::new(None),
codex_linux_sandbox_exe: None,
user_shell: shell::Shell::Unknown,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
};
let session = Arc::new(Session {
conversation_id,
tx_event,
state: Mutex::new(SessionState::new()),
active_turn: Mutex::new(None),
services,
next_internal_sub_id: AtomicU64::new(0),
});
(session, turn_context, rx_event)
}
#[derive(Clone, Copy)]
struct NeverEndingTask(TaskKind);
#[async_trait::async_trait]
impl SessionTask for NeverEndingTask {
fn kind(&self) -> TaskKind {
self.0
}
async fn run(
self: Arc<Self>,
_session: Arc<SessionTaskContext>,
_ctx: Arc<TurnContext>,
_sub_id: String,
_input: Vec<InputItem>,
) -> Option<String> {
loop {
sleep(Duration::from_secs(60)).await;
}
}
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
if let TaskKind::Review = self.0 {
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
}
}
}
#[tokio::test]
async fn abort_regular_task_emits_turn_aborted_only() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-regular".to_string();
let input = vec![InputItem::Text {
text: "hello".to_string(),
}];
sess.spawn_task(
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask(TaskKind::Regular),
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let evt = rx.recv().await.expect("event");
match evt.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected event: {other:?}"),
}
assert!(rx.try_recv().is_err());
}
#[tokio::test]
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
let (sess, tc, rx) = make_session_and_context_with_rx();
let sub_id = "sub-review".to_string();
let input = vec![InputItem::Text {
text: "start review".to_string(),
}];
sess.spawn_task(
Arc::clone(&tc),
sub_id.clone(),
input,
NeverEndingTask(TaskKind::Review),
)
.await;
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
let first = rx.recv().await.expect("first event");
match first.msg {
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
other => panic!("unexpected first event: {other:?}"),
}
let second = rx.recv().await.expect("second event");
match second.msg {
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
other => panic!("unexpected second event: {other:?}"),
}
let history = sess.history_snapshot().await;
let found = history.iter().any(|item| match item {
ResponseItem::Message { role, content, .. } if role == "user" => {
content.iter().any(|ci| match ci {
ContentItem::InputText { text } => {
text.contains("<user_action>")
&& text.contains("review")
&& text.contains("interrupted")
}
_ => false,
})
}
_ => false,
});
assert!(
found,
"synthetic review interruption not recorded in history"
);
}
fn sample_rollout(
session: &Session,
turn_context: &TurnContext,