Handle cancelling/aborting while processing a turn (#5543)
Currently we collect all all turn items in a vector, then we add it to the history on success. This result in losing those items on errors including aborting `ctrl+c`. This PR: - Adds the ability for the tool call to handle cancellation - bubble the turn items up to where we are recording this info Admittedly, this logic is an ad-hoc logic that doesn't handle a lot of error edge cases. The right thing to do is recording to the history on the spot as `items`/`tool calls output` come. However, this isn't possible because of having different `task_kind` that has different `conversation_histories`. The `try_run_turn` has no idea what thread are we using. We cannot also pass an `arc` to the `conversation_histories` because it's a private element of `state`. That's said, `abort` is the most common case and we should cover it until we remove `task kind`
This commit is contained in:
@@ -10,6 +10,7 @@ use crate::function_tool::FunctionCallError;
|
|||||||
use crate::mcp::auth::McpAuthStatusEntry;
|
use crate::mcp::auth::McpAuthStatusEntry;
|
||||||
use crate::parse_command::parse_command;
|
use crate::parse_command::parse_command;
|
||||||
use crate::parse_turn_item;
|
use crate::parse_turn_item;
|
||||||
|
use crate::response_processing::process_items;
|
||||||
use crate::review_format::format_review_findings_block;
|
use crate::review_format::format_review_findings_block;
|
||||||
use crate::terminal;
|
use crate::terminal;
|
||||||
use crate::user_notification::UserNotifier;
|
use crate::user_notification::UserNotifier;
|
||||||
@@ -855,7 +856,7 @@ impl Session {
|
|||||||
|
|
||||||
/// Records input items: always append to conversation history and
|
/// Records input items: always append to conversation history and
|
||||||
/// persist these response items to rollout.
|
/// persist these response items to rollout.
|
||||||
async fn record_conversation_items(&self, items: &[ResponseItem]) {
|
pub(crate) async fn record_conversation_items(&self, items: &[ResponseItem]) {
|
||||||
self.record_into_history(items).await;
|
self.record_into_history(items).await;
|
||||||
self.persist_rollout_response_items(items).await;
|
self.persist_rollout_response_items(items).await;
|
||||||
}
|
}
|
||||||
@@ -1608,109 +1609,13 @@ pub(crate) async fn run_task(
|
|||||||
let token_limit_reached = total_usage_tokens
|
let token_limit_reached = total_usage_tokens
|
||||||
.map(|tokens| tokens >= limit)
|
.map(|tokens| tokens >= limit)
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
let (responses, items_to_record_in_conversation_history) = process_items(
|
||||||
let mut responses = Vec::<ResponseInputItem>::new();
|
processed_items,
|
||||||
for processed_response_item in processed_items {
|
is_review_mode,
|
||||||
let ProcessedResponseItem { item, response } = processed_response_item;
|
&mut review_thread_history,
|
||||||
match (&item, &response) {
|
&sess,
|
||||||
(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
|
)
|
||||||
// If the model returned a message, we need to record it.
|
.await;
|
||||||
items_to_record_in_conversation_history.push(item);
|
|
||||||
}
|
|
||||||
(
|
|
||||||
ResponseItem::LocalShellCall { .. },
|
|
||||||
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
|
||||||
) => {
|
|
||||||
items_to_record_in_conversation_history.push(item);
|
|
||||||
items_to_record_in_conversation_history.push(
|
|
||||||
ResponseItem::FunctionCallOutput {
|
|
||||||
call_id: call_id.clone(),
|
|
||||||
output: output.clone(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
(
|
|
||||||
ResponseItem::FunctionCall { .. },
|
|
||||||
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
|
||||||
) => {
|
|
||||||
items_to_record_in_conversation_history.push(item);
|
|
||||||
items_to_record_in_conversation_history.push(
|
|
||||||
ResponseItem::FunctionCallOutput {
|
|
||||||
call_id: call_id.clone(),
|
|
||||||
output: output.clone(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
(
|
|
||||||
ResponseItem::CustomToolCall { .. },
|
|
||||||
Some(ResponseInputItem::CustomToolCallOutput { call_id, output }),
|
|
||||||
) => {
|
|
||||||
items_to_record_in_conversation_history.push(item);
|
|
||||||
items_to_record_in_conversation_history.push(
|
|
||||||
ResponseItem::CustomToolCallOutput {
|
|
||||||
call_id: call_id.clone(),
|
|
||||||
output: output.clone(),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
(
|
|
||||||
ResponseItem::FunctionCall { .. },
|
|
||||||
Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
|
|
||||||
) => {
|
|
||||||
items_to_record_in_conversation_history.push(item);
|
|
||||||
let output = match result {
|
|
||||||
Ok(call_tool_result) => {
|
|
||||||
convert_call_tool_result_to_function_call_output_payload(
|
|
||||||
call_tool_result,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
Err(err) => FunctionCallOutputPayload {
|
|
||||||
content: err.clone(),
|
|
||||||
success: Some(false),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
items_to_record_in_conversation_history.push(
|
|
||||||
ResponseItem::FunctionCallOutput {
|
|
||||||
call_id: call_id.clone(),
|
|
||||||
output,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
(
|
|
||||||
ResponseItem::Reasoning {
|
|
||||||
id,
|
|
||||||
summary,
|
|
||||||
content,
|
|
||||||
encrypted_content,
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
) => {
|
|
||||||
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
|
|
||||||
id: id.clone(),
|
|
||||||
summary: summary.clone(),
|
|
||||||
content: content.clone(),
|
|
||||||
encrypted_content: encrypted_content.clone(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
warn!("Unexpected response item: {item:?} with response: {response:?}");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if let Some(response) = response {
|
|
||||||
responses.push(response);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only attempt to take the lock if there is something to record.
|
|
||||||
if !items_to_record_in_conversation_history.is_empty() {
|
|
||||||
if is_review_mode {
|
|
||||||
review_thread_history
|
|
||||||
.record_items(items_to_record_in_conversation_history.iter());
|
|
||||||
} else {
|
|
||||||
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if token_limit_reached {
|
if token_limit_reached {
|
||||||
if auto_compact_recently_attempted {
|
if auto_compact_recently_attempted {
|
||||||
@@ -1749,7 +1654,16 @@ pub(crate) async fn run_task(
|
|||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Err(CodexErr::TurnAborted) => {
|
Err(CodexErr::TurnAborted {
|
||||||
|
dangling_artifacts: processed_items,
|
||||||
|
}) => {
|
||||||
|
let _ = process_items(
|
||||||
|
processed_items,
|
||||||
|
is_review_mode,
|
||||||
|
&mut review_thread_history,
|
||||||
|
&sess,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
// Aborted turn is reported via a different event.
|
// Aborted turn is reported via a different event.
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -1850,7 +1764,13 @@ async fn run_turn(
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(output) => return Ok(output),
|
Ok(output) => return Ok(output),
|
||||||
Err(CodexErr::TurnAborted) => return Err(CodexErr::TurnAborted),
|
Err(CodexErr::TurnAborted {
|
||||||
|
dangling_artifacts: processed_items,
|
||||||
|
}) => {
|
||||||
|
return Err(CodexErr::TurnAborted {
|
||||||
|
dangling_artifacts: processed_items,
|
||||||
|
});
|
||||||
|
}
|
||||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||||
@@ -1903,9 +1823,9 @@ async fn run_turn(
|
|||||||
/// "handled" such that it produces a `ResponseInputItem` that needs to be
|
/// "handled" such that it produces a `ResponseInputItem` that needs to be
|
||||||
/// sent back to the model on the next turn.
|
/// sent back to the model on the next turn.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct ProcessedResponseItem {
|
pub struct ProcessedResponseItem {
|
||||||
pub(crate) item: ResponseItem,
|
pub item: ResponseItem,
|
||||||
pub(crate) response: Option<ResponseInputItem>,
|
pub response: Option<ResponseInputItem>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -1954,7 +1874,15 @@ async fn try_run_turn(
|
|||||||
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
// Poll the next item from the model stream. We must inspect *both* Ok and Err
|
||||||
// cases so that transient stream failures (e.g., dropped SSE connection before
|
// cases so that transient stream failures (e.g., dropped SSE connection before
|
||||||
// `response.completed`) bubble up and trigger the caller's retry logic.
|
// `response.completed`) bubble up and trigger the caller's retry logic.
|
||||||
let event = stream.next().or_cancel(&cancellation_token).await?;
|
let event = match stream.next().or_cancel(&cancellation_token).await {
|
||||||
|
Ok(event) => event,
|
||||||
|
Err(codex_async_utils::CancelErr::Cancelled) => {
|
||||||
|
let processed_items = output.try_collect().await?;
|
||||||
|
return Err(CodexErr::TurnAborted {
|
||||||
|
dangling_artifacts: processed_items,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let event = match event {
|
let event = match event {
|
||||||
Some(res) => res?,
|
Some(res) => res?,
|
||||||
@@ -1978,7 +1906,8 @@ async fn try_run_turn(
|
|||||||
let payload_preview = call.payload.log_payload().into_owned();
|
let payload_preview = call.payload.log_payload().into_owned();
|
||||||
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview);
|
||||||
|
|
||||||
let response = tool_runtime.handle_tool_call(call);
|
let response =
|
||||||
|
tool_runtime.handle_tool_call(call, cancellation_token.child_token());
|
||||||
|
|
||||||
output.push_back(
|
output.push_back(
|
||||||
async move {
|
async move {
|
||||||
@@ -2060,12 +1989,7 @@ async fn try_run_turn(
|
|||||||
} => {
|
} => {
|
||||||
sess.update_token_usage_info(turn_context.as_ref(), token_usage.as_ref())
|
sess.update_token_usage_info(turn_context.as_ref(), token_usage.as_ref())
|
||||||
.await;
|
.await;
|
||||||
|
let processed_items = output.try_collect().await?;
|
||||||
let processed_items = output
|
|
||||||
.try_collect()
|
|
||||||
.or_cancel(&cancellation_token)
|
|
||||||
.await??;
|
|
||||||
|
|
||||||
let unified_diff = {
|
let unified_diff = {
|
||||||
let mut tracker = turn_diff_tracker.lock().await;
|
let mut tracker = turn_diff_tracker.lock().await;
|
||||||
tracker.get_unified_diff()
|
tracker.get_unified_diff()
|
||||||
@@ -2169,7 +2093,7 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
fn convert_call_tool_result_to_function_call_output_payload(
|
pub(crate) fn convert_call_tool_result_to_function_call_output_payload(
|
||||||
call_tool_result: &CallToolResult,
|
call_tool_result: &CallToolResult,
|
||||||
) -> FunctionCallOutputPayload {
|
) -> FunctionCallOutputPayload {
|
||||||
let CallToolResult {
|
let CallToolResult {
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::codex::ProcessedResponseItem;
|
||||||
use crate::exec::ExecToolCallOutput;
|
use crate::exec::ExecToolCallOutput;
|
||||||
use crate::token_data::KnownPlan;
|
use crate::token_data::KnownPlan;
|
||||||
use crate::token_data::PlanType;
|
use crate::token_data::PlanType;
|
||||||
@@ -53,8 +54,11 @@ pub enum SandboxErr {
|
|||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum CodexErr {
|
pub enum CodexErr {
|
||||||
|
// todo(aibrahim): git rid of this error carrying the dangling artifacts
|
||||||
#[error("turn aborted")]
|
#[error("turn aborted")]
|
||||||
TurnAborted,
|
TurnAborted {
|
||||||
|
dangling_artifacts: Vec<ProcessedResponseItem>,
|
||||||
|
},
|
||||||
|
|
||||||
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
|
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
|
||||||
/// handshake has succeeded but **before** it finished emitting `response.completed`.
|
/// handshake has succeeded but **before** it finished emitting `response.completed`.
|
||||||
@@ -158,7 +162,9 @@ pub enum CodexErr {
|
|||||||
|
|
||||||
impl From<CancelErr> for CodexErr {
|
impl From<CancelErr> for CodexErr {
|
||||||
fn from(_: CancelErr) -> Self {
|
fn from(_: CancelErr) -> Self {
|
||||||
CodexErr::TurnAborted
|
CodexErr::TurnAborted {
|
||||||
|
dangling_artifacts: Vec::new(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ mod mcp_tool_call;
|
|||||||
mod message_history;
|
mod message_history;
|
||||||
mod model_provider_info;
|
mod model_provider_info;
|
||||||
pub mod parse_command;
|
pub mod parse_command;
|
||||||
|
mod response_processing;
|
||||||
pub mod sandboxing;
|
pub mod sandboxing;
|
||||||
pub mod token_data;
|
pub mod token_data;
|
||||||
mod truncate;
|
mod truncate;
|
||||||
|
|||||||
112
codex-rs/core/src/response_processing.rs
Normal file
112
codex-rs/core/src/response_processing.rs
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
use crate::codex::Session;
|
||||||
|
use crate::conversation_history::ConversationHistory;
|
||||||
|
use codex_protocol::models::FunctionCallOutputPayload;
|
||||||
|
use codex_protocol::models::ResponseInputItem;
|
||||||
|
use codex_protocol::models::ResponseItem;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
/// Process streamed `ResponseItem`s from the model into the pair of:
|
||||||
|
/// - items we should record in conversation history; and
|
||||||
|
/// - `ResponseInputItem`s to send back to the model on the next turn.
|
||||||
|
pub(crate) async fn process_items(
|
||||||
|
processed_items: Vec<crate::codex::ProcessedResponseItem>,
|
||||||
|
is_review_mode: bool,
|
||||||
|
review_thread_history: &mut ConversationHistory,
|
||||||
|
sess: &Session,
|
||||||
|
) -> (Vec<ResponseInputItem>, Vec<ResponseItem>) {
|
||||||
|
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
||||||
|
let mut responses = Vec::<ResponseInputItem>::new();
|
||||||
|
for processed_response_item in processed_items {
|
||||||
|
let crate::codex::ProcessedResponseItem { item, response } = processed_response_item;
|
||||||
|
match (&item, &response) {
|
||||||
|
(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
|
||||||
|
// If the model returned a message, we need to record it.
|
||||||
|
items_to_record_in_conversation_history.push(item);
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ResponseItem::LocalShellCall { .. },
|
||||||
|
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
||||||
|
) => {
|
||||||
|
items_to_record_in_conversation_history.push(item);
|
||||||
|
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
|
||||||
|
call_id: call_id.clone(),
|
||||||
|
output: output.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ResponseItem::FunctionCall { .. },
|
||||||
|
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
||||||
|
) => {
|
||||||
|
items_to_record_in_conversation_history.push(item);
|
||||||
|
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
|
||||||
|
call_id: call_id.clone(),
|
||||||
|
output: output.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ResponseItem::CustomToolCall { .. },
|
||||||
|
Some(ResponseInputItem::CustomToolCallOutput { call_id, output }),
|
||||||
|
) => {
|
||||||
|
items_to_record_in_conversation_history.push(item);
|
||||||
|
items_to_record_in_conversation_history.push(ResponseItem::CustomToolCallOutput {
|
||||||
|
call_id: call_id.clone(),
|
||||||
|
output: output.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ResponseItem::FunctionCall { .. },
|
||||||
|
Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
|
||||||
|
) => {
|
||||||
|
items_to_record_in_conversation_history.push(item);
|
||||||
|
let output = match result {
|
||||||
|
Ok(call_tool_result) => {
|
||||||
|
crate::codex::convert_call_tool_result_to_function_call_output_payload(
|
||||||
|
call_tool_result,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Err(err) => FunctionCallOutputPayload {
|
||||||
|
content: err.clone(),
|
||||||
|
success: Some(false),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
|
||||||
|
call_id: call_id.clone(),
|
||||||
|
output,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
(
|
||||||
|
ResponseItem::Reasoning {
|
||||||
|
id,
|
||||||
|
summary,
|
||||||
|
content,
|
||||||
|
encrypted_content,
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
) => {
|
||||||
|
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
|
||||||
|
id: id.clone(),
|
||||||
|
summary: summary.clone(),
|
||||||
|
content: content.clone(),
|
||||||
|
encrypted_content: encrypted_content.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
warn!("Unexpected response item: {item:?} with response: {response:?}");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Some(response) = response {
|
||||||
|
responses.push(response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only attempt to take the lock if there is something to record.
|
||||||
|
if !items_to_record_in_conversation_history.is_empty() {
|
||||||
|
if is_review_mode {
|
||||||
|
review_thread_history.record_items(items_to_record_in_conversation_history.iter());
|
||||||
|
} else {
|
||||||
|
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(responses, items_to_record_in_conversation_history)
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tokio_util::either::Either;
|
use tokio_util::either::Either;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
use tokio_util::task::AbortOnDropHandle;
|
use tokio_util::task::AbortOnDropHandle;
|
||||||
|
|
||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
@@ -9,8 +10,10 @@ use crate::codex::TurnContext;
|
|||||||
use crate::error::CodexErr;
|
use crate::error::CodexErr;
|
||||||
use crate::function_tool::FunctionCallError;
|
use crate::function_tool::FunctionCallError;
|
||||||
use crate::tools::context::SharedTurnDiffTracker;
|
use crate::tools::context::SharedTurnDiffTracker;
|
||||||
|
use crate::tools::context::ToolPayload;
|
||||||
use crate::tools::router::ToolCall;
|
use crate::tools::router::ToolCall;
|
||||||
use crate::tools::router::ToolRouter;
|
use crate::tools::router::ToolRouter;
|
||||||
|
use codex_protocol::models::FunctionCallOutputPayload;
|
||||||
use codex_protocol::models::ResponseInputItem;
|
use codex_protocol::models::ResponseInputItem;
|
||||||
|
|
||||||
pub(crate) struct ToolCallRuntime {
|
pub(crate) struct ToolCallRuntime {
|
||||||
@@ -40,6 +43,7 @@ impl ToolCallRuntime {
|
|||||||
pub(crate) fn handle_tool_call(
|
pub(crate) fn handle_tool_call(
|
||||||
&self,
|
&self,
|
||||||
call: ToolCall,
|
call: ToolCall,
|
||||||
|
cancellation_token: CancellationToken,
|
||||||
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
||||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||||
|
|
||||||
@@ -48,18 +52,24 @@ impl ToolCallRuntime {
|
|||||||
let turn = Arc::clone(&self.turn_context);
|
let turn = Arc::clone(&self.turn_context);
|
||||||
let tracker = Arc::clone(&self.tracker);
|
let tracker = Arc::clone(&self.tracker);
|
||||||
let lock = Arc::clone(&self.parallel_execution);
|
let lock = Arc::clone(&self.parallel_execution);
|
||||||
|
let aborted_response = Self::aborted_response(&call);
|
||||||
|
|
||||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||||
let _guard = if supports_parallel {
|
tokio::select! {
|
||||||
Either::Left(lock.read().await)
|
_ = cancellation_token.cancelled() => Ok(aborted_response),
|
||||||
} else {
|
res = async {
|
||||||
Either::Right(lock.write().await)
|
let _guard = if supports_parallel {
|
||||||
};
|
Either::Left(lock.read().await)
|
||||||
|
} else {
|
||||||
|
Either::Right(lock.write().await)
|
||||||
|
};
|
||||||
|
|
||||||
router
|
router
|
||||||
.dispatch_tool_call(session, turn, tracker, call)
|
.dispatch_tool_call(session, turn, tracker, call)
|
||||||
.await
|
.await
|
||||||
|
} => res,
|
||||||
|
}
|
||||||
}));
|
}));
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
@@ -74,3 +84,25 @@ impl ToolCallRuntime {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ToolCallRuntime {
|
||||||
|
fn aborted_response(call: &ToolCall) -> ResponseInputItem {
|
||||||
|
match &call.payload {
|
||||||
|
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||||
|
call_id: call.call_id.clone(),
|
||||||
|
output: "aborted".to_string(),
|
||||||
|
},
|
||||||
|
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
|
||||||
|
call_id: call.call_id.clone(),
|
||||||
|
result: Err("aborted".to_string()),
|
||||||
|
},
|
||||||
|
_ => ResponseInputItem::FunctionCallOutput {
|
||||||
|
call_id: call.call_id.clone(),
|
||||||
|
output: FunctionCallOutputPayload {
|
||||||
|
content: "aborted".to_string(),
|
||||||
|
success: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -35,6 +35,22 @@ impl ResponseMock {
|
|||||||
pub fn requests(&self) -> Vec<ResponsesRequest> {
|
pub fn requests(&self) -> Vec<ResponsesRequest> {
|
||||||
self.requests.lock().unwrap().clone()
|
self.requests.lock().unwrap().clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if any captured request contains a `function_call` with the
|
||||||
|
/// provided `call_id`.
|
||||||
|
pub fn saw_function_call(&self, call_id: &str) -> bool {
|
||||||
|
self.requests()
|
||||||
|
.iter()
|
||||||
|
.any(|req| req.has_function_call(call_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the `output` string for a matching `function_call_output` with
|
||||||
|
/// the provided `call_id`, searching across all captured requests.
|
||||||
|
pub fn function_call_output_text(&self, call_id: &str) -> Option<String> {
|
||||||
|
self.requests()
|
||||||
|
.iter()
|
||||||
|
.find_map(|req| req.function_call_output_text(call_id))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -70,6 +86,28 @@ impl ResponsesRequest {
|
|||||||
.unwrap_or_else(|| panic!("function call output {call_id} item not found in request"))
|
.unwrap_or_else(|| panic!("function call output {call_id} item not found in request"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if this request's `input` contains a `function_call` with
|
||||||
|
/// the specified `call_id`.
|
||||||
|
pub fn has_function_call(&self, call_id: &str) -> bool {
|
||||||
|
self.input().iter().any(|item| {
|
||||||
|
item.get("type").and_then(Value::as_str) == Some("function_call")
|
||||||
|
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If present, returns the `output` string of the `function_call_output`
|
||||||
|
/// entry matching `call_id` in this request's `input`.
|
||||||
|
pub fn function_call_output_text(&self, call_id: &str) -> Option<String> {
|
||||||
|
let binding = self.input();
|
||||||
|
let item = binding.iter().find(|item| {
|
||||||
|
item.get("type").and_then(Value::as_str) == Some("function_call_output")
|
||||||
|
&& item.get("call_id").and_then(Value::as_str) == Some(call_id)
|
||||||
|
})?;
|
||||||
|
item.get("output")
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.map(str::to_string)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn header(&self, name: &str) -> Option<String> {
|
pub fn header(&self, name: &str) -> Option<String> {
|
||||||
self.0
|
self.0
|
||||||
.headers
|
.headers
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use codex_core::protocol::EventMsg;
|
use codex_core::protocol::EventMsg;
|
||||||
@@ -5,7 +6,9 @@ use codex_core::protocol::Op;
|
|||||||
use codex_protocol::user_input::UserInput;
|
use codex_protocol::user_input::UserInput;
|
||||||
use core_test_support::responses::ev_completed;
|
use core_test_support::responses::ev_completed;
|
||||||
use core_test_support::responses::ev_function_call;
|
use core_test_support::responses::ev_function_call;
|
||||||
|
use core_test_support::responses::ev_response_created;
|
||||||
use core_test_support::responses::mount_sse_once;
|
use core_test_support::responses::mount_sse_once;
|
||||||
|
use core_test_support::responses::mount_sse_sequence;
|
||||||
use core_test_support::responses::sse;
|
use core_test_support::responses::sse;
|
||||||
use core_test_support::responses::start_mock_server;
|
use core_test_support::responses::start_mock_server;
|
||||||
use core_test_support::test_codex::test_codex;
|
use core_test_support::test_codex::test_codex;
|
||||||
@@ -67,3 +70,98 @@ async fn interrupt_long_running_tool_emits_turn_aborted() {
|
|||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// After an interrupt we expect the next request to the model to include both
|
||||||
|
/// the original tool call and an `"aborted"` `function_call_output`. This test
|
||||||
|
/// exercises the follow-up flow: it sends another user turn, inspects the mock
|
||||||
|
/// responses server, and ensures the model receives the synthesized abort.
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn interrupt_tool_records_history_entries() {
|
||||||
|
let command = vec![
|
||||||
|
"bash".to_string(),
|
||||||
|
"-lc".to_string(),
|
||||||
|
"sleep 60".to_string(),
|
||||||
|
];
|
||||||
|
let call_id = "call-history";
|
||||||
|
|
||||||
|
let args = json!({
|
||||||
|
"command": command,
|
||||||
|
"timeout_ms": 60_000
|
||||||
|
})
|
||||||
|
.to_string();
|
||||||
|
let first_body = sse(vec![
|
||||||
|
ev_response_created("resp-history"),
|
||||||
|
ev_function_call(call_id, "shell", &args),
|
||||||
|
ev_completed("resp-history"),
|
||||||
|
]);
|
||||||
|
let follow_up_body = sse(vec![
|
||||||
|
ev_response_created("resp-followup"),
|
||||||
|
ev_completed("resp-followup"),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let server = start_mock_server().await;
|
||||||
|
let response_mock = mount_sse_sequence(&server, vec![first_body, follow_up_body]).await;
|
||||||
|
|
||||||
|
let fixture = test_codex().build(&server).await.unwrap();
|
||||||
|
let codex = Arc::clone(&fixture.codex);
|
||||||
|
|
||||||
|
let wait_timeout = Duration::from_millis(100);
|
||||||
|
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![UserInput::Text {
|
||||||
|
text: "start history recording".into(),
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
wait_for_event_with_timeout(
|
||||||
|
&codex,
|
||||||
|
|ev| matches!(ev, EventMsg::ExecCommandBegin(_)),
|
||||||
|
wait_timeout,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
codex.submit(Op::Interrupt).await.unwrap();
|
||||||
|
|
||||||
|
wait_for_event_with_timeout(
|
||||||
|
&codex,
|
||||||
|
|ev| matches!(ev, EventMsg::TurnAborted(_)),
|
||||||
|
wait_timeout,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![UserInput::Text {
|
||||||
|
text: "follow up".into(),
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
wait_for_event_with_timeout(
|
||||||
|
&codex,
|
||||||
|
|ev| matches!(ev, EventMsg::TaskComplete(_)),
|
||||||
|
wait_timeout,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let requests = response_mock.requests();
|
||||||
|
assert!(
|
||||||
|
requests.len() == 2,
|
||||||
|
"expected two calls to the responses API, got {}",
|
||||||
|
requests.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
response_mock.saw_function_call(call_id),
|
||||||
|
"function call not recorded in responses payload"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
response_mock.function_call_output_text(call_id).as_deref(),
|
||||||
|
Some("aborted"),
|
||||||
|
"aborted function call output not recorded in responses payload"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user