chore: unify history loading (#2736)
We have two ways of loading a conversation with a previous history: forking a conversation and the experimental resume that we had before. In this PR, I am unifying their code paths. The unified path gets the history items and records them in a brand-new conversation. This PR also constrains the rollout recorder's responsibilities to only recording to disk and loading from disk. The PR also fixes an existing bug that occurs when forking twice in a row: History 1: <Environment Context> UserMessage_1 UserMessage_2 UserMessage_3 **Fork with n = 1 (only remove one element)** History 2: <Environment Context> UserMessage_1 UserMessage_2 <Environment Context> **Fork with n = 1 (only remove one element)** History 3: <Environment Context> UserMessage_1 UserMessage_2 **<Environment Context>** This shouldn't happen, but it did because we were appending the `<Environment Context>` after each spawn and it is counted as a _user message_, so the second fork dropped only the trailing `<Environment Context>` (which was then re-appended) instead of UserMessage_2. Now, we don't add this message when restoring an old conversation.
This commit is contained in:
@@ -43,6 +43,7 @@ use crate::client_common::ResponseEvent;
|
|||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::config_types::ShellEnvironmentPolicy;
|
use crate::config_types::ShellEnvironmentPolicy;
|
||||||
use crate::conversation_history::ConversationHistory;
|
use crate::conversation_history::ConversationHistory;
|
||||||
|
use crate::conversation_manager::InitialHistory;
|
||||||
use crate::environment_context::EnvironmentContext;
|
use crate::environment_context::EnvironmentContext;
|
||||||
use crate::error::CodexErr;
|
use crate::error::CodexErr;
|
||||||
use crate::error::Result as CodexResult;
|
use crate::error::Result as CodexResult;
|
||||||
@@ -169,7 +170,7 @@ impl Codex {
|
|||||||
pub async fn spawn(
|
pub async fn spawn(
|
||||||
config: Config,
|
config: Config,
|
||||||
auth_manager: Arc<AuthManager>,
|
auth_manager: Arc<AuthManager>,
|
||||||
initial_history: Option<Vec<ResponseItem>>,
|
conversation_history: InitialHistory,
|
||||||
) -> CodexResult<CodexSpawnOk> {
|
) -> CodexResult<CodexSpawnOk> {
|
||||||
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
|
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
|
||||||
let (tx_event, rx_event) = async_channel::unbounded();
|
let (tx_event, rx_event) = async_channel::unbounded();
|
||||||
@@ -177,7 +178,6 @@ impl Codex {
|
|||||||
let user_instructions = get_user_instructions(&config).await;
|
let user_instructions = get_user_instructions(&config).await;
|
||||||
|
|
||||||
let config = Arc::new(config);
|
let config = Arc::new(config);
|
||||||
let resume_path = config.experimental_resume.clone();
|
|
||||||
|
|
||||||
let configure_session = ConfigureSession {
|
let configure_session = ConfigureSession {
|
||||||
provider: config.model_provider.clone(),
|
provider: config.model_provider.clone(),
|
||||||
@@ -191,7 +191,6 @@ impl Codex {
|
|||||||
disable_response_storage: config.disable_response_storage,
|
disable_response_storage: config.disable_response_storage,
|
||||||
notify: config.notify.clone(),
|
notify: config.notify.clone(),
|
||||||
cwd: config.cwd.clone(),
|
cwd: config.cwd.clone(),
|
||||||
resume_path,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Generate a unique ID for the lifetime of this Codex session.
|
// Generate a unique ID for the lifetime of this Codex session.
|
||||||
@@ -200,13 +199,15 @@ impl Codex {
|
|||||||
config.clone(),
|
config.clone(),
|
||||||
auth_manager.clone(),
|
auth_manager.clone(),
|
||||||
tx_event.clone(),
|
tx_event.clone(),
|
||||||
initial_history,
|
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
error!("Failed to create session: {e:#}");
|
error!("Failed to create session: {e:#}");
|
||||||
CodexErr::InternalAgentDied
|
CodexErr::InternalAgentDied
|
||||||
})?;
|
})?;
|
||||||
|
session
|
||||||
|
.record_initial_history(&turn_context, conversation_history)
|
||||||
|
.await;
|
||||||
let session_id = session.session_id;
|
let session_id = session.session_id;
|
||||||
|
|
||||||
// This task will run until Op::Shutdown is received.
|
// This task will run until Op::Shutdown is received.
|
||||||
@@ -352,8 +353,6 @@ struct ConfigureSession {
|
|||||||
/// `ConfigureSession` operation so that the business-logic layer can
|
/// `ConfigureSession` operation so that the business-logic layer can
|
||||||
/// operate deterministically.
|
/// operate deterministically.
|
||||||
cwd: PathBuf,
|
cwd: PathBuf,
|
||||||
|
|
||||||
resume_path: Option<PathBuf>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
@@ -362,8 +361,8 @@ impl Session {
|
|||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
auth_manager: Arc<AuthManager>,
|
auth_manager: Arc<AuthManager>,
|
||||||
tx_event: Sender<Event>,
|
tx_event: Sender<Event>,
|
||||||
initial_history: Option<Vec<ResponseItem>>,
|
|
||||||
) -> anyhow::Result<(Arc<Self>, TurnContext)> {
|
) -> anyhow::Result<(Arc<Self>, TurnContext)> {
|
||||||
|
let session_id = Uuid::new_v4();
|
||||||
let ConfigureSession {
|
let ConfigureSession {
|
||||||
provider,
|
provider,
|
||||||
model,
|
model,
|
||||||
@@ -376,7 +375,6 @@ impl Session {
|
|||||||
disable_response_storage,
|
disable_response_storage,
|
||||||
notify,
|
notify,
|
||||||
cwd,
|
cwd,
|
||||||
resume_path,
|
|
||||||
} = configure_session;
|
} = configure_session;
|
||||||
debug!("Configuring session: model={model}; provider={provider:?}");
|
debug!("Configuring session: model={model}; provider={provider:?}");
|
||||||
if !cwd.is_absolute() {
|
if !cwd.is_absolute() {
|
||||||
@@ -392,89 +390,25 @@ impl Session {
|
|||||||
// - spin up MCP connection manager
|
// - spin up MCP connection manager
|
||||||
// - perform default shell discovery
|
// - perform default shell discovery
|
||||||
// - load history metadata
|
// - load history metadata
|
||||||
let rollout_fut = async {
|
let rollout_fut = RolloutRecorder::new(&config, session_id, user_instructions.clone());
|
||||||
match resume_path.as_ref() {
|
|
||||||
Some(path) => RolloutRecorder::resume(path, cwd.clone())
|
|
||||||
.await
|
|
||||||
.map(|(rec, saved)| (saved.session_id, Some(saved), rec)),
|
|
||||||
None => {
|
|
||||||
let session_id = Uuid::new_v4();
|
|
||||||
RolloutRecorder::new(&config, session_id, user_instructions.clone())
|
|
||||||
.await
|
|
||||||
.map(|rec| (session_id, None, rec))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mcp_fut = McpConnectionManager::new(config.mcp_servers.clone());
|
let mcp_fut = McpConnectionManager::new(config.mcp_servers.clone());
|
||||||
let default_shell_fut = shell::default_user_shell();
|
let default_shell_fut = shell::default_user_shell();
|
||||||
let history_meta_fut = crate::message_history::history_metadata(&config);
|
let history_meta_fut = crate::message_history::history_metadata(&config);
|
||||||
|
|
||||||
// Join all independent futures.
|
// Join all independent futures.
|
||||||
let (rollout_res, mcp_res, default_shell, (history_log_id, history_entry_count)) =
|
let (rollout_recorder, mcp_res, default_shell, (history_log_id, history_entry_count)) =
|
||||||
tokio::join!(rollout_fut, mcp_fut, default_shell_fut, history_meta_fut);
|
tokio::join!(rollout_fut, mcp_fut, default_shell_fut, history_meta_fut);
|
||||||
|
|
||||||
// Handle rollout result, which determines the session_id.
|
let rollout_recorder = rollout_recorder.map_err(|e| {
|
||||||
struct RolloutResult {
|
error!("failed to initialize rollout recorder: {e:#}");
|
||||||
session_id: Uuid,
|
anyhow::anyhow!("failed to initialize rollout recorder: {e:#}")
|
||||||
rollout_recorder: Option<RolloutRecorder>,
|
})?;
|
||||||
restored_items: Option<Vec<ResponseItem>>,
|
|
||||||
}
|
|
||||||
let rollout_result = match rollout_res {
|
|
||||||
Ok((session_id, maybe_saved, recorder)) => {
|
|
||||||
let restored_items: Option<Vec<ResponseItem>> = initial_history.or_else(|| {
|
|
||||||
maybe_saved.and_then(|saved_session| {
|
|
||||||
if saved_session.items.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(saved_session.items)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
});
|
|
||||||
RolloutResult {
|
|
||||||
session_id,
|
|
||||||
rollout_recorder: Some(recorder),
|
|
||||||
restored_items,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
if let Some(path) = resume_path.as_ref() {
|
|
||||||
return Err(anyhow::anyhow!(
|
|
||||||
"failed to resume rollout from {path:?}: {e}"
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let message = format!("failed to initialize rollout recorder: {e}");
|
|
||||||
post_session_configured_error_events.push(Event {
|
|
||||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
|
||||||
msg: EventMsg::Error(ErrorEvent {
|
|
||||||
message: message.clone(),
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
warn!("{message}");
|
|
||||||
|
|
||||||
RolloutResult {
|
|
||||||
session_id: Uuid::new_v4(),
|
|
||||||
rollout_recorder: None,
|
|
||||||
restored_items: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let RolloutResult {
|
|
||||||
session_id,
|
|
||||||
rollout_recorder,
|
|
||||||
restored_items,
|
|
||||||
} = rollout_result;
|
|
||||||
|
|
||||||
// Create the mutable state for the Session.
|
// Create the mutable state for the Session.
|
||||||
let mut state = State {
|
let state = State {
|
||||||
history: ConversationHistory::new(),
|
history: ConversationHistory::new(),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
if let Some(restored_items) = restored_items {
|
|
||||||
state.history.record_items(&restored_items);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle MCP manager result and record any startup failures.
|
// Handle MCP manager result and record any startup failures.
|
||||||
let (mcp_connection_manager, failed_clients) = match mcp_res {
|
let (mcp_connection_manager, failed_clients) = match mcp_res {
|
||||||
@@ -539,26 +473,12 @@ impl Session {
|
|||||||
session_manager: ExecSessionManager::default(),
|
session_manager: ExecSessionManager::default(),
|
||||||
notify,
|
notify,
|
||||||
state: Mutex::new(state),
|
state: Mutex::new(state),
|
||||||
rollout: Mutex::new(rollout_recorder),
|
rollout: Mutex::new(Some(rollout_recorder)),
|
||||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||||
user_shell: default_shell,
|
user_shell: default_shell,
|
||||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||||
});
|
});
|
||||||
|
|
||||||
// record the initial user instructions and environment context,
|
|
||||||
// regardless of whether we restored items.
|
|
||||||
let mut conversation_items = Vec::<ResponseItem>::with_capacity(2);
|
|
||||||
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
|
|
||||||
conversation_items.push(Prompt::format_user_instructions_message(user_instructions));
|
|
||||||
}
|
|
||||||
conversation_items.push(ResponseItem::from(EnvironmentContext::new(
|
|
||||||
Some(turn_context.cwd.clone()),
|
|
||||||
Some(turn_context.approval_policy),
|
|
||||||
Some(turn_context.sandbox_policy.clone()),
|
|
||||||
Some(sess.user_shell.clone()),
|
|
||||||
)));
|
|
||||||
sess.record_conversation_items(&conversation_items).await;
|
|
||||||
|
|
||||||
// Dispatch the SessionConfiguredEvent first and then report any errors.
|
// Dispatch the SessionConfiguredEvent first and then report any errors.
|
||||||
let events = std::iter::once(Event {
|
let events = std::iter::once(Event {
|
||||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
id: INITIAL_SUBMIT_ID.to_owned(),
|
||||||
@@ -596,6 +516,42 @@ impl Session {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn record_initial_history(
|
||||||
|
&self,
|
||||||
|
turn_context: &TurnContext,
|
||||||
|
conversation_history: InitialHistory,
|
||||||
|
) {
|
||||||
|
match conversation_history {
|
||||||
|
InitialHistory::New => {
|
||||||
|
self.record_initial_history_new(turn_context).await;
|
||||||
|
}
|
||||||
|
InitialHistory::Resumed(items) => {
|
||||||
|
self.record_initial_history_resumed(items).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn record_initial_history_new(&self, turn_context: &TurnContext) {
|
||||||
|
// record the initial user instructions and environment context,
|
||||||
|
// when starting a brand-new conversation (nothing is restored on this path).
|
||||||
|
// TODO: Those items shouldn't be "user messages" IMO. Maybe developer messages.
|
||||||
|
let mut conversation_items = Vec::<ResponseItem>::with_capacity(2);
|
||||||
|
if let Some(user_instructions) = turn_context.user_instructions.as_deref() {
|
||||||
|
conversation_items.push(Prompt::format_user_instructions_message(user_instructions));
|
||||||
|
}
|
||||||
|
conversation_items.push(ResponseItem::from(EnvironmentContext::new(
|
||||||
|
Some(turn_context.cwd.clone()),
|
||||||
|
Some(turn_context.approval_policy),
|
||||||
|
Some(turn_context.sandbox_policy.clone()),
|
||||||
|
Some(self.user_shell.clone()),
|
||||||
|
)));
|
||||||
|
self.record_conversation_items(&conversation_items).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn record_initial_history_resumed(&self, items: Vec<ResponseItem>) {
|
||||||
|
self.record_conversation_items(&items).await;
|
||||||
|
}
|
||||||
|
|
||||||
/// Sends the given event to the client and swallows the send event, if
|
/// Sends the given event to the client and swallows the send event, if
|
||||||
/// any, logging it as an error.
|
/// any, logging it as an error.
|
||||||
pub(crate) async fn send_event(&self, event: Event) {
|
pub(crate) async fn send_event(&self, event: Event) {
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use codex_login::AuthManager;
|
use codex_login::AuthManager;
|
||||||
@@ -16,8 +17,15 @@ use crate::error::Result as CodexResult;
|
|||||||
use crate::protocol::Event;
|
use crate::protocol::Event;
|
||||||
use crate::protocol::EventMsg;
|
use crate::protocol::EventMsg;
|
||||||
use crate::protocol::SessionConfiguredEvent;
|
use crate::protocol::SessionConfiguredEvent;
|
||||||
|
use crate::rollout::RolloutRecorder;
|
||||||
use codex_protocol::models::ResponseItem;
|
use codex_protocol::models::ResponseItem;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum InitialHistory {
|
||||||
|
New,
|
||||||
|
Resumed(Vec<ResponseItem>),
|
||||||
|
}
|
||||||
|
|
||||||
/// Represents a newly created Codex conversation, including the first event
|
/// Represents a newly created Codex conversation, including the first event
|
||||||
/// (which is [`EventMsg::SessionConfigured`]).
|
/// (which is [`EventMsg::SessionConfigured`]).
|
||||||
pub struct NewConversation {
|
pub struct NewConversation {
|
||||||
@@ -57,14 +65,21 @@ impl ConversationManager {
|
|||||||
config: Config,
|
config: Config,
|
||||||
auth_manager: Arc<AuthManager>,
|
auth_manager: Arc<AuthManager>,
|
||||||
) -> CodexResult<NewConversation> {
|
) -> CodexResult<NewConversation> {
|
||||||
let CodexSpawnOk {
|
// TO BE REFACTORED: use the config experimental_resume field until we have a mainstream way.
|
||||||
codex,
|
if let Some(resume_path) = config.experimental_resume.as_ref() {
|
||||||
session_id: conversation_id,
|
let initial_history = RolloutRecorder::get_rollout_history(resume_path).await?;
|
||||||
} = {
|
let CodexSpawnOk {
|
||||||
let initial_history = None;
|
codex,
|
||||||
Codex::spawn(config, auth_manager, initial_history).await?
|
session_id: conversation_id,
|
||||||
};
|
} = Codex::spawn(config, auth_manager, initial_history).await?;
|
||||||
self.finalize_spawn(codex, conversation_id).await
|
self.finalize_spawn(codex, conversation_id).await
|
||||||
|
} else {
|
||||||
|
let CodexSpawnOk {
|
||||||
|
codex,
|
||||||
|
session_id: conversation_id,
|
||||||
|
} = { Codex::spawn(config, auth_manager, InitialHistory::New).await? };
|
||||||
|
self.finalize_spawn(codex, conversation_id).await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn finalize_spawn(
|
async fn finalize_spawn(
|
||||||
@@ -110,6 +125,20 @@ impl ConversationManager {
|
|||||||
.ok_or_else(|| CodexErr::ConversationNotFound(conversation_id))
|
.ok_or_else(|| CodexErr::ConversationNotFound(conversation_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn resume_conversation_from_rollout(
|
||||||
|
&self,
|
||||||
|
config: Config,
|
||||||
|
rollout_path: PathBuf,
|
||||||
|
auth_manager: Arc<AuthManager>,
|
||||||
|
) -> CodexResult<NewConversation> {
|
||||||
|
let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
|
||||||
|
let CodexSpawnOk {
|
||||||
|
codex,
|
||||||
|
session_id: conversation_id,
|
||||||
|
} = Codex::spawn(config, auth_manager, initial_history).await?;
|
||||||
|
self.finalize_spawn(codex, conversation_id).await
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn remove_conversation(&self, conversation_id: Uuid) {
|
pub async fn remove_conversation(&self, conversation_id: Uuid) {
|
||||||
self.conversations.write().await.remove(&conversation_id);
|
self.conversations.write().await.remove(&conversation_id);
|
||||||
}
|
}
|
||||||
@@ -125,7 +154,7 @@ impl ConversationManager {
|
|||||||
config: Config,
|
config: Config,
|
||||||
) -> CodexResult<NewConversation> {
|
) -> CodexResult<NewConversation> {
|
||||||
// Compute the prefix up to the cut point.
|
// Compute the prefix up to the cut point.
|
||||||
let truncated_history =
|
let history =
|
||||||
truncate_after_dropping_last_messages(conversation_history, num_messages_to_drop);
|
truncate_after_dropping_last_messages(conversation_history, num_messages_to_drop);
|
||||||
|
|
||||||
// Spawn a new conversation with the computed initial history.
|
// Spawn a new conversation with the computed initial history.
|
||||||
@@ -133,7 +162,7 @@ impl ConversationManager {
|
|||||||
let CodexSpawnOk {
|
let CodexSpawnOk {
|
||||||
codex,
|
codex,
|
||||||
session_id: conversation_id,
|
session_id: conversation_id,
|
||||||
} = Codex::spawn(config, auth_manager, Some(truncated_history)).await?;
|
} = Codex::spawn(config, auth_manager, history).await?;
|
||||||
|
|
||||||
self.finalize_spawn(codex, conversation_id).await
|
self.finalize_spawn(codex, conversation_id).await
|
||||||
}
|
}
|
||||||
@@ -141,9 +170,9 @@ impl ConversationManager {
|
|||||||
|
|
||||||
/// Return a prefix of `items` obtained by dropping the last `n` user messages
|
/// Return a prefix of `items` obtained by dropping the last `n` user messages
|
||||||
/// and all items that follow them.
|
/// and all items that follow them.
|
||||||
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> Vec<ResponseItem> {
|
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
|
||||||
if n == 0 || items.is_empty() {
|
if n == 0 {
|
||||||
return items;
|
return InitialHistory::Resumed(items);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Walk backwards counting only `user` Message items, find cut index.
|
// Walk backwards counting only `user` Message items, find cut index.
|
||||||
@@ -161,11 +190,11 @@ fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) ->
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if count < n {
|
if cut_index == 0 {
|
||||||
// If fewer than n messages exist, drop everything.
|
// No prefix remains after dropping; start a new conversation.
|
||||||
Vec::new()
|
InitialHistory::New
|
||||||
} else {
|
} else {
|
||||||
items.into_iter().take(cut_index).collect()
|
InitialHistory::Resumed(items.into_iter().take(cut_index).collect())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,10 +252,10 @@ mod tests {
|
|||||||
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
|
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
truncated,
|
truncated,
|
||||||
vec![items[0].clone(), items[1].clone(), items[2].clone()]
|
InitialHistory::Resumed(vec![items[0].clone(), items[1].clone(), items[2].clone(),])
|
||||||
);
|
);
|
||||||
|
|
||||||
let truncated2 = truncate_after_dropping_last_messages(items, 2);
|
let truncated2 = truncate_after_dropping_last_messages(items, 2);
|
||||||
assert!(truncated2.is_empty());
|
assert_eq!(truncated2, InitialHistory::New);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ use tracing::warn;
|
|||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
use crate::conversation_manager::InitialHistory;
|
||||||
use crate::git_info::GitInfo;
|
use crate::git_info::GitInfo;
|
||||||
use crate::git_info::collect_git_info;
|
use crate::git_info::collect_git_info;
|
||||||
use codex_protocol::models::ResponseItem;
|
use codex_protocol::models::ResponseItem;
|
||||||
@@ -157,20 +158,14 @@ impl RolloutRecorder {
|
|||||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn resume(
|
pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
|
||||||
path: &Path,
|
|
||||||
cwd: std::path::PathBuf,
|
|
||||||
) -> std::io::Result<(Self, SavedSession)> {
|
|
||||||
info!("Resuming rollout from {path:?}");
|
info!("Resuming rollout from {path:?}");
|
||||||
let text = tokio::fs::read_to_string(path).await?;
|
let text = tokio::fs::read_to_string(path).await?;
|
||||||
let mut lines = text.lines();
|
let mut lines = text.lines();
|
||||||
let meta_line = lines
|
let _ = lines
|
||||||
.next()
|
.next()
|
||||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||||
let session: SessionMeta = serde_json::from_str(meta_line)
|
|
||||||
.map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
|
|
||||||
let mut items = Vec::new();
|
let mut items = Vec::new();
|
||||||
let mut state = SessionStateSnapshot::default();
|
|
||||||
|
|
||||||
for line in lines {
|
for line in lines {
|
||||||
if line.trim().is_empty() {
|
if line.trim().is_empty() {
|
||||||
@@ -185,9 +180,6 @@ impl RolloutRecorder {
|
|||||||
.map(|s| s == "state")
|
.map(|s| s == "state")
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
{
|
{
|
||||||
if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
|
|
||||||
state = s
|
|
||||||
}
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||||
@@ -207,27 +199,12 @@ impl RolloutRecorder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let saved = SavedSession {
|
|
||||||
session: session.clone(),
|
|
||||||
items: items.clone(),
|
|
||||||
state: state.clone(),
|
|
||||||
session_id: session.id,
|
|
||||||
};
|
|
||||||
|
|
||||||
let file = std::fs::OpenOptions::new()
|
|
||||||
.append(true)
|
|
||||||
.read(true)
|
|
||||||
.open(path)?;
|
|
||||||
|
|
||||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
|
||||||
tokio::task::spawn(rollout_writer(
|
|
||||||
tokio::fs::File::from_std(file),
|
|
||||||
rx,
|
|
||||||
None,
|
|
||||||
cwd,
|
|
||||||
));
|
|
||||||
info!("Resumed rollout successfully from {path:?}");
|
info!("Resumed rollout successfully from {path:?}");
|
||||||
Ok((Self { tx }, saved))
|
if items.is_empty() {
|
||||||
|
Ok(InitialHistory::New)
|
||||||
|
} else {
|
||||||
|
Ok(InitialHistory::Resumed(items))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn shutdown(&self) -> std::io::Result<()> {
|
pub async fn shutdown(&self) -> std::io::Result<()> {
|
||||||
|
|||||||
@@ -388,7 +388,7 @@ async fn integration_creates_and_checks_session_file() {
|
|||||||
"No message found in session file containing the marker"
|
"No message found in session file containing the marker"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Second run: resume and append.
|
// Second run: resume should create a NEW session file that contains both old and new history.
|
||||||
let orig_len = content.lines().count();
|
let orig_len = content.lines().count();
|
||||||
let marker2 = format!("integration-resume-{}", Uuid::new_v4());
|
let marker2 = format!("integration-resume-{}", Uuid::new_v4());
|
||||||
let prompt2 = format!("echo {marker2}");
|
let prompt2 = format!("echo {marker2}");
|
||||||
@@ -419,31 +419,58 @@ async fn integration_creates_and_checks_session_file() {
|
|||||||
let output2 = cmd2.output().unwrap();
|
let output2 = cmd2.output().unwrap();
|
||||||
assert!(output2.status.success(), "resume codex-cli run failed");
|
assert!(output2.status.success(), "resume codex-cli run failed");
|
||||||
|
|
||||||
// The rollout writer runs on a background async task; give it a moment to flush.
|
// Find the new session file containing the resumed marker.
|
||||||
let mut new_len = orig_len;
|
let deadline = Instant::now() + Duration::from_secs(10);
|
||||||
let deadline = Instant::now() + Duration::from_secs(5);
|
let mut resumed_path: Option<std::path::PathBuf> = None;
|
||||||
let mut content2 = String::new();
|
while Instant::now() < deadline && resumed_path.is_none() {
|
||||||
while Instant::now() < deadline {
|
for entry in WalkDir::new(&sessions_dir) {
|
||||||
if let Ok(c) = std::fs::read_to_string(&path) {
|
let entry = match entry {
|
||||||
let count = c.lines().count();
|
Ok(e) => e,
|
||||||
if count > orig_len {
|
Err(_) => continue,
|
||||||
content2 = c;
|
};
|
||||||
new_len = count;
|
if !entry.file_type().is_file() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let p = entry.path();
|
||||||
|
let Ok(c) = std::fs::read_to_string(p) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if c.contains(&marker2) {
|
||||||
|
resumed_path = Some(p.to_path_buf());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::thread::sleep(Duration::from_millis(50));
|
if resumed_path.is_none() {
|
||||||
|
std::thread::sleep(Duration::from_millis(50));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if content2.is_empty() {
|
|
||||||
// last attempt
|
let resumed_path = resumed_path.expect("No resumed session file found containing the marker2");
|
||||||
content2 = std::fs::read_to_string(&path).unwrap();
|
// Resume should have written to a new file, not the original one.
|
||||||
new_len = content2.lines().count();
|
assert_ne!(
|
||||||
}
|
resumed_path, path,
|
||||||
assert!(new_len > orig_len, "rollout file did not grow after resume");
|
"resume should create a new session file"
|
||||||
assert!(content2.contains(&marker), "rollout lost original marker");
|
);
|
||||||
|
|
||||||
|
let resumed_content = std::fs::read_to_string(&resumed_path).unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
content2.contains(&marker2),
|
resumed_content.contains(&marker),
|
||||||
"rollout missing resumed marker"
|
"resumed file missing original marker"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
resumed_content.contains(&marker2),
|
||||||
|
"resumed file missing resumed marker"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Original file should remain unchanged.
|
||||||
|
let content_after = std::fs::read_to_string(&path).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
content_after.lines().count(),
|
||||||
|
orig_len,
|
||||||
|
"original rollout file should not change on resume"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
154
codex-rs/core/tests/suite/fork_conversation.rs
Normal file
154
codex-rs/core/tests/suite/fork_conversation.rs
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
use codex_core::ConversationManager;
|
||||||
|
use codex_core::ModelProviderInfo;
|
||||||
|
use codex_core::NewConversation;
|
||||||
|
use codex_core::built_in_model_providers;
|
||||||
|
use codex_core::protocol::ConversationHistoryResponseEvent;
|
||||||
|
use codex_core::protocol::EventMsg;
|
||||||
|
use codex_core::protocol::InputItem;
|
||||||
|
use codex_core::protocol::Op;
|
||||||
|
use codex_login::CodexAuth;
|
||||||
|
use core_test_support::load_default_config_for_test;
|
||||||
|
use core_test_support::wait_for_event;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
use wiremock::Mock;
|
||||||
|
use wiremock::MockServer;
|
||||||
|
use wiremock::ResponseTemplate;
|
||||||
|
use wiremock::matchers::method;
|
||||||
|
use wiremock::matchers::path;
|
||||||
|
|
||||||
|
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
||||||
|
fn sse_completed(id: &str) -> String {
|
||||||
|
core_test_support::load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn fork_conversation_twice_drops_to_first_message() {
|
||||||
|
// Start a mock server that completes three turns.
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
let sse = sse_completed("resp");
|
||||||
|
let first = ResponseTemplate::new(200)
|
||||||
|
.insert_header("content-type", "text/event-stream")
|
||||||
|
.set_body_raw(sse.clone(), "text/event-stream");
|
||||||
|
|
||||||
|
// Expect three calls to /v1/responses – one per user input.
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/responses"))
|
||||||
|
.respond_with(first)
|
||||||
|
.expect(3)
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Configure Codex to use the mock server.
|
||||||
|
let model_provider = ModelProviderInfo {
|
||||||
|
base_url: Some(format!("{}/v1", server.uri())),
|
||||||
|
..built_in_model_providers()["openai"].clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let home = TempDir::new().unwrap();
|
||||||
|
let mut config = load_default_config_for_test(&home);
|
||||||
|
config.model_provider = model_provider.clone();
|
||||||
|
let config_for_fork = config.clone();
|
||||||
|
|
||||||
|
let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
|
||||||
|
let NewConversation {
|
||||||
|
conversation: codex,
|
||||||
|
..
|
||||||
|
} = conversation_manager
|
||||||
|
.new_conversation(config)
|
||||||
|
.await
|
||||||
|
.expect("create conversation");
|
||||||
|
|
||||||
|
// Send three user messages; wait for three completed turns.
|
||||||
|
for text in ["first", "second", "third"] {
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![InputItem::Text {
|
||||||
|
text: text.to_string(),
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let _ = wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request history from the base conversation.
|
||||||
|
codex.submit(Op::GetHistory).await.unwrap();
|
||||||
|
let base_history =
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::ConversationHistory(_))).await;
|
||||||
|
|
||||||
|
// Capture entries from the base history and compute expected prefixes after each fork.
|
||||||
|
let entries_after_three = match &base_history {
|
||||||
|
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { entries, .. }) => {
|
||||||
|
entries.clone()
|
||||||
|
}
|
||||||
|
_ => panic!("expected ConversationHistory event"),
|
||||||
|
};
|
||||||
|
// History layout for this test:
|
||||||
|
// [0] user instructions,
|
||||||
|
// [1] environment context,
|
||||||
|
// [2] "first" user message,
|
||||||
|
// [3] "second" user message,
|
||||||
|
// [4] "third" user message.
|
||||||
|
|
||||||
|
// Fork 1: drops the last user message and everything after.
|
||||||
|
let expected_after_first = vec![
|
||||||
|
entries_after_three[0].clone(),
|
||||||
|
entries_after_three[1].clone(),
|
||||||
|
entries_after_three[2].clone(),
|
||||||
|
entries_after_three[3].clone(),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Fork 2: drops the last user message and everything after.
|
||||||
|
// [0] user instructions,
|
||||||
|
// [1] environment context,
|
||||||
|
// [2] "first" user message,
|
||||||
|
let expected_after_second = vec![
|
||||||
|
entries_after_three[0].clone(),
|
||||||
|
entries_after_three[1].clone(),
|
||||||
|
entries_after_three[2].clone(),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Fork once with n=1 → drops the last user message and everything after.
|
||||||
|
let NewConversation {
|
||||||
|
conversation: codex_fork1,
|
||||||
|
..
|
||||||
|
} = conversation_manager
|
||||||
|
.fork_conversation(entries_after_three.clone(), 1, config_for_fork.clone())
|
||||||
|
.await
|
||||||
|
.expect("fork 1");
|
||||||
|
|
||||||
|
codex_fork1.submit(Op::GetHistory).await.unwrap();
|
||||||
|
let fork1_history = wait_for_event(&codex_fork1, |ev| {
|
||||||
|
matches!(ev, EventMsg::ConversationHistory(_))
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
let entries_after_first_fork = match &fork1_history {
|
||||||
|
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { entries, .. }) => {
|
||||||
|
assert!(matches!(
|
||||||
|
fork1_history,
|
||||||
|
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { ref entries, .. }) if *entries == expected_after_first
|
||||||
|
));
|
||||||
|
entries.clone()
|
||||||
|
}
|
||||||
|
_ => panic!("expected ConversationHistory event after first fork"),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fork again with n=1 → drops the (new) last user message, leaving only the first.
|
||||||
|
let NewConversation {
|
||||||
|
conversation: codex_fork2,
|
||||||
|
..
|
||||||
|
} = conversation_manager
|
||||||
|
.fork_conversation(entries_after_first_fork.clone(), 1, config_for_fork.clone())
|
||||||
|
.await
|
||||||
|
.expect("fork 2");
|
||||||
|
|
||||||
|
codex_fork2.submit(Op::GetHistory).await.unwrap();
|
||||||
|
let fork2_history = wait_for_event(&codex_fork2, |ev| {
|
||||||
|
matches!(ev, EventMsg::ConversationHistory(_))
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
assert!(matches!(
|
||||||
|
fork2_history,
|
||||||
|
EventMsg::ConversationHistory(ConversationHistoryResponseEvent { ref entries, .. }) if *entries == expected_after_second
|
||||||
|
));
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ mod client;
|
|||||||
mod compact;
|
mod compact;
|
||||||
mod exec;
|
mod exec;
|
||||||
mod exec_stream_events;
|
mod exec_stream_events;
|
||||||
|
mod fork_conversation;
|
||||||
mod live_cli;
|
mod live_cli;
|
||||||
mod prompt_caching;
|
mod prompt_caching;
|
||||||
mod seatbelt;
|
mod seatbelt;
|
||||||
|
|||||||
Reference in New Issue
Block a user