[Rust] Allow resuming a session that was killed with ctrl + c (#1387)

Previously, if you ctrl+c'd a conversation, all subsequent turns would
400 because the Responses API never got a response for one of its call
ids. This ensures that if we aren't sending a call id by hand, we
generate a synthetic aborted call.

Fixes #1244 


https://github.com/user-attachments/assets/5126354f-b970-45f5-8c65-f811bca8294a
This commit is contained in:
Gabriel Peal
2025-06-26 14:40:42 -04:00
committed by GitHub
parent fcfe43c7df
commit a339a7bcce
4 changed files with 108 additions and 19 deletions

View File

@@ -425,7 +425,12 @@ where
response_id,
token_usage,
})));
} // No other `Ok` variants exist at the moment, continue polling.
}
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
// These events are exclusive to the Responses API and
// will never appear in a Chat Completions stream.
continue;
}
}
}
}
@@ -439,7 +444,7 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
///
/// ```ignore
/// OutputItemDone(<full message>)
/// Completed { .. }
/// Completed
/// ```
///
/// No other `OutputItemDone` events will be seen by the caller.

View File

@@ -168,7 +168,7 @@ impl ModelClient {
// negligible.
if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
let body = (res.text().await).unwrap_or_default();
let body = res.text().await.unwrap_or_default();
return Err(CodexErr::UnexpectedStatus(status, body));
}
@@ -208,6 +208,9 @@ struct SseEvent {
item: Option<Value>,
}
#[derive(Debug, Deserialize)]
struct ResponseCreated {}
#[derive(Debug, Deserialize)]
struct ResponseCompleted {
id: String,
@@ -335,6 +338,11 @@ where
return;
}
}
"response.created" => {
if event.response.is_some() {
let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
}
}
// Final response completed includes array of output items & id
"response.completed" => {
if let Some(resp_val) = event.response {
@@ -350,7 +358,6 @@ where
};
}
"response.content_part.done"
| "response.created"
| "response.function_call_arguments.delta"
| "response.in_progress"
| "response.output_item.added"

View File

@@ -51,6 +51,7 @@ impl Prompt {
#[derive(Debug)]
pub enum ResponseEvent {
Created,
OutputItemDone(ResponseItem),
Completed {
response_id: String,

View File

@@ -1,6 +1,7 @@
// Poisoned mutex should fail the program
#![allow(clippy::unwrap_used)]
use std::borrow::Cow;
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::Path;
@@ -188,7 +189,7 @@ pub(crate) struct Session {
/// Optional rollout recorder for persisting the conversation transcript so
/// sessions can be replayed or inspected later.
rollout: Mutex<Option<crate::rollout::RolloutRecorder>>,
rollout: Mutex<Option<RolloutRecorder>>,
state: Mutex<State>,
codex_linux_sandbox_exe: Option<PathBuf>,
}
@@ -206,6 +207,9 @@ impl Session {
struct State {
approved_commands: HashSet<Vec<String>>,
current_task: Option<AgentTask>,
/// Call IDs that have been sent from the Responses API but have not been sent back yet.
/// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400.
pending_call_ids: HashSet<String>,
previous_response_id: Option<String>,
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
pending_input: Vec<ResponseInputItem>,
@@ -312,7 +316,7 @@ impl Session {
/// Append the given items to the session's rollout transcript (if enabled)
/// and persist them to disk.
async fn record_rollout_items(&self, items: &[ResponseItem]) {
// Clone the recorder outside of the mutex so we dont hold the lock
// Clone the recorder outside of the mutex so we don't hold the lock
// across an await point (MutexGuard is not Send).
let recorder = {
let guard = self.rollout.lock().unwrap();
@@ -411,6 +415,8 @@ impl Session {
pub fn abort(&self) {
info!("Aborting existing session");
let mut state = self.state.lock().unwrap();
// Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn.
// We will generate a synthetic aborted response for each pending call id.
state.pending_approvals.clear();
state.pending_input.clear();
if let Some(task) = state.current_task.take() {
@@ -431,7 +437,7 @@ impl Session {
}
let Ok(json) = serde_json::to_string(&notification) else {
tracing::error!("failed to serialise notification payload");
error!("failed to serialise notification payload");
return;
};
@@ -443,7 +449,7 @@ impl Session {
// Fire-and-forget we do not wait for completion.
if let Err(e) = command.spawn() {
tracing::warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
}
}
}
@@ -647,7 +653,7 @@ async fn submission_loop(
match RolloutRecorder::new(&config, session_id, instructions.clone()).await {
Ok(r) => Some(r),
Err(e) => {
tracing::warn!("failed to initialise rollout recorder: {e}");
warn!("failed to initialise rollout recorder: {e}");
None
}
};
@@ -742,7 +748,7 @@ async fn submission_loop(
tokio::spawn(async move {
if let Err(e) = crate::message_history::append_entry(&text, &id, &config).await
{
tracing::warn!("failed to append to message history: {e}");
warn!("failed to append to message history: {e}");
}
});
}
@@ -772,7 +778,7 @@ async fn submission_loop(
};
if let Err(e) = tx_event.send(event).await {
tracing::warn!("failed to send GetHistoryEntryResponse event: {e}");
warn!("failed to send GetHistoryEntryResponse event: {e}");
}
});
}
@@ -1052,6 +1058,7 @@ async fn run_turn(
/// events map to a `ResponseItem`. A `ResponseItem` may need to be
/// "handled" such that it produces a `ResponseInputItem` that needs to be
/// sent back to the model on the next turn.
#[derive(Debug)]
struct ProcessedResponseItem {
item: ResponseItem,
response: Option<ResponseInputItem>,
@@ -1062,7 +1069,57 @@ async fn try_run_turn(
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<Vec<ProcessedResponseItem>> {
let mut stream = sess.client.clone().stream(prompt).await?;
// call_ids that are part of this response.
let completed_call_ids = prompt
.input
.iter()
.filter_map(|ri| match ri {
ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
_ => None,
})
.collect::<Vec<_>>();
// call_ids that were pending but are not part of this response.
// This usually happens because the user interrupted the model before we responded to one of its tool calls
// and then the user sent a follow-up message.
let missing_calls = {
sess.state
.lock()
.unwrap()
.pending_call_ids
.iter()
.filter_map(|call_id| {
if completed_call_ids.contains(&call_id) {
None
} else {
Some(call_id.clone())
}
})
.map(|call_id| ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
success: Some(false),
},
})
.collect::<Vec<_>>()
};
let prompt: Cow<Prompt> = if missing_calls.is_empty() {
Cow::Borrowed(prompt)
} else {
// Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses.
let input = [missing_calls, prompt.input.clone()].concat();
Cow::Owned(Prompt {
input,
..prompt.clone()
})
};
let mut stream = sess.client.clone().stream(&prompt).await?;
// Buffer all the incoming messages from the stream first, then execute them.
// If we execute a function call in the middle of handling the stream, it can time out.
@@ -1074,8 +1131,27 @@ async fn try_run_turn(
let mut output = Vec::new();
for event in input {
match event {
ResponseEvent::Created => {
let mut state = sess.state.lock().unwrap();
// We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids.
state.pending_call_ids.clear();
}
ResponseEvent::OutputItemDone(item) => {
let call_id = match &item {
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id),
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
_ => None,
};
if let Some(call_id) = call_id {
// We just got a new call id so we need to make sure to respond to it in the next turn.
let mut state = sess.state.lock().unwrap();
state.pending_call_ids.insert(call_id.clone());
}
let response = handle_response_item(sess, sub_id, item.clone()).await?;
output.push(ProcessedResponseItem { item, response });
}
ResponseEvent::Completed {
@@ -1138,7 +1214,7 @@ async fn handle_response_item(
arguments,
call_id,
} => {
tracing::info!("FunctionCall: {arguments}");
info!("FunctionCall: {arguments}");
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
}
ResponseItem::LocalShellCall {
@@ -1220,7 +1296,7 @@ async fn handle_function_call(
// Unknown function: reply with structured failure so the model can adapt.
ResponseInputItem::FunctionCallOutput {
call_id,
output: crate::models::FunctionCallOutputPayload {
output: FunctionCallOutputPayload {
content: format!("unsupported call: {}", name),
success: None,
},
@@ -1252,7 +1328,7 @@ fn parse_container_exec_arguments(
// allow model to re-sample
let output = ResponseInputItem::FunctionCallOutput {
call_id: call_id.to_string(),
output: crate::models::FunctionCallOutputPayload {
output: FunctionCallOutputPayload {
content: format!("failed to parse function arguments: {e}"),
success: None,
},
@@ -1320,7 +1396,7 @@ async fn handle_container_exec_with_params(
ReviewDecision::Denied | ReviewDecision::Abort => {
return ResponseInputItem::FunctionCallOutput {
call_id,
output: crate::models::FunctionCallOutputPayload {
output: FunctionCallOutputPayload {
content: "exec command rejected by user".to_string(),
success: None,
},
@@ -1336,7 +1412,7 @@ async fn handle_container_exec_with_params(
SafetyCheck::Reject { reason } => {
return ResponseInputItem::FunctionCallOutput {
call_id,
output: crate::models::FunctionCallOutputPayload {
output: FunctionCallOutputPayload {
content: format!("exec command rejected: {reason}"),
success: None,
},
@@ -1870,7 +1946,7 @@ fn apply_changes_from_apply_patch(action: &ApplyPatchAction) -> anyhow::Result<A
})
}
fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
fn get_writable_roots(cwd: &Path) -> Vec<PathBuf> {
let mut writable_roots = Vec::new();
if cfg!(target_os = "macos") {
// On macOS, $TMPDIR is private to the user.
@@ -1898,7 +1974,7 @@ fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
}
/// Exec output is a pre-serialized JSON payload
fn format_exec_output(output: &str, exit_code: i32, duration: std::time::Duration) -> String {
fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> String {
#[derive(Serialize)]
struct ExecMetadata {
exit_code: i32,