Support graceful agent interruption (#5287)
This commit is contained in:
34
codex-rs/Cargo.lock
generated
34
codex-rs/Cargo.lock
generated
@@ -899,6 +899,16 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "codex-async-utils"
|
||||||
|
version = "0.0.0"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"pretty_assertions",
|
||||||
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "codex-backend-client"
|
name = "codex-backend-client"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
@@ -1037,6 +1047,7 @@ dependencies = [
|
|||||||
"chrono",
|
"chrono",
|
||||||
"codex-app-server-protocol",
|
"codex-app-server-protocol",
|
||||||
"codex-apply-patch",
|
"codex-apply-patch",
|
||||||
|
"codex-async-utils",
|
||||||
"codex-file-search",
|
"codex-file-search",
|
||||||
"codex-mcp-client",
|
"codex-mcp-client",
|
||||||
"codex-otel",
|
"codex-otel",
|
||||||
@@ -1073,6 +1084,7 @@ dependencies = [
|
|||||||
"similar",
|
"similar",
|
||||||
"strum_macros 0.27.2",
|
"strum_macros 0.27.2",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
|
"test-log",
|
||||||
"thiserror 2.0.16",
|
"thiserror 2.0.16",
|
||||||
"time",
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
@@ -6022,6 +6034,28 @@ version = "0.5.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
|
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "test-log"
|
||||||
|
version = "0.2.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b"
|
||||||
|
dependencies = [
|
||||||
|
"env_logger",
|
||||||
|
"test-log-macros",
|
||||||
|
"tracing-subscriber",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "test-log-macros"
|
||||||
|
version = "0.2.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.104",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "textwrap"
|
name = "textwrap"
|
||||||
version = "0.11.0"
|
version = "0.11.0"
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
members = [
|
members = [
|
||||||
"backend-client",
|
"backend-client",
|
||||||
"ansi-escape",
|
"ansi-escape",
|
||||||
|
"async-utils",
|
||||||
"app-server",
|
"app-server",
|
||||||
"app-server-protocol",
|
"app-server-protocol",
|
||||||
"apply-patch",
|
"apply-patch",
|
||||||
@@ -56,6 +57,7 @@ codex-arg0 = { path = "arg0" }
|
|||||||
codex-chatgpt = { path = "chatgpt" }
|
codex-chatgpt = { path = "chatgpt" }
|
||||||
codex-common = { path = "common" }
|
codex-common = { path = "common" }
|
||||||
codex-core = { path = "core" }
|
codex-core = { path = "core" }
|
||||||
|
codex-async-utils = { path = "async-utils" }
|
||||||
codex-exec = { path = "exec" }
|
codex-exec = { path = "exec" }
|
||||||
codex-feedback = { path = "feedback" }
|
codex-feedback = { path = "feedback" }
|
||||||
codex-file-search = { path = "file-search" }
|
codex-file-search = { path = "file-search" }
|
||||||
@@ -164,6 +166,7 @@ strum_macros = "0.27.2"
|
|||||||
supports-color = "3.0.2"
|
supports-color = "3.0.2"
|
||||||
sys-locale = "0.3.2"
|
sys-locale = "0.3.2"
|
||||||
tempfile = "3.23.0"
|
tempfile = "3.23.0"
|
||||||
|
test-log = "0.2.18"
|
||||||
textwrap = "0.16.2"
|
textwrap = "0.16.2"
|
||||||
thiserror = "2.0.16"
|
thiserror = "2.0.16"
|
||||||
time = "0.3"
|
time = "0.3"
|
||||||
|
|||||||
15
codex-rs/async-utils/Cargo.toml
Normal file
15
codex-rs/async-utils/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
[package]
|
||||||
|
edition.workspace = true
|
||||||
|
name = "codex-async-utils"
|
||||||
|
version.workspace = true
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
async-trait.workspace = true
|
||||||
|
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "time"] }
|
||||||
|
tokio-util.workspace = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
pretty_assertions.workspace = true
|
||||||
86
codex-rs/async-utils/src/lib.rs
Normal file
86
codex-rs/async-utils/src/lib.rs
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use std::future::Future;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum CancelErr {
|
||||||
|
Cancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait OrCancelExt: Sized {
|
||||||
|
type Output;
|
||||||
|
|
||||||
|
async fn or_cancel(self, token: &CancellationToken) -> Result<Self::Output, CancelErr>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<F> OrCancelExt for F
|
||||||
|
where
|
||||||
|
F: Future + Send,
|
||||||
|
F::Output: Send,
|
||||||
|
{
|
||||||
|
type Output = F::Output;
|
||||||
|
|
||||||
|
async fn or_cancel(self, token: &CancellationToken) -> Result<Self::Output, CancelErr> {
|
||||||
|
tokio::select! {
|
||||||
|
_ = token.cancelled() => Err(CancelErr::Cancelled),
|
||||||
|
res = self => Ok(res),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::task;
|
||||||
|
use tokio::time::sleep;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn returns_ok_when_future_completes_first() {
|
||||||
|
let token = CancellationToken::new();
|
||||||
|
let value = async { 42 };
|
||||||
|
|
||||||
|
let result = value.or_cancel(&token).await;
|
||||||
|
|
||||||
|
assert_eq!(Ok(42), result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn returns_err_when_token_cancelled_first() {
|
||||||
|
let token = CancellationToken::new();
|
||||||
|
let token_clone = token.clone();
|
||||||
|
|
||||||
|
let cancel_handle = task::spawn(async move {
|
||||||
|
sleep(Duration::from_millis(10)).await;
|
||||||
|
token_clone.cancel();
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = async {
|
||||||
|
sleep(Duration::from_millis(100)).await;
|
||||||
|
7
|
||||||
|
}
|
||||||
|
.or_cancel(&token)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
cancel_handle.await.expect("cancel task panicked");
|
||||||
|
assert_eq!(Err(CancelErr::Cancelled), result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn returns_err_when_token_already_cancelled() {
|
||||||
|
let token = CancellationToken::new();
|
||||||
|
token.cancel();
|
||||||
|
|
||||||
|
let result = async {
|
||||||
|
sleep(Duration::from_millis(50)).await;
|
||||||
|
5
|
||||||
|
}
|
||||||
|
.or_cancel(&token)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert_eq!(Err(CancelErr::Cancelled), result);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,6 +26,7 @@ codex-mcp-client = { workspace = true }
|
|||||||
codex-otel = { workspace = true, features = ["otel"] }
|
codex-otel = { workspace = true, features = ["otel"] }
|
||||||
codex-protocol = { workspace = true }
|
codex-protocol = { workspace = true }
|
||||||
codex-rmcp-client = { workspace = true }
|
codex-rmcp-client = { workspace = true }
|
||||||
|
codex-async-utils = { workspace = true }
|
||||||
codex-utils-string = { workspace = true }
|
codex-utils-string = { workspace = true }
|
||||||
dirs = { workspace = true }
|
dirs = { workspace = true }
|
||||||
dunce = { workspace = true }
|
dunce = { workspace = true }
|
||||||
@@ -47,6 +48,7 @@ shlex = { workspace = true }
|
|||||||
similar = { workspace = true }
|
similar = { workspace = true }
|
||||||
strum_macros = { workspace = true }
|
strum_macros = { workspace = true }
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
|
test-log = { workspace = true }
|
||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
time = { workspace = true, features = [
|
time = { workspace = true, features = [
|
||||||
"formatting",
|
"formatting",
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ use serde_json;
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
@@ -119,6 +120,7 @@ use crate::unified_exec::UnifiedExecSessionManager;
|
|||||||
use crate::user_instructions::UserInstructions;
|
use crate::user_instructions::UserInstructions;
|
||||||
use crate::user_notification::UserNotification;
|
use crate::user_notification::UserNotification;
|
||||||
use crate::util::backoff;
|
use crate::util::backoff;
|
||||||
|
use codex_async_utils::OrCancelExt;
|
||||||
use codex_otel::otel_event_manager::OtelEventManager;
|
use codex_otel::otel_event_manager::OtelEventManager;
|
||||||
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||||
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||||
@@ -1170,19 +1172,6 @@ impl Session {
|
|||||||
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn interrupt_task_sync(&self) {
|
|
||||||
if let Ok(mut active) = self.active_turn.try_lock()
|
|
||||||
&& let Some(at) = active.as_mut()
|
|
||||||
{
|
|
||||||
at.try_clear_pending_sync();
|
|
||||||
let tasks = at.drain_tasks();
|
|
||||||
*active = None;
|
|
||||||
for (_sub_id, task) in tasks {
|
|
||||||
task.handle.abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn notifier(&self) -> &UserNotifier {
|
pub(crate) fn notifier(&self) -> &UserNotifier {
|
||||||
&self.services.notifier
|
&self.services.notifier
|
||||||
}
|
}
|
||||||
@@ -1196,12 +1185,6 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for Session {
|
|
||||||
fn drop(&mut self) {
|
|
||||||
self.interrupt_task_sync();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn submission_loop(
|
async fn submission_loop(
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: TurnContext,
|
turn_context: TurnContext,
|
||||||
@@ -1711,6 +1694,7 @@ pub(crate) async fn run_task(
|
|||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<InputItem>,
|
input: Vec<InputItem>,
|
||||||
task_kind: TaskKind,
|
task_kind: TaskKind,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
if input.is_empty() {
|
if input.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
@@ -1795,6 +1779,7 @@ pub(crate) async fn run_task(
|
|||||||
sub_id.clone(),
|
sub_id.clone(),
|
||||||
turn_input,
|
turn_input,
|
||||||
task_kind,
|
task_kind,
|
||||||
|
cancellation_token.child_token(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
@@ -1956,6 +1941,10 @@ pub(crate) async fn run_task(
|
|||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
Err(CodexErr::TurnAborted) => {
|
||||||
|
// Aborted turn is reported via a different event.
|
||||||
|
break;
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
info!("Turn error: {e:#}");
|
info!("Turn error: {e:#}");
|
||||||
let event = Event {
|
let event = Event {
|
||||||
@@ -2022,6 +2011,7 @@ async fn run_turn(
|
|||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<ResponseItem>,
|
input: Vec<ResponseItem>,
|
||||||
task_kind: TaskKind,
|
task_kind: TaskKind,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
) -> CodexResult<TurnRunResult> {
|
) -> CodexResult<TurnRunResult> {
|
||||||
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
|
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
|
||||||
let router = Arc::new(ToolRouter::from_config(
|
let router = Arc::new(ToolRouter::from_config(
|
||||||
@@ -2052,10 +2042,12 @@ async fn run_turn(
|
|||||||
&sub_id,
|
&sub_id,
|
||||||
&prompt,
|
&prompt,
|
||||||
task_kind,
|
task_kind,
|
||||||
|
cancellation_token.child_token(),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(output) => return Ok(output),
|
Ok(output) => return Ok(output),
|
||||||
|
Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
|
||||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||||
@@ -2118,6 +2110,7 @@ struct TurnRunResult {
|
|||||||
total_token_usage: Option<TokenUsage>,
|
total_token_usage: Option<TokenUsage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn try_run_turn(
|
async fn try_run_turn(
|
||||||
router: Arc<ToolRouter>,
|
router: Arc<ToolRouter>,
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
@@ -2126,6 +2119,7 @@ async fn try_run_turn(
|
|||||||
sub_id: &str,
|
sub_id: &str,
|
||||||
prompt: &Prompt,
|
prompt: &Prompt,
|
||||||
task_kind: TaskKind,
|
task_kind: TaskKind,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
) -> CodexResult<TurnRunResult> {
|
) -> CodexResult<TurnRunResult> {
|
||||||
// call_ids that are part of this response.
|
// call_ids that are part of this response.
|
||||||
let completed_call_ids = prompt
|
let completed_call_ids = prompt
|
||||||
@@ -2195,7 +2189,8 @@ async fn try_run_turn(
|
|||||||
.client
|
.client
|
||||||
.clone()
|
.clone()
|
||||||
.stream_with_task_kind(prompt.as_ref(), task_kind)
|
.stream_with_task_kind(prompt.as_ref(), task_kind)
|
||||||
.await?;
|
.or_cancel(&cancellation_token)
|
||||||
|
.await??;
|
||||||
|
|
||||||
let tool_runtime = ToolCallRuntime::new(
|
let tool_runtime = ToolCallRuntime::new(
|
||||||
Arc::clone(&router),
|
Arc::clone(&router),
|
||||||
@@ -2211,7 +2206,8 @@ async fn try_run_turn(
|
|||||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||||
let event = stream.next().await;
|
let event = stream.next().or_cancel(&cancellation_token).await?;
|
||||||
|
|
||||||
let event = match event {
|
let event = match event {
|
||||||
Some(res) => res?,
|
Some(res) => res?,
|
||||||
None => {
|
None => {
|
||||||
@@ -2316,7 +2312,10 @@ async fn try_run_turn(
|
|||||||
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
sess.update_token_usage_info(sub_id, turn_context.as_ref(), token_usage.as_ref())
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let processed_items: Vec<ProcessedResponseItem> = output.try_collect().await?;
|
let processed_items = output
|
||||||
|
.try_collect()
|
||||||
|
.or_cancel(&cancellation_token)
|
||||||
|
.await??;
|
||||||
|
|
||||||
let unified_diff = {
|
let unified_diff = {
|
||||||
let mut tracker = turn_diff_tracker.lock().await;
|
let mut tracker = turn_diff_tracker.lock().await;
|
||||||
@@ -2554,6 +2553,8 @@ mod tests {
|
|||||||
use codex_app_server_protocol::AuthMode;
|
use codex_app_server_protocol::AuthMode;
|
||||||
use codex_protocol::models::ContentItem;
|
use codex_protocol::models::ContentItem;
|
||||||
use codex_protocol::models::ResponseItem;
|
use codex_protocol::models::ResponseItem;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::time::sleep;
|
||||||
|
|
||||||
use mcp_types::ContentBlock;
|
use mcp_types::ContentBlock;
|
||||||
use mcp_types::TextContent;
|
use mcp_types::TextContent;
|
||||||
@@ -2563,8 +2564,6 @@ mod tests {
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration as StdDuration;
|
use std::time::Duration as StdDuration;
|
||||||
use tokio::time::Duration;
|
|
||||||
use tokio::time::sleep;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn reconstruct_history_matches_live_compactions() {
|
fn reconstruct_history_matches_live_compactions() {
|
||||||
@@ -2944,12 +2943,15 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
struct NeverEndingTask(TaskKind);
|
struct NeverEndingTask {
|
||||||
|
kind: TaskKind,
|
||||||
|
listen_to_cancellation_token: bool,
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl SessionTask for NeverEndingTask {
|
impl SessionTask for NeverEndingTask {
|
||||||
fn kind(&self) -> TaskKind {
|
fn kind(&self) -> TaskKind {
|
||||||
self.0
|
self.kind
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run(
|
async fn run(
|
||||||
@@ -2958,20 +2960,26 @@ mod tests {
|
|||||||
_ctx: Arc<TurnContext>,
|
_ctx: Arc<TurnContext>,
|
||||||
_sub_id: String,
|
_sub_id: String,
|
||||||
_input: Vec<InputItem>,
|
_input: Vec<InputItem>,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
|
if self.listen_to_cancellation_token {
|
||||||
|
cancellation_token.cancelled().await;
|
||||||
|
return None;
|
||||||
|
}
|
||||||
loop {
|
loop {
|
||||||
sleep(Duration::from_secs(60)).await;
|
sleep(Duration::from_secs(60)).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||||
if let TaskKind::Review = self.0 {
|
if let TaskKind::Review = self.kind {
|
||||||
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
exit_review_mode(session.clone_session(), sub_id.to_string(), None).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
#[test_log::test]
|
||||||
async fn abort_regular_task_emits_turn_aborted_only() {
|
async fn abort_regular_task_emits_turn_aborted_only() {
|
||||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||||
let sub_id = "sub-regular".to_string();
|
let sub_id = "sub-regular".to_string();
|
||||||
@@ -2982,7 +2990,41 @@ mod tests {
|
|||||||
Arc::clone(&tc),
|
Arc::clone(&tc),
|
||||||
sub_id.clone(),
|
sub_id.clone(),
|
||||||
input,
|
input,
|
||||||
NeverEndingTask(TaskKind::Regular),
|
NeverEndingTask {
|
||||||
|
kind: TaskKind::Regular,
|
||||||
|
listen_to_cancellation_token: false,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||||
|
|
||||||
|
let evt = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout waiting for event")
|
||||||
|
.expect("event");
|
||||||
|
match evt.msg {
|
||||||
|
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||||
|
other => panic!("unexpected event: {other:?}"),
|
||||||
|
}
|
||||||
|
assert!(rx.try_recv().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn abort_gracefuly_emits_turn_aborted_only() {
|
||||||
|
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||||
|
let sub_id = "sub-regular".to_string();
|
||||||
|
let input = vec![InputItem::Text {
|
||||||
|
text: "hello".to_string(),
|
||||||
|
}];
|
||||||
|
sess.spawn_task(
|
||||||
|
Arc::clone(&tc),
|
||||||
|
sub_id.clone(),
|
||||||
|
input,
|
||||||
|
NeverEndingTask {
|
||||||
|
kind: TaskKind::Regular,
|
||||||
|
listen_to_cancellation_token: true,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -2996,7 +3038,7 @@ mod tests {
|
|||||||
assert!(rx.try_recv().is_err());
|
assert!(rx.try_recv().is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
async fn abort_review_task_emits_exited_then_aborted_and_records_history() {
|
||||||
let (sess, tc, rx) = make_session_and_context_with_rx();
|
let (sess, tc, rx) = make_session_and_context_with_rx();
|
||||||
let sub_id = "sub-review".to_string();
|
let sub_id = "sub-review".to_string();
|
||||||
@@ -3007,18 +3049,27 @@ mod tests {
|
|||||||
Arc::clone(&tc),
|
Arc::clone(&tc),
|
||||||
sub_id.clone(),
|
sub_id.clone(),
|
||||||
input,
|
input,
|
||||||
NeverEndingTask(TaskKind::Review),
|
NeverEndingTask {
|
||||||
|
kind: TaskKind::Review,
|
||||||
|
listen_to_cancellation_token: false,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||||
|
|
||||||
let first = rx.recv().await.expect("first event");
|
let first = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout waiting for first event")
|
||||||
|
.expect("first event");
|
||||||
match first.msg {
|
match first.msg {
|
||||||
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
|
EventMsg::ExitedReviewMode(ev) => assert!(ev.review_output.is_none()),
|
||||||
other => panic!("unexpected first event: {other:?}"),
|
other => panic!("unexpected first event: {other:?}"),
|
||||||
}
|
}
|
||||||
let second = rx.recv().await.expect("second event");
|
let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("timeout waiting for second event")
|
||||||
|
.expect("second event");
|
||||||
match second.msg {
|
match second.msg {
|
||||||
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
EventMsg::TurnAborted(e) => assert_eq!(TurnAbortReason::Interrupted, e.reason),
|
||||||
other => panic!("unexpected second event: {other:?}"),
|
other => panic!("unexpected second event: {other:?}"),
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use crate::exec::ExecToolCallOutput;
|
|||||||
use crate::token_data::KnownPlan;
|
use crate::token_data::KnownPlan;
|
||||||
use crate::token_data::PlanType;
|
use crate::token_data::PlanType;
|
||||||
use crate::truncate::truncate_middle;
|
use crate::truncate::truncate_middle;
|
||||||
|
use codex_async_utils::CancelErr;
|
||||||
use codex_protocol::ConversationId;
|
use codex_protocol::ConversationId;
|
||||||
use codex_protocol::protocol::RateLimitSnapshot;
|
use codex_protocol::protocol::RateLimitSnapshot;
|
||||||
use reqwest::StatusCode;
|
use reqwest::StatusCode;
|
||||||
@@ -50,6 +51,9 @@ pub enum SandboxErr {
|
|||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum CodexErr {
|
pub enum CodexErr {
|
||||||
|
#[error("turn aborted")]
|
||||||
|
TurnAborted,
|
||||||
|
|
||||||
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
|
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
|
||||||
/// handshake has succeeded but **before** it finished emitting `response.completed`.
|
/// handshake has succeeded but **before** it finished emitting `response.completed`.
|
||||||
///
|
///
|
||||||
@@ -150,6 +154,12 @@ pub enum CodexErr {
|
|||||||
EnvVar(EnvVarError),
|
EnvVar(EnvVarError),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<CancelErr> for CodexErr {
|
||||||
|
fn from(_: CancelErr) -> Self {
|
||||||
|
CodexErr::TurnAborted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ConnectionFailedError {
|
pub struct ConnectionFailedError {
|
||||||
pub source: reqwest::Error,
|
pub source: reqwest::Error,
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ mod client;
|
|||||||
mod client_common;
|
mod client_common;
|
||||||
pub mod codex;
|
pub mod codex;
|
||||||
mod codex_conversation;
|
mod codex_conversation;
|
||||||
pub mod token_data;
|
|
||||||
pub use codex_conversation::CodexConversation;
|
pub use codex_conversation::CodexConversation;
|
||||||
mod command_safety;
|
mod command_safety;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
@@ -39,6 +38,7 @@ mod mcp_tool_call;
|
|||||||
mod message_history;
|
mod message_history;
|
||||||
mod model_provider_info;
|
mod model_provider_info;
|
||||||
pub mod parse_command;
|
pub mod parse_command;
|
||||||
|
pub mod token_data;
|
||||||
mod truncate;
|
mod truncate;
|
||||||
mod unified_exec;
|
mod unified_exec;
|
||||||
mod user_instructions;
|
mod user_instructions;
|
||||||
@@ -107,5 +107,4 @@ pub use codex_protocol::models::LocalShellExecAction;
|
|||||||
pub use codex_protocol::models::LocalShellStatus;
|
pub use codex_protocol::models::LocalShellStatus;
|
||||||
pub use codex_protocol::models::ReasoningItemContent;
|
pub use codex_protocol::models::ReasoningItemContent;
|
||||||
pub use codex_protocol::models::ResponseItem;
|
pub use codex_protocol::models::ResponseItem;
|
||||||
|
|
||||||
pub mod otel_init;
|
pub mod otel_init;
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ use indexmap::IndexMap;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio::task::AbortHandle;
|
use tokio::sync::Notify;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use tokio_util::task::AbortOnDropHandle;
|
||||||
|
|
||||||
use codex_protocol::models::ResponseInputItem;
|
use codex_protocol::models::ResponseInputItem;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
@@ -46,9 +48,11 @@ impl TaskKind {
|
|||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct RunningTask {
|
pub(crate) struct RunningTask {
|
||||||
pub(crate) handle: AbortHandle,
|
pub(crate) done: Arc<Notify>,
|
||||||
pub(crate) kind: TaskKind,
|
pub(crate) kind: TaskKind,
|
||||||
pub(crate) task: Arc<dyn SessionTask>,
|
pub(crate) task: Arc<dyn SessionTask>,
|
||||||
|
pub(crate) cancellation_token: CancellationToken,
|
||||||
|
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ActiveTurn {
|
impl ActiveTurn {
|
||||||
@@ -115,13 +119,6 @@ impl ActiveTurn {
|
|||||||
let mut ts = self.turn_state.lock().await;
|
let mut ts = self.turn_state.lock().await;
|
||||||
ts.clear_pending();
|
ts.clear_pending();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Best-effort, non-blocking variant for synchronous contexts (Drop/interrupt).
|
|
||||||
pub(crate) fn try_clear_pending_sync(&self) {
|
|
||||||
if let Ok(mut ts) = self.turn_state.try_lock() {
|
|
||||||
ts.clear_pending();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
use crate::codex::TurnContext;
|
use crate::codex::TurnContext;
|
||||||
use crate::codex::compact;
|
use crate::codex::compact;
|
||||||
@@ -25,6 +26,7 @@ impl SessionTask for CompactTask {
|
|||||||
ctx: Arc<TurnContext>,
|
ctx: Arc<TurnContext>,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<InputItem>,
|
input: Vec<InputItem>,
|
||||||
|
_cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
|
compact::run_compact_task(session.clone_session(), ctx, sub_id, input).await
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,15 @@ mod regular;
|
|||||||
mod review;
|
mod review;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use tokio::select;
|
||||||
|
use tokio::sync::Notify;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
use tokio_util::task::AbortOnDropHandle;
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
use crate::codex::TurnContext;
|
use crate::codex::TurnContext;
|
||||||
@@ -23,6 +29,8 @@ pub(crate) use compact::CompactTask;
|
|||||||
pub(crate) use regular::RegularTask;
|
pub(crate) use regular::RegularTask;
|
||||||
pub(crate) use review::ReviewTask;
|
pub(crate) use review::ReviewTask;
|
||||||
|
|
||||||
|
const GRACEFULL_INTERRUPTION_TIMEOUT_MS: u64 = 100;
|
||||||
|
|
||||||
/// Thin wrapper that exposes the parts of [`Session`] task runners need.
|
/// Thin wrapper that exposes the parts of [`Session`] task runners need.
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct SessionTaskContext {
|
pub(crate) struct SessionTaskContext {
|
||||||
@@ -49,6 +57,7 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
|
|||||||
ctx: Arc<TurnContext>,
|
ctx: Arc<TurnContext>,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<InputItem>,
|
input: Vec<InputItem>,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String>;
|
) -> Option<String>;
|
||||||
|
|
||||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||||
@@ -69,26 +78,42 @@ impl Session {
|
|||||||
let task: Arc<dyn SessionTask> = Arc::new(task);
|
let task: Arc<dyn SessionTask> = Arc::new(task);
|
||||||
let task_kind = task.kind();
|
let task_kind = task.kind();
|
||||||
|
|
||||||
|
let cancellation_token = CancellationToken::new();
|
||||||
|
let done = Arc::new(Notify::new());
|
||||||
|
|
||||||
|
let done_clone = Arc::clone(&done);
|
||||||
let handle = {
|
let handle = {
|
||||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||||
let ctx = Arc::clone(&turn_context);
|
let ctx = Arc::clone(&turn_context);
|
||||||
let task_for_run = Arc::clone(&task);
|
let task_for_run = Arc::clone(&task);
|
||||||
let sub_clone = sub_id.clone();
|
let sub_clone = sub_id.clone();
|
||||||
|
let task_cancellation_token = cancellation_token.child_token();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let last_agent_message = task_for_run
|
let last_agent_message = task_for_run
|
||||||
.run(Arc::clone(&session_ctx), ctx, sub_clone.clone(), input)
|
.run(
|
||||||
|
Arc::clone(&session_ctx),
|
||||||
|
ctx,
|
||||||
|
sub_clone.clone(),
|
||||||
|
input,
|
||||||
|
task_cancellation_token.child_token(),
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
|
||||||
let sess = session_ctx.clone_session();
|
if !task_cancellation_token.is_cancelled() {
|
||||||
sess.on_task_finished(sub_clone, last_agent_message).await;
|
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
|
||||||
|
let sess = session_ctx.clone_session();
|
||||||
|
sess.on_task_finished(sub_clone, last_agent_message).await;
|
||||||
|
}
|
||||||
|
done_clone.notify_waiters();
|
||||||
})
|
})
|
||||||
.abort_handle()
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let running_task = RunningTask {
|
let running_task = RunningTask {
|
||||||
handle,
|
done,
|
||||||
|
handle: Arc::new(AbortOnDropHandle::new(handle)),
|
||||||
kind: task_kind,
|
kind: task_kind,
|
||||||
task,
|
task,
|
||||||
|
cancellation_token,
|
||||||
};
|
};
|
||||||
self.register_new_active_task(sub_id, running_task).await;
|
self.register_new_active_task(sub_id, running_task).await;
|
||||||
}
|
}
|
||||||
@@ -143,14 +168,24 @@ impl Session {
|
|||||||
task: RunningTask,
|
task: RunningTask,
|
||||||
reason: TurnAbortReason,
|
reason: TurnAbortReason,
|
||||||
) {
|
) {
|
||||||
if task.handle.is_finished() {
|
if task.cancellation_token.is_cancelled() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
|
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
|
||||||
|
task.cancellation_token.cancel();
|
||||||
let session_task = task.task;
|
let session_task = task.task;
|
||||||
let handle = task.handle;
|
|
||||||
handle.abort();
|
select! {
|
||||||
|
_ = task.done.notified() => {
|
||||||
|
},
|
||||||
|
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
|
||||||
|
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
task.handle.abort();
|
||||||
|
|
||||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||||
session_task.abort(session_ctx, &sub_id).await;
|
session_task.abort(session_ctx, &sub_id).await;
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
use crate::codex::TurnContext;
|
use crate::codex::TurnContext;
|
||||||
use crate::codex::run_task;
|
use crate::codex::run_task;
|
||||||
@@ -25,8 +26,17 @@ impl SessionTask for RegularTask {
|
|||||||
ctx: Arc<TurnContext>,
|
ctx: Arc<TurnContext>,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<InputItem>,
|
input: Vec<InputItem>,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
let sess = session.clone_session();
|
let sess = session.clone_session();
|
||||||
run_task(sess, ctx, sub_id, input, TaskKind::Regular).await
|
run_task(
|
||||||
|
sess,
|
||||||
|
ctx,
|
||||||
|
sub_id,
|
||||||
|
input,
|
||||||
|
TaskKind::Regular,
|
||||||
|
cancellation_token,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
use crate::codex::TurnContext;
|
use crate::codex::TurnContext;
|
||||||
use crate::codex::exit_review_mode;
|
use crate::codex::exit_review_mode;
|
||||||
@@ -26,9 +27,18 @@ impl SessionTask for ReviewTask {
|
|||||||
ctx: Arc<TurnContext>,
|
ctx: Arc<TurnContext>,
|
||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<InputItem>,
|
input: Vec<InputItem>,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
) -> Option<String> {
|
) -> Option<String> {
|
||||||
let sess = session.clone_session();
|
let sess = session.clone_session();
|
||||||
run_task(sess, ctx, sub_id, input, TaskKind::Review).await
|
run_task(
|
||||||
|
sess,
|
||||||
|
ctx,
|
||||||
|
sub_id,
|
||||||
|
input,
|
||||||
|
TaskKind::Review,
|
||||||
|
cancellation_token,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {
|
||||||
|
|||||||
Reference in New Issue
Block a user