From a339a7bcce153974d7f590b38f6e53dd01c2cc66 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 26 Jun 2025 14:40:42 -0400 Subject: [PATCH] [Rust] Allow resuming a session that was killed with ctrl + c (#1387) Previously, if you ctrl+c'd a conversation, all subsequent turns would 400 because the Responses API never got a response for one of its call ids. This ensures that if we aren't sending a call id by hand, we generate a synthetic aborted call. Fixes #1244 https://github.com/user-attachments/assets/5126354f-b970-45f5-8c65-f811bca8294a --- codex-rs/core/src/chat_completions.rs | 9 ++- codex-rs/core/src/client.rs | 11 ++- codex-rs/core/src/client_common.rs | 1 + codex-rs/core/src/codex.rs | 106 ++++++++++++++++++++++---- 4 files changed, 108 insertions(+), 19 deletions(-) diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index 12c5b7af..dfe06d1f 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -425,7 +425,12 @@ where response_id, token_usage, }))); - } // No other `Ok` variants exist at the moment, continue polling. + } + Poll::Ready(Some(Ok(ResponseEvent::Created))) => { + // These events are exclusive to the Responses API and + // will never appear in a Chat Completions stream. + continue; + } } } } @@ -439,7 +444,7 @@ pub(crate) trait AggregateStreamExt: Stream> + Size /// /// ```ignore /// OutputItemDone() - /// Completed { .. } + /// Completed /// ``` /// /// No other `OutputItemDone` events will be seen by the caller. diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 4770796d..6daa3a89 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -168,7 +168,7 @@ impl ModelClient { // negligible. if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) { // Surface the error body to callers. Use `unwrap_or_default` per Clippy. - let body = (res.text().await).unwrap_or_default(); + let body = res.text().await.unwrap_or_default(); return Err(CodexErr::UnexpectedStatus(status, body)); } @@ -208,6 +208,9 @@ struct SseEvent { item: Option, } +#[derive(Debug, Deserialize)] +struct ResponseCreated {} + #[derive(Debug, Deserialize)] struct ResponseCompleted { id: String, @@ -335,6 +338,11 @@ where return; } } + "response.created" => { + if event.response.is_some() { + let _ = tx_event.send(Ok(ResponseEvent::Created {})).await; + } + } // Final response completed – includes array of output items & id "response.completed" => { if let Some(resp_val) = event.response { @@ -350,7 +358,6 @@ where }; } "response.content_part.done" - | "response.created" | "response.function_call_arguments.delta" | "response.in_progress" | "response.output_item.added" diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index e17cf22c..b08880a0 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -51,6 +51,7 @@ impl Prompt { #[derive(Debug)] pub enum ResponseEvent { + Created, OutputItemDone(ResponseItem), Completed { response_id: String, diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index a43f75a7..ec6e0bd1 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1,6 +1,7 @@ // Poisoned mutex should fail the program #![allow(clippy::unwrap_used)] +use std::borrow::Cow; use std::collections::HashMap; use std::collections::HashSet; use std::path::Path; @@ -188,7 +189,7 @@ pub(crate) struct Session { /// Optional rollout recorder for persisting the conversation transcript so /// sessions can be replayed or inspected later. - rollout: Mutex>, + rollout: Mutex>, state: Mutex, codex_linux_sandbox_exe: Option, } @@ -206,6 +207,9 @@ impl Session { struct State { approved_commands: HashSet>, current_task: Option, + /// Call IDs that have been sent from the Responses API but have not been sent back yet. + /// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400. + pending_call_ids: HashSet, previous_response_id: Option, pending_approvals: HashMap>, pending_input: Vec, @@ -312,7 +316,7 @@ impl Session { /// Append the given items to the session's rollout transcript (if enabled) /// and persist them to disk. async fn record_rollout_items(&self, items: &[ResponseItem]) { - // Clone the recorder outside of the mutex so we don’t hold the lock + // Clone the recorder outside of the mutex so we don't hold the lock // across an await point (MutexGuard is not Send). let recorder = { let guard = self.rollout.lock().unwrap(); @@ -411,6 +415,8 @@ impl Session { pub fn abort(&self) { info!("Aborting existing session"); let mut state = self.state.lock().unwrap(); + // Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn. + // We will generate a synthetic aborted response for each pending call id. state.pending_approvals.clear(); state.pending_input.clear(); if let Some(task) = state.current_task.take() { @@ -431,7 +437,7 @@ impl Session { } let Ok(json) = serde_json::to_string(¬ification) else { - tracing::error!("failed to serialise notification payload"); + error!("failed to serialise notification payload"); return; }; @@ -443,7 +449,7 @@ impl Session { // Fire-and-forget – we do not wait for completion. if let Err(e) = command.spawn() { - tracing::warn!("failed to spawn notifier '{}': {e}", notify_command[0]); + warn!("failed to spawn notifier '{}': {e}", notify_command[0]); } } } @@ -647,7 +653,7 @@ async fn submission_loop( match RolloutRecorder::new(&config, session_id, instructions.clone()).await { Ok(r) => Some(r), Err(e) => { - tracing::warn!("failed to initialise rollout recorder: {e}"); + warn!("failed to initialise rollout recorder: {e}"); None } }; @@ -742,7 +748,7 @@ async fn submission_loop( tokio::spawn(async move { if let Err(e) = crate::message_history::append_entry(&text, &id, &config).await { - tracing::warn!("failed to append to message history: {e}"); + warn!("failed to append to message history: {e}"); } }); } @@ -772,7 +778,7 @@ async fn submission_loop( }; if let Err(e) = tx_event.send(event).await { - tracing::warn!("failed to send GetHistoryEntryResponse event: {e}"); + warn!("failed to send GetHistoryEntryResponse event: {e}"); } }); } @@ -1052,6 +1058,7 @@ async fn run_turn( /// events map to a `ResponseItem`. A `ResponseItem` may need to be /// "handled" such that it produces a `ResponseInputItem` that needs to be /// sent back to the model on the next turn. +#[derive(Debug)] struct ProcessedResponseItem { item: ResponseItem, response: Option, @@ -1062,7 +1069,57 @@ async fn try_run_turn( sub_id: &str, prompt: &Prompt, ) -> CodexResult> { - let mut stream = sess.client.clone().stream(prompt).await?; + // call_ids that are part of this response. + let completed_call_ids = prompt + .input + .iter() + .filter_map(|ri| match ri { + ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id), + ResponseItem::LocalShellCall { + call_id: Some(call_id), + .. + } => Some(call_id), + _ => None, + }) + .collect::>(); + + // call_ids that were pending but are not part of this response. + // This usually happens because the user interrupted the model before we responded to one of its tool calls + // and then the user sent a follow-up message. + let missing_calls = { + sess.state + .lock() + .unwrap() + .pending_call_ids + .iter() + .filter_map(|call_id| { + if completed_call_ids.contains(&call_id) { + None + } else { + Some(call_id.clone()) + } + }) + .map(|call_id| ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: FunctionCallOutputPayload { + content: "aborted".to_string(), + success: Some(false), + }, + }) + .collect::>() + }; + let prompt: Cow = if missing_calls.is_empty() { + Cow::Borrowed(prompt) + } else { + // Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses. + let input = [missing_calls, prompt.input.clone()].concat(); + Cow::Owned(Prompt { + input, + ..prompt.clone() + }) + }; + + let mut stream = sess.client.clone().stream(&prompt).await?; // Buffer all the incoming messages from the stream first, then execute them. // If we execute a function call in the middle of handling the stream, it can time out. @@ -1074,8 +1131,27 @@ async fn try_run_turn( let mut output = Vec::new(); for event in input { match event { + ResponseEvent::Created => { + let mut state = sess.state.lock().unwrap(); + // We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids. + state.pending_call_ids.clear(); + } ResponseEvent::OutputItemDone(item) => { + let call_id = match &item { + ResponseItem::LocalShellCall { + call_id: Some(call_id), + .. + } => Some(call_id), + ResponseItem::FunctionCall { call_id, .. } => Some(call_id), + _ => None, + }; + if let Some(call_id) = call_id { + // We just got a new call id so we need to make sure to respond to it in the next turn. + let mut state = sess.state.lock().unwrap(); + state.pending_call_ids.insert(call_id.clone()); + } let response = handle_response_item(sess, sub_id, item.clone()).await?; + output.push(ProcessedResponseItem { item, response }); } ResponseEvent::Completed { @@ -1138,7 +1214,7 @@ async fn handle_response_item( arguments, call_id, } => { - tracing::info!("FunctionCall: {arguments}"); + info!("FunctionCall: {arguments}"); Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await) } ResponseItem::LocalShellCall { @@ -1220,7 +1296,7 @@ async fn handle_function_call( // Unknown function: reply with structured failure so the model can adapt. ResponseInputItem::FunctionCallOutput { call_id, - output: crate::models::FunctionCallOutputPayload { + output: FunctionCallOutputPayload { content: format!("unsupported call: {}", name), success: None, }, @@ -1252,7 +1328,7 @@ fn parse_container_exec_arguments( // allow model to re-sample let output = ResponseInputItem::FunctionCallOutput { call_id: call_id.to_string(), - output: crate::models::FunctionCallOutputPayload { + output: FunctionCallOutputPayload { content: format!("failed to parse function arguments: {e}"), success: None, }, @@ -1320,7 +1396,7 @@ async fn handle_container_exec_with_params( ReviewDecision::Denied | ReviewDecision::Abort => { return ResponseInputItem::FunctionCallOutput { call_id, - output: crate::models::FunctionCallOutputPayload { + output: FunctionCallOutputPayload { content: "exec command rejected by user".to_string(), success: None, }, @@ -1336,7 +1412,7 @@ async fn handle_container_exec_with_params( SafetyCheck::Reject { reason } => { return ResponseInputItem::FunctionCallOutput { call_id, - output: crate::models::FunctionCallOutputPayload { + output: FunctionCallOutputPayload { content: format!("exec command rejected: {reason}"), success: None, }, @@ -1870,7 +1946,7 @@ fn apply_changes_from_apply_patch(action: &ApplyPatchAction) -> anyhow::Result Vec { +fn get_writable_roots(cwd: &Path) -> Vec { let mut writable_roots = Vec::new(); if cfg!(target_os = "macos") { // On macOS, $TMPDIR is private to the user. @@ -1898,7 +1974,7 @@ fn get_writable_roots(cwd: &Path) -> Vec { } /// Exec output is a pre-serialized JSON payload -fn format_exec_output(output: &str, exit_code: i32, duration: std::time::Duration) -> String { +fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> String { #[derive(Serialize)] struct ExecMetadata { exit_code: i32,