feat: add the time after aborting (#5996)

Tell the model how much time passed after the user aborted the call.
This commit is contained in:
Ahmed Ibrahim
2025-11-03 11:44:06 -08:00
committed by GitHub
parent 5f3a0473f1
commit 6ee7fbcfff
2 changed files with 44 additions and 11 deletions

View File

@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tokio_util::either::Either;
@@ -53,13 +54,16 @@ impl ToolCallRuntime {
let turn = Arc::clone(&self.turn_context);
let tracker = Arc::clone(&self.tracker);
let lock = Arc::clone(&self.parallel_execution);
let aborted_response = Self::aborted_response(&call);
let started = Instant::now();
let readiness = self.turn_context.tool_call_gate.clone();
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
AbortOnDropHandle::new(tokio::spawn(async move {
tokio::select! {
_ = cancellation_token.cancelled() => Ok(aborted_response),
_ = cancellation_token.cancelled() => {
let secs = started.elapsed().as_secs_f32().max(0.1);
Ok(Self::aborted_response(&call, secs))
},
res = async {
tracing::info!("waiting for tool gate");
readiness.wait_ready().await;
@@ -71,7 +75,7 @@ impl ToolCallRuntime {
};
router
.dispatch_tool_call(session, turn, tracker, call)
.dispatch_tool_call(session, turn, tracker, call.clone())
.await
} => res,
}
@@ -91,23 +95,32 @@ impl ToolCallRuntime {
}
impl ToolCallRuntime {
fn aborted_response(call: &ToolCall) -> ResponseInputItem {
fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem {
match &call.payload {
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
call_id: call.call_id.clone(),
output: "aborted".to_string(),
output: Self::abort_message(call, secs),
},
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
call_id: call.call_id.clone(),
result: Err("aborted".to_string()),
result: Err(Self::abort_message(call, secs)),
},
_ => ResponseInputItem::FunctionCallOutput {
call_id: call.call_id.clone(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
content: Self::abort_message(call, secs),
..Default::default()
},
},
}
}
fn abort_message(call: &ToolCall, secs: f32) -> String {
match call.tool_name.as_str() {
"shell" | "container.exec" | "local_shell" | "unified_exec" => {
format!("Wall time: {secs:.1} seconds\naborted by user")
}
_ => format!("aborted by user after {secs:.1}s"),
}
}
}

View File

@@ -1,3 +1,4 @@
use assert_matches::assert_matches;
use std::sync::Arc;
use std::time::Duration;
@@ -13,6 +14,7 @@ use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event_with_timeout;
use regex_lite::Regex;
use serde_json::json;
/// Integration test: spawn a longrunning shell tool via a mocked Responses SSE
@@ -123,6 +125,7 @@ async fn interrupt_tool_records_history_entries() {
)
.await;
tokio::time::sleep(Duration::from_secs_f32(0.1)).await;
codex.submit(Op::Interrupt).await.unwrap();
wait_for_event_with_timeout(
@@ -159,9 +162,26 @@ async fn interrupt_tool_records_history_entries() {
response_mock.saw_function_call(call_id),
"function call not recorded in responses payload"
);
assert_eq!(
response_mock.function_call_output_text(call_id).as_deref(),
Some("aborted"),
"aborted function call output not recorded in responses payload"
let output = response_mock
.function_call_output_text(call_id)
.expect("missing function_call_output text");
let re = Regex::new(r"^Wall time: ([0-9]+(?:\.[0-9])?) seconds\naborted by user$")
.expect("compile regex");
let captures = re.captures(&output);
assert_matches!(
captures.as_ref(),
Some(caps) if caps.get(1).is_some(),
"aborted message with elapsed seconds"
);
let secs: f32 = captures
.expect("aborted message with elapsed seconds")
.get(1)
.unwrap()
.as_str()
.parse()
.unwrap();
assert!(
secs >= 0.1,
"expected at least one tenth of a second of elapsed time, got {secs}"
);
}