ref: state - 2 (#4229)
Extract tasks into a dedicated module and start abstracting them behind a trait (more to come on this, but each task will be tackled in a dedicated PR). The goal is to drop `AgentTask` and to (potentially) have a set of tasks running during each turn.
This commit is contained in:
@@ -24,7 +24,6 @@ use codex_protocol::protocol::ReviewRequest;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::TaskStartedEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use codex_protocol::protocol::TurnContextItem;
|
||||
use futures::prelude::*;
|
||||
use mcp_types::CallToolResult;
|
||||
@@ -34,7 +33,6 @@ use serde_json;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::AbortHandle;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
@@ -107,7 +105,6 @@ use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::protocol::StreamErrorEvent;
|
||||
use crate::protocol::Submission;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TokenCountEvent;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::protocol::TurnDiffEvent;
|
||||
@@ -120,6 +117,9 @@ use crate::safety::assess_safety_for_untrusted_command;
|
||||
use crate::shell;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::SessionServices;
|
||||
use crate::tasks::CompactTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
@@ -262,7 +262,7 @@ pub(crate) struct Session {
|
||||
conversation_id: ConversationId,
|
||||
tx_event: Sender<Event>,
|
||||
state: Mutex<SessionState>,
|
||||
active_turn: Mutex<Option<ActiveTurn>>,
|
||||
pub(crate) active_turn: Mutex<Option<ActiveTurn>>,
|
||||
services: SessionServices,
|
||||
next_internal_sub_id: AtomicU64,
|
||||
}
|
||||
@@ -495,38 +495,6 @@ impl Session {
|
||||
Ok((sess, turn_context))
|
||||
}
|
||||
|
||||
pub async fn set_task(&self, task: AgentTask) {
|
||||
let mut state = self.state.lock().await;
|
||||
if let Some(current_task) = state.current_task.take() {
|
||||
current_task.abort(TurnAbortReason::Replaced);
|
||||
}
|
||||
state.current_task = Some(task);
|
||||
if let Some(current_task) = &state.current_task {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
*active = Some(ActiveTurn {
|
||||
sub_id: current_task.sub_id.clone(),
|
||||
turn_state: std::sync::Arc::new(tokio::sync::Mutex::new(
|
||||
crate::state::TurnState::default(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_task(&self, sub_id: &str) {
|
||||
let mut state = self.state.lock().await;
|
||||
if let Some(task) = &state.current_task
|
||||
&& task.sub_id == sub_id
|
||||
{
|
||||
state.current_task.take();
|
||||
}
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = &*active
|
||||
&& at.sub_id == sub_id
|
||||
{
|
||||
*active = None;
|
||||
}
|
||||
}
|
||||
|
||||
fn next_internal_sub_id(&self) -> String {
|
||||
let id = self
|
||||
.next_internal_sub_id
|
||||
@@ -1015,26 +983,25 @@ impl Session {
|
||||
|
||||
/// Returns the input if there was no task running to inject into
|
||||
pub async fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
|
||||
let state = self.state.lock().await;
|
||||
if state.current_task.is_some() {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut() {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
match active.as_mut() {
|
||||
Some(at) => {
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
ts.push_pending_input(input.into());
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
Err(input)
|
||||
None => Err(input),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pending_input(&self) -> Vec<ResponseInputItem> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut() {
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
ts.take_pending_input()
|
||||
} else {
|
||||
Vec::with_capacity(0)
|
||||
match active.as_mut() {
|
||||
Some(at) => {
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
ts.take_pending_input()
|
||||
}
|
||||
None => Vec::with_capacity(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1050,29 +1017,20 @@ impl Session {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn interrupt_task(&self) {
|
||||
pub async fn interrupt_task(self: &Arc<Self>) {
|
||||
info!("interrupt received: abort current task, if any");
|
||||
let mut state = self.state.lock().await;
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut() {
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
ts.clear_pending();
|
||||
}
|
||||
if let Some(task) = state.current_task.take() {
|
||||
task.abort(TurnAbortReason::Interrupted);
|
||||
}
|
||||
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
}
|
||||
|
||||
fn interrupt_task_sync(&self) {
|
||||
if let Ok(mut state) = self.state.try_lock() {
|
||||
if let Ok(mut active) = self.active_turn.try_lock()
|
||||
&& let Some(at) = active.as_mut()
|
||||
&& let Ok(mut ts) = at.turn_state.try_lock()
|
||||
{
|
||||
ts.clear_pending();
|
||||
}
|
||||
if let Some(task) = state.current_task.take() {
|
||||
task.abort(TurnAbortReason::Interrupted);
|
||||
if let Ok(mut active) = self.active_turn.try_lock()
|
||||
&& let Some(at) = active.as_mut()
|
||||
{
|
||||
at.try_clear_pending_sync();
|
||||
let tasks = at.drain_tasks();
|
||||
*active = None;
|
||||
for (_sub_id, task) in tasks {
|
||||
task.handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1111,106 +1069,6 @@ pub(crate) struct ApplyPatchCommandContext {
|
||||
pub(crate) changes: HashMap<PathBuf, FileChange>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
enum AgentTaskKind {
|
||||
Regular,
|
||||
Review,
|
||||
Compact,
|
||||
}
|
||||
|
||||
/// A series of Turns in response to user input.
|
||||
pub(crate) struct AgentTask {
|
||||
sess: Arc<Session>,
|
||||
sub_id: String,
|
||||
handle: AbortHandle,
|
||||
kind: AgentTaskKind,
|
||||
}
|
||||
|
||||
impl AgentTask {
|
||||
fn spawn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Self {
|
||||
let handle = {
|
||||
let sess = sess.clone();
|
||||
let sub_id = sub_id.clone();
|
||||
let tc = Arc::clone(&turn_context);
|
||||
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
|
||||
};
|
||||
Self {
|
||||
sess,
|
||||
sub_id,
|
||||
handle,
|
||||
kind: AgentTaskKind::Regular,
|
||||
}
|
||||
}
|
||||
|
||||
fn review(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Self {
|
||||
let handle = {
|
||||
let sess = sess.clone();
|
||||
let sub_id = sub_id.clone();
|
||||
let tc = Arc::clone(&turn_context);
|
||||
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
|
||||
};
|
||||
Self {
|
||||
sess,
|
||||
sub_id,
|
||||
handle,
|
||||
kind: AgentTaskKind::Review,
|
||||
}
|
||||
}
|
||||
|
||||
fn compact(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Self {
|
||||
let handle = {
|
||||
let sess = sess.clone();
|
||||
let sub_id = sub_id.clone();
|
||||
let tc = Arc::clone(&turn_context);
|
||||
tokio::spawn(async move { compact::run_compact_task(sess, tc, sub_id, input).await })
|
||||
.abort_handle()
|
||||
};
|
||||
Self {
|
||||
sess,
|
||||
sub_id,
|
||||
handle,
|
||||
kind: AgentTaskKind::Compact,
|
||||
}
|
||||
}
|
||||
|
||||
fn abort(self, reason: TurnAbortReason) {
|
||||
// TOCTOU?
|
||||
if !self.handle.is_finished() {
|
||||
self.handle.abort();
|
||||
let sub_id = self.sub_id.clone();
|
||||
let is_review = self.kind == AgentTaskKind::Review;
|
||||
let sess = self.sess;
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
if is_review {
|
||||
exit_review_mode(sess.clone(), sub_id.clone(), None).await;
|
||||
}
|
||||
// Ensure active turn state is cleared when a task is aborted.
|
||||
sess.remove_task(&sub_id).await;
|
||||
sess.send_event(event).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn submission_loop(
|
||||
sess: Arc<Session>,
|
||||
turn_context: TurnContext,
|
||||
@@ -1318,9 +1176,8 @@ async fn submission_loop(
|
||||
// attempt to inject input into current task
|
||||
if let Err(items) = sess.inject_input(items).await {
|
||||
// no current task, spawn a new one
|
||||
let task =
|
||||
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
|
||||
sess.set_task(task).await;
|
||||
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Op::UserTurn {
|
||||
@@ -1396,10 +1253,9 @@ async fn submission_loop(
|
||||
// Install the new persistent context for subsequent tasks/turns.
|
||||
turn_context = Arc::new(fresh_turn_context);
|
||||
|
||||
// no current task, spawn a new one with the per‑turn context
|
||||
let task =
|
||||
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
|
||||
sess.set_task(task).await;
|
||||
// no current task, spawn a new one with the per-turn context
|
||||
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Op::ExecApproval { id, decision } => match decision {
|
||||
@@ -1497,16 +1353,12 @@ async fn submission_loop(
|
||||
}])
|
||||
.await
|
||||
{
|
||||
compact::spawn_compact_task(
|
||||
sess.clone(),
|
||||
Arc::clone(&turn_context),
|
||||
sub.id,
|
||||
items,
|
||||
)
|
||||
.await;
|
||||
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Op::Shutdown => {
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
info!("Shutting down Codex instance");
|
||||
|
||||
// Gracefully flush and shutdown rollout recorder on session end so tests
|
||||
@@ -1648,8 +1500,7 @@ async fn spawn_review_thread(
|
||||
|
||||
// Clone sub_id for the upcoming announcement before moving it into the task.
|
||||
let sub_id_for_event = sub_id.clone();
|
||||
let task = AgentTask::review(sess.clone(), tc.clone(), sub_id, input);
|
||||
sess.set_task(task).await;
|
||||
sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await;
|
||||
|
||||
// Announce entering review mode so UIs can switch modes.
|
||||
sess.send_event(Event {
|
||||
@@ -1676,14 +1527,14 @@ async fn spawn_review_thread(
|
||||
/// Review mode: when `turn_context.is_review_mode` is true, the turn runs in an
|
||||
/// isolated in-memory thread without the parent session's prior history or
|
||||
/// user_instructions. Emits ExitedReviewMode upon final review message.
|
||||
async fn run_task(
|
||||
pub(crate) async fn run_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
) -> Option<String> {
|
||||
if input.is_empty() {
|
||||
return;
|
||||
return None;
|
||||
}
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
@@ -1955,12 +1806,7 @@ async fn run_task(
|
||||
.await;
|
||||
}
|
||||
|
||||
sess.remove_task(&sub_id).await;
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
last_agent_message
|
||||
}
|
||||
|
||||
/// Parse the review output; when not valid JSON, build a structured
|
||||
@@ -3219,7 +3065,7 @@ fn convert_call_tool_result_to_function_call_output_payload(
|
||||
|
||||
/// Emits an ExitedReviewMode Event with optional ReviewOutput,
|
||||
/// and records a developer message with the review output.
|
||||
async fn exit_review_mode(
|
||||
pub(crate) async fn exit_review_mode(
|
||||
session: Arc<Session>,
|
||||
task_sub_id: String,
|
||||
review_output: Option<ReviewOutputEvent>,
|
||||
@@ -3283,6 +3129,9 @@ mod tests {
|
||||
use crate::protocol::CompactedItem;
|
||||
use crate::protocol::InitialHistory;
|
||||
use crate::protocol::ResumedHistory;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::TextContent;
|
||||
@@ -3292,6 +3141,8 @@ mod tests {
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration as StdDuration;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[test]
|
||||
fn reconstruct_history_matches_live_compactions() {
|
||||
@@ -3577,6 +3428,174 @@ mod tests {
|
||||
(session, turn_context)
|
||||
}
|
||||
|
||||
// Like make_session_and_context, but returns Arc<Session> and the event receiver
|
||||
// so tests can assert on emitted events.
|
||||
fn make_session_and_context_with_rx() -> (
|
||||
Arc<Session>,
|
||||
Arc<TurnContext>,
|
||||
async_channel::Receiver<Event>,
|
||||
) {
|
||||
let (tx_event, rx_event) = async_channel::unbounded();
|
||||
let codex_home = tempfile::tempdir().expect("create temp dir");
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
ConfigToml::default(),
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)
|
||||
.expect("load default test config");
|
||||
let config = Arc::new(config);
|
||||
let conversation_id = ConversationId::default();
|
||||
let client = ModelClient::new(
|
||||
config.clone(),
|
||||
None,
|
||||
config.model_provider.clone(),
|
||||
config.model_reasoning_effort,
|
||||
config.model_reasoning_summary,
|
||||
conversation_id,
|
||||
);
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &config.model_family,
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
});
|
||||
let turn_context = Arc::new(TurnContext {
|
||||
client,
|
||||
cwd: config.cwd.clone(),
|
||||
base_instructions: config.base_instructions.clone(),
|
||||
user_instructions: config.user_instructions.clone(),
|
||||
approval_policy: config.approval_policy,
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
shell_environment_policy: config.shell_environment_policy.clone(),
|
||||
tools_config,
|
||||
is_review_mode: false,
|
||||
final_output_json_schema: None,
|
||||
});
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager: McpConnectionManager::default(),
|
||||
session_manager: ExecSessionManager::default(),
|
||||
unified_exec_manager: UnifiedExecSessionManager::default(),
|
||||
notifier: UserNotifier::default(),
|
||||
rollout: Mutex::new(None),
|
||||
codex_linux_sandbox_exe: None,
|
||||
user_shell: shell::Shell::Unknown,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
};
|
||||
let session = Arc::new(Session {
|
||||
conversation_id,
|
||||
tx_event,
|
||||
state: Mutex::new(SessionState::new()),
|
||||
active_turn: Mutex::new(None),
|
||||
services,
|
||||
next_internal_sub_id: AtomicU64::new(0),
|
||||
});
|
||||
(session, turn_context, rx_event)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct NeverEndingTask(TaskKind);
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SessionTask for NeverEndingTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
self.0
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
_session: Arc<SessionTaskContext>,
|
||||
_ctx: Arc<TurnContext>,
|
||||
_sub_id: String,
|
||||
_input: Vec<InputItem>,
|
||||
) -> Option<String> {
|
||||
loop {
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
if let TaskKind::Review = self.0 {
|
||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_regular_task_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-regular".to_string();
|
||||
let input = vec![InputItem::Text {
|
||||
text: "hello".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask(TaskKind::Regular),
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let evt = rx.recv().await.expect("event");
|
||||
match evt.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-review".to_string();
|
||||
let input = vec![InputItem::Text {
|
||||
text: "start review".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask(TaskKind::Review),
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let first = rx.recv().await.expect("first event");
|
||||
match first.msg {
|
||||
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
|
||||
other => panic!("unexpected first event: {other:?}"),
|
||||
}
|
||||
let second = rx.recv().await.expect("second event");
|
||||
match second.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
|
||||
let history = sess.history_snapshot().await;
|
||||
let found = history.iter().any(|item| match item {
|
||||
ResponseItem::Message { role, content, .. } if role == "user" => {
|
||||
content.iter().any(|ci| match ci {
|
||||
ContentItem::InputText { text } => {
|
||||
text.contains("<user_action>")
|
||||
&& text.contains("review")
|
||||
&& text.contains("interrupted")
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
});
|
||||
assert!(
|
||||
found,
|
||||
"synthetic review interruption not recorded in history"
|
||||
);
|
||||
}
|
||||
|
||||
fn sample_rollout(
|
||||
session: &Session,
|
||||
turn_context: &TurnContext,
|
||||
|
||||
Reference in New Issue
Block a user