Handle resuming/forking after compact (#3533)

We need to construct the history different when compact happens. For
this, we need to just consider the history after compact and convert
compact to a response item.

This needs to change and use `build_compact_history` when this #3446 is
merged.
This commit is contained in:
Ahmed Ibrahim
2025-09-14 09:23:31 -04:00
committed by GitHub
parent 4891ee29c5
commit bbea6bbf7e
6 changed files with 1118 additions and 46 deletions

View File

@@ -130,6 +130,8 @@ use codex_protocol::models::ShellToolCallParams;
use codex_protocol::protocol::InitialHistory;
mod compact;
use self::compact::build_compacted_history;
use self::compact::collect_user_messages;
// A convenience extension trait for acquiring mutex locks where poisoning is
// unrecoverable and should abort the program. This avoids scattered `.unwrap()`
@@ -205,7 +207,7 @@ impl Codex {
config.clone(),
auth_manager.clone(),
tx_event.clone(),
conversation_history.clone(),
conversation_history,
)
.await
.map_err(|e| {
@@ -564,9 +566,10 @@ impl Session {
let persist = matches!(conversation_history, InitialHistory::Forked(_));
// Always add response items to conversation history
let response_items = conversation_history.get_response_items();
if !response_items.is_empty() {
self.record_into_history(&response_items);
let reconstructed_history =
self.reconstruct_history_from_rollout(turn_context, &rollout_items);
if !reconstructed_history.is_empty() {
self.record_into_history(&reconstructed_history);
}
// If persisting, persist all rollout items as-is (recorder filters)
@@ -678,6 +681,33 @@ impl Session {
self.persist_rollout_response_items(items).await;
}
fn reconstruct_history_from_rollout(
&self,
turn_context: &TurnContext,
rollout_items: &[RolloutItem],
) -> Vec<ResponseItem> {
let mut history = ConversationHistory::new();
for item in rollout_items {
match item {
RolloutItem::ResponseItem(response_item) => {
history.record_items(std::iter::once(response_item));
}
RolloutItem::Compacted(compacted) => {
let snapshot = history.contents();
let user_messages = collect_user_messages(&snapshot);
let rebuilt = build_compacted_history(
self.build_initial_context(turn_context),
&user_messages,
&compacted.message,
);
history.replace(rebuilt);
}
_ => {}
}
}
history.contents()
}
/// Append ResponseItems to the in-memory conversation history only.
fn record_into_history(&self, items: &[ResponseItem]) {
self.state
@@ -3220,18 +3250,59 @@ async fn exit_review_mode(
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ConfigOverrides;
use crate::config::ConfigToml;
use crate::protocol::CompactedItem;
use crate::protocol::InitialHistory;
use crate::protocol::ResumedHistory;
use codex_protocol::models::ContentItem;
use mcp_types::ContentBlock;
use mcp_types::TextContent;
use pretty_assertions::assert_eq;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration as StdDuration;
fn text_block(s: &str) -> ContentBlock {
ContentBlock::TextContent(TextContent {
annotations: None,
text: s.to_string(),
r#type: "text".to_string(),
})
#[test]
fn reconstruct_history_matches_live_compactions() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);
let reconstructed = session.reconstruct_history_from_rollout(&turn_context, &rollout_items);
assert_eq!(expected, reconstructed);
}
#[test]
fn record_initial_history_reconstructs_resumed_transcript() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);
tokio_test::block_on(session.record_initial_history(
&turn_context,
InitialHistory::Resumed(ResumedHistory {
conversation_id: ConversationId::default(),
history: rollout_items,
rollout_path: PathBuf::from("/tmp/resume.jsonl"),
}),
));
let actual = session.state.lock_unchecked().history.contents();
assert_eq!(expected, actual);
}
#[test]
fn record_initial_history_reconstructs_forked_transcript() {
let (session, turn_context) = make_session_and_context();
let (rollout_items, expected) = sample_rollout(&session, &turn_context);
tokio_test::block_on(
session.record_initial_history(&turn_context, InitialHistory::Forked(rollout_items)),
);
let actual = session.state.lock_unchecked().history.contents();
assert_eq!(expected, actual);
}
#[test]
@@ -3386,4 +3457,174 @@ mod tests {
assert_eq!(expected, got);
}
fn text_block(s: &str) -> ContentBlock {
ContentBlock::TextContent(TextContent {
annotations: None,
text: s.to_string(),
r#type: "text".to_string(),
})
}
fn make_session_and_context() -> (Session, TurnContext) {
let (tx_event, _rx_event) = async_channel::unbounded();
let codex_home = tempfile::tempdir().expect("create temp dir");
let config = Config::load_from_base_config_with_overrides(
ConfigToml::default(),
ConfigOverrides::default(),
codex_home.path().to_path_buf(),
)
.expect("load default test config");
let config = Arc::new(config);
let conversation_id = ConversationId::default();
let client = ModelClient::new(
config.clone(),
None,
config.model_provider.clone(),
config.model_reasoning_effort,
config.model_reasoning_summary,
conversation_id,
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &config.model_family,
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
include_view_image_tool: config.include_view_image_tool,
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
});
let turn_context = TurnContext {
client,
cwd: config.cwd.clone(),
base_instructions: config.base_instructions.clone(),
user_instructions: config.user_instructions.clone(),
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
shell_environment_policy: config.shell_environment_policy.clone(),
tools_config,
is_review_mode: false,
};
let session = Session {
conversation_id,
tx_event,
mcp_connection_manager: McpConnectionManager::default(),
session_manager: ExecSessionManager::default(),
unified_exec_manager: UnifiedExecSessionManager::default(),
notify: None,
rollout: Mutex::new(None),
state: Mutex::new(State {
history: ConversationHistory::new(),
..Default::default()
}),
codex_linux_sandbox_exe: None,
user_shell: shell::Shell::Unknown,
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
};
(session, turn_context)
}
fn sample_rollout(
session: &Session,
turn_context: &TurnContext,
) -> (Vec<RolloutItem>, Vec<ResponseItem>) {
let mut rollout_items = Vec::new();
let mut live_history = ConversationHistory::new();
let initial_context = session.build_initial_context(turn_context);
for item in &initial_context {
rollout_items.push(RolloutItem::ResponseItem(item.clone()));
}
live_history.record_items(initial_context.iter());
let user1 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "first user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user1));
rollout_items.push(RolloutItem::ResponseItem(user1.clone()));
let assistant1 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply one".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant1));
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));
let summary1 = "summary one";
let snapshot1 = live_history.contents();
let user_messages1 = collect_user_messages(&snapshot1);
let rebuilt1 = build_compacted_history(
session.build_initial_context(turn_context),
&user_messages1,
summary1,
);
live_history.replace(rebuilt1);
rollout_items.push(RolloutItem::Compacted(CompactedItem {
message: summary1.to_string(),
}));
let user2 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "second user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user2));
rollout_items.push(RolloutItem::ResponseItem(user2.clone()));
let assistant2 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply two".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant2));
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));
let summary2 = "summary two";
let snapshot2 = live_history.contents();
let user_messages2 = collect_user_messages(&snapshot2);
let rebuilt2 = build_compacted_history(
session.build_initial_context(turn_context),
&user_messages2,
summary2,
);
live_history.replace(rebuilt2);
rollout_items.push(RolloutItem::Compacted(CompactedItem {
message: summary2.to_string(),
}));
let user3 = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "third user".to_string(),
}],
};
live_history.record_items(std::iter::once(&user3));
rollout_items.push(RolloutItem::ResponseItem(user3.clone()));
let assistant3 = ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "assistant reply three".to_string(),
}],
};
live_history.record_items(std::iter::once(&assistant3));
rollout_items.push(RolloutItem::ResponseItem(assistant3.clone()));
(rollout_items, live_history.contents())
}
}

View File

@@ -176,8 +176,8 @@ async fn run_compact_task_inner(
};
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
let user_messages = collect_user_messages(&history_snapshot);
let new_history =
build_compacted_history(&sess, turn_context.as_ref(), &user_messages, &summary_text);
let initial_context = sess.build_initial_context(turn_context.as_ref());
let new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
{
let mut state = sess.state.lock_unchecked();
state.history.replace(new_history);
@@ -223,7 +223,7 @@ fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
}
}
fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
items
.iter()
.filter_map(|item| match item {
@@ -243,13 +243,12 @@ fn is_session_prefix_message(text: &str) -> bool {
)
}
fn build_compacted_history(
sess: &Session,
turn_context: &TurnContext,
pub(crate) fn build_compacted_history(
initial_context: Vec<ResponseItem>,
user_messages: &[String],
summary_text: &str,
) -> Vec<ResponseItem> {
let mut history = sess.build_initial_context(turn_context);
let mut history = initial_context;
let user_messages_text = if user_messages.is_empty() {
"(none)".to_string()
} else {