diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 7340f5cc..baa5321d 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::time::Instant; use tokio::sync::RwLock; use tokio_util::either::Either; @@ -53,13 +54,16 @@ impl ToolCallRuntime { let turn = Arc::clone(&self.turn_context); let tracker = Arc::clone(&self.tracker); let lock = Arc::clone(&self.parallel_execution); - let aborted_response = Self::aborted_response(&call); + let started = Instant::now(); let readiness = self.turn_context.tool_call_gate.clone(); let handle: AbortOnDropHandle> = AbortOnDropHandle::new(tokio::spawn(async move { tokio::select! { - _ = cancellation_token.cancelled() => Ok(aborted_response), + _ = cancellation_token.cancelled() => { + let secs = started.elapsed().as_secs_f32().max(0.1); + Ok(Self::aborted_response(&call, secs)) + }, res = async { tracing::info!("waiting for tool gate"); readiness.wait_ready().await; @@ -71,7 +75,7 @@ impl ToolCallRuntime { }; router - .dispatch_tool_call(session, turn, tracker, call) + .dispatch_tool_call(session, turn, tracker, call.clone()) .await } => res, } @@ -91,23 +95,32 @@ impl ToolCallRuntime { } impl ToolCallRuntime { - fn aborted_response(call: &ToolCall) -> ResponseInputItem { + fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem { match &call.payload { ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput { call_id: call.call_id.clone(), - output: "aborted".to_string(), + output: Self::abort_message(call, secs), }, ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput { call_id: call.call_id.clone(), - result: Err("aborted".to_string()), + result: Err(Self::abort_message(call, secs)), }, _ => ResponseInputItem::FunctionCallOutput { call_id: call.call_id.clone(), output: FunctionCallOutputPayload { - content: "aborted".to_string(), + content: Self::abort_message(call, secs), ..Default::default() }, }, } } + + fn abort_message(call: &ToolCall, secs: f32) -> String { + match call.tool_name.as_str() { + "shell" | "container.exec" | "local_shell" | "unified_exec" => { + format!("Wall time: {secs:.1} seconds\naborted by user") + } + _ => format!("aborted by user after {secs:.1}s"), + } + } } diff --git a/codex-rs/core/tests/suite/abort_tasks.rs b/codex-rs/core/tests/suite/abort_tasks.rs index c9d59508..6c037395 100644 --- a/codex-rs/core/tests/suite/abort_tasks.rs +++ b/codex-rs/core/tests/suite/abort_tasks.rs @@ -1,3 +1,4 @@ +use assert_matches::assert_matches; use std::sync::Arc; use std::time::Duration; @@ -13,6 +14,7 @@ use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event_with_timeout; +use regex_lite::Regex; use serde_json::json; /// Integration test: spawn a long‑running shell tool via a mocked Responses SSE @@ -123,6 +125,7 @@ async fn interrupt_tool_records_history_entries() { ) .await; + tokio::time::sleep(Duration::from_secs_f32(0.1)).await; codex.submit(Op::Interrupt).await.unwrap(); wait_for_event_with_timeout( @@ -159,9 +162,26 @@ async fn interrupt_tool_records_history_entries() { response_mock.saw_function_call(call_id), "function call not recorded in responses payload" ); - assert_eq!( - response_mock.function_call_output_text(call_id).as_deref(), - Some("aborted"), - "aborted function call output not recorded in responses payload" + let output = response_mock + .function_call_output_text(call_id) + .expect("missing function_call_output text"); + let re = Regex::new(r"^Wall time: ([0-9]+(?:\.[0-9])?) seconds\naborted by user$") + .expect("compile regex"); + let captures = re.captures(&output); + assert_matches!( + captures.as_ref(), + Some(caps) if caps.get(1).is_some(), + "aborted message with elapsed seconds" + ); + let secs: f32 = captures + .expect("aborted message with elapsed seconds") + .get(1) + .unwrap() + .as_str() + .parse() + .unwrap(); + assert!( + secs >= 0.1, + "expected at least one tenth of a second of elapsed time, got {secs}" ); }