[Rust] Allow resuming a session that was killed with ctrl + c (#1387)

Previously, if you ctrl+c'd a conversation mid-turn, all subsequent turns would
400 because the Responses API never received an output for one of its pending
call ids. This change ensures that if we aren't sending back the output for a
pending call id ourselves, we generate a synthetic aborted call output in its place.

Fixes #1244 


https://github.com/user-attachments/assets/5126354f-b970-45f5-8c65-f811bca8294a
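
In other words, before the next turn is sent, any call id still pending from an interrupted turn gets a synthetic `"aborted"` function-call output prepended to the input. Below is a condensed, self-contained sketch of that idea; the `synthesize_aborted_outputs` helper and its `FunctionCallOutput` struct are illustrative stand-ins, not the real `ResponseItem`/`FunctionCallOutputPayload` types used in the diff.

```rust
use std::collections::HashSet;

// Stand-in for the payload the real code builds for each unanswered call id.
#[derive(Debug)]
struct FunctionCallOutput {
    call_id: String,
    content: String,
    success: Option<bool>,
}

/// For every call id the model is still waiting on that was not answered
/// (e.g. because the turn was interrupted with ctrl+c), emit a synthetic
/// "aborted" output so the next Responses API request includes an output
/// for every pending call id and does not 400.
fn synthesize_aborted_outputs(
    pending_call_ids: &HashSet<String>,
    answered_call_ids: &HashSet<String>,
) -> Vec<FunctionCallOutput> {
    pending_call_ids
        .iter()
        .filter(|id| !answered_call_ids.contains(*id))
        .map(|id| FunctionCallOutput {
            call_id: id.clone(),
            content: "aborted".to_string(),
            success: Some(false),
        })
        .collect()
}

fn main() {
    let pending: HashSet<String> = ["call_1".to_string(), "call_2".to_string()].into();
    let answered: HashSet<String> = ["call_2".to_string()].into();
    // Only call_1 gets a synthetic aborted output; call_2 was answered normally.
    println!("{:?}", synthesize_aborted_outputs(&pending, &answered));
}
```

In the actual change, these synthetic outputs are concatenated ahead of the prompt input via a `Cow<Prompt>`, so the borrowed prompt is only cloned when something is actually missing.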
Author: Gabriel Peal
Date: 2025-06-26 14:40:42 -04:00 (committed by GitHub)
Parent: fcfe43c7df
Commit: a339a7bcce

4 changed files with 108 additions and 19 deletions

View File

@@ -425,7 +425,12 @@ where
                     response_id,
                     token_usage,
                 })));
-            } // No other `Ok` variants exist at the moment, continue polling.
+            }
+            Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
+                // These events are exclusive to the Responses API and
+                // will never appear in a Chat Completions stream.
+                continue;
+            }
         }
     }
 }
@@ -439,7 +444,7 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
     ///
     /// ```ignore
     /// OutputItemDone(<full message>)
-    /// Completed { .. }
+    /// Completed
     /// ```
     ///
     /// No other `OutputItemDone` events will be seen by the caller.

View File

@@ -168,7 +168,7 @@ impl ModelClient {
             // negligible.
             if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
                 // Surface the error body to callers. Use `unwrap_or_default` per Clippy.
-                let body = (res.text().await).unwrap_or_default();
+                let body = res.text().await.unwrap_or_default();
                 return Err(CodexErr::UnexpectedStatus(status, body));
             }
@@ -208,6 +208,9 @@ struct SseEvent {
     item: Option<Value>,
 }

+#[derive(Debug, Deserialize)]
+struct ResponseCreated {}
+
 #[derive(Debug, Deserialize)]
 struct ResponseCompleted {
     id: String,
@@ -335,6 +338,11 @@ where
                     return;
                 }
             }
+            "response.created" => {
+                if event.response.is_some() {
+                    let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
+                }
+            }
             // Final response completed includes array of output items & id
             "response.completed" => {
                 if let Some(resp_val) = event.response {
@@ -350,7 +358,6 @@ where
                 };
             }
             "response.content_part.done"
-            | "response.created"
             | "response.function_call_arguments.delta"
             | "response.in_progress"
             | "response.output_item.added"

View File

@@ -51,6 +51,7 @@ impl Prompt {

 #[derive(Debug)]
 pub enum ResponseEvent {
+    Created,
     OutputItemDone(ResponseItem),
     Completed {
         response_id: String,

View File

@@ -1,6 +1,7 @@
 // Poisoned mutex should fail the program
 #![allow(clippy::unwrap_used)]

+use std::borrow::Cow;
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::path::Path;
@@ -188,7 +189,7 @@ pub(crate) struct Session {
     /// Optional rollout recorder for persisting the conversation transcript so
     /// sessions can be replayed or inspected later.
-    rollout: Mutex<Option<crate::rollout::RolloutRecorder>>,
+    rollout: Mutex<Option<RolloutRecorder>>,
     state: Mutex<State>,
     codex_linux_sandbox_exe: Option<PathBuf>,
 }
@@ -206,6 +207,9 @@ impl Session {
 struct State {
     approved_commands: HashSet<Vec<String>>,
     current_task: Option<AgentTask>,
+    /// Call IDs that have been sent from the Responses API but have not been sent back yet.
+    /// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400.
+    pending_call_ids: HashSet<String>,
     previous_response_id: Option<String>,
     pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
     pending_input: Vec<ResponseInputItem>,
@@ -312,7 +316,7 @@ impl Session {
     /// Append the given items to the session's rollout transcript (if enabled)
     /// and persist them to disk.
     async fn record_rollout_items(&self, items: &[ResponseItem]) {
-        // Clone the recorder outside of the mutex so we dont hold the lock
+        // Clone the recorder outside of the mutex so we don't hold the lock
         // across an await point (MutexGuard is not Send).
         let recorder = {
             let guard = self.rollout.lock().unwrap();
@@ -411,6 +415,8 @@ impl Session {
     pub fn abort(&self) {
         info!("Aborting existing session");
         let mut state = self.state.lock().unwrap();
+        // Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn.
+        // We will generate a synthetic aborted response for each pending call id.
         state.pending_approvals.clear();
         state.pending_input.clear();
         if let Some(task) = state.current_task.take() {
@@ -431,7 +437,7 @@ impl Session {
         }
         let Ok(json) = serde_json::to_string(&notification) else {
-            tracing::error!("failed to serialise notification payload");
+            error!("failed to serialise notification payload");
             return;
         };
@@ -443,7 +449,7 @@ impl Session {
         // Fire-and-forget we do not wait for completion.
         if let Err(e) = command.spawn() {
-            tracing::warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
+            warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
         }
     }
 }
@@ -647,7 +653,7 @@ async fn submission_loop(
             match RolloutRecorder::new(&config, session_id, instructions.clone()).await {
                 Ok(r) => Some(r),
                 Err(e) => {
-                    tracing::warn!("failed to initialise rollout recorder: {e}");
+                    warn!("failed to initialise rollout recorder: {e}");
                     None
                 }
             };
@@ -742,7 +748,7 @@ async fn submission_loop(
             tokio::spawn(async move {
                 if let Err(e) = crate::message_history::append_entry(&text, &id, &config).await
                 {
-                    tracing::warn!("failed to append to message history: {e}");
+                    warn!("failed to append to message history: {e}");
                 }
             });
         }
@@ -772,7 +778,7 @@ async fn submission_loop(
             };
             if let Err(e) = tx_event.send(event).await {
-                tracing::warn!("failed to send GetHistoryEntryResponse event: {e}");
+                warn!("failed to send GetHistoryEntryResponse event: {e}");
             }
         });
     }
@@ -1052,6 +1058,7 @@ async fn run_turn(
 /// events map to a `ResponseItem`. A `ResponseItem` may need to be
 /// "handled" such that it produces a `ResponseInputItem` that needs to be
 /// sent back to the model on the next turn.
+#[derive(Debug)]
 struct ProcessedResponseItem {
     item: ResponseItem,
     response: Option<ResponseInputItem>,
@@ -1062,7 +1069,57 @@ async fn try_run_turn(
     sub_id: &str,
     prompt: &Prompt,
 ) -> CodexResult<Vec<ProcessedResponseItem>> {
-    let mut stream = sess.client.clone().stream(prompt).await?;
+    // call_ids that are part of this response.
+    let completed_call_ids = prompt
+        .input
+        .iter()
+        .filter_map(|ri| match ri {
+            ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
+            ResponseItem::LocalShellCall {
+                call_id: Some(call_id),
+                ..
+            } => Some(call_id),
+            _ => None,
+        })
+        .collect::<Vec<_>>();
+
+    // call_ids that were pending but are not part of this response.
+    // This usually happens because the user interrupted the model before we responded to one of its tool calls
+    // and then the user sent a follow-up message.
+    let missing_calls = {
+        sess.state
+            .lock()
+            .unwrap()
+            .pending_call_ids
+            .iter()
+            .filter_map(|call_id| {
+                if completed_call_ids.contains(&call_id) {
+                    None
+                } else {
+                    Some(call_id.clone())
+                }
+            })
+            .map(|call_id| ResponseItem::FunctionCallOutput {
+                call_id: call_id.clone(),
+                output: FunctionCallOutputPayload {
+                    content: "aborted".to_string(),
+                    success: Some(false),
+                },
+            })
+            .collect::<Vec<_>>()
+    };
+
+    let prompt: Cow<Prompt> = if missing_calls.is_empty() {
+        Cow::Borrowed(prompt)
+    } else {
+        // Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses.
+        let input = [missing_calls, prompt.input.clone()].concat();
+        Cow::Owned(Prompt {
+            input,
+            ..prompt.clone()
+        })
+    };
+
+    let mut stream = sess.client.clone().stream(&prompt).await?;

     // Buffer all the incoming messages from the stream first, then execute them.
     // If we execute a function call in the middle of handling the stream, it can time out.
@@ -1074,8 +1131,27 @@ async fn try_run_turn(
     let mut output = Vec::new();
     for event in input {
         match event {
+            ResponseEvent::Created => {
+                let mut state = sess.state.lock().unwrap();
+                // We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids.
+                state.pending_call_ids.clear();
+            }
             ResponseEvent::OutputItemDone(item) => {
+                let call_id = match &item {
+                    ResponseItem::LocalShellCall {
+                        call_id: Some(call_id),
+                        ..
+                    } => Some(call_id),
+                    ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
+                    _ => None,
+                };
+                if let Some(call_id) = call_id {
+                    // We just got a new call id so we need to make sure to respond to it in the next turn.
+                    let mut state = sess.state.lock().unwrap();
+                    state.pending_call_ids.insert(call_id.clone());
+                }
                 let response = handle_response_item(sess, sub_id, item.clone()).await?;
                 output.push(ProcessedResponseItem { item, response });
             }
             ResponseEvent::Completed {
@@ -1138,7 +1214,7 @@ async fn handle_response_item(
             arguments,
             call_id,
         } => {
-            tracing::info!("FunctionCall: {arguments}");
+            info!("FunctionCall: {arguments}");
             Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
         }
         ResponseItem::LocalShellCall {
@@ -1220,7 +1296,7 @@ async fn handle_function_call(
             // Unknown function: reply with structured failure so the model can adapt.
             ResponseInputItem::FunctionCallOutput {
                 call_id,
-                output: crate::models::FunctionCallOutputPayload {
+                output: FunctionCallOutputPayload {
                     content: format!("unsupported call: {}", name),
                     success: None,
                 },
@@ -1252,7 +1328,7 @@ fn parse_container_exec_arguments(
             // allow model to re-sample
             let output = ResponseInputItem::FunctionCallOutput {
                 call_id: call_id.to_string(),
-                output: crate::models::FunctionCallOutputPayload {
+                output: FunctionCallOutputPayload {
                     content: format!("failed to parse function arguments: {e}"),
                     success: None,
                 },
@@ -1320,7 +1396,7 @@ async fn handle_container_exec_with_params(
         ReviewDecision::Denied | ReviewDecision::Abort => {
             return ResponseInputItem::FunctionCallOutput {
                 call_id,
-                output: crate::models::FunctionCallOutputPayload {
+                output: FunctionCallOutputPayload {
                     content: "exec command rejected by user".to_string(),
                     success: None,
                 },
@@ -1336,7 +1412,7 @@ async fn handle_container_exec_with_params(
         SafetyCheck::Reject { reason } => {
             return ResponseInputItem::FunctionCallOutput {
                 call_id,
-                output: crate::models::FunctionCallOutputPayload {
+                output: FunctionCallOutputPayload {
                     content: format!("exec command rejected: {reason}"),
                     success: None,
                 },
@@ -1870,7 +1946,7 @@ fn apply_changes_from_apply_patch(action: &ApplyPatchAction) -> anyhow::Result<A
     })
 }

-fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
+fn get_writable_roots(cwd: &Path) -> Vec<PathBuf> {
     let mut writable_roots = Vec::new();
     if cfg!(target_os = "macos") {
         // On macOS, $TMPDIR is private to the user.
@@ -1898,7 +1974,7 @@ fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
 }

 /// Exec output is a pre-serialized JSON payload
-fn format_exec_output(output: &str, exit_code: i32, duration: std::time::Duration) -> String {
+fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> String {
     #[derive(Serialize)]
     struct ExecMetadata {
         exit_code: i32,