[Rust] Allow resuming a session that was killed with ctrl + c (#1387)
Previously, if you ctrl+c'd a conversation, all subsequent turns would 400 because the Responses API never got a response for one of its call ids. This ensures that if we aren't sending a call id by hand, we generate a synthetic aborted call. Fixes #1244 https://github.com/user-attachments/assets/5126354f-b970-45f5-8c65-f811bca8294a
This commit is contained in:
@@ -425,7 +425,12 @@ where
|
||||
response_id,
|
||||
token_usage,
|
||||
})));
|
||||
} // No other `Ok` variants exist at the moment, continue polling.
|
||||
}
|
||||
Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
|
||||
// These events are exclusive to the Responses API and
|
||||
// will never appear in a Chat Completions stream.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -439,7 +444,7 @@ pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Size
|
||||
///
|
||||
/// ```ignore
|
||||
/// OutputItemDone(<full message>)
|
||||
/// Completed { .. }
|
||||
/// Completed
|
||||
/// ```
|
||||
///
|
||||
/// No other `OutputItemDone` events will be seen by the caller.
|
||||
|
||||
@@ -168,7 +168,7 @@ impl ModelClient {
|
||||
// negligible.
|
||||
if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
|
||||
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
|
||||
let body = (res.text().await).unwrap_or_default();
|
||||
let body = res.text().await.unwrap_or_default();
|
||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
||||
}
|
||||
|
||||
@@ -208,6 +208,9 @@ struct SseEvent {
|
||||
item: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCreated {}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompleted {
|
||||
id: String,
|
||||
@@ -335,6 +338,11 @@ where
|
||||
return;
|
||||
}
|
||||
}
|
||||
"response.created" => {
|
||||
if event.response.is_some() {
|
||||
let _ = tx_event.send(Ok(ResponseEvent::Created {})).await;
|
||||
}
|
||||
}
|
||||
// Final response completed – includes array of output items & id
|
||||
"response.completed" => {
|
||||
if let Some(resp_val) = event.response {
|
||||
@@ -350,7 +358,6 @@ where
|
||||
};
|
||||
}
|
||||
"response.content_part.done"
|
||||
| "response.created"
|
||||
| "response.function_call_arguments.delta"
|
||||
| "response.in_progress"
|
||||
| "response.output_item.added"
|
||||
|
||||
@@ -51,6 +51,7 @@ impl Prompt {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum ResponseEvent {
|
||||
Created,
|
||||
OutputItemDone(ResponseItem),
|
||||
Completed {
|
||||
response_id: String,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// Poisoned mutex should fail the program
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::path::Path;
|
||||
@@ -188,7 +189,7 @@ pub(crate) struct Session {
|
||||
|
||||
/// Optional rollout recorder for persisting the conversation transcript so
|
||||
/// sessions can be replayed or inspected later.
|
||||
rollout: Mutex<Option<crate::rollout::RolloutRecorder>>,
|
||||
rollout: Mutex<Option<RolloutRecorder>>,
|
||||
state: Mutex<State>,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
}
|
||||
@@ -206,6 +207,9 @@ impl Session {
|
||||
struct State {
|
||||
approved_commands: HashSet<Vec<String>>,
|
||||
current_task: Option<AgentTask>,
|
||||
/// Call IDs that have been sent from the Responses API but have not been sent back yet.
|
||||
/// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400.
|
||||
pending_call_ids: HashSet<String>,
|
||||
previous_response_id: Option<String>,
|
||||
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
@@ -312,7 +316,7 @@ impl Session {
|
||||
/// Append the given items to the session's rollout transcript (if enabled)
|
||||
/// and persist them to disk.
|
||||
async fn record_rollout_items(&self, items: &[ResponseItem]) {
|
||||
// Clone the recorder outside of the mutex so we don’t hold the lock
|
||||
// Clone the recorder outside of the mutex so we don't hold the lock
|
||||
// across an await point (MutexGuard is not Send).
|
||||
let recorder = {
|
||||
let guard = self.rollout.lock().unwrap();
|
||||
@@ -411,6 +415,8 @@ impl Session {
|
||||
pub fn abort(&self) {
|
||||
info!("Aborting existing session");
|
||||
let mut state = self.state.lock().unwrap();
|
||||
// Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn.
|
||||
// We will generate a synthetic aborted response for each pending call id.
|
||||
state.pending_approvals.clear();
|
||||
state.pending_input.clear();
|
||||
if let Some(task) = state.current_task.take() {
|
||||
@@ -431,7 +437,7 @@ impl Session {
|
||||
}
|
||||
|
||||
let Ok(json) = serde_json::to_string(¬ification) else {
|
||||
tracing::error!("failed to serialise notification payload");
|
||||
error!("failed to serialise notification payload");
|
||||
return;
|
||||
};
|
||||
|
||||
@@ -443,7 +449,7 @@ impl Session {
|
||||
|
||||
// Fire-and-forget – we do not wait for completion.
|
||||
if let Err(e) = command.spawn() {
|
||||
tracing::warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
|
||||
warn!("failed to spawn notifier '{}': {e}", notify_command[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -647,7 +653,7 @@ async fn submission_loop(
|
||||
match RolloutRecorder::new(&config, session_id, instructions.clone()).await {
|
||||
Ok(r) => Some(r),
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to initialise rollout recorder: {e}");
|
||||
warn!("failed to initialise rollout recorder: {e}");
|
||||
None
|
||||
}
|
||||
};
|
||||
@@ -742,7 +748,7 @@ async fn submission_loop(
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = crate::message_history::append_entry(&text, &id, &config).await
|
||||
{
|
||||
tracing::warn!("failed to append to message history: {e}");
|
||||
warn!("failed to append to message history: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -772,7 +778,7 @@ async fn submission_loop(
|
||||
};
|
||||
|
||||
if let Err(e) = tx_event.send(event).await {
|
||||
tracing::warn!("failed to send GetHistoryEntryResponse event: {e}");
|
||||
warn!("failed to send GetHistoryEntryResponse event: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -1052,6 +1058,7 @@ async fn run_turn(
|
||||
/// events map to a `ResponseItem`. A `ResponseItem` may need to be
|
||||
/// "handled" such that it produces a `ResponseInputItem` that needs to be
|
||||
/// sent back to the model on the next turn.
|
||||
#[derive(Debug)]
|
||||
struct ProcessedResponseItem {
|
||||
item: ResponseItem,
|
||||
response: Option<ResponseInputItem>,
|
||||
@@ -1062,7 +1069,57 @@ async fn try_run_turn(
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
||||
// call_ids that are part of this response.
|
||||
let completed_call_ids = prompt
|
||||
.input
|
||||
.iter()
|
||||
.filter_map(|ri| match ri {
|
||||
ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => Some(call_id),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// call_ids that were pending but are not part of this response.
|
||||
// This usually happens because the user interrupted the model before we responded to one of its tool calls
|
||||
// and then the user sent a follow-up message.
|
||||
let missing_calls = {
|
||||
sess.state
|
||||
.lock()
|
||||
.unwrap()
|
||||
.pending_call_ids
|
||||
.iter()
|
||||
.filter_map(|call_id| {
|
||||
if completed_call_ids.contains(&call_id) {
|
||||
None
|
||||
} else {
|
||||
Some(call_id.clone())
|
||||
}
|
||||
})
|
||||
.map(|call_id| ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: Some(false),
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
let prompt: Cow<Prompt> = if missing_calls.is_empty() {
|
||||
Cow::Borrowed(prompt)
|
||||
} else {
|
||||
// Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses.
|
||||
let input = [missing_calls, prompt.input.clone()].concat();
|
||||
Cow::Owned(Prompt {
|
||||
input,
|
||||
..prompt.clone()
|
||||
})
|
||||
};
|
||||
|
||||
let mut stream = sess.client.clone().stream(&prompt).await?;
|
||||
|
||||
// Buffer all the incoming messages from the stream first, then execute them.
|
||||
// If we execute a function call in the middle of handling the stream, it can time out.
|
||||
@@ -1074,8 +1131,27 @@ async fn try_run_turn(
|
||||
let mut output = Vec::new();
|
||||
for event in input {
|
||||
match event {
|
||||
ResponseEvent::Created => {
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
// We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids.
|
||||
state.pending_call_ids.clear();
|
||||
}
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
let call_id = match &item {
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => Some(call_id),
|
||||
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(call_id) = call_id {
|
||||
// We just got a new call id so we need to make sure to respond to it in the next turn.
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
state.pending_call_ids.insert(call_id.clone());
|
||||
}
|
||||
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
||||
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
@@ -1138,7 +1214,7 @@ async fn handle_response_item(
|
||||
arguments,
|
||||
call_id,
|
||||
} => {
|
||||
tracing::info!("FunctionCall: {arguments}");
|
||||
info!("FunctionCall: {arguments}");
|
||||
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
@@ -1220,7 +1296,7 @@ async fn handle_function_call(
|
||||
// Unknown function: reply with structured failure so the model can adapt.
|
||||
ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: crate::models::FunctionCallOutputPayload {
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("unsupported call: {}", name),
|
||||
success: None,
|
||||
},
|
||||
@@ -1252,7 +1328,7 @@ fn parse_container_exec_arguments(
|
||||
// allow model to re-sample
|
||||
let output = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: call_id.to_string(),
|
||||
output: crate::models::FunctionCallOutputPayload {
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("failed to parse function arguments: {e}"),
|
||||
success: None,
|
||||
},
|
||||
@@ -1320,7 +1396,7 @@ async fn handle_container_exec_with_params(
|
||||
ReviewDecision::Denied | ReviewDecision::Abort => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: crate::models::FunctionCallOutputPayload {
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "exec command rejected by user".to_string(),
|
||||
success: None,
|
||||
},
|
||||
@@ -1336,7 +1412,7 @@ async fn handle_container_exec_with_params(
|
||||
SafetyCheck::Reject { reason } => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: crate::models::FunctionCallOutputPayload {
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("exec command rejected: {reason}"),
|
||||
success: None,
|
||||
},
|
||||
@@ -1870,7 +1946,7 @@ fn apply_changes_from_apply_patch(action: &ApplyPatchAction) -> anyhow::Result<A
|
||||
})
|
||||
}
|
||||
|
||||
fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
|
||||
fn get_writable_roots(cwd: &Path) -> Vec<PathBuf> {
|
||||
let mut writable_roots = Vec::new();
|
||||
if cfg!(target_os = "macos") {
|
||||
// On macOS, $TMPDIR is private to the user.
|
||||
@@ -1898,7 +1974,7 @@ fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
|
||||
}
|
||||
|
||||
/// Exec output is a pre-serialized JSON payload
|
||||
fn format_exec_output(output: &str, exit_code: i32, duration: std::time::Duration) -> String {
|
||||
fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> String {
|
||||
#[derive(Serialize)]
|
||||
struct ExecMetadata {
|
||||
exit_code: i32,
|
||||
|
||||
Reference in New Issue
Block a user