feat: add the time after aborting (#5996)
Tell the model how much time passed after the user aborted the call.
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tokio_util::either::Either;
|
use tokio_util::either::Either;
|
||||||
@@ -53,13 +54,16 @@ impl ToolCallRuntime {
|
|||||||
let turn = Arc::clone(&self.turn_context);
|
let turn = Arc::clone(&self.turn_context);
|
||||||
let tracker = Arc::clone(&self.tracker);
|
let tracker = Arc::clone(&self.tracker);
|
||||||
let lock = Arc::clone(&self.parallel_execution);
|
let lock = Arc::clone(&self.parallel_execution);
|
||||||
let aborted_response = Self::aborted_response(&call);
|
let started = Instant::now();
|
||||||
let readiness = self.turn_context.tool_call_gate.clone();
|
let readiness = self.turn_context.tool_call_gate.clone();
|
||||||
|
|
||||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = cancellation_token.cancelled() => Ok(aborted_response),
|
_ = cancellation_token.cancelled() => {
|
||||||
|
let secs = started.elapsed().as_secs_f32().max(0.1);
|
||||||
|
Ok(Self::aborted_response(&call, secs))
|
||||||
|
},
|
||||||
res = async {
|
res = async {
|
||||||
tracing::info!("waiting for tool gate");
|
tracing::info!("waiting for tool gate");
|
||||||
readiness.wait_ready().await;
|
readiness.wait_ready().await;
|
||||||
@@ -71,7 +75,7 @@ impl ToolCallRuntime {
|
|||||||
};
|
};
|
||||||
|
|
||||||
router
|
router
|
||||||
.dispatch_tool_call(session, turn, tracker, call)
|
.dispatch_tool_call(session, turn, tracker, call.clone())
|
||||||
.await
|
.await
|
||||||
} => res,
|
} => res,
|
||||||
}
|
}
|
||||||
@@ -91,23 +95,32 @@ impl ToolCallRuntime {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ToolCallRuntime {
|
impl ToolCallRuntime {
|
||||||
fn aborted_response(call: &ToolCall) -> ResponseInputItem {
|
fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem {
|
||||||
match &call.payload {
|
match &call.payload {
|
||||||
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||||
call_id: call.call_id.clone(),
|
call_id: call.call_id.clone(),
|
||||||
output: "aborted".to_string(),
|
output: Self::abort_message(call, secs),
|
||||||
},
|
},
|
||||||
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
|
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
|
||||||
call_id: call.call_id.clone(),
|
call_id: call.call_id.clone(),
|
||||||
result: Err("aborted".to_string()),
|
result: Err(Self::abort_message(call, secs)),
|
||||||
},
|
},
|
||||||
_ => ResponseInputItem::FunctionCallOutput {
|
_ => ResponseInputItem::FunctionCallOutput {
|
||||||
call_id: call.call_id.clone(),
|
call_id: call.call_id.clone(),
|
||||||
output: FunctionCallOutputPayload {
|
output: FunctionCallOutputPayload {
|
||||||
content: "aborted".to_string(),
|
content: Self::abort_message(call, secs),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn abort_message(call: &ToolCall, secs: f32) -> String {
|
||||||
|
match call.tool_name.as_str() {
|
||||||
|
"shell" | "container.exec" | "local_shell" | "unified_exec" => {
|
||||||
|
format!("Wall time: {secs:.1} seconds\naborted by user")
|
||||||
|
}
|
||||||
|
_ => format!("aborted by user after {secs:.1}s"),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use assert_matches::assert_matches;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
@@ -13,6 +14,7 @@ use core_test_support::responses::sse;
|
|||||||
use core_test_support::responses::start_mock_server;
|
use core_test_support::responses::start_mock_server;
|
||||||
use core_test_support::test_codex::test_codex;
|
use core_test_support::test_codex::test_codex;
|
||||||
use core_test_support::wait_for_event_with_timeout;
|
use core_test_support::wait_for_event_with_timeout;
|
||||||
|
use regex_lite::Regex;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
||||||
/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE
|
/// Integration test: spawn a long‑running shell tool via a mocked Responses SSE
|
||||||
@@ -123,6 +125,7 @@ async fn interrupt_tool_records_history_entries() {
|
|||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_secs_f32(0.1)).await;
|
||||||
codex.submit(Op::Interrupt).await.unwrap();
|
codex.submit(Op::Interrupt).await.unwrap();
|
||||||
|
|
||||||
wait_for_event_with_timeout(
|
wait_for_event_with_timeout(
|
||||||
@@ -159,9 +162,26 @@ async fn interrupt_tool_records_history_entries() {
|
|||||||
response_mock.saw_function_call(call_id),
|
response_mock.saw_function_call(call_id),
|
||||||
"function call not recorded in responses payload"
|
"function call not recorded in responses payload"
|
||||||
);
|
);
|
||||||
assert_eq!(
|
let output = response_mock
|
||||||
response_mock.function_call_output_text(call_id).as_deref(),
|
.function_call_output_text(call_id)
|
||||||
Some("aborted"),
|
.expect("missing function_call_output text");
|
||||||
"aborted function call output not recorded in responses payload"
|
let re = Regex::new(r"^Wall time: ([0-9]+(?:\.[0-9])?) seconds\naborted by user$")
|
||||||
|
.expect("compile regex");
|
||||||
|
let captures = re.captures(&output);
|
||||||
|
assert_matches!(
|
||||||
|
captures.as_ref(),
|
||||||
|
Some(caps) if caps.get(1).is_some(),
|
||||||
|
"aborted message with elapsed seconds"
|
||||||
|
);
|
||||||
|
let secs: f32 = captures
|
||||||
|
.expect("aborted message with elapsed seconds")
|
||||||
|
.get(1)
|
||||||
|
.unwrap()
|
||||||
|
.as_str()
|
||||||
|
.parse()
|
||||||
|
.unwrap();
|
||||||
|
assert!(
|
||||||
|
secs >= 0.1,
|
||||||
|
"expected at least one tenth of a second of elapsed time, got {secs}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user