diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 47d80c5c..7c8151d2 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -1001,12 +1001,6 @@ impl Session {
         }
     }
 
-    // todo (aibrahim): get rid of this method. we shouldn't deal with vec[resposne_item] and rather use ConversationHistory.
-    pub(crate) async fn history_snapshot(&self) -> Vec<ResponseItem> {
-        let mut state = self.state.lock().await;
-        state.history_snapshot()
-    }
-
     pub(crate) async fn clone_history(&self) -> ConversationHistory {
         let state = self.state.lock().await;
         state.clone_history()
@@ -1746,11 +1740,11 @@ pub(crate) async fn run_task(
             if !pending_input.is_empty() {
                 review_thread_history.record_items(&pending_input);
             }
-            review_thread_history.get_history()
+            review_thread_history.get_history_for_prompt()
         } else {
             sess.record_conversation_items(&turn_context, &pending_input)
                 .await;
-            sess.history_snapshot().await
+            sess.clone_history().await.get_history_for_prompt()
         };
 
         let turn_input_messages: Vec<String> = turn_input
@@ -1907,13 +1901,6 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
     }
 }
 
-fn filter_model_visible_history(input: Vec<ResponseItem>) -> Vec<ResponseItem> {
-    input
-        .into_iter()
-        .filter(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }))
-        .collect()
-}
-
 async fn run_turn(
     sess: Arc<Session>,
     turn_context: Arc<TurnContext>,
@@ -1934,7 +1921,7 @@ async fn run_turn(
         .supports_parallel_tool_calls;
     let parallel_tool_calls = model_supports_parallel;
     let prompt = Prompt {
-        input: filter_model_visible_history(input),
+        input,
         tools: router.specs(),
         parallel_tool_calls,
         base_instructions_override: turn_context.base_instructions.clone(),
@@ -2462,7 +2449,9 @@ mod tests {
             },
         )));
 
-        let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() });
+        let actual = tokio_test::block_on(async {
+            session.state.lock().await.clone_history().get_history()
+        });
         assert_eq!(expected, actual);
     }
 
@@ -2473,7 +2462,9 @@
 
         tokio_test::block_on(session.record_initial_history(InitialHistory::Forked(rollout_items)));
 
-        let actual = tokio_test::block_on(async { session.state.lock().await.history_snapshot() });
+        let actual = tokio_test::block_on(async {
+            session.state.lock().await.clone_history().get_history()
+        });
         assert_eq!(expected, actual);
     }
 
@@ -2870,7 +2861,7 @@ mod tests {
             }
         }
 
-        let history = sess.history_snapshot().await;
+        let history = sess.clone_history().await.get_history();
         let found = history.iter().any(|item| match item {
             ResponseItem::Message { role, content, .. } if role == "user" => {
                 content.iter().any(|ci| match ci {
diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs
index dc7a6aa5..a8340cd0 100644
--- a/codex-rs/core/src/codex/compact.rs
+++ b/codex-rs/core/src/codex/compact.rs
@@ -2,7 +2,6 @@ use std::sync::Arc;
 
 use super::Session;
 use super::TurnContext;
-use super::filter_model_visible_history;
 use super::get_last_assistant_message_from_turn;
 use crate::Prompt;
 use crate::client_common::ResponseEvent;
@@ -86,10 +85,9 @@ async fn run_compact_task_inner(
     sess.persist_rollout_items(&[rollout_item]).await;
 
     loop {
-        let turn_input = history.get_history();
-        let prompt_input = filter_model_visible_history(turn_input.clone());
+        let turn_input = history.get_history_for_prompt();
         let prompt = Prompt {
-            input: prompt_input.clone(),
+            input: turn_input.clone(),
             ..Default::default()
         };
         let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
@@ -111,7 +109,7 @@ async fn run_compact_task_inner(
                 return;
             }
             Err(e @ CodexErr::ContextWindowExceeded) => {
-                if prompt_input.len() > 1 {
+                if turn_input.len() > 1 {
                     // Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
                     error!(
                         "Context window exceeded while compacting; removing oldest history item. Error: {e}"
@@ -150,7 +148,7 @@ async fn run_compact_task_inner(
         }
     }
 
-    let history_snapshot = sess.history_snapshot().await;
+    let history_snapshot = sess.clone_history().await.get_history();
     let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
     let user_messages = collect_user_messages(&history_snapshot);
     let initial_context = sess.build_initial_context(turn_context.as_ref());
diff --git a/codex-rs/core/src/conversation_history.rs b/codex-rs/core/src/conversation_history.rs
index 27291ebc..9c724df6 100644
--- a/codex-rs/core/src/conversation_history.rs
+++ b/codex-rs/core/src/conversation_history.rs
@@ -67,6 +67,15 @@ impl ConversationHistory {
         self.contents()
     }
 
+    // Returns the history prepared for sending to the model, with extra
+    // response items filtered out and ghost snapshots removed.
+    pub(crate) fn get_history_for_prompt(&mut self) -> Vec<ResponseItem> {
+        let mut history = self.get_history();
+        Self::remove_ghost_snapshots(&mut history);
+        Self::remove_reasoning_before_last_turn(&mut history);
+        history
+    }
+
     pub(crate) fn remove_first_item(&mut self) {
         if !self.items.is_empty() {
             // Remove the oldest item (front of the list). Items are ordered from
@@ -111,6 +120,29 @@
         self.items.clone()
     }
 
+    fn remove_ghost_snapshots(items: &mut Vec<ResponseItem>) {
+        items.retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }));
+    }
+
+    fn remove_reasoning_before_last_turn(items: &mut Vec<ResponseItem>) {
+        // The Responses API drops reasoning items before the last user message.
+        // Sending them is usually harmless but can lead to validation errors when switching between API organizations.
+        // https://cookbook.openai.com/examples/responses_api/reasoning_items#caching
+        let Some(last_user_index) = items
+            .iter()
+            // Use the last user message as the turn boundary.
+            .rposition(|item| matches!(item, ResponseItem::Message { role, .. } if role == "user"))
+        else {
+            return;
+        };
+        let mut index = 0usize;
+        items.retain(|item| {
+            let keep = index >= last_user_index || !matches!(item, ResponseItem::Reasoning { .. });
+            index += 1;
+            keep
+        });
+    }
+
     fn ensure_call_outputs_present(&mut self) {
         // Collect synthetic outputs to insert immediately after their calls.
         // Store the insertion position (index of call) alongside the item so
@@ -498,6 +530,7 @@ fn is_api_message(message: &ResponseItem) -> bool {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use codex_git_tooling::GhostCommit;
     use codex_protocol::models::ContentItem;
     use codex_protocol::models::FunctionCallOutputPayload;
     use codex_protocol::models::LocalShellAction;
@@ -515,6 +548,15 @@
         }
     }
 
+    fn reasoning(id: &str) -> ResponseItem {
+        ResponseItem::Reasoning {
+            id: id.to_string(),
+            summary: Vec::new(),
+            content: None,
+            encrypted_content: None,
+        }
+    }
+
     fn create_history_with_items(items: Vec<ResponseItem>) -> ConversationHistory {
         let mut h = ConversationHistory::new();
         h.record_items(items.iter());
@@ -571,6 +613,50 @@
         );
     }
 
+    #[test]
+    fn get_history_drops_reasoning_before_last_user_message() {
+        let mut history = ConversationHistory::new();
+        let items = vec![
+            user_msg("initial"),
+            reasoning("first"),
+            assistant_msg("ack"),
+            user_msg("latest"),
+            reasoning("second"),
+            assistant_msg("ack"),
+            reasoning("third"),
+        ];
+        history.record_items(items.iter());
+
+        let filtered = history.get_history_for_prompt();
+        assert_eq!(
+            filtered,
+            vec![
+                user_msg("initial"),
+                assistant_msg("ack"),
+                user_msg("latest"),
+                reasoning("second"),
+                assistant_msg("ack"),
+                reasoning("third"),
+            ]
+        );
+        let reasoning_count = history
+            .contents()
+            .iter()
+            .filter(|item| matches!(item, ResponseItem::Reasoning { .. }))
+            .count();
+        assert_eq!(reasoning_count, 3);
+    }
+
+    #[test]
+    fn get_history_for_prompt_drops_ghost_commits() {
+        let items = vec![ResponseItem::GhostSnapshot {
+            ghost_commit: GhostCommit::new("ghost-1".to_string(), None, Vec::new(), Vec::new()),
+        }];
+        let mut history = create_history_with_items(items);
+        let filtered = history.get_history_for_prompt();
+        assert_eq!(filtered, vec![]);
+    }
+
     #[test]
     fn remove_first_item_removes_matching_output_for_function_call() {
         let items = vec![
diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs
index 7c4603d9..a41d2b63 100644
--- a/codex-rs/core/src/state/session.rs
+++ b/codex-rs/core/src/state/session.rs
@@ -34,10 +34,6 @@ impl SessionState {
         self.history.record_items(items)
     }
 
-    pub(crate) fn history_snapshot(&mut self) -> Vec<ResponseItem> {
-        self.history.get_history()
-    }
-
     pub(crate) fn clone_history(&self) -> ConversationHistory {
         self.history.clone()
     }
diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs
index 511c0b5b..7fc6a69c 100644
--- a/codex-rs/core/tests/common/responses.rs
+++ b/codex-rs/core/tests/common/responses.rs
@@ -68,6 +68,14 @@ impl ResponsesRequest {
             .clone()
     }
 
+    pub fn inputs_of_type(&self, ty: &str) -> Vec<Value> {
+        self.input()
+            .iter()
+            .filter(|item| item.get("type").and_then(Value::as_str) == Some(ty))
+            .cloned()
+            .collect()
+    }
+
     pub fn function_call_output(&self, call_id: &str) -> Value {
         self.call_output(call_id, "function_call_output")
     }
diff --git a/codex-rs/core/tests/suite/prompt_caching.rs b/codex-rs/core/tests/suite/prompt_caching.rs
index 1bf12eb9..04304126 100644
--- a/codex-rs/core/tests/suite/prompt_caching.rs
+++ b/codex-rs/core/tests/suite/prompt_caching.rs
@@ -18,7 +18,10 @@ use codex_core::shell::default_user_shell;
 use codex_protocol::user_input::UserInput;
 use core_test_support::load_default_config_for_test;
 use core_test_support::load_sse_fixture_with_id;
+use core_test_support::responses;
+use core_test_support::responses::mount_sse_once;
 use core_test_support::skip_if_no_network;
+use core_test_support::test_codex::test_codex;
 use core_test_support::wait_for_event;
 use std::collections::HashMap;
 use tempfile::TempDir;
@@ -883,3 +886,68 @@ async fn send_user_turn_with_changes_sends_environment_context() {
     ]);
     assert_eq!(body2["input"], expected_input_2);
 }
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn cached_prompt_filters_reasoning_items_from_previous_turns() -> anyhow::Result<()> {
+    skip_if_no_network!(Ok(()));
+
+    let server = responses::start_mock_server().await;
+    let call_id = "shell-call";
+    let shell_args = serde_json::json!({
+        "command": ["/bin/echo", "tool output"],
+        "timeout_ms": 1_000,
+    });
+
+    let initial_response = responses::sse(vec![
+        responses::ev_response_created("resp-first"),
+        responses::ev_reasoning_item("reason-1", &["Planning shell command"], &[]),
+        responses::ev_function_call(
+            call_id,
+            "shell",
+            &serde_json::to_string(&shell_args).expect("serialize shell args"),
+        ),
+        responses::ev_completed("resp-first"),
+    ]);
+    let follow_up_response = responses::sse(vec![
+        responses::ev_response_created("resp-follow-up"),
+        responses::ev_reasoning_item(
+            "reason-2",
+            &["Shell execution completed"],
+            &["stdout: tool output"],
+        ),
+        responses::ev_assistant_message("assistant-1", "First turn reply"),
+        responses::ev_completed("resp-follow-up"),
+    ]);
+    let second_turn_response = responses::sse(vec![
+        responses::ev_response_created("resp-second"),
+        responses::ev_assistant_message("assistant-2", "Second turn reply"),
+        responses::ev_completed("resp-second"),
+    ]);
+    mount_sse_once(&server, initial_response).await;
+    let second_request = mount_sse_once(&server, follow_up_response).await;
+    let third_request = mount_sse_once(&server, second_turn_response).await;
+
+    let mut builder = test_codex();
+    let test = builder.build(&server).await?;
+
+    test.submit_turn("hello 1").await?;
+    test.submit_turn("hello 2").await?;
+
+    let second_request_input = second_request.single_request();
+    let reasoning_items = second_request_input.inputs_of_type("reasoning");
+    assert_eq!(
+        reasoning_items.len(),
+        1,
+        "expected first turn follow-up to include reasoning item"
+    );
+
+    let third_request_input = third_request.single_request();
+    let cached_reasoning = third_request_input.inputs_of_type("reasoning");
+    assert_eq!(
+        cached_reasoning.len(),
+        0,
+        "expected cached prompt to filter out prior reasoning items"
+    );
+
+    Ok(())
+}