Move changing turn input functionalities to ConversationHistory (#5473)
We are doing some ad-hoc logic while dealing with conversation history. Ideally, we shouldn't mutate `vec[responseitem]` manually at all and should depend on `ConversationHistory` for those changes. Those changes are: - Adding input to the history - Removing items from the history - Correcting history I am also adding some `error` logs for cases we shouldn't ideally face. For example, we shouldn't be missing `toolcalls` or `outputs`. We shouldn't hit `ContextWindowExceeded` while performing `compact` This refactor will give us granular control over our context management.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::path::PathBuf;
|
||||
@@ -873,7 +872,7 @@ impl Session {
|
||||
history.record_items(std::iter::once(response_item));
|
||||
}
|
||||
RolloutItem::Compacted(compacted) => {
|
||||
let snapshot = history.contents();
|
||||
let snapshot = history.get_history();
|
||||
let user_messages = collect_user_messages(&snapshot);
|
||||
let rebuilt = build_compacted_history(
|
||||
self.build_initial_context(turn_context),
|
||||
@@ -885,7 +884,7 @@ impl Session {
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
history.contents()
|
||||
history.get_history()
|
||||
}
|
||||
|
||||
/// Append ResponseItems to the in-memory conversation history only.
|
||||
@@ -934,11 +933,17 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
// todo (aibrahim): get rid of this method. we shouldn't deal with vec[resposne_item] and rather use ConversationHistory.
|
||||
pub(crate) async fn history_snapshot(&self) -> Vec<ResponseItem> {
|
||||
let state = self.state.lock().await;
|
||||
let mut state = self.state.lock().await;
|
||||
state.history_snapshot()
|
||||
}
|
||||
|
||||
pub(crate) async fn clone_history(&self) -> ConversationHistory {
|
||||
let state = self.state.lock().await;
|
||||
state.clone_history()
|
||||
}
|
||||
|
||||
async fn update_token_usage_info(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
@@ -1030,16 +1035,6 @@ impl Session {
|
||||
self.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
/// Build the full turn input by concatenating the current conversation
|
||||
/// history with additional items for this turn.
|
||||
pub async fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
||||
let history = {
|
||||
let state = self.state.lock().await;
|
||||
state.history_snapshot()
|
||||
};
|
||||
[history, extra].concat()
|
||||
}
|
||||
|
||||
/// Returns the input if there was no task running to inject into
|
||||
pub async fn inject_input(&self, input: Vec<UserInput>) -> Result<(), Vec<UserInput>> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
@@ -1526,11 +1521,13 @@ pub(crate) async fn run_task(
|
||||
// model sees a fresh conversation without the parent session's history.
|
||||
// For normal turns, continue recording to the session history as before.
|
||||
let is_review_mode = turn_context.is_review_mode;
|
||||
let mut review_thread_history: Vec<ResponseItem> = Vec::new();
|
||||
|
||||
let mut review_thread_history: ConversationHistory = ConversationHistory::new();
|
||||
if is_review_mode {
|
||||
// Seed review threads with environment context so the model knows the working directory.
|
||||
review_thread_history.extend(sess.build_initial_context(turn_context.as_ref()));
|
||||
review_thread_history.push(initial_input_for_turn.into());
|
||||
review_thread_history
|
||||
.record_items(sess.build_initial_context(turn_context.as_ref()).iter());
|
||||
review_thread_history.record_items(std::iter::once(&initial_input_for_turn.into()));
|
||||
} else {
|
||||
sess.record_input_and_rollout_usermsg(turn_context.as_ref(), &initial_input_for_turn)
|
||||
.await;
|
||||
@@ -1565,12 +1562,12 @@ pub(crate) async fn run_task(
|
||||
// represents an append-only log without duplicates.
|
||||
let turn_input: Vec<ResponseItem> = if is_review_mode {
|
||||
if !pending_input.is_empty() {
|
||||
review_thread_history.extend(pending_input);
|
||||
review_thread_history.record_items(&pending_input);
|
||||
}
|
||||
review_thread_history.clone()
|
||||
review_thread_history.get_history()
|
||||
} else {
|
||||
sess.record_conversation_items(&pending_input).await;
|
||||
sess.turn_input_with_history(pending_input).await
|
||||
sess.history_snapshot().await
|
||||
};
|
||||
|
||||
let turn_input_messages: Vec<String> = turn_input
|
||||
@@ -1708,7 +1705,7 @@ pub(crate) async fn run_task(
|
||||
if !items_to_record_in_conversation_history.is_empty() {
|
||||
if is_review_mode {
|
||||
review_thread_history
|
||||
.extend(items_to_record_in_conversation_history.clone());
|
||||
.record_items(items_to_record_in_conversation_history.iter());
|
||||
} else {
|
||||
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
||||
.await;
|
||||
@@ -1927,61 +1924,6 @@ async fn try_run_turn(
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
// call_ids that are part of this response.
|
||||
let completed_call_ids = prompt
|
||||
.input
|
||||
.iter()
|
||||
.filter_map(|ri| match ri {
|
||||
ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => Some(call_id),
|
||||
ResponseItem::CustomToolCallOutput { call_id, .. } => Some(call_id),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// call_ids that were pending but are not part of this response.
|
||||
// This usually happens because the user interrupted the model before we responded to one of its tool calls
|
||||
// and then the user sent a follow-up message.
|
||||
let missing_calls = {
|
||||
prompt
|
||||
.input
|
||||
.iter()
|
||||
.filter_map(|ri| match ri {
|
||||
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => Some(call_id),
|
||||
ResponseItem::CustomToolCall { call_id, .. } => Some(call_id),
|
||||
_ => None,
|
||||
})
|
||||
.filter_map(|call_id| {
|
||||
if completed_call_ids.contains(&call_id) {
|
||||
None
|
||||
} else {
|
||||
Some(call_id.clone())
|
||||
}
|
||||
})
|
||||
.map(|call_id| ResponseItem::CustomToolCallOutput {
|
||||
call_id,
|
||||
output: "aborted".to_string(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
let prompt: Cow<Prompt> = if missing_calls.is_empty() {
|
||||
Cow::Borrowed(prompt)
|
||||
} else {
|
||||
// Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses.
|
||||
let input = [missing_calls, prompt.input.clone()].concat();
|
||||
Cow::Owned(Prompt {
|
||||
input,
|
||||
..prompt.clone()
|
||||
})
|
||||
};
|
||||
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
@@ -1990,11 +1932,12 @@ async fn try_run_turn(
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context
|
||||
.client
|
||||
.clone()
|
||||
.stream_with_task_kind(prompt.as_ref(), task_kind)
|
||||
.stream_with_task_kind(prompt, task_kind)
|
||||
.or_cancel(&cancellation_token)
|
||||
.await??;
|
||||
|
||||
@@ -2982,7 +2925,7 @@ mod tests {
|
||||
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));
|
||||
|
||||
let summary1 = "summary one";
|
||||
let snapshot1 = live_history.contents();
|
||||
let snapshot1 = live_history.get_history();
|
||||
let user_messages1 = collect_user_messages(&snapshot1);
|
||||
let rebuilt1 = build_compacted_history(
|
||||
session.build_initial_context(turn_context),
|
||||
@@ -3015,7 +2958,7 @@ mod tests {
|
||||
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));
|
||||
|
||||
let summary2 = "summary two";
|
||||
let snapshot2 = live_history.contents();
|
||||
let snapshot2 = live_history.get_history();
|
||||
let user_messages2 = collect_user_messages(&snapshot2);
|
||||
let rebuilt2 = build_compacted_history(
|
||||
session.build_initial_context(turn_context),
|
||||
@@ -3047,7 +2990,7 @@ mod tests {
|
||||
live_history.record_items(std::iter::once(&assistant3));
|
||||
rollout_items.push(RolloutItem::ResponseItem(assistant3.clone()));
|
||||
|
||||
(rollout_items, live_history.contents())
|
||||
(rollout_items, live_history.get_history())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -24,6 +24,7 @@ use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::RolloutItem;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use futures::prelude::*;
|
||||
use tracing::error;
|
||||
|
||||
pub const SUMMARIZATION_PROMPT: &str = include_str!("../../templates/compact/prompt.md");
|
||||
const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000;
|
||||
@@ -64,9 +65,10 @@ async fn run_compact_task_inner(
|
||||
input: Vec<UserInput>,
|
||||
) {
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let mut turn_input = sess
|
||||
.turn_input_with_history(vec![initial_input_for_turn.clone().into()])
|
||||
.await;
|
||||
|
||||
let mut history = sess.clone_history().await;
|
||||
history.record_items(&[initial_input_for_turn.into()]);
|
||||
|
||||
let mut truncated_count = 0usize;
|
||||
|
||||
let max_retries = turn_context.client.get_provider().stream_max_retries();
|
||||
@@ -83,6 +85,7 @@ async fn run_compact_task_inner(
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
|
||||
loop {
|
||||
let turn_input = history.get_history();
|
||||
let prompt = Prompt {
|
||||
input: turn_input.clone(),
|
||||
..Default::default()
|
||||
@@ -107,7 +110,11 @@ async fn run_compact_task_inner(
|
||||
}
|
||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||
if turn_input.len() > 1 {
|
||||
turn_input.remove(0);
|
||||
// Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
|
||||
error!(
|
||||
"Context window exceeded while compacting; removing oldest history item. Error: {e}"
|
||||
);
|
||||
history.remove_first_item();
|
||||
truncated_count += 1;
|
||||
retries = 0;
|
||||
continue;
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use tracing::error;
|
||||
|
||||
/// Transcript of conversation history
|
||||
#[derive(Debug, Clone, Default)]
|
||||
@@ -12,11 +14,6 @@ impl ConversationHistory {
|
||||
Self { items: Vec::new() }
|
||||
}
|
||||
|
||||
/// Returns a clone of the contents in the transcript.
|
||||
pub(crate) fn contents(&self) -> Vec<ResponseItem> {
|
||||
self.items.clone()
|
||||
}
|
||||
|
||||
/// `items` is ordered from oldest to newest.
|
||||
pub(crate) fn record_items<I>(&mut self, items: I)
|
||||
where
|
||||
@@ -32,9 +29,287 @@ impl ConversationHistory {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_history(&mut self) -> Vec<ResponseItem> {
|
||||
self.normalize_history();
|
||||
self.contents()
|
||||
}
|
||||
|
||||
pub(crate) fn remove_first_item(&mut self) {
|
||||
if !self.items.is_empty() {
|
||||
// Remove the oldest item (front of the list). Items are ordered from
|
||||
// oldest → newest, so index 0 is the first entry recorded.
|
||||
let removed = self.items.remove(0);
|
||||
// If the removed item participates in a call/output pair, also remove
|
||||
// its corresponding counterpart to keep the invariants intact without
|
||||
// running a full normalization pass.
|
||||
self.remove_corresponding_for(&removed);
|
||||
}
|
||||
}
|
||||
|
||||
/// This function enforces a couple of invariants on the in-memory history:
|
||||
/// 1. every call (function/custom) has a corresponding output entry
|
||||
/// 2. every output has a corresponding call entry
|
||||
fn normalize_history(&mut self) {
|
||||
// all function/tool calls must have a corresponding output
|
||||
self.ensure_call_outputs_present();
|
||||
|
||||
// all outputs must have a corresponding function/tool call
|
||||
self.remove_orphan_outputs();
|
||||
}
|
||||
|
||||
/// Returns a clone of the contents in the transcript.
|
||||
fn contents(&self) -> Vec<ResponseItem> {
|
||||
self.items.clone()
|
||||
}
|
||||
|
||||
fn ensure_call_outputs_present(&mut self) {
|
||||
// Collect synthetic outputs to insert immediately after their calls.
|
||||
// Store the insertion position (index of call) alongside the item so
|
||||
// we can insert in reverse order and avoid index shifting.
|
||||
let mut missing_outputs_to_insert: Vec<(usize, ResponseItem)> = Vec::new();
|
||||
|
||||
for (idx, item) in self.items.iter().enumerate() {
|
||||
match item {
|
||||
ResponseItem::FunctionCall { call_id, .. } => {
|
||||
let has_output = self.items.iter().any(|i| match i {
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
if !has_output {
|
||||
error_or_panic(format!(
|
||||
"Function call output is missing for call id: {call_id}"
|
||||
));
|
||||
missing_outputs_to_insert.push((
|
||||
idx,
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
ResponseItem::CustomToolCall { call_id, .. } => {
|
||||
let has_output = self.items.iter().any(|i| match i {
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
if !has_output {
|
||||
error_or_panic(format!(
|
||||
"Custom tool call output is missing for call id: {call_id}"
|
||||
));
|
||||
missing_outputs_to_insert.push((
|
||||
idx,
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: "aborted".to_string(),
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
// LocalShellCall is represented in upstream streams by a FunctionCallOutput
|
||||
ResponseItem::LocalShellCall { call_id, .. } => {
|
||||
if let Some(call_id) = call_id.as_ref() {
|
||||
let has_output = self.items.iter().any(|i| match i {
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
if !has_output {
|
||||
error_or_panic(format!(
|
||||
"Local shell call output is missing for call id: {call_id}"
|
||||
));
|
||||
missing_outputs_to_insert.push((
|
||||
idx,
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::Message { .. } => {
|
||||
// nothing to do for these variants
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !missing_outputs_to_insert.is_empty() {
|
||||
// Insert from the end to avoid shifting subsequent indices.
|
||||
missing_outputs_to_insert.sort_by_key(|(i, _)| *i);
|
||||
for (idx, item) in missing_outputs_to_insert.into_iter().rev() {
|
||||
let insert_pos = idx + 1; // place immediately after the call
|
||||
if insert_pos <= self.items.len() {
|
||||
self.items.insert(insert_pos, item);
|
||||
} else {
|
||||
self.items.push(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_orphan_outputs(&mut self) {
|
||||
// Work on a snapshot to avoid borrowing `self.items` while mutating it.
|
||||
let snapshot = self.items.clone();
|
||||
let mut orphan_output_call_ids: std::collections::HashSet<String> =
|
||||
std::collections::HashSet::new();
|
||||
|
||||
for item in &snapshot {
|
||||
match item {
|
||||
ResponseItem::FunctionCallOutput { call_id, .. } => {
|
||||
let has_call = snapshot.iter().any(|i| match i {
|
||||
ResponseItem::FunctionCall {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(existing),
|
||||
..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
if !has_call {
|
||||
error_or_panic(format!("Function call is missing for call id: {call_id}"));
|
||||
orphan_output_call_ids.insert(call_id.clone());
|
||||
}
|
||||
}
|
||||
ResponseItem::CustomToolCallOutput { call_id, .. } => {
|
||||
let has_call = snapshot.iter().any(|i| match i {
|
||||
ResponseItem::CustomToolCall {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
if !has_call {
|
||||
error_or_panic(format!(
|
||||
"Custom tool call is missing for call id: {call_id}"
|
||||
));
|
||||
orphan_output_call_ids.insert(call_id.clone());
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::Message { .. } => {
|
||||
// nothing to do for these variants
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !orphan_output_call_ids.is_empty() {
|
||||
let ids = orphan_output_call_ids;
|
||||
self.items.retain(|i| match i {
|
||||
ResponseItem::FunctionCallOutput { call_id, .. }
|
||||
| ResponseItem::CustomToolCallOutput { call_id, .. } => !ids.contains(call_id),
|
||||
_ => true,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
|
||||
self.items = items;
|
||||
}
|
||||
|
||||
/// Removes the corresponding paired item for the provided `item`, if any.
|
||||
///
|
||||
/// Pairs:
|
||||
/// - FunctionCall <-> FunctionCallOutput
|
||||
/// - CustomToolCall <-> CustomToolCallOutput
|
||||
/// - LocalShellCall(call_id: Some) <-> FunctionCallOutput
|
||||
fn remove_corresponding_for(&mut self, item: &ResponseItem) {
|
||||
match item {
|
||||
ResponseItem::FunctionCall { call_id, .. } => {
|
||||
self.remove_first_matching(|i| match i {
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
}
|
||||
ResponseItem::CustomToolCall { call_id, .. } => {
|
||||
self.remove_first_matching(|i| match i {
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => {
|
||||
self.remove_first_matching(|i| match i {
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, .. } => {
|
||||
self.remove_first_matching(|i| match i {
|
||||
ResponseItem::FunctionCall {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(existing),
|
||||
..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
}
|
||||
ResponseItem::CustomToolCallOutput { call_id, .. } => {
|
||||
self.remove_first_matching(|i| match i {
|
||||
ResponseItem::CustomToolCall {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove the first item matching the predicate.
|
||||
fn remove_first_matching<F>(&mut self, predicate: F)
|
||||
where
|
||||
F: FnMut(&ResponseItem) -> bool,
|
||||
{
|
||||
if let Some(pos) = self.items.iter().position(predicate) {
|
||||
self.items.remove(pos);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn error_or_panic(message: String) {
|
||||
if cfg!(debug_assertions) || env!("CARGO_PKG_VERSION").contains("alpha") {
|
||||
panic!("{message}");
|
||||
} else {
|
||||
error!("{message}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Anything that is not a system message or "reasoning" message is considered
|
||||
@@ -57,6 +332,11 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::LocalShellAction;
|
||||
use codex_protocol::models::LocalShellExecAction;
|
||||
use codex_protocol::models::LocalShellStatus;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
fn assistant_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
@@ -68,6 +348,12 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn create_history_with_items(items: Vec<ResponseItem>) -> ConversationHistory {
|
||||
let mut h = ConversationHistory::new();
|
||||
h.record_items(items.iter());
|
||||
h
|
||||
}
|
||||
|
||||
fn user_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
@@ -117,4 +403,452 @@ mod tests {
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_first_item_removes_matching_output_for_function_call() {
|
||||
let items = vec![
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "do_it".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-1".to_string(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "call-1".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "ok".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.remove_first_item();
|
||||
assert_eq!(h.contents(), vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_first_item_removes_matching_call_for_output() {
|
||||
let items = vec![
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "call-2".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "ok".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "do_it".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-2".to_string(),
|
||||
},
|
||||
];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.remove_first_item();
|
||||
assert_eq!(h.contents(), vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_first_item_handles_local_shell_pair() {
|
||||
let items = vec![
|
||||
ResponseItem::LocalShellCall {
|
||||
id: None,
|
||||
call_id: Some("call-3".to_string()),
|
||||
status: LocalShellStatus::Completed,
|
||||
action: LocalShellAction::Exec(LocalShellExecAction {
|
||||
command: vec!["echo".to_string(), "hi".to_string()],
|
||||
timeout_ms: None,
|
||||
working_directory: None,
|
||||
env: None,
|
||||
user: None,
|
||||
}),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "call-3".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "ok".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.remove_first_item();
|
||||
assert_eq!(h.contents(), vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_first_item_handles_custom_tool_pair() {
|
||||
let items = vec![
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "tool-1".to_string(),
|
||||
name: "my_tool".to_string(),
|
||||
input: "{}".to_string(),
|
||||
},
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: "tool-1".to_string(),
|
||||
output: "ok".to_string(),
|
||||
},
|
||||
];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.remove_first_item();
|
||||
assert_eq!(h.contents(), vec![]);
|
||||
}
|
||||
|
||||
//TODO(aibrahim): run CI in release mode.
|
||||
#[cfg(not(debug_assertions))]
|
||||
#[test]
|
||||
fn normalize_adds_missing_output_for_function_call() {
|
||||
let items = vec![ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "do_it".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-x".to_string(),
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
|
||||
h.normalize_history();
|
||||
|
||||
assert_eq!(
|
||||
h.contents(),
|
||||
vec![
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "do_it".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-x".to_string(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "call-x".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
#[test]
|
||||
fn normalize_adds_missing_output_for_custom_tool_call() {
|
||||
let items = vec![ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "tool-x".to_string(),
|
||||
name: "custom".to_string(),
|
||||
input: "{}".to_string(),
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
|
||||
h.normalize_history();
|
||||
|
||||
assert_eq!(
|
||||
h.contents(),
|
||||
vec![
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "tool-x".to_string(),
|
||||
name: "custom".to_string(),
|
||||
input: "{}".to_string(),
|
||||
},
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: "tool-x".to_string(),
|
||||
output: "aborted".to_string(),
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
#[test]
|
||||
fn normalize_adds_missing_output_for_local_shell_call_with_id() {
|
||||
let items = vec![ResponseItem::LocalShellCall {
|
||||
id: None,
|
||||
call_id: Some("shell-1".to_string()),
|
||||
status: LocalShellStatus::Completed,
|
||||
action: LocalShellAction::Exec(LocalShellExecAction {
|
||||
command: vec!["echo".to_string(), "hi".to_string()],
|
||||
timeout_ms: None,
|
||||
working_directory: None,
|
||||
env: None,
|
||||
user: None,
|
||||
}),
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
|
||||
h.normalize_history();
|
||||
|
||||
assert_eq!(
|
||||
h.contents(),
|
||||
vec![
|
||||
ResponseItem::LocalShellCall {
|
||||
id: None,
|
||||
call_id: Some("shell-1".to_string()),
|
||||
status: LocalShellStatus::Completed,
|
||||
action: LocalShellAction::Exec(LocalShellExecAction {
|
||||
command: vec!["echo".to_string(), "hi".to_string()],
|
||||
timeout_ms: None,
|
||||
working_directory: None,
|
||||
env: None,
|
||||
user: None,
|
||||
}),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "shell-1".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
#[test]
|
||||
fn normalize_removes_orphan_function_call_output() {
|
||||
let items = vec![ResponseItem::FunctionCallOutput {
|
||||
call_id: "orphan-1".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "ok".to_string(),
|
||||
success: None,
|
||||
},
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
|
||||
h.normalize_history();
|
||||
|
||||
assert_eq!(h.contents(), vec![]);
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
#[test]
|
||||
fn normalize_removes_orphan_custom_tool_call_output() {
|
||||
let items = vec![ResponseItem::CustomToolCallOutput {
|
||||
call_id: "orphan-2".to_string(),
|
||||
output: "ok".to_string(),
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
|
||||
h.normalize_history();
|
||||
|
||||
assert_eq!(h.contents(), vec![]);
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
#[test]
|
||||
fn normalize_mixed_inserts_and_removals() {
|
||||
let items = vec![
|
||||
// Will get an inserted output
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "f1".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "c1".to_string(),
|
||||
},
|
||||
// Orphan output that should be removed
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "c2".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "ok".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
// Will get an inserted custom tool output
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "t1".to_string(),
|
||||
name: "tool".to_string(),
|
||||
input: "{}".to_string(),
|
||||
},
|
||||
// Local shell call also gets an inserted function call output
|
||||
ResponseItem::LocalShellCall {
|
||||
id: None,
|
||||
call_id: Some("s1".to_string()),
|
||||
status: LocalShellStatus::Completed,
|
||||
action: LocalShellAction::Exec(LocalShellExecAction {
|
||||
command: vec!["echo".to_string()],
|
||||
timeout_ms: None,
|
||||
working_directory: None,
|
||||
env: None,
|
||||
user: None,
|
||||
}),
|
||||
},
|
||||
];
|
||||
let mut h = create_history_with_items(items);
|
||||
|
||||
h.normalize_history();
|
||||
|
||||
assert_eq!(
|
||||
h.contents(),
|
||||
vec![
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "f1".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "c1".to_string(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "c1".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "t1".to_string(),
|
||||
name: "tool".to_string(),
|
||||
input: "{}".to_string(),
|
||||
},
|
||||
ResponseItem::CustomToolCallOutput {
|
||||
call_id: "t1".to_string(),
|
||||
output: "aborted".to_string(),
|
||||
},
|
||||
ResponseItem::LocalShellCall {
|
||||
id: None,
|
||||
call_id: Some("s1".to_string()),
|
||||
status: LocalShellStatus::Completed,
|
||||
action: LocalShellAction::Exec(LocalShellExecAction {
|
||||
command: vec!["echo".to_string()],
|
||||
timeout_ms: None,
|
||||
working_directory: None,
|
||||
env: None,
|
||||
user: None,
|
||||
}),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "s1".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "aborted".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// In debug builds we panic on normalization errors instead of silently fixing them.
|
||||
#[cfg(debug_assertions)]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn normalize_adds_missing_output_for_function_call_panics_in_debug() {
|
||||
let items = vec![ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "do_it".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "call-x".to_string(),
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.normalize_history();
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn normalize_adds_missing_output_for_custom_tool_call_panics_in_debug() {
|
||||
let items = vec![ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "tool-x".to_string(),
|
||||
name: "custom".to_string(),
|
||||
input: "{}".to_string(),
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.normalize_history();
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn normalize_adds_missing_output_for_local_shell_call_with_id_panics_in_debug() {
|
||||
let items = vec![ResponseItem::LocalShellCall {
|
||||
id: None,
|
||||
call_id: Some("shell-1".to_string()),
|
||||
status: LocalShellStatus::Completed,
|
||||
action: LocalShellAction::Exec(LocalShellExecAction {
|
||||
command: vec!["echo".to_string(), "hi".to_string()],
|
||||
timeout_ms: None,
|
||||
working_directory: None,
|
||||
env: None,
|
||||
user: None,
|
||||
}),
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.normalize_history();
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn normalize_removes_orphan_function_call_output_panics_in_debug() {
|
||||
let items = vec![ResponseItem::FunctionCallOutput {
|
||||
call_id: "orphan-1".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "ok".to_string(),
|
||||
success: None,
|
||||
},
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.normalize_history();
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn normalize_removes_orphan_custom_tool_call_output_panics_in_debug() {
|
||||
let items = vec![ResponseItem::CustomToolCallOutput {
|
||||
call_id: "orphan-2".to_string(),
|
||||
output: "ok".to_string(),
|
||||
}];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.normalize_history();
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn normalize_mixed_inserts_and_removals_panics_in_debug() {
|
||||
let items = vec![
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "f1".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "c1".to_string(),
|
||||
},
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: "c2".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "ok".to_string(),
|
||||
success: None,
|
||||
},
|
||||
},
|
||||
ResponseItem::CustomToolCall {
|
||||
id: None,
|
||||
status: None,
|
||||
call_id: "t1".to_string(),
|
||||
name: "tool".to_string(),
|
||||
input: "{}".to_string(),
|
||||
},
|
||||
ResponseItem::LocalShellCall {
|
||||
id: None,
|
||||
call_id: Some("s1".to_string()),
|
||||
status: LocalShellStatus::Completed,
|
||||
action: LocalShellAction::Exec(LocalShellExecAction {
|
||||
command: vec!["echo".to_string()],
|
||||
timeout_ms: None,
|
||||
working_directory: None,
|
||||
env: None,
|
||||
user: None,
|
||||
}),
|
||||
},
|
||||
];
|
||||
let mut h = create_history_with_items(items);
|
||||
h.normalize_history();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,8 +36,12 @@ impl SessionState {
|
||||
self.history.record_items(items)
|
||||
}
|
||||
|
||||
pub(crate) fn history_snapshot(&self) -> Vec<ResponseItem> {
|
||||
self.history.contents()
|
||||
pub(crate) fn history_snapshot(&mut self) -> Vec<ResponseItem> {
|
||||
self.history.get_history()
|
||||
}
|
||||
|
||||
pub(crate) fn clone_history(&self) -> ConversationHistory {
|
||||
self.history.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn replace_history(&mut self, items: Vec<ResponseItem>) {
|
||||
|
||||
@@ -97,6 +97,10 @@ impl Match for ResponseMock {
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push(ResponsesRequest(request.clone()));
|
||||
|
||||
// Enforce invariant checks on every request body captured by the mock.
|
||||
// Panic on orphan tool outputs or calls to catch regressions early.
|
||||
validate_request_body_invariants(request);
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -386,3 +390,90 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> Res
|
||||
|
||||
response_mock
|
||||
}
|
||||
|
||||
/// Validate invariants on the request body sent to `/v1/responses`.
|
||||
///
|
||||
/// - No `function_call_output`/`custom_tool_call_output` with missing/empty `call_id`.
|
||||
/// - Every `function_call_output` must match a prior `function_call` or
|
||||
/// `local_shell_call` with the same `call_id` in the same `input`.
|
||||
/// - Every `custom_tool_call_output` must match a prior `custom_tool_call`.
|
||||
/// - Additionally, enforce symmetry: every `function_call`/`custom_tool_call`
|
||||
/// in the `input` must have a matching output entry.
|
||||
fn validate_request_body_invariants(request: &wiremock::Request) {
|
||||
let Ok(body): Result<Value, _> = request.body_json() else {
|
||||
return;
|
||||
};
|
||||
let Some(items) = body.get("input").and_then(Value::as_array) else {
|
||||
panic!("input array not found in request");
|
||||
};
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
fn get_call_id(item: &Value) -> Option<&str> {
|
||||
item.get("call_id")
|
||||
.and_then(Value::as_str)
|
||||
.filter(|id| !id.is_empty())
|
||||
}
|
||||
|
||||
fn gather_ids(items: &[Value], kind: &str) -> HashSet<String> {
|
||||
items
|
||||
.iter()
|
||||
.filter(|item| item.get("type").and_then(Value::as_str) == Some(kind))
|
||||
.filter_map(get_call_id)
|
||||
.map(str::to_string)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn gather_output_ids(items: &[Value], kind: &str, missing_msg: &str) -> HashSet<String> {
|
||||
items
|
||||
.iter()
|
||||
.filter(|item| item.get("type").and_then(Value::as_str) == Some(kind))
|
||||
.map(|item| {
|
||||
let Some(id) = get_call_id(item) else {
|
||||
panic!("{missing_msg}");
|
||||
};
|
||||
id.to_string()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
let function_calls = gather_ids(items, "function_call");
|
||||
let custom_tool_calls = gather_ids(items, "custom_tool_call");
|
||||
let local_shell_calls = gather_ids(items, "local_shell_call");
|
||||
let function_call_outputs = gather_output_ids(
|
||||
items,
|
||||
"function_call_output",
|
||||
"orphan function_call_output with empty call_id should be dropped",
|
||||
);
|
||||
let custom_tool_call_outputs = gather_output_ids(
|
||||
items,
|
||||
"custom_tool_call_output",
|
||||
"orphan custom_tool_call_output with empty call_id should be dropped",
|
||||
);
|
||||
|
||||
for cid in &function_call_outputs {
|
||||
assert!(
|
||||
function_calls.contains(cid) || local_shell_calls.contains(cid),
|
||||
"function_call_output without matching call in input: {cid}",
|
||||
);
|
||||
}
|
||||
for cid in &custom_tool_call_outputs {
|
||||
assert!(
|
||||
custom_tool_calls.contains(cid),
|
||||
"custom_tool_call_output without matching call in input: {cid}",
|
||||
);
|
||||
}
|
||||
|
||||
for cid in &function_calls {
|
||||
assert!(
|
||||
function_call_outputs.contains(cid),
|
||||
"Function call output is missing for call id: {cid}",
|
||||
);
|
||||
}
|
||||
for cid in &custom_tool_calls {
|
||||
assert!(
|
||||
custom_tool_call_outputs.contains(cid),
|
||||
"Custom tool call output is missing for call id: {cid}",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,62 +227,6 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_mock_server().await;
|
||||
let mut builder = test_codex();
|
||||
let test = builder.build(&server).await?;
|
||||
|
||||
let local_shell_event = json!({
|
||||
"type": "response.output_item.done",
|
||||
"item": {
|
||||
"type": "local_shell_call",
|
||||
"status": "completed",
|
||||
"action": {
|
||||
"type": "exec",
|
||||
"command": ["/bin/echo", "hi"],
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_response_created("resp-1"),
|
||||
local_shell_event,
|
||||
ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
let second_mock = mount_sse_once(
|
||||
&server,
|
||||
sse(vec![
|
||||
ev_assistant_message("msg-1", "done"),
|
||||
ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
submit_turn(
|
||||
&test,
|
||||
"check shell output",
|
||||
AskForApproval::Never,
|
||||
SandboxPolicy::DangerFullAccess,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let item = second_mock.single_request().function_call_output("");
|
||||
assert_eq!(item.get("call_id").and_then(Value::as_str), Some(""));
|
||||
assert_eq!(
|
||||
item.get("output").and_then(Value::as_str),
|
||||
Some("LocalShellCall without call_id or id"),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
|
||||
let server = start_mock_server().await;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user