ref: state - 2 (#4229)
Extracting tasks in a module and start abstraction behind a Trait (more to come on this but each task will be tackled in a dedicated PR) The goal was to drop the ActiveTask and to have a (potentially) set of tasks during each turn
This commit is contained in:
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
@@ -673,6 +673,7 @@ dependencies = [
|
||||
"askama",
|
||||
"assert_cmd",
|
||||
"async-channel",
|
||||
"async-trait",
|
||||
"base64",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -685,6 +686,7 @@ dependencies = [
|
||||
"env-flags",
|
||||
"eventsource-stream",
|
||||
"futures",
|
||||
"indexmap 2.10.0",
|
||||
"landlock",
|
||||
"libc",
|
||||
"maplit",
|
||||
|
||||
@@ -85,6 +85,7 @@ icu_decimal = "2.0.0"
|
||||
icu_locale_core = "2.0.0"
|
||||
ignore = "0.4.23"
|
||||
image = { version = "^0.25.8", default-features = false }
|
||||
indexmap = "2.6.0"
|
||||
insta = "1.43.2"
|
||||
itertools = "0.14.0"
|
||||
landlock = "0.4.1"
|
||||
|
||||
@@ -15,6 +15,7 @@ workspace = true
|
||||
anyhow = { workspace = true }
|
||||
askama = { workspace = true }
|
||||
async-channel = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
@@ -26,6 +27,7 @@ dirs = { workspace = true }
|
||||
env-flags = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
indexmap = { workspace = true }
|
||||
libc = { workspace = true }
|
||||
mcp-types = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
|
||||
@@ -24,7 +24,6 @@ use codex_protocol::protocol::ReviewRequest;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::protocol::TaskStartedEvent;
|
||||
use codex_protocol::protocol::TurnAbortReason;
|
||||
use codex_protocol::protocol::TurnAbortedEvent;
|
||||
use codex_protocol::protocol::TurnContextItem;
|
||||
use futures::prelude::*;
|
||||
use mcp_types::CallToolResult;
|
||||
@@ -34,7 +33,6 @@ use serde_json;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::AbortHandle;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
@@ -107,7 +105,6 @@ use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::protocol::StreamErrorEvent;
|
||||
use crate::protocol::Submission;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TokenCountEvent;
|
||||
use crate::protocol::TokenUsage;
|
||||
use crate::protocol::TurnDiffEvent;
|
||||
@@ -120,6 +117,9 @@ use crate::safety::assess_safety_for_untrusted_command;
|
||||
use crate::shell;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::SessionServices;
|
||||
use crate::tasks::CompactTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use crate::unified_exec::UnifiedExecSessionManager;
|
||||
use crate::user_instructions::UserInstructions;
|
||||
@@ -262,7 +262,7 @@ pub(crate) struct Session {
|
||||
conversation_id: ConversationId,
|
||||
tx_event: Sender<Event>,
|
||||
state: Mutex<SessionState>,
|
||||
active_turn: Mutex<Option<ActiveTurn>>,
|
||||
pub(crate) active_turn: Mutex<Option<ActiveTurn>>,
|
||||
services: SessionServices,
|
||||
next_internal_sub_id: AtomicU64,
|
||||
}
|
||||
@@ -495,38 +495,6 @@ impl Session {
|
||||
Ok((sess, turn_context))
|
||||
}
|
||||
|
||||
pub async fn set_task(&self, task: AgentTask) {
|
||||
let mut state = self.state.lock().await;
|
||||
if let Some(current_task) = state.current_task.take() {
|
||||
current_task.abort(TurnAbortReason::Replaced);
|
||||
}
|
||||
state.current_task = Some(task);
|
||||
if let Some(current_task) = &state.current_task {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
*active = Some(ActiveTurn {
|
||||
sub_id: current_task.sub_id.clone(),
|
||||
turn_state: std::sync::Arc::new(tokio::sync::Mutex::new(
|
||||
crate::state::TurnState::default(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_task(&self, sub_id: &str) {
|
||||
let mut state = self.state.lock().await;
|
||||
if let Some(task) = &state.current_task
|
||||
&& task.sub_id == sub_id
|
||||
{
|
||||
state.current_task.take();
|
||||
}
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = &*active
|
||||
&& at.sub_id == sub_id
|
||||
{
|
||||
*active = None;
|
||||
}
|
||||
}
|
||||
|
||||
fn next_internal_sub_id(&self) -> String {
|
||||
let id = self
|
||||
.next_internal_sub_id
|
||||
@@ -1015,26 +983,25 @@ impl Session {
|
||||
|
||||
/// Returns the input if there was no task running to inject into
|
||||
pub async fn inject_input(&self, input: Vec<InputItem>) -> Result<(), Vec<InputItem>> {
|
||||
let state = self.state.lock().await;
|
||||
if state.current_task.is_some() {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut() {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
match active.as_mut() {
|
||||
Some(at) => {
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
ts.push_pending_input(input.into());
|
||||
Ok(())
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
Err(input)
|
||||
None => Err(input),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pending_input(&self) -> Vec<ResponseInputItem> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut() {
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
ts.take_pending_input()
|
||||
} else {
|
||||
Vec::with_capacity(0)
|
||||
match active.as_mut() {
|
||||
Some(at) => {
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
ts.take_pending_input()
|
||||
}
|
||||
None => Vec::with_capacity(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1050,29 +1017,20 @@ impl Session {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn interrupt_task(&self) {
|
||||
pub async fn interrupt_task(self: &Arc<Self>) {
|
||||
info!("interrupt received: abort current task, if any");
|
||||
let mut state = self.state.lock().await;
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut() {
|
||||
let mut ts = at.turn_state.lock().await;
|
||||
ts.clear_pending();
|
||||
}
|
||||
if let Some(task) = state.current_task.take() {
|
||||
task.abort(TurnAbortReason::Interrupted);
|
||||
}
|
||||
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
}
|
||||
|
||||
fn interrupt_task_sync(&self) {
|
||||
if let Ok(mut state) = self.state.try_lock() {
|
||||
if let Ok(mut active) = self.active_turn.try_lock()
|
||||
&& let Some(at) = active.as_mut()
|
||||
&& let Ok(mut ts) = at.turn_state.try_lock()
|
||||
{
|
||||
ts.clear_pending();
|
||||
}
|
||||
if let Some(task) = state.current_task.take() {
|
||||
task.abort(TurnAbortReason::Interrupted);
|
||||
if let Ok(mut active) = self.active_turn.try_lock()
|
||||
&& let Some(at) = active.as_mut()
|
||||
{
|
||||
at.try_clear_pending_sync();
|
||||
let tasks = at.drain_tasks();
|
||||
*active = None;
|
||||
for (_sub_id, task) in tasks {
|
||||
task.handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1111,106 +1069,6 @@ pub(crate) struct ApplyPatchCommandContext {
|
||||
pub(crate) changes: HashMap<PathBuf, FileChange>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
enum AgentTaskKind {
|
||||
Regular,
|
||||
Review,
|
||||
Compact,
|
||||
}
|
||||
|
||||
/// A series of Turns in response to user input.
|
||||
pub(crate) struct AgentTask {
|
||||
sess: Arc<Session>,
|
||||
sub_id: String,
|
||||
handle: AbortHandle,
|
||||
kind: AgentTaskKind,
|
||||
}
|
||||
|
||||
impl AgentTask {
|
||||
fn spawn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Self {
|
||||
let handle = {
|
||||
let sess = sess.clone();
|
||||
let sub_id = sub_id.clone();
|
||||
let tc = Arc::clone(&turn_context);
|
||||
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
|
||||
};
|
||||
Self {
|
||||
sess,
|
||||
sub_id,
|
||||
handle,
|
||||
kind: AgentTaskKind::Regular,
|
||||
}
|
||||
}
|
||||
|
||||
fn review(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Self {
|
||||
let handle = {
|
||||
let sess = sess.clone();
|
||||
let sub_id = sub_id.clone();
|
||||
let tc = Arc::clone(&turn_context);
|
||||
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
|
||||
};
|
||||
Self {
|
||||
sess,
|
||||
sub_id,
|
||||
handle,
|
||||
kind: AgentTaskKind::Review,
|
||||
}
|
||||
}
|
||||
|
||||
fn compact(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Self {
|
||||
let handle = {
|
||||
let sess = sess.clone();
|
||||
let sub_id = sub_id.clone();
|
||||
let tc = Arc::clone(&turn_context);
|
||||
tokio::spawn(async move { compact::run_compact_task(sess, tc, sub_id, input).await })
|
||||
.abort_handle()
|
||||
};
|
||||
Self {
|
||||
sess,
|
||||
sub_id,
|
||||
handle,
|
||||
kind: AgentTaskKind::Compact,
|
||||
}
|
||||
}
|
||||
|
||||
fn abort(self, reason: TurnAbortReason) {
|
||||
// TOCTOU?
|
||||
if !self.handle.is_finished() {
|
||||
self.handle.abort();
|
||||
let sub_id = self.sub_id.clone();
|
||||
let is_review = self.kind == AgentTaskKind::Review;
|
||||
let sess = self.sess;
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
if is_review {
|
||||
exit_review_mode(sess.clone(), sub_id.clone(), None).await;
|
||||
}
|
||||
// Ensure active turn state is cleared when a task is aborted.
|
||||
sess.remove_task(&sub_id).await;
|
||||
sess.send_event(event).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn submission_loop(
|
||||
sess: Arc<Session>,
|
||||
turn_context: TurnContext,
|
||||
@@ -1318,9 +1176,8 @@ async fn submission_loop(
|
||||
// attempt to inject input into current task
|
||||
if let Err(items) = sess.inject_input(items).await {
|
||||
// no current task, spawn a new one
|
||||
let task =
|
||||
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
|
||||
sess.set_task(task).await;
|
||||
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Op::UserTurn {
|
||||
@@ -1396,10 +1253,9 @@ async fn submission_loop(
|
||||
// Install the new persistent context for subsequent tasks/turns.
|
||||
turn_context = Arc::new(fresh_turn_context);
|
||||
|
||||
// no current task, spawn a new one with the per‑turn context
|
||||
let task =
|
||||
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
|
||||
sess.set_task(task).await;
|
||||
// no current task, spawn a new one with the per-turn context
|
||||
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, RegularTask)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Op::ExecApproval { id, decision } => match decision {
|
||||
@@ -1497,16 +1353,12 @@ async fn submission_loop(
|
||||
}])
|
||||
.await
|
||||
{
|
||||
compact::spawn_compact_task(
|
||||
sess.clone(),
|
||||
Arc::clone(&turn_context),
|
||||
sub.id,
|
||||
items,
|
||||
)
|
||||
.await;
|
||||
sess.spawn_task(Arc::clone(&turn_context), sub.id, items, CompactTask)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Op::Shutdown => {
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
info!("Shutting down Codex instance");
|
||||
|
||||
// Gracefully flush and shutdown rollout recorder on session end so tests
|
||||
@@ -1648,8 +1500,7 @@ async fn spawn_review_thread(
|
||||
|
||||
// Clone sub_id for the upcoming announcement before moving it into the task.
|
||||
let sub_id_for_event = sub_id.clone();
|
||||
let task = AgentTask::review(sess.clone(), tc.clone(), sub_id, input);
|
||||
sess.set_task(task).await;
|
||||
sess.spawn_task(tc.clone(), sub_id, input, ReviewTask).await;
|
||||
|
||||
// Announce entering review mode so UIs can switch modes.
|
||||
sess.send_event(Event {
|
||||
@@ -1676,14 +1527,14 @@ async fn spawn_review_thread(
|
||||
/// Review mode: when `turn_context.is_review_mode` is true, the turn runs in an
|
||||
/// isolated in-memory thread without the parent session's prior history or
|
||||
/// user_instructions. Emits ExitedReviewMode upon final review message.
|
||||
async fn run_task(
|
||||
pub(crate) async fn run_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
) -> Option<String> {
|
||||
if input.is_empty() {
|
||||
return;
|
||||
return None;
|
||||
}
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
@@ -1955,12 +1806,7 @@ async fn run_task(
|
||||
.await;
|
||||
}
|
||||
|
||||
sess.remove_task(&sub_id).await;
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
last_agent_message
|
||||
}
|
||||
|
||||
/// Parse the review output; when not valid JSON, build a structured
|
||||
@@ -3219,7 +3065,7 @@ fn convert_call_tool_result_to_function_call_output_payload(
|
||||
|
||||
/// Emits an ExitedReviewMode Event with optional ReviewOutput,
|
||||
/// and records a developer message with the review output.
|
||||
async fn exit_review_mode(
|
||||
pub(crate) async fn exit_review_mode(
|
||||
session: Arc<Session>,
|
||||
task_sub_id: String,
|
||||
review_output: Option<ReviewOutputEvent>,
|
||||
@@ -3283,6 +3129,9 @@ mod tests {
|
||||
use crate::protocol::CompactedItem;
|
||||
use crate::protocol::InitialHistory;
|
||||
use crate::protocol::ResumedHistory;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::TextContent;
|
||||
@@ -3292,6 +3141,8 @@ mod tests {
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration as StdDuration;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[test]
|
||||
fn reconstruct_history_matches_live_compactions() {
|
||||
@@ -3577,6 +3428,174 @@ mod tests {
|
||||
(session, turn_context)
|
||||
}
|
||||
|
||||
// Like make_session_and_context, but returns Arc<Session> and the event receiver
|
||||
// so tests can assert on emitted events.
|
||||
fn make_session_and_context_with_rx() -> (
|
||||
Arc<Session>,
|
||||
Arc<TurnContext>,
|
||||
async_channel::Receiver<Event>,
|
||||
) {
|
||||
let (tx_event, rx_event) = async_channel::unbounded();
|
||||
let codex_home = tempfile::tempdir().expect("create temp dir");
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
ConfigToml::default(),
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)
|
||||
.expect("load default test config");
|
||||
let config = Arc::new(config);
|
||||
let conversation_id = ConversationId::default();
|
||||
let client = ModelClient::new(
|
||||
config.clone(),
|
||||
None,
|
||||
config.model_provider.clone(),
|
||||
config.model_reasoning_effort,
|
||||
config.model_reasoning_summary,
|
||||
conversation_id,
|
||||
);
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &config.model_family,
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||
include_view_image_tool: config.include_view_image_tool,
|
||||
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||
});
|
||||
let turn_context = Arc::new(TurnContext {
|
||||
client,
|
||||
cwd: config.cwd.clone(),
|
||||
base_instructions: config.base_instructions.clone(),
|
||||
user_instructions: config.user_instructions.clone(),
|
||||
approval_policy: config.approval_policy,
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
shell_environment_policy: config.shell_environment_policy.clone(),
|
||||
tools_config,
|
||||
is_review_mode: false,
|
||||
final_output_json_schema: None,
|
||||
});
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager: McpConnectionManager::default(),
|
||||
session_manager: ExecSessionManager::default(),
|
||||
unified_exec_manager: UnifiedExecSessionManager::default(),
|
||||
notifier: UserNotifier::default(),
|
||||
rollout: Mutex::new(None),
|
||||
codex_linux_sandbox_exe: None,
|
||||
user_shell: shell::Shell::Unknown,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
};
|
||||
let session = Arc::new(Session {
|
||||
conversation_id,
|
||||
tx_event,
|
||||
state: Mutex::new(SessionState::new()),
|
||||
active_turn: Mutex::new(None),
|
||||
services,
|
||||
next_internal_sub_id: AtomicU64::new(0),
|
||||
});
|
||||
(session, turn_context, rx_event)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct NeverEndingTask(TaskKind);
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl SessionTask for NeverEndingTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
self.0
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
_session: Arc<SessionTaskContext>,
|
||||
_ctx: Arc<TurnContext>,
|
||||
_sub_id: String,
|
||||
_input: Vec<InputItem>,
|
||||
) -> Option<String> {
|
||||
loop {
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
if let TaskKind::Review = self.0 {
|
||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_regular_task_emits_turn_aborted_only() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-regular".to_string();
|
||||
let input = vec![InputItem::Text {
|
||||
text: "hello".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask(TaskKind::Regular),
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let evt = rx.recv().await.expect("event");
|
||||
match evt.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected event: {other:?}"),
|
||||
}
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||
let sub_id = "sub-review".to_string();
|
||||
let input = vec![InputItem::Text {
|
||||
text: "start review".to_string(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
sub_id.clone(),
|
||||
input,
|
||||
NeverEndingTask(TaskKind::Review),
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let first = rx.recv().await.expect("first event");
|
||||
match first.msg {
|
||||
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
|
||||
other => panic!("unexpected first event: {other:?}"),
|
||||
}
|
||||
let second = rx.recv().await.expect("second event");
|
||||
match second.msg {
|
||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||
other => panic!("unexpected second event: {other:?}"),
|
||||
}
|
||||
|
||||
let history = sess.history_snapshot().await;
|
||||
let found = history.iter().any(|item| match item {
|
||||
ResponseItem::Message { role, content, .. } if role == "user" => {
|
||||
content.iter().any(|ci| match ci {
|
||||
ContentItem::InputText { text } => {
|
||||
text.contains("<user_action>")
|
||||
&& text.contains("review")
|
||||
&& text.contains("interrupted")
|
||||
}
|
||||
_ => false,
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
});
|
||||
assert!(
|
||||
found,
|
||||
"synthetic review interruption not recorded in history"
|
||||
);
|
||||
}
|
||||
|
||||
fn sample_rollout(
|
||||
session: &Session,
|
||||
turn_context: &TurnContext,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::AgentTask;
|
||||
use super::Session;
|
||||
use super::TurnContext;
|
||||
use super::get_last_assistant_message_from_turn;
|
||||
@@ -15,7 +14,6 @@ use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::protocol::InputMessageKind;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TaskStartedEvent;
|
||||
use crate::protocol::TurnContextItem;
|
||||
use crate::truncate::truncate_middle;
|
||||
@@ -37,17 +35,7 @@ struct HistoryBridgeTemplate<'a> {
|
||||
summary_text: &'a str,
|
||||
}
|
||||
|
||||
pub(super) async fn spawn_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
let task = AgentTask::compact(sess.clone(), turn_context, sub_id, input);
|
||||
sess.set_task(task).await;
|
||||
}
|
||||
|
||||
pub(super) async fn run_inline_auto_compact_task(
|
||||
pub(crate) async fn run_inline_auto_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
) {
|
||||
@@ -55,15 +43,15 @@ pub(super) async fn run_inline_auto_compact_task(
|
||||
let input = vec![InputItem::Text {
|
||||
text: SUMMARIZATION_PROMPT.to_string(),
|
||||
}];
|
||||
run_compact_task_inner(sess, turn_context, sub_id, input, false).await;
|
||||
run_compact_task_inner(sess, turn_context, sub_id, input).await;
|
||||
}
|
||||
|
||||
pub(super) async fn run_compact_task(
|
||||
pub(crate) async fn run_compact_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) {
|
||||
) -> Option<String> {
|
||||
let start_event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TaskStarted(TaskStartedEvent {
|
||||
@@ -71,14 +59,8 @@ pub(super) async fn run_compact_task(
|
||||
}),
|
||||
};
|
||||
sess.send_event(start_event).await;
|
||||
run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input, true).await;
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent {
|
||||
last_agent_message: None,
|
||||
}),
|
||||
};
|
||||
sess.send_event(event).await;
|
||||
run_compact_task_inner(sess.clone(), turn_context, sub_id.clone(), input).await;
|
||||
None
|
||||
}
|
||||
|
||||
async fn run_compact_task_inner(
|
||||
@@ -86,7 +68,6 @@ async fn run_compact_task_inner(
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
remove_task_on_completion: bool,
|
||||
) {
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let turn_input = sess
|
||||
@@ -148,9 +129,6 @@ async fn run_compact_task_inner(
|
||||
}
|
||||
}
|
||||
|
||||
if remove_task_on_completion {
|
||||
sess.remove_task(&sub_id).await;
|
||||
}
|
||||
let history_snapshot = sess.history_snapshot().await;
|
||||
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
|
||||
let user_messages = collect_user_messages(&history_snapshot);
|
||||
|
||||
@@ -77,6 +77,7 @@ pub use rollout::list::ConversationsPage;
|
||||
pub use rollout::list::Cursor;
|
||||
mod function_tool;
|
||||
mod state;
|
||||
mod tasks;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
|
||||
|
||||
@@ -5,4 +5,5 @@ mod turn;
|
||||
pub(crate) use service::SessionServices;
|
||||
pub(crate) use session::SessionState;
|
||||
pub(crate) use turn::ActiveTurn;
|
||||
pub(crate) use turn::TurnState;
|
||||
pub(crate) use turn::RunningTask;
|
||||
pub(crate) use turn::TaskKind;
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::collections::HashSet;
|
||||
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
use crate::codex::AgentTask;
|
||||
use crate::conversation_history::ConversationHistory;
|
||||
use crate::protocol::RateLimitSnapshot;
|
||||
use crate::protocol::TokenUsage;
|
||||
@@ -14,7 +13,6 @@ use crate::protocol::TokenUsageInfo;
|
||||
#[derive(Default)]
|
||||
pub(crate) struct SessionState {
|
||||
pub(crate) approved_commands: HashSet<Vec<String>>,
|
||||
pub(crate) current_task: Option<AgentTask>,
|
||||
pub(crate) history: ConversationHistory,
|
||||
pub(crate) token_info: Option<TokenUsageInfo>,
|
||||
pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
|
||||
|
||||
@@ -1,21 +1,61 @@
|
||||
//! Turn-scoped state and active turn metadata scaffolding.
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::AbortHandle;
|
||||
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::protocol::ReviewDecision;
|
||||
use crate::tasks::SessionTask;
|
||||
|
||||
/// Metadata about the currently running turn.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct ActiveTurn {
|
||||
pub(crate) sub_id: String,
|
||||
pub(crate) tasks: IndexMap<String, RunningTask>,
|
||||
pub(crate) turn_state: Arc<Mutex<TurnState>>,
|
||||
}
|
||||
|
||||
impl Default for ActiveTurn {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tasks: IndexMap::new(),
|
||||
turn_state: Arc::new(Mutex::new(TurnState::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub(crate) enum TaskKind {
|
||||
Regular,
|
||||
Review,
|
||||
Compact,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RunningTask {
|
||||
pub(crate) handle: AbortHandle,
|
||||
pub(crate) kind: TaskKind,
|
||||
pub(crate) task: Arc<dyn SessionTask>,
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
pub(crate) fn add_task(&mut self, sub_id: String, task: RunningTask) {
|
||||
self.tasks.insert(sub_id, task);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
|
||||
self.tasks.swap_remove(sub_id);
|
||||
self.tasks.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn drain_tasks(&mut self) -> IndexMap<String, RunningTask> {
|
||||
std::mem::take(&mut self.tasks)
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable state for a single turn.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct TurnState {
|
||||
@@ -58,3 +98,18 @@ impl TurnState {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
/// Clear any pending approvals and input buffered for the current turn.
|
||||
pub(crate) async fn clear_pending(&self) {
|
||||
let mut ts = self.turn_state.lock().await;
|
||||
ts.clear_pending();
|
||||
}
|
||||
|
||||
/// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt).
|
||||
pub(crate) fn try_clear_pending_sync(&self) {
|
||||
if let Ok(mut ts) = self.turn_state.try_lock() {
|
||||
ts.clear_pending();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
31
codex-rs/core/src/tasks/compact.rs
Normal file
31
codex-rs/core/src/tasks/compact.rs
Normal file
@@ -0,0 +1,31 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::compact;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::state::TaskKind;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
pub(crate) struct CompactTask;
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for CompactTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Compact
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Option<String> {
|
||||
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
|
||||
}
|
||||
}
|
||||
166
codex-rs/core/src/tasks/mod.rs
Normal file
166
codex-rs/core/src/tasks/mod.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
mod compact;
|
||||
mod regular;
|
||||
mod review;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::protocol::TaskCompleteEvent;
|
||||
use crate::protocol::TurnAbortReason;
|
||||
use crate::protocol::TurnAbortedEvent;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::RunningTask;
|
||||
use crate::state::TaskKind;
|
||||
|
||||
pub(crate) use compact::CompactTask;
|
||||
pub(crate) use regular::RegularTask;
|
||||
pub(crate) use review::ReviewTask;
|
||||
|
||||
/// Thin wrapper that exposes the parts of [`Session`] task runners need.
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct SessionTaskContext {
|
||||
session: Arc<Session>,
|
||||
}
|
||||
|
||||
impl SessionTaskContext {
|
||||
pub(crate) fn new(session: Arc<Session>) -> Self {
|
||||
Self { session }
|
||||
}
|
||||
|
||||
pub(crate) fn clone_session(&self) -> Arc<Session> {
|
||||
Arc::clone(&self.session)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait SessionTask: Send + Sync + 'static {
|
||||
fn kind(&self) -> TaskKind;
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Option<String>;
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
let _ = (session, sub_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Session {
|
||||
pub async fn spawn_task<T: SessionTask>(
|
||||
self: &Arc<Self>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
task: T,
|
||||
) {
|
||||
self.abort_all_tasks(TurnAbortReason::Replaced).await;
|
||||
|
||||
let task: Arc<dyn SessionTask> = Arc::new(task);
|
||||
let task_kind = task.kind();
|
||||
|
||||
let handle = {
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
let ctx = Arc::clone(&turn_context);
|
||||
let task_for_run = Arc::clone(&task);
|
||||
let sub_clone = sub_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let last_agent_message = task_for_run
|
||||
.run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input)
|
||||
.await;
|
||||
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
||||
let sess = session_ctx.clone_session();
|
||||
sess.on_task_finished(sub_clone, last_agent_message).await;
|
||||
})
|
||||
.abort_handle()
|
||||
};
|
||||
|
||||
let running_task = RunningTask {
|
||||
handle,
|
||||
kind: task_kind,
|
||||
task,
|
||||
};
|
||||
self.register_new_active_task(sub_id, running_task).await;
|
||||
}
|
||||
|
||||
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
|
||||
for (sub_id, task) in self.take_all_running_tasks().await {
|
||||
self.handle_task_abort(sub_id, task, reason.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn on_task_finished(
|
||||
self: &Arc<Self>,
|
||||
sub_id: String,
|
||||
last_agent_message: Option<String>,
|
||||
) {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
if let Some(at) = active.as_mut()
|
||||
&& at.remove_task(&sub_id)
|
||||
{
|
||||
*active = None;
|
||||
}
|
||||
drop(active);
|
||||
let event = Event {
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
}
|
||||
|
||||
async fn register_new_active_task(&self, sub_id: String, task: RunningTask) {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let mut turn = ActiveTurn::default();
|
||||
turn.add_task(sub_id, task);
|
||||
*active = Some(turn);
|
||||
}
|
||||
|
||||
async fn take_all_running_tasks(&self) -> Vec<(String, RunningTask)> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
match active.take() {
|
||||
Some(mut at) => {
|
||||
at.clear_pending().await;
|
||||
let tasks = at.drain_tasks();
|
||||
tasks.into_iter().collect()
|
||||
}
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_task_abort(
|
||||
self: &Arc<Self>,
|
||||
sub_id: String,
|
||||
task: RunningTask,
|
||||
reason: TurnAbortReason,
|
||||
) {
|
||||
if task.handle.is_finished() {
|
||||
return;
|
||||
}
|
||||
|
||||
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
|
||||
let session_task = task.task;
|
||||
let handle = task.handle;
|
||||
handle.abort();
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
session_task.abort(session_ctx, &sub_id).await;
|
||||
|
||||
let event = Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
|
||||
};
|
||||
self.send_event(event).await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {}
|
||||
32
codex-rs/core/src/tasks/regular.rs
Normal file
32
codex-rs/core/src/tasks/regular.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::run_task;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::state::TaskKind;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
pub(crate) struct RegularTask;
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for RegularTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Regular
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
run_task(sess, ctx, sub_id, input).await
|
||||
}
|
||||
}
|
||||
37
codex-rs/core/src/tasks/review.rs
Normal file
37
codex-rs/core/src/tasks/review.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::codex::exit_review_mode;
|
||||
use crate::codex::run_task;
|
||||
use crate::protocol::InputItem;
|
||||
use crate::state::TaskKind;
|
||||
|
||||
use super::SessionTask;
|
||||
use super::SessionTaskContext;
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
pub(crate) struct ReviewTask;
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for ReviewTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Review
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
sub_id: String,
|
||||
input: Vec<InputItem>,
|
||||
) -> Option<String> {
|
||||
let sess = session.clone_session();
|
||||
run_task(sess, ctx, sub_id, input).await
|
||||
}
|
||||
|
||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||
}
|
||||
}
|
||||
66
codex-rs/core/tests/suite/abort_tasks.rs
Normal file
66
codex-rs/core/tests/suite/abort_tasks.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use serde_json::json;
|
||||
use wiremock::matchers::body_string_contains;
|
||||
|
||||
/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE
|
||||
/// function call, then interrupt the session and expect TurnAborted.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn interrupt_long_running_tool_emits_turn_aborted() {
|
||||
let command = vec![
|
||||
"bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"sleep 60".to_string(),
|
||||
];
|
||||
|
||||
let args = json!({
|
||||
"command": command,
|
||||
"timeout_ms": 60_000
|
||||
})
|
||||
.to_string();
|
||||
let body = sse(vec![ev_function_call("call_sleep", "shell", &args)]);
|
||||
|
||||
let server = start_mock_server().await;
|
||||
mount_sse_once(&server, body_string_contains("start sleep"), body).await;
|
||||
|
||||
let codex = test_codex().build(&server).await.unwrap().codex;
|
||||
|
||||
let wait_timeout = Duration::from_secs(5);
|
||||
|
||||
// Kick off a turn that triggers the function call.
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "start sleep".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Wait until the exec begins to avoid a race, then interrupt.
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|ev| matches!(ev, EventMsg::ExecCommandBegin(_)),
|
||||
wait_timeout,
|
||||
)
|
||||
.await;
|
||||
|
||||
codex.submit(Op::Interrupt).await.unwrap();
|
||||
|
||||
// Expect TurnAborted soon after.
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|ev| matches!(ev, EventMsg::TurnAborted(_)),
|
||||
wait_timeout,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
// Aggregates all former standalone integration tests as modules.
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
mod abort_tasks;
|
||||
mod cli_stream;
|
||||
mod client;
|
||||
mod compact;
|
||||
|
||||
Reference in New Issue
Block a user