Handle cancelling/aborting while processing a turn (#5543)
Currently we collect all all turn items in a vector, then we add it to the history on success. This result in losing those items on errors including aborting `ctrl+c`. This PR: - Adds the ability for the tool call to handle cancellation - bubble the turn items up to where we are recording this info Admittedly, this logic is an ad-hoc logic that doesn't handle a lot of error edge cases. The right thing to do is recording to the history on the spot as `items`/`tool calls output` come. However, this isn't possible because of having different `task_kind` that has different `conversation_histories`. The `try_run_turn` has no idea what thread are we using. We cannot also pass an `arc` to the `conversation_histories` because it's a private element of `state`. That's said, `abort` is the most common case and we should cover it until we remove `task kind`
This commit is contained in:
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_util::either::Either;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
|
||||
use crate::codex::Session;
|
||||
@@ -9,8 +10,10 @@ use crate::codex::TurnContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
|
||||
pub(crate) struct ToolCallRuntime {
|
||||
@@ -40,6 +43,7 @@ impl ToolCallRuntime {
|
||||
pub(crate) fn handle_tool_call(
|
||||
&self,
|
||||
call: ToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
||||
let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);
|
||||
|
||||
@@ -48,18 +52,24 @@ impl ToolCallRuntime {
|
||||
let turn = Arc::clone(&self.turn_context);
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
let aborted_response = Self::aborted_response(&call);
|
||||
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
Either::Right(lock.write().await)
|
||||
};
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => Ok(aborted_response),
|
||||
res = async {
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
Either::Right(lock.write().await)
|
||||
};
|
||||
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, call)
|
||||
.await
|
||||
router
|
||||
.dispatch_tool_call(session, turn, tracker, call)
|
||||
.await
|
||||
} => res,
|
||||
}
|
||||
}));
|
||||
|
||||
async move {
|
||||
@@ -74,3 +84,25 @@ impl ToolCallRuntime {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
fn aborted_response(call: &ToolCall) -> ResponseInputItem {
|
||||
match &call.payload {
|
||||
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: "aborted".to_string(),
|
||||
},
|
||||
ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
result: Err("aborted".to_string()),
|
||||
},
|
||||
_ => ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call.call_id.clone(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user