Handle cancelling/aborting while processing a turn (#5543)
Currently we collect all turn items in a vector and only add them to the history on success. This results in losing those items on errors, including aborting with `ctrl+c`. This PR: - Adds the ability for the tool call to handle cancellation - Bubbles the turn items up to where we are recording this info. Admittedly, this is ad-hoc logic that doesn't handle a lot of error edge cases. The right thing to do is to record to the history on the spot as `items`/`tool calls output` come in. However, this isn't possible because different `task_kind` values have different `conversation_histories`, and `try_run_turn` has no idea which thread we are using. We also cannot pass an `arc` to the `conversation_histories` because it's a private element of `state`. That said, `abort` is the most common case and we should cover it until we remove `task kind`.
This commit is contained in:
@@ -35,6 +35,22 @@ impl ResponseMock {
|
||||
/// Returns a cloned snapshot of the requests captured so far.
///
/// Panics if acquiring the lock fails (the `unwrap` on `lock()`).
pub fn requests(&self) -> Vec<ResponsesRequest> {
    let captured = self.requests.lock().unwrap();
    captured.clone()
}
|
||||
|
||||
/// Returns true if any captured request contains a `function_call` with the
|
||||
/// provided `call_id`.
|
||||
pub fn saw_function_call(&self, call_id: &str) -> bool {
|
||||
self.requests()
|
||||
.iter()
|
||||
.any(|req| req.has_function_call(call_id))
|
||||
}
|
||||
|
||||
/// Returns the `output` string for a matching `function_call_output` with
|
||||
/// the provided `call_id`, searching across all captured requests.
|
||||
pub fn function_call_output_text(&self, call_id: &str) -> Option<String> {
|
||||
self.requests()
|
||||
.iter()
|
||||
.find_map(|req| req.function_call_output_text(call_id))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -70,6 +86,28 @@ impl ResponsesRequest {
|
||||
.unwrap_or_else(|| panic!("function call output {call_id} item not found in request"))
|
||||
}
|
||||
|
||||
/// Returns true if this request's `input` contains a `function_call` with
|
||||
/// the specified `call_id`.
|
||||
pub fn has_function_call(&self, call_id: &str) -> bool {
|
||||
self.input().iter().any(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call")
|
||||
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||
})
|
||||
}
|
||||
|
||||
/// If present, returns the `output` string of the `function_call_output`
|
||||
/// entry matching `call_id` in this request's `input`.
|
||||
pub fn function_call_output_text(&self, call_id: &str) -> Option<String> {
|
||||
let binding = self.input();
|
||||
let item = binding.iter().find(|item| {
|
||||
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||
})?;
|
||||
item.get("output")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string)
|
||||
}
|
||||
|
||||
pub fn header(&self, name: &str) -> Option<String> {
|
||||
self.0
|
||||
.headers
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::protocol::EventMsg;
|
||||
@@ -5,7 +6,9 @@ use codex_core::protocol::Op;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::responses::ev_completed;
|
||||
use core_test_support::responses::ev_function_call;
|
||||
use core_test_support::responses::ev_response_created;
|
||||
use core_test_support::responses::mount_sse_once;
|
||||
use core_test_support::responses::mount_sse_sequence;
|
||||
use core_test_support::responses::sse;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
@@ -67,3 +70,98 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// After an interrupt we expect the next request to the model to include both
|
||||
/// the original tool call and an `"aborted"` `function_call_output`. This test
|
||||
/// exercises the follow-up flow: it sends another user turn, inspects the mock
|
||||
/// responses server, and ensures the model receives the synthesized abort.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn interrupt_tool_records_history_entries() {
|
||||
let command = vec![
|
||||
"bash".to_string(),
|
||||
"-lc".to_string(),
|
||||
"sleep 60".to_string(),
|
||||
];
|
||||
let call_id = "call-history";
|
||||
|
||||
let args = json!({
|
||||
"command": command,
|
||||
"timeout_ms": 60_000
|
||||
})
|
||||
.to_string();
|
||||
let first_body = sse(vec![
|
||||
ev_response_created("resp-history"),
|
||||
ev_function_call(call_id, "shell", &args),
|
||||
ev_completed("resp-history"),
|
||||
]);
|
||||
let follow_up_body = sse(vec![
|
||||
ev_response_created("resp-followup"),
|
||||
ev_completed("resp-followup"),
|
||||
]);
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let response_mock = mount_sse_sequence(&server, vec![first_body, follow_up_body]).await;
|
||||
|
||||
let fixture = test_codex().build(&server).await.unwrap();
|
||||
let codex = Arc::clone(&fixture.codex);
|
||||
|
||||
let wait_timeout = Duration::from_millis(100);
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "start history recording".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|ev| matches!(ev, EventMsg::ExecCommandBegin(_)),
|
||||
wait_timeout,
|
||||
)
|
||||
.await;
|
||||
|
||||
codex.submit(Op::Interrupt).await.unwrap();
|
||||
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|ev| matches!(ev, EventMsg::TurnAborted(_)),
|
||||
wait_timeout,
|
||||
)
|
||||
.await;
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "follow up".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
wait_for_event_with_timeout(
|
||||
&codex,
|
||||
|ev| matches!(ev, EventMsg::TaskComplete(_)),
|
||||
wait_timeout,
|
||||
)
|
||||
.await;
|
||||
|
||||
let requests = response_mock.requests();
|
||||
assert!(
|
||||
requests.len() == 2,
|
||||
"expected two calls to the responses API, got {}",
|
||||
requests.len()
|
||||
);
|
||||
|
||||
assert!(
|
||||
response_mock.saw_function_call(call_id),
|
||||
"function call not recorded in responses payload"
|
||||
);
|
||||
assert_eq!(
|
||||
response_mock.function_call_output_text(call_id).as_deref(),
|
||||
Some("aborted"),
|
||||
"aborted function call output not recorded in responses payload"
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user