Filter out reasoning items from previous turns (#5857)
Reduces request size and prevents 400 errors when switching between API organizations. Based on the Responses API behavior described in https://cookbook.openai.com/examples/responses_api/reasoning_items#caching
This commit is contained in:
@@ -1001,12 +1001,6 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo (aibrahim): get rid of this method. We shouldn't deal with Vec<ResponseItem> directly; use ConversationHistory instead.
|
|
||||||
pub(crate) async fn history_snapshot(&self) -> Vec<ResponseItem> {
|
|
||||||
let mut state = self.state.lock().await;
|
|
||||||
state.history_snapshot()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) async fn clone_history(&self) -> ConversationHistory {
|
pub(crate) async fn clone_history(&self) -> ConversationHistory {
|
||||||
let state = self.state.lock().await;
|
let state = self.state.lock().await;
|
||||||
state.clone_history()
|
state.clone_history()
|
||||||
@@ -1746,11 +1740,11 @@ pub(crate) async fn run_task(
|
|||||||
if !pending_input.is_empty() {
|
if !pending_input.is_empty() {
|
||||||
review_thread_history.record_items(&pending_input);
|
review_thread_history.record_items(&pending_input);
|
||||||
}
|
}
|
||||||
review_thread_history.get_history()
|
review_thread_history.get_history_for_prompt()
|
||||||
} else {
|
} else {
|
||||||
sess.record_conversation_items(&turn_context, &pending_input)
|
sess.record_conversation_items(&turn_context, &pending_input)
|
||||||
.await;
|
.await;
|
||||||
sess.history_snapshot().await
|
sess.clone_history().await.get_history_for_prompt()
|
||||||
};
|
};
|
||||||
|
|
||||||
let turn_input_messages: Vec<String> = turn_input
|
let turn_input_messages: Vec<String> = turn_input
|
||||||
@@ -1907,13 +1901,6 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn filter_model_visible_history(input: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
|
||||||
input
|
|
||||||
.into_iter()
|
|
||||||
.filter(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }))
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn run_turn(
|
async fn run_turn(
|
||||||
sess: Arc<Session>,
|
sess: Arc<Session>,
|
||||||
turn_context: Arc<TurnContext>,
|
turn_context: Arc<TurnContext>,
|
||||||
@@ -1934,7 +1921,7 @@ async fn run_turn(
|
|||||||
.supports_parallel_tool_calls;
|
.supports_parallel_tool_calls;
|
||||||
let parallel_tool_calls = model_supports_parallel;
|
let parallel_tool_calls = model_supports_parallel;
|
||||||
let prompt = Prompt {
|
let prompt = Prompt {
|
||||||
input: filter_model_visible_history(input),
|
input,
|
||||||
tools: router.specs(),
|
tools: router.specs(),
|
||||||
parallel_tool_calls,
|
parallel_tool_calls,
|
||||||
base_instructions_override: turn_context.base_instructions.clone(),
|
base_instructions_override: turn_context.base_instructions.clone(),
|
||||||
@@ -2462,7 +2449,9 @@ mod tests {
|
|||||||
},
|
},
|
||||||
)));
|
)));
|
||||||
|
|
||||||
let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() });
|
let actual = tokio_test::block_on(async {
|
||||||
|
session.state.lock().await.clone_history().get_history()
|
||||||
|
});
|
||||||
assert_eq!(expected, actual);
|
assert_eq!(expected, actual);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2473,7 +2462,9 @@ mod tests {
|
|||||||
|
|
||||||
tokio_test::block_on(session.record_initial_history(InitialHistory::Forked(rollout_items)));
|
tokio_test::block_on(session.record_initial_history(InitialHistory::Forked(rollout_items)));
|
||||||
|
|
||||||
let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() });
|
let actual = tokio_test::block_on(async {
|
||||||
|
session.state.lock().await.clone_history().get_history()
|
||||||
|
});
|
||||||
assert_eq!(expected, actual);
|
assert_eq!(expected, actual);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2870,7 +2861,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let history = sess.history_snapshot().await;
|
let history = sess.clone_history().await.get_history();
|
||||||
let found = history.iter().any(|item| match item {
|
let found = history.iter().any(|item| match item {
|
||||||
ResponseItem::Message { role, content, .. } if role == "user" => {
|
ResponseItem::Message { role, content, .. } if role == "user" => {
|
||||||
content.iter().any(|ci| match ci {
|
content.iter().any(|ci| match ci {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use super::Session;
|
use super::Session;
|
||||||
use super::TurnContext;
|
use super::TurnContext;
|
||||||
use super::filter_model_visible_history;
|
|
||||||
use super::get_last_assistant_message_from_turn;
|
use super::get_last_assistant_message_from_turn;
|
||||||
use crate::Prompt;
|
use crate::Prompt;
|
||||||
use crate::client_common::ResponseEvent;
|
use crate::client_common::ResponseEvent;
|
||||||
@@ -86,10 +85,9 @@ async fn run_compact_task_inner(
|
|||||||
sess.persist_rollout_items(&[rollout_item]).await;
|
sess.persist_rollout_items(&[rollout_item]).await;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let turn_input = history.get_history();
|
let turn_input = history.get_history_for_prompt();
|
||||||
let prompt_input = filter_model_visible_history(turn_input.clone());
|
|
||||||
let prompt = Prompt {
|
let prompt = Prompt {
|
||||||
input: prompt_input.clone(),
|
input: turn_input.clone(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
|
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
|
||||||
@@ -111,7 +109,7 @@ async fn run_compact_task_inner(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||||
if prompt_input.len() > 1 {
|
if turn_input.len() > 1 {
|
||||||
// Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
|
// Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
|
||||||
error!(
|
error!(
|
||||||
"Context window exceeded while compacting; removing oldest history item. Error: {e}"
|
"Context window exceeded while compacting; removing oldest history item. Error: {e}"
|
||||||
@@ -150,7 +148,7 @@ async fn run_compact_task_inner(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let history_snapshot = sess.history_snapshot().await;
|
let history_snapshot = sess.clone_history().await.get_history();
|
||||||
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
|
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
|
||||||
let user_messages = collect_user_messages(&history_snapshot);
|
let user_messages = collect_user_messages(&history_snapshot);
|
||||||
let initial_context = sess.build_initial_context(turn_context.as_ref());
|
let initial_context = sess.build_initial_context(turn_context.as_ref());
|
||||||
|
|||||||
@@ -67,6 +67,15 @@ impl ConversationHistory {
|
|||||||
self.contents()
|
self.contents()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the history prepared for sending to the model.
|
||||||
|
// Reasoning items that precede the last user message are filtered out and GhostSnapshot items are removed.
|
||||||
|
pub(crate) fn get_history_for_prompt(&mut self) -> Vec<ResponseItem> {
|
||||||
|
let mut history = self.get_history();
|
||||||
|
Self::remove_ghost_snapshots(&mut history);
|
||||||
|
Self::remove_reasoning_before_last_turn(&mut history);
|
||||||
|
history
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn remove_first_item(&mut self) {
|
pub(crate) fn remove_first_item(&mut self) {
|
||||||
if !self.items.is_empty() {
|
if !self.items.is_empty() {
|
||||||
// Remove the oldest item (front of the list). Items are ordered from
|
// Remove the oldest item (front of the list). Items are ordered from
|
||||||
@@ -111,6 +120,29 @@ impl ConversationHistory {
|
|||||||
self.items.clone()
|
self.items.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn remove_ghost_snapshots(items: &mut Vec<ResponseItem>) {
|
||||||
|
items.retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_reasoning_before_last_turn(items: &mut Vec<ResponseItem>) {
|
||||||
|
// Responses API drops reasoning items before the last user message.
|
||||||
|
// Sending them is normally harmless, but it can trigger validation (400) errors when switching between API organizations.
|
||||||
|
// https://cookbook.openai.com/examples/responses_api/reasoning_items#caching
|
||||||
|
let Some(last_user_index) = items
|
||||||
|
.iter()
|
||||||
|
// Use last user message as the turn boundary.
|
||||||
|
.rposition(|item| matches!(item, ResponseItem::Message { role, .. } if role == "user"))
|
||||||
|
else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let mut index = 0usize;
|
||||||
|
items.retain(|item| {
|
||||||
|
let keep = index >= last_user_index || !matches!(item, ResponseItem::Reasoning { .. });
|
||||||
|
index += 1;
|
||||||
|
keep
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
fn ensure_call_outputs_present(&mut self) {
|
fn ensure_call_outputs_present(&mut self) {
|
||||||
// Collect synthetic outputs to insert immediately after their calls.
|
// Collect synthetic outputs to insert immediately after their calls.
|
||||||
// Store the insertion position (index of call) alongside the item so
|
// Store the insertion position (index of call) alongside the item so
|
||||||
@@ -498,6 +530,7 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use codex_git_tooling::GhostCommit;
|
||||||
use codex_protocol::models::ContentItem;
|
use codex_protocol::models::ContentItem;
|
||||||
use codex_protocol::models::FunctionCallOutputPayload;
|
use codex_protocol::models::FunctionCallOutputPayload;
|
||||||
use codex_protocol::models::LocalShellAction;
|
use codex_protocol::models::LocalShellAction;
|
||||||
@@ -515,6 +548,15 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reasoning(id: &str) -> ResponseItem {
|
||||||
|
ResponseItem::Reasoning {
|
||||||
|
id: id.to_string(),
|
||||||
|
summary: Vec::new(),
|
||||||
|
content: None,
|
||||||
|
encrypted_content: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn create_history_with_items(items: Vec<ResponseItem>) -> ConversationHistory {
|
fn create_history_with_items(items: Vec<ResponseItem>) -> ConversationHistory {
|
||||||
let mut h = ConversationHistory::new();
|
let mut h = ConversationHistory::new();
|
||||||
h.record_items(items.iter());
|
h.record_items(items.iter());
|
||||||
@@ -571,6 +613,50 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn get_history_drops_reasoning_before_last_user_message() {
|
||||||
|
let mut history = ConversationHistory::new();
|
||||||
|
let items = vec![
|
||||||
|
user_msg("initial"),
|
||||||
|
reasoning("first"),
|
||||||
|
assistant_msg("ack"),
|
||||||
|
user_msg("latest"),
|
||||||
|
reasoning("second"),
|
||||||
|
assistant_msg("ack"),
|
||||||
|
reasoning("third"),
|
||||||
|
];
|
||||||
|
history.record_items(items.iter());
|
||||||
|
|
||||||
|
let filtered = history.get_history_for_prompt();
|
||||||
|
assert_eq!(
|
||||||
|
filtered,
|
||||||
|
vec![
|
||||||
|
user_msg("initial"),
|
||||||
|
assistant_msg("ack"),
|
||||||
|
user_msg("latest"),
|
||||||
|
reasoning("second"),
|
||||||
|
assistant_msg("ack"),
|
||||||
|
reasoning("third"),
|
||||||
|
]
|
||||||
|
);
|
||||||
|
let reasoning_count = history
|
||||||
|
.contents()
|
||||||
|
.iter()
|
||||||
|
.filter(|item| matches!(item, ResponseItem::Reasoning { .. }))
|
||||||
|
.count();
|
||||||
|
assert_eq!(reasoning_count, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn get_history_for_prompt_drops_ghost_commits() {
|
||||||
|
let items = vec![ResponseItem::GhostSnapshot {
|
||||||
|
ghost_commit: GhostCommit::new("ghost-1".to_string(), None, Vec::new(), Vec::new()),
|
||||||
|
}];
|
||||||
|
let mut history = create_history_with_items(items);
|
||||||
|
let filtered = history.get_history_for_prompt();
|
||||||
|
assert_eq!(filtered, vec![]);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn remove_first_item_removes_matching_output_for_function_call() {
|
fn remove_first_item_removes_matching_output_for_function_call() {
|
||||||
let items = vec![
|
let items = vec![
|
||||||
|
|||||||
@@ -34,10 +34,6 @@ impl SessionState {
|
|||||||
self.history.record_items(items)
|
self.history.record_items(items)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn history_snapshot(&mut self) -> Vec<ResponseItem> {
|
|
||||||
self.history.get_history()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn clone_history(&self) -> ConversationHistory {
|
pub(crate) fn clone_history(&self) -> ConversationHistory {
|
||||||
self.history.clone()
|
self.history.clone()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,6 +68,14 @@ impl ResponsesRequest {
|
|||||||
.clone()
|
.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn inputs_of_type(&self, ty: &str) -> Vec<Value> {
|
||||||
|
self.input()
|
||||||
|
.iter()
|
||||||
|
.filter(|item| item.get("type").and_then(Value::as_str) == Some(ty))
|
||||||
|
.cloned()
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn function_call_output(&self, call_id: &str) -> Value {
|
pub fn function_call_output(&self, call_id: &str) -> Value {
|
||||||
self.call_output(call_id, "function_call_output")
|
self.call_output(call_id, "function_call_output")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,10 @@ use codex_core::shell::default_user_shell;
|
|||||||
use codex_protocol::user_input::UserInput;
|
use codex_protocol::user_input::UserInput;
|
||||||
use core_test_support::load_default_config_for_test;
|
use core_test_support::load_default_config_for_test;
|
||||||
use core_test_support::load_sse_fixture_with_id;
|
use core_test_support::load_sse_fixture_with_id;
|
||||||
|
use core_test_support::responses;
|
||||||
|
use core_test_support::responses::mount_sse_once;
|
||||||
use core_test_support::skip_if_no_network;
|
use core_test_support::skip_if_no_network;
|
||||||
|
use core_test_support::test_codex::test_codex;
|
||||||
use core_test_support::wait_for_event;
|
use core_test_support::wait_for_event;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
@@ -883,3 +886,68 @@ async fn send_user_turn_with_changes_sends_environment_context() {
|
|||||||
]);
|
]);
|
||||||
assert_eq!(body2["input"], expected_input_2);
|
assert_eq!(body2["input"], expected_input_2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn cached_prompt_filters_reasoning_items_from_previous_turns() -> anyhow::Result<()> {
|
||||||
|
skip_if_no_network!(Ok(()));
|
||||||
|
|
||||||
|
let server = responses::start_mock_server().await;
|
||||||
|
let call_id = "shell-call";
|
||||||
|
let shell_args = serde_json::json!({
|
||||||
|
"command": ["/bin/echo", "tool output"],
|
||||||
|
"timeout_ms": 1_000,
|
||||||
|
});
|
||||||
|
|
||||||
|
let initial_response = responses::sse(vec![
|
||||||
|
responses::ev_response_created("resp-first"),
|
||||||
|
responses::ev_reasoning_item("reason-1", &["Planning shell command"], &[]),
|
||||||
|
responses::ev_function_call(
|
||||||
|
call_id,
|
||||||
|
"shell",
|
||||||
|
&serde_json::to_string(&shell_args).expect("serialize shell args"),
|
||||||
|
),
|
||||||
|
responses::ev_completed("resp-first"),
|
||||||
|
]);
|
||||||
|
let follow_up_response = responses::sse(vec![
|
||||||
|
responses::ev_response_created("resp-follow-up"),
|
||||||
|
responses::ev_reasoning_item(
|
||||||
|
"reason-2",
|
||||||
|
&["Shell execution completed"],
|
||||||
|
&["stdout: tool output"],
|
||||||
|
),
|
||||||
|
responses::ev_assistant_message("assistant-1", "First turn reply"),
|
||||||
|
responses::ev_completed("resp-follow-up"),
|
||||||
|
]);
|
||||||
|
let second_turn_response = responses::sse(vec![
|
||||||
|
responses::ev_response_created("resp-second"),
|
||||||
|
responses::ev_assistant_message("assistant-2", "Second turn reply"),
|
||||||
|
responses::ev_completed("resp-second"),
|
||||||
|
]);
|
||||||
|
mount_sse_once(&server, initial_response).await;
|
||||||
|
let second_request = mount_sse_once(&server, follow_up_response).await;
|
||||||
|
let third_request = mount_sse_once(&server, second_turn_response).await;
|
||||||
|
|
||||||
|
let mut builder = test_codex();
|
||||||
|
let test = builder.build(&server).await?;
|
||||||
|
|
||||||
|
test.submit_turn("hello 1").await?;
|
||||||
|
test.submit_turn("hello 2").await?;
|
||||||
|
|
||||||
|
let second_request_input = second_request.single_request();
|
||||||
|
let reasoning_items = second_request_input.inputs_of_type("reasoning");
|
||||||
|
assert_eq!(
|
||||||
|
reasoning_items.len(),
|
||||||
|
1,
|
||||||
|
"expected first turn follow-up to include reasoning item"
|
||||||
|
);
|
||||||
|
|
||||||
|
let third_request_input = third_request.single_request();
|
||||||
|
let cached_reasoning = third_request_input.inputs_of_type("reasoning");
|
||||||
|
assert_eq!(
|
||||||
|
cached_reasoning.len(),
|
||||||
|
0,
|
||||||
|
"expected cached prompt to filter out prior reasoning items"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user