chore: unify history loading (#2736)
We have two ways of loading conversation with a previous history. Fork conversation and the experimental resume that we had before. In this PR, I am unifying their code path. The path is getting the history items and recording them in a brand new conversation. This PR also constraint the rollout recorder responsibilities to be only recording to the disk and loading from the disk. The PR also fixes a current bug when we have two forking in a row: History 1: <Environment Context> UserMessage_1 UserMessage_2 UserMessage_3 **Fork with n = 1 (only remove one element)** History 2: <Environment Context> UserMessage_1 UserMessage_2 <Environment Context> **Fork with n = 1 (only remove one element)** History 2: <Environment Context> UserMessage_1 UserMessage_2 **<Environment Context>** This shouldn't happen but because we were appending the `<Environment Context>` after each spawning and it's considered as _user message_. Now, we don't add this message if restoring and old conversation.
This commit is contained in:
@@ -43,6 +43,7 @@ use crate::client_common::ResponseEvent;
|
||||
use crate::config::Config;
|
||||
use crate::config_types::ShellEnvironmentPolicy;
|
||||
use crate::conversation_history::ConversationHistory;
|
||||
use crate::conversation_manager::InitialHistory;
|
||||
use crate::environment_context::EnvironmentContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
@@ -169,7 +170,7 @@ impl Codex {
|
||||
pub async fn spawn(
|
||||
config: Config,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
initial_history: Option<Vec<ResponseItem>>,
|
||||
conversation_history: InitialHistory,
|
||||
) -> CodexResult<CodexSpawnOk> {
|
||||
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
|
||||
let (tx_event, rx_event) = async_channel::unbounded();
|
||||
@@ -177,7 +178,6 @@ impl Codex {
|
||||
let user_instructions = get_user_instructions(&config).await;
|
||||
|
||||
let config = Arc::new(config);
|
||||
let resume_path = config.experimental_resume.clone();
|
||||
|
||||
let configure_session = ConfigureSession {
|
||||
provider: config.model_provider.clone(),
|
||||
@@ -191,7 +191,6 @@ impl Codex {
|
||||
disable_response_storage: config.disable_response_storage,
|
||||
notify: config.notify.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
resume_path,
|
||||
};
|
||||
|
||||
// Generate a unique ID for the lifetime of this Codex session.
|
||||
@@ -200,13 +199,15 @@ impl Codex {
|
||||
config.clone(),
|
||||
auth_manager.clone(),
|
||||
tx_event.clone(),
|
||||
initial_history,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to create session: {e:#}");
|
||||
CodexErr::InternalAgentDied
|
||||
})?;
|
||||
session
|
||||
.record_initial_history(&turn_context, conversation_history)
|
||||
.await;
|
||||
let session_id = session.session_id;
|
||||
|
||||
// This task will run until Op::Shutdown is received.
|
||||
@@ -352,8 +353,6 @@ struct ConfigureSession {
|
||||
/// `ConfigureSession` operation so that the business-logic layer can
|
||||
/// operate deterministically.
|
||||
cwd: PathBuf,
|
||||
|
||||
resume_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -362,8 +361,8 @@ impl Session {
|
||||
config: Arc<Config>,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
tx_event: Sender<Event>,
|
||||
initial_history: Option<Vec<ResponseItem>>,
|
||||
) -> anyhow::Result<(Arc<Self>, TurnContext)> {
|
||||
let session_id = Uuid::new_v4();
|
||||
let ConfigureSession {
|
||||
provider,
|
||||
model,
|
||||
@@ -376,7 +375,6 @@ impl Session {
|
||||
disable_response_storage,
|
||||
notify,
|
||||
cwd,
|
||||
resume_path,
|
||||
} = configure_session;
|
||||
debug!("Configuring session: model={model}; provider={provider:?}");
|
||||
if !cwd.is_absolute() {
|
||||
@@ -392,89 +390,25 @@ impl Session {
|
||||
// - spin up MCP connection manager
|
||||
// - perform default shell discovery
|
||||
// - load history metadata
|
||||
let rollout_fut = async {
|
||||
match resume_path.as_ref() {
|
||||
Some(path) => RolloutRecorder::resume(path, cwd.clone())
|
||||
.await
|
||||
.map(|(rec, saved)| (saved.session_id, Some(saved), rec)),
|
||||
None => {
|
||||
let session_id = Uuid::new_v4();
|
||||
RolloutRecorder::new(&config, session_id, user_instructions.clone())
|
||||
.await
|
||||
.map(|rec| (session_id, None, rec))
|
||||
}
|
||||
}
|
||||
};
|
||||
let rollout_fut = RolloutRecorder::new(&config, session_id, user_instructions.clone());
|
||||
|
||||
let mcp_fut = McpConnectionManager::new(config.mcp_servers.clone());
|
||||
let default_shell_fut = shell::default_user_shell();
|
||||
let history_meta_fut = crate::message_history::history_metadata(&config);
|
||||
|
||||
// Join all independent futures.
|
||||
let (rollout_res, mcp_res, default_shell, (history_log_id, history_entry_count)) =
|
||||
let (rollout_recorder, mcp_res, default_shell, (history_log_id, history_entry_count)) =
|
||||
tokio::join!(rollout_fut, mcp_fut, default_shell_fut, history_meta_fut);
|
||||
|
||||
// Handle rollout result, which determines the session_id.
|
||||
struct RolloutResult {
|
||||
session_id: Uuid,
|
||||
rollout_recorder: Option<RolloutRecorder>,
|
||||
restored_items: Option<Vec<ResponseItem>>,
|
||||
}
|
||||
let rollout_result = match rollout_res {
|
||||
Ok((session_id, maybe_saved, recorder)) => {
|
||||
let restored_items: Option<Vec<ResponseItem>> = initial_history.or_else(|| {
|
||||
maybe_saved.and_then(|saved_session| {
|
||||
if saved_session.items.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(saved_session.items)
|
||||
}
|
||||
})
|
||||
});
|
||||
RolloutResult {
|
||||
session_id,
|
||||
rollout_recorder: Some(recorder),
|
||||
restored_items,
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(path) = resume_path.as_ref() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"failed to resume rollout from {path:?}: {e}"
|
||||
));
|
||||
}
|
||||
|
||||
let message = format!("failed to initialize rollout recorder: {e}");
|
||||
post_session_configured_error_events.push(Event {
|
||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
||||
msg: EventMsg::Error(ErrorEvent {
|
||||
message: message.clone(),
|
||||
}),
|
||||
});
|
||||
warn!("{message}");
|
||||
|
||||
RolloutResult {
|
||||
session_id: Uuid::new_v4(),
|
||||
rollout_recorder: None,
|
||||
restored_items: None,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let RolloutResult {
|
||||
session_id,
|
||||
rollout_recorder,
|
||||
restored_items,
|
||||
} = rollout_result;
|
||||
|
||||
let rollout_recorder = rollout_recorder.map_err(|e| {
|
||||
error!("failed to initialize rollout recorder: {e:#}");
|
||||
anyhow::anyhow!("failed to initialize rollout recorder: {e:#}")
|
||||
})?;
|
||||
// Create the mutable state for the Session.
|
||||
let mut state = State {
|
||||
let state = State {
|
||||
history: ConversationHistory::new(),
|
||||
..Default::default()
|
||||
};
|
||||
if let Some(restored_items) = restored_items {
|
||||
state.history.record_items(&restored_items);
|
||||
}
|
||||
|
||||
// Handle MCP manager result and record any startup failures.
|
||||
let (mcp_connection_manager, failed_clients) = match mcp_res {
|
||||
@@ -539,26 +473,12 @@ impl Session {
|
||||
session_manager: ExecSessionManager::default(),
|
||||
notify,
|
||||
state: Mutex::new(state),
|
||||
rollout: Mutex::new(rollout_recorder),
|
||||
rollout: Mutex::new(Some(rollout_recorder)),
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
user_shell: default_shell,
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
});
|
||||
|
||||
// record the initial user instructions and environment context,
|
||||
// regardless of whether we restored items.
|
||||
let mut conversation_items = Vec::<ResponseItem>::with_capacity(2);
|
||||
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
|
||||
conversation_items.push(Prompt::format_user_instructions_message(user_instructions));
|
||||
}
|
||||
conversation_items.push(ResponseItem::from(EnvironmentContext::new(
|
||||
Some(turn_context.cwd.clone()),
|
||||
Some(turn_context.approval_policy),
|
||||
Some(turn_context.sandbox_policy.clone()),
|
||||
Some(sess.user_shell.clone()),
|
||||
)));
|
||||
sess.record_conversation_items(&conversation_items).await;
|
||||
|
||||
// Dispatch the SessionConfiguredEvent first and then report any errors.
|
||||
let events = std::iter::once(Event {
|
||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
||||
@@ -596,6 +516,42 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_initial_history(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
conversation_history: InitialHistory,
|
||||
) {
|
||||
match conversation_history {
|
||||
InitialHistory::New => {
|
||||
self.record_initial_history_new(turn_context).await;
|
||||
}
|
||||
InitialHistory::Resumed(items) => {
|
||||
self.record_initial_history_resumed(items).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_initial_history_new(&self, turn_context: &TurnContext) {
|
||||
// record the initial user instructions and environment context,
|
||||
// regardless of whether we restored items.
|
||||
// TODO: Those items shouldn't be "user messages" IMO. Maybe developer messages.
|
||||
let mut conversation_items = Vec::<ResponseItem>::with_capacity(2);
|
||||
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
|
||||
conversation_items.push(Prompt::format_user_instructions_message(user_instructions));
|
||||
}
|
||||
conversation_items.push(ResponseItem::from(EnvironmentContext::new(
|
||||
Some(turn_context.cwd.clone()),
|
||||
Some(turn_context.approval_policy),
|
||||
Some(turn_context.sandbox_policy.clone()),
|
||||
Some(self.user_shell.clone()),
|
||||
)));
|
||||
self.record_conversation_items(&conversation_items).await;
|
||||
}
|
||||
|
||||
async fn record_initial_history_resumed(&self, items: Vec<ResponseItem>) {
|
||||
self.record_conversation_items(&items).await;
|
||||
}
|
||||
|
||||
/// Sends the given event to the client and swallows the send event, if
|
||||
/// any, logging it as an error.
|
||||
pub(crate) async fn send_event(&self, event: Event) {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_login::AuthManager;
|
||||
@@ -16,8 +17,15 @@ use crate::error::Result as CodexResult;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum InitialHistory {
|
||||
New,
|
||||
Resumed(Vec<ResponseItem>),
|
||||
}
|
||||
|
||||
/// Represents a newly created Codex conversation, including the first event
|
||||
/// (which is [`EventMsg::SessionConfigured`]).
|
||||
pub struct NewConversation {
|
||||
@@ -57,14 +65,21 @@ impl ConversationManager {
|
||||
config: Config,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
) -> CodexResult<NewConversation> {
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = {
|
||||
let initial_history = None;
|
||||
Codex::spawn(config, auth_manager, initial_history).await?
|
||||
};
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
// TO BE REFACTORED: use the config experimental_resume field until we have a mainstream way.
|
||||
if let Some(resume_path) = config.experimental_resume.as_ref() {
|
||||
let initial_history = RolloutRecorder::get_rollout_history(resume_path).await?;
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, initial_history).await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
} else {
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = { Codex::spawn(config, auth_manager, InitialHistory::New).await? };
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn finalize_spawn(
|
||||
@@ -110,6 +125,20 @@ impl ConversationManager {
|
||||
.ok_or_else(|| CodexErr::ConversationNotFound(conversation_id))
|
||||
}
|
||||
|
||||
pub async fn resume_conversation_from_rollout(
|
||||
&self,
|
||||
config: Config,
|
||||
rollout_path: PathBuf,
|
||||
auth_manager: Arc<AuthManager>,
|
||||
) -> CodexResult<NewConversation> {
|
||||
let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, initial_history).await?;
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
|
||||
pub async fn remove_conversation(&self, conversation_id: Uuid) {
|
||||
self.conversations.write().await.remove(&conversation_id);
|
||||
}
|
||||
@@ -125,7 +154,7 @@ impl ConversationManager {
|
||||
config: Config,
|
||||
) -> CodexResult<NewConversation> {
|
||||
// Compute the prefix up to the cut point.
|
||||
let truncated_history =
|
||||
let history =
|
||||
truncate_after_dropping_last_messages(conversation_history, num_messages_to_drop);
|
||||
|
||||
// Spawn a new conversation with the computed initial history.
|
||||
@@ -133,7 +162,7 @@ impl ConversationManager {
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = Codex::spawn(config, auth_manager, Some(truncated_history)).await?;
|
||||
} = Codex::spawn(config, auth_manager, history).await?;
|
||||
|
||||
self.finalize_spawn(codex, conversation_id).await
|
||||
}
|
||||
@@ -141,9 +170,9 @@ impl ConversationManager {
|
||||
|
||||
/// Return a prefix of `items` obtained by dropping the last `n` user messages
|
||||
/// and all items that follow them.
|
||||
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> Vec<ResponseItem> {
|
||||
if n == 0 || items.is_empty() {
|
||||
return items;
|
||||
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
|
||||
if n == 0 {
|
||||
return InitialHistory::Resumed(items);
|
||||
}
|
||||
|
||||
// Walk backwards counting only `user` Message items, find cut index.
|
||||
@@ -161,11 +190,11 @@ fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) ->
|
||||
}
|
||||
}
|
||||
}
|
||||
if count < n {
|
||||
// If fewer than n messages exist, drop everything.
|
||||
Vec::new()
|
||||
if cut_index == 0 {
|
||||
// No prefix remains after dropping; start a new conversation.
|
||||
InitialHistory::New
|
||||
} else {
|
||||
items.into_iter().take(cut_index).collect()
|
||||
InitialHistory::Resumed(items.into_iter().take(cut_index).collect())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,10 +252,10 @@ mod tests {
|
||||
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
|
||||
assert_eq!(
|
||||
truncated,
|
||||
vec![items[0].clone(), items[1].clone(), items[2].clone()]
|
||||
InitialHistory::Resumed(vec![items[0].clone(), items[1].clone(), items[2].clone(),])
|
||||
);
|
||||
|
||||
let truncated2 = truncate_after_dropping_last_messages(items, 2);
|
||||
assert!(truncated2.is_empty());
|
||||
assert_eq!(truncated2, InitialHistory::New);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ use tracing::warn;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::conversation_manager::InitialHistory;
|
||||
use crate::git_info::GitInfo;
|
||||
use crate::git_info::collect_git_info;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -157,20 +158,14 @@ impl RolloutRecorder {
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||
}
|
||||
|
||||
pub async fn resume(
|
||||
path: &Path,
|
||||
cwd: std::path::PathBuf,
|
||||
) -> std::io::Result<(Self, SavedSession)> {
|
||||
pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
let mut lines = text.lines();
|
||||
let meta_line = lines
|
||||
let _ = lines
|
||||
.next()
|
||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||
let session: SessionMeta = serde_json::from_str(meta_line)
|
||||
.map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
|
||||
let mut items = Vec::new();
|
||||
let mut state = SessionStateSnapshot::default();
|
||||
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
@@ -185,9 +180,6 @@ impl RolloutRecorder {
|
||||
.map(|s| s == "state")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
|
||||
state = s
|
||||
}
|
||||
continue;
|
||||
}
|
||||
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
@@ -207,27 +199,12 @@ impl RolloutRecorder {
|
||||
}
|
||||
}
|
||||
|
||||
let saved = SavedSession {
|
||||
session: session.clone(),
|
||||
items: items.clone(),
|
||||
state: state.clone(),
|
||||
session_id: session.id,
|
||||
};
|
||||
|
||||
let file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.read(true)
|
||||
.open(path)?;
|
||||
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
None,
|
||||
cwd,
|
||||
));
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok((Self { tx }, saved))
|
||||
if items.is_empty() {
|
||||
Ok(InitialHistory::New)
|
||||
} else {
|
||||
Ok(InitialHistory::Resumed(items))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> std::io::Result<()> {
|
||||
|
||||
Reference in New Issue
Block a user