Add session loading support to Codex (#1602)
## Summary

- extend rollout format to store all session data in JSON
- add resume/write helpers for rollouts
- track session state after each conversation
- support `LoadSession` op to resume a previous rollout
- allow starting Codex with an existing session via the `experimental_resume` config variable

A user-friendly way to explore the available sessions is left for a follow-up.

## Testing

- `cargo test --no-run` *(fails: `cargo: command not found`)*

---

https://chatgpt.com/codex/tasks/task_i_68792a29dd5c832190bf6930d3466fba

The video below is outdated: use `-c experimental_resume=<full path>` instead of `--resume <full path>`.

https://github.com/user-attachments/assets/7a9975c7-aa04-4f4e-899a-9e87defd947a
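For context, a rollout is a plain JSONL file under `~/.codex/sessions/YYYY/MM/DD/`: the first line is the serialized `SessionMeta`, each subsequent line is a recorded `ResponseItem`, and `record_type: "state"` lines capture the latest session-state snapshot. A minimal sketch of the layout (field values are illustrative, not taken from a real session):

```jsonl
{"id":"c0ffee00-0000-4000-8000-000000000000","timestamp":"2025-07-17T12:00:00.000Z","instructions":null}
{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}
{"record_type":"state","previous_response_id":"resp_123"}
```

Resuming parses the first line back into `SessionMeta`, replays the item lines, keeps the last state line, and reopens the file in append mode.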
```diff
@@ -64,7 +64,11 @@ impl CliConfigOverrides {
         // `-c model=o3` without the quotes.
         let value: Value = match parse_toml_value(value_str) {
             Ok(v) => v,
-            Err(_) => Value::String(value_str.to_string()),
+            Err(_) => {
+                // Strip leading/trailing quotes if present
+                let trimmed = value_str.trim().trim_matches(|c| c == '"' || c == '\'');
+                Value::String(trimmed.to_string())
+            }
         };

         Ok((key.to_string(), value))
```
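As a minimal illustration of the new fallback (a standalone sketch, not the actual `CliConfigOverrides` code): when a `-c key=value` override fails to parse as TOML, surrounding quotes are now stripped before the value is stored as a string, so a shell-quoted path still comes through clean.

```rust
fn fallback_string(value_str: &str) -> String {
    // Mirrors the new `Err(_)` arm above: strip leading/trailing quotes if present.
    value_str
        .trim()
        .trim_matches(|c| c == '"' || c == '\'')
        .to_string()
}

fn main() {
    assert_eq!(fallback_string("\"/tmp/rollout.jsonl\""), "/tmp/rollout.jsonl");
    assert_eq!(fallback_string("o3"), "o3");
}
```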
```diff
@@ -102,6 +102,9 @@ impl Codex {
    /// of `Codex` and the ID of the `SessionInitialized` event that was
    /// submitted to start the session.
    pub async fn spawn(config: Config, ctrl_c: Arc<Notify>) -> CodexResult<(Codex, String)> {
+        // experimental resume path (undocumented)
+        let resume_path = config.experimental_resume.clone();
+        info!("resume_path: {resume_path:?}");
        let (tx_sub, rx_sub) = async_channel::bounded(64);
        let (tx_event, rx_event) = async_channel::bounded(1600);
```
```diff
@@ -117,6 +120,7 @@ impl Codex {
            disable_response_storage: config.disable_response_storage,
            notify: config.notify.clone(),
            cwd: config.cwd.clone(),
+            resume_path: resume_path.clone(),
        };

        let config = Arc::new(config);
```
```diff
@@ -306,24 +310,30 @@ impl Session {
    /// transcript, if enabled.
    async fn record_conversation_items(&self, items: &[ResponseItem]) {
        debug!("Recording items for conversation: {items:?}");
-        self.record_rollout_items(items).await;
+        self.record_state_snapshot(items).await;

        if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
            transcript.record_items(items);
        }
    }

-    /// Append the given items to the session's rollout transcript (if enabled)
-    /// and persist them to disk.
-    async fn record_rollout_items(&self, items: &[ResponseItem]) {
-        // Clone the recorder outside of the mutex so we don't hold the lock
-        // across an await point (MutexGuard is not Send).
+    async fn record_state_snapshot(&self, items: &[ResponseItem]) {
+        let snapshot = {
+            let state = self.state.lock().unwrap();
+            crate::rollout::SessionStateSnapshot {
+                previous_response_id: state.previous_response_id.clone(),
+            }
+        };
+
        let recorder = {
            let guard = self.rollout.lock().unwrap();
            guard.as_ref().cloned()
        };

        if let Some(rec) = recorder {
+            if let Err(e) = rec.record_state(snapshot).await {
+                error!("failed to record rollout state: {e:#}");
+            }
            if let Err(e) = rec.record_items(items).await {
                error!("failed to record rollout items: {e:#}");
            }
```
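The snapshot and the recorder are both cloned out of their mutexes before any await because a `std::sync::MutexGuard` is not `Send` and must not be held across an await point. A minimal sketch of the pattern (illustrative names, not code from this commit):

```rust
use std::sync::Mutex;

async fn record(state: &Mutex<Vec<String>>) {
    // Take what we need inside a block so the guard is dropped here...
    let snapshot = {
        let guard = state.lock().unwrap();
        guard.clone()
    };
    // ...and only the owned clone crosses the await point.
    tokio::task::yield_now().await;
    println!("recorded {} entries", snapshot.len());
}
```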
```diff
@@ -517,7 +527,7 @@ async fn submission_loop(
    ctrl_c: Arc<Notify>,
) {
    // Generate a unique ID for the lifetime of this Codex session.
-    let session_id = Uuid::new_v4();
+    let mut session_id = Uuid::new_v4();

    let mut sess: Option<Arc<Session>> = None;
    // shorthand - send an event when there is no active session
```
```diff
@@ -570,8 +580,11 @@ async fn submission_loop(
                disable_response_storage,
                notify,
                cwd,
+                resume_path,
            } => {
-                info!("Configuring session: model={model}; provider={provider:?}");
+                info!(
+                    "Configuring session: model={model}; provider={provider:?}; resume={resume_path:?}"
+                );
                if !cwd.is_absolute() {
                    let message = format!("cwd is not absolute: {cwd:?}");
                    error!(message);
```
```diff
@@ -584,6 +597,41 @@ async fn submission_loop(
                    }
                    return;
                }
+                // Optionally resume an existing rollout.
+                let mut restored_items: Option<Vec<ResponseItem>> = None;
+                let mut restored_prev_id: Option<String> = None;
+                let rollout_recorder: Option<RolloutRecorder> =
+                    if let Some(path) = resume_path.as_ref() {
+                        match RolloutRecorder::resume(path).await {
+                            Ok((rec, saved)) => {
+                                session_id = saved.session_id;
+                                restored_prev_id = saved.state.previous_response_id;
+                                if !saved.items.is_empty() {
+                                    restored_items = Some(saved.items);
+                                }
+                                Some(rec)
+                            }
+                            Err(e) => {
+                                warn!("failed to resume rollout from {path:?}: {e}");
+                                None
+                            }
+                        }
+                    } else {
+                        None
+                    };
+
+                let rollout_recorder = match rollout_recorder {
+                    Some(rec) => Some(rec),
+                    None => match RolloutRecorder::new(&config, session_id, instructions.clone())
+                        .await
+                    {
+                        Ok(r) => Some(r),
+                        Err(e) => {
+                            warn!("failed to initialise rollout recorder: {e}");
+                            None
+                        }
+                    },
+                };
+
                let client = ModelClient::new(
                    config.clone(),
```
```diff
@@ -644,21 +692,6 @@ async fn submission_loop(
                        });
                    }
                }

-                // Attempt to create a RolloutRecorder *before* moving the
-                // `instructions` value into the Session struct.
-                // TODO: if ConfigureSession is sent twice, we will create an
-                // overlapping rollout file. Consider passing RolloutRecorder
-                // from above.
-                let rollout_recorder =
-                    match RolloutRecorder::new(&config, session_id, instructions.clone()).await {
-                        Ok(r) => Some(r),
-                        Err(e) => {
-                            warn!("failed to initialise rollout recorder: {e}");
-                            None
-                        }
-                    };
-
                sess = Some(Arc::new(Session {
                    client,
                    tx_event: tx_event.clone(),
```
```diff
@@ -676,6 +709,19 @@ async fn submission_loop(
                    codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
                }));

+                // Patch restored state into the newly created session.
+                if let Some(sess_arc) = &sess {
+                    if restored_prev_id.is_some() || restored_items.is_some() {
+                        let mut st = sess_arc.state.lock().unwrap();
+                        st.previous_response_id = restored_prev_id;
+                        if let (Some(hist), Some(items)) =
+                            (st.zdr_transcript.as_mut(), restored_items.as_ref())
+                        {
+                            hist.record_items(items.iter());
+                        }
+                    }
+                }
+
                // Gather history metadata for SessionConfiguredEvent.
                let (history_log_id, history_entry_count) =
                    crate::message_history::history_metadata(&config).await;
```
```diff
@@ -744,6 +790,8 @@ async fn submission_loop(
                }
            }
            Op::AddToHistory { text } => {
+                // TODO: What should we do if we got AddToHistory before ConfigureSession?
+                // currently, if ConfigureSession has resume path, this history will be ignored
                let id = session_id;
                let config = config.clone();
                tokio::spawn(async move {
```
```diff
@@ -137,6 +137,9 @@ pub struct Config {

    /// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
    pub chatgpt_base_url: String,
+
+    /// Experimental rollout resume path (absolute path to .jsonl; undocumented).
+    pub experimental_resume: Option<PathBuf>,
}

impl Config {
```
```diff
@@ -321,6 +324,9 @@ pub struct ConfigToml {

    /// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
    pub chatgpt_base_url: Option<String>,
+
+    /// Experimental rollout resume path (absolute path to .jsonl; undocumented).
+    pub experimental_resume: Option<PathBuf>,
}

impl ConfigToml {
```
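Since `ConfigToml` gains the same field, the resume path can also live in `config.toml` instead of being passed with `-c` on every run. An illustrative entry (the path is a placeholder):

```toml
# Absolute path to a previously written rollout file.
experimental_resume = "/home/user/.codex/sessions/2025/07/17/session.jsonl"
```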
```diff
@@ -448,6 +454,9 @@ impl Config {
                    .as_ref()
                    .map(|info| info.max_output_tokens)
            });
+
+        let experimental_resume = cfg.experimental_resume;
+
        let config = Self {
            model,
            model_context_window,
```
```diff
@@ -494,6 +503,8 @@ impl Config {
                .chatgpt_base_url
                .or(cfg.chatgpt_base_url)
                .unwrap_or("https://chatgpt.com/backend-api/".to_string()),
+
+            experimental_resume,
        };
        Ok(config)
    }
```
```diff
@@ -806,6 +817,7 @@ disable_response_storage = true
                model_reasoning_summary: ReasoningSummary::Detailed,
                model_supports_reasoning_summaries: false,
                chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
+                experimental_resume: None,
            },
            o3_profile_config
        );
```
```diff
@@ -852,6 +864,7 @@ disable_response_storage = true
            model_reasoning_summary: ReasoningSummary::default(),
            model_supports_reasoning_summaries: false,
            chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
+            experimental_resume: None,
        };

        assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
```
```diff
@@ -913,6 +926,7 @@ disable_response_storage = true
            model_reasoning_summary: ReasoningSummary::default(),
            model_supports_reasoning_summaries: false,
            chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
+            experimental_resume: None,
        };

        assert_eq!(expected_zdr_profile_config, zdr_profile_config);
```
```diff
@@ -69,6 +69,10 @@ pub enum Op {
        /// `ConfigureSession` operation so that the business-logic layer can
        /// operate deterministically.
        cwd: std::path::PathBuf,
+
+        /// Path to a rollout file to resume from.
+        #[serde(skip_serializing_if = "Option::is_none")]
+        resume_path: Option<std::path::PathBuf>,
    },

    /// Abort current task.
```
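Thanks to `skip_serializing_if`, the field vanishes from the serialized op when it is `None`, so existing clients are unaffected. An illustrative `ConfigureSession` fragment with a resume path set (other fields elided, values hypothetical):

```json
{"cwd": "/work/repo", "resume_path": "/home/user/.codex/sessions/2025/07/17/session.jsonl"}
```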
```diff
@@ -1,33 +1,47 @@
-//! Functionality to persist a Codex conversation *rollout* – a linear list of
-//! [`ResponseItem`] objects exchanged during a session – to disk so that
-//! sessions can be replayed or inspected later (mirrors the behaviour of the
-//! upstream TypeScript implementation).
+//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.

 use std::fs::File;
 use std::fs::{self};
 use std::io::Error as IoError;
+use std::path::Path;

+use serde::Deserialize;
 use serde::Serialize;
+use serde_json::Value;
 use time::OffsetDateTime;
 use time::format_description::FormatItem;
 use time::macros::format_description;
 use tokio::io::AsyncWriteExt;
 use tokio::sync::mpsc::Sender;
 use tokio::sync::mpsc::{self};
+use tracing::info;
 use uuid::Uuid;

 use crate::config::Config;
 use crate::models::ResponseItem;

-/// Folder inside `~/.codex` that holds saved rollouts.
 const SESSIONS_SUBDIR: &str = "sessions";

-#[derive(Serialize)]
-struct SessionMeta {
-    id: String,
-    timestamp: String,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    instructions: Option<String>,
+#[derive(Serialize, Deserialize, Clone, Default)]
+pub struct SessionMeta {
+    pub id: Uuid,
+    pub timestamp: String,
+    pub instructions: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Default, Clone)]
+pub struct SessionStateSnapshot {
+    pub previous_response_id: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Default, Clone)]
+pub struct SavedSession {
+    pub session: SessionMeta,
+    #[serde(default)]
+    pub items: Vec<ResponseItem>,
+    #[serde(default)]
+    pub state: SessionStateSnapshot,
+    pub session_id: Uuid,
 }

 /// Records all [`ResponseItem`]s for a session and flushes them to disk after
```
```diff
@@ -41,7 +55,13 @@ struct SessionMeta {
 /// ```
 #[derive(Clone)]
 pub(crate) struct RolloutRecorder {
-    tx: Sender<String>,
+    tx: Sender<RolloutCmd>,
+}
+
+#[derive(Clone)]
+enum RolloutCmd {
+    AddItems(Vec<ResponseItem>),
+    UpdateState(SessionStateSnapshot),
 }

 impl RolloutRecorder {
```
```diff
@@ -59,7 +79,6 @@ impl RolloutRecorder {
            timestamp,
        } = create_log_file(config, uuid)?;

-        // Build the static session metadata JSON first.
        let timestamp_format: &[FormatItem] = format_description!(
            "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
        );
```
```diff
@@ -69,46 +88,29 @@ impl RolloutRecorder {

        let meta = SessionMeta {
            timestamp,
-            id: session_id.to_string(),
+            id: session_id,
            instructions,
        };

        // A reasonably-sized bounded channel. If the buffer fills up the send
        // future will yield, which is fine – we only need to ensure we do not
        // perform *blocking* I/O on the caller’s thread.
-        let (tx, mut rx) = mpsc::channel::<String>(256);
+        let (tx, rx) = mpsc::channel::<RolloutCmd>(256);

        // Spawn a Tokio task that owns the file handle and performs async
        // writes. Using `tokio::fs::File` keeps everything on the async I/O
        // driver instead of blocking the runtime.
-        tokio::task::spawn(async move {
-            let mut file = tokio::fs::File::from_std(file);
-
-            while let Some(line) = rx.recv().await {
-                // Write line + newline, then flush to disk.
-                if let Err(e) = file.write_all(line.as_bytes()).await {
-                    tracing::warn!("rollout writer: failed to write line: {e}");
-                    break;
-                }
-                if let Err(e) = file.write_all(b"\n").await {
-                    tracing::warn!("rollout writer: failed to write newline: {e}");
-                    break;
-                }
-                if let Err(e) = file.flush().await {
-                    tracing::warn!("rollout writer: failed to flush: {e}");
-                    break;
-                }
-            }
-        });
-
-        let recorder = Self { tx };
-        // Ensure SessionMeta is the first item in the file.
-        recorder.record_item(&meta).await?;
-        Ok(recorder)
+        tokio::task::spawn(rollout_writer(
+            tokio::fs::File::from_std(file),
+            rx,
+            Some(meta),
+        ));
+
+        Ok(Self { tx })
    }

-    /// Append `items` to the rollout file.
    pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
+        let mut filtered = Vec::new();
        for item in items {
            match item {
                // Note that function calls may look a bit strange if they are
```
```diff
@@ -117,27 +119,86 @@ impl RolloutRecorder {
                ResponseItem::Message { .. }
                | ResponseItem::LocalShellCall { .. }
                | ResponseItem::FunctionCall { .. }
-                | ResponseItem::FunctionCallOutput { .. } => {}
+                | ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
                ResponseItem::Reasoning { .. } | ResponseItem::Other => {
                    // These should never be serialized.
                    continue;
                }
            }
-            self.record_item(item).await?;
        }
-        Ok(())
+        if filtered.is_empty() {
+            return Ok(());
+        }
+        self.tx
+            .send(RolloutCmd::AddItems(filtered))
+            .await
+            .map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
    }

-    async fn record_item(&self, item: &impl Serialize) -> std::io::Result<()> {
-        // Serialize the item to JSON first so that the writer thread only has
-        // to perform the actual write.
-        let json = serde_json::to_string(item)
-            .map_err(|e| IoError::other(format!("failed to serialize response items: {e}")))?;
-
+    pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
        self.tx
-            .send(json)
+            .send(RolloutCmd::UpdateState(state))
            .await
-            .map_err(|e| IoError::other(format!("failed to queue rollout item: {e}")))
+            .map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
+    }
+
+    pub async fn resume(path: &Path) -> std::io::Result<(Self, SavedSession)> {
+        info!("Resuming rollout from {path:?}");
+        let text = tokio::fs::read_to_string(path).await?;
+        let mut lines = text.lines();
+        let meta_line = lines
+            .next()
+            .ok_or_else(|| IoError::other("empty session file"))?;
+        let session: SessionMeta = serde_json::from_str(meta_line)
+            .map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
+        let mut items = Vec::new();
+        let mut state = SessionStateSnapshot::default();
+
+        for line in lines {
+            if line.trim().is_empty() {
+                continue;
+            }
+            let v: Value = match serde_json::from_str(line) {
+                Ok(v) => v,
+                Err(_) => continue,
+            };
+            if v.get("record_type")
+                .and_then(|rt| rt.as_str())
+                .map(|s| s == "state")
+                .unwrap_or(false)
+            {
+                if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
+                    state = s
+                }
+                continue;
+            }
+            if let Ok(item) = serde_json::from_value::<ResponseItem>(v.clone()) {
+                match item {
+                    ResponseItem::Message { .. }
+                    | ResponseItem::LocalShellCall { .. }
+                    | ResponseItem::FunctionCall { .. }
+                    | ResponseItem::FunctionCallOutput { .. } => items.push(item),
+                    ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
+                }
+            }
+        }
+
+        let saved = SavedSession {
+            session: session.clone(),
+            items: items.clone(),
+            state: state.clone(),
+            session_id: session.id,
+        };
+
+        let file = std::fs::OpenOptions::new()
+            .append(true)
+            .read(true)
+            .open(path)?;
+
+        let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
+        tokio::task::spawn(rollout_writer(tokio::fs::File::from_std(file), rx, None));
+        info!("Resumed rollout successfully from {path:?}");
+        Ok((Self { tx }, saved))
    }
}
```
```diff
@@ -185,3 +246,54 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
        timestamp,
    })
}
+
+async fn rollout_writer(
+    mut file: tokio::fs::File,
+    mut rx: mpsc::Receiver<RolloutCmd>,
+    meta: Option<SessionMeta>,
+) {
+    if let Some(meta) = meta {
+        if let Ok(json) = serde_json::to_string(&meta) {
+            let _ = file.write_all(json.as_bytes()).await;
+            let _ = file.write_all(b"\n").await;
+            let _ = file.flush().await;
+        }
+    }
+    while let Some(cmd) = rx.recv().await {
+        match cmd {
+            RolloutCmd::AddItems(items) => {
+                for item in items {
+                    match item {
+                        ResponseItem::Message { .. }
+                        | ResponseItem::LocalShellCall { .. }
+                        | ResponseItem::FunctionCall { .. }
+                        | ResponseItem::FunctionCallOutput { .. } => {
+                            if let Ok(json) = serde_json::to_string(&item) {
+                                let _ = file.write_all(json.as_bytes()).await;
+                                let _ = file.write_all(b"\n").await;
+                            }
+                        }
+                        ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
+                    }
+                }
+                let _ = file.flush().await;
+            }
+            RolloutCmd::UpdateState(state) => {
+                #[derive(Serialize)]
+                struct StateLine<'a> {
+                    record_type: &'static str,
+                    #[serde(flatten)]
+                    state: &'a SessionStateSnapshot,
+                }
+                if let Ok(json) = serde_json::to_string(&StateLine {
+                    record_type: "state",
+                    state: &state,
+                }) {
+                    let _ = file.write_all(json.as_bytes()).await;
+                    let _ = file.write_all(b"\n").await;
+                    let _ = file.flush().await;
+                }
+            }
+        }
+    }
+}
```
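Because of `#[serde(flatten)]`, the `StateLine` wrapper serializes to a single flat JSONL record, which is exactly the shape `resume` keys on via `record_type` (value illustrative):

```json
{"record_type":"state","previous_response_id":"resp_123"}
```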
```diff
@@ -2,7 +2,6 @@

 use assert_cmd::Command as AssertCommand;
 use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
-use serde_json::Value;
 use std::time::Duration;
 use std::time::Instant;
 use tempfile::TempDir;
```
```diff
@@ -123,6 +122,7 @@ async fn responses_api_stream_cli() {
    assert!(stdout.contains("fixture hello"));
}

+/// End-to-end: create a session (writes rollout), verify the file, then resume and confirm append.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn integration_creates_and_checks_session_file() {
    // Honor sandbox network restrictions for CI parity with the other tests.
```
```diff
@@ -170,45 +170,66 @@ async fn integration_creates_and_checks_session_file() {
        String::from_utf8_lossy(&output.stderr)
    );

-    // 5. Sessions are written asynchronously; wait briefly for the directory to appear.
+    // Wait for sessions dir to appear.
    let sessions_dir = home.path().join("sessions");
-    let start = Instant::now();
-    while !sessions_dir.exists() && start.elapsed() < Duration::from_secs(3) {
+    let dir_deadline = Instant::now() + Duration::from_secs(5);
+    while !sessions_dir.exists() && Instant::now() < dir_deadline {
        std::thread::sleep(Duration::from_millis(50));
    }
+    assert!(sessions_dir.exists(), "sessions directory never appeared");

-    // 6. Scan all session files and find the one that contains our marker.
-    let mut matching_files = vec![];
-    for entry in WalkDir::new(&sessions_dir) {
-        let entry = entry.unwrap();
-        if entry.file_type().is_file() && entry.file_name().to_string_lossy().ends_with(".jsonl") {
-            let path = entry.path();
-            let content = std::fs::read_to_string(path).unwrap();
-            let mut lines = content.lines();
-            // Skip SessionMeta (first line)
-            let _ = lines.next();
-            for line in lines {
-                let item: Value = serde_json::from_str(line).unwrap();
-                if let Some("message") = item.get("type").and_then(|t| t.as_str()) {
-                    if let Some(content) = item.get("content") {
-                        if content.to_string().contains(&marker) {
-                            matching_files.push(path.to_owned());
+    // Find the session file that contains `marker`.
+    let deadline = Instant::now() + Duration::from_secs(10);
+    let mut matching_path: Option<std::path::PathBuf> = None;
+    while Instant::now() < deadline && matching_path.is_none() {
+        for entry in WalkDir::new(&sessions_dir) {
+            let entry = match entry {
+                Ok(e) => e,
+                Err(_) => continue,
+            };
+            if !entry.file_type().is_file() {
+                continue;
+            }
+            if !entry.file_name().to_string_lossy().ends_with(".jsonl") {
+                continue;
+            }
+            let path = entry.path();
+            let Ok(content) = std::fs::read_to_string(path) else {
+                continue;
+            };
+            let mut lines = content.lines();
+            if lines.next().is_none() {
+                continue;
+            }
+            for line in lines {
+                if line.trim().is_empty() {
+                    continue;
+                }
+                let item: serde_json::Value = match serde_json::from_str(line) {
+                    Ok(v) => v,
+                    Err(_) => continue,
+                };
+                if item.get("type").and_then(|t| t.as_str()) == Some("message") {
+                    if let Some(c) = item.get("content") {
+                        if c.to_string().contains(&marker) {
+                            matching_path = Some(path.to_path_buf());
                            break;
                        }
                    }
                }
            }
        }
+        if matching_path.is_none() {
+            std::thread::sleep(Duration::from_millis(50));
+        }
    }
-    assert_eq!(
-        matching_files.len(),
-        1,
-        "Expected exactly one session file containing the marker, found {}",
-        matching_files.len()
-    );
-    let path = &matching_files[0];

-    // 7. Verify directory structure: sessions/YYYY/MM/DD/filename.jsonl
+    let path = match matching_path {
+        Some(p) => p,
+        None => panic!("No session file containing the marker was found"),
+    };
+
+    // Basic sanity checks on location and metadata.
    let rel = match path.strip_prefix(&sessions_dir) {
        Ok(r) => r,
        Err(_) => panic!("session file should live under sessions/"),
```
```diff
@@ -237,7 +258,6 @@ async fn integration_creates_and_checks_session_file() {
        day.len() == 2 && day.chars().all(|c| c.is_ascii_digit()),
        "Day dir not zero-padded 2-digit numeric: {day}"
    );
-    // Range checks (best-effort; won't fail on leading zeros)
    if let Ok(m) = month.parse::<u8>() {
        assert!((1..=12).contains(&m), "Month out of range: {m}");
    }
```
```diff
@@ -245,23 +265,32 @@ async fn integration_creates_and_checks_session_file() {
        assert!((1..=31).contains(&d), "Day out of range: {d}");
    }

-    // 8. Parse SessionMeta line and basic sanity checks.
-    let content = std::fs::read_to_string(path).unwrap();
+    let content =
+        std::fs::read_to_string(&path).unwrap_or_else(|_| panic!("Failed to read session file"));
    let mut lines = content.lines();
-    let meta: Value = serde_json::from_str(lines.next().unwrap()).unwrap();
+    let meta_line = lines
+        .next()
+        .ok_or("missing session meta line")
+        .unwrap_or_else(|_| panic!("missing session meta line"));
+    let meta: serde_json::Value = serde_json::from_str(meta_line)
+        .unwrap_or_else(|_| panic!("Failed to parse session meta line as JSON"));
    assert!(meta.get("id").is_some(), "SessionMeta missing id");
    assert!(
        meta.get("timestamp").is_some(),
        "SessionMeta missing timestamp"
    );

-    // 9. Confirm at least one message contains the marker.
    let mut found_message = false;
    for line in lines {
-        let item: Value = serde_json::from_str(line).unwrap();
-        if item.get("type").map(|t| t == "message").unwrap_or(false) {
-            if let Some(content) = item.get("content") {
-                if content.to_string().contains(&marker) {
+        if line.trim().is_empty() {
+            continue;
+        }
+        let Ok(item) = serde_json::from_str::<serde_json::Value>(line) else {
+            continue;
+        };
+        if item.get("type").and_then(|t| t.as_str()) == Some("message") {
+            if let Some(c) = item.get("content") {
+                if c.to_string().contains(&marker) {
                    found_message = true;
                    break;
                }
            }
```
```diff
@@ -272,4 +301,61 @@ async fn integration_creates_and_checks_session_file() {
        found_message,
        "No message found in session file containing the marker"
    );
+
+    // Second run: resume and append.
+    let orig_len = content.lines().count();
+    let marker2 = format!("integration-resume-{}", Uuid::new_v4());
+    let prompt2 = format!("echo {marker2}");
+    // Cross‑platform safe resume override. On Windows, backslashes in a TOML string must be escaped
+    // or the parse will fail and the raw literal (including quotes) may be preserved all the way down
+    // to Config, which in turn breaks resume because the path is invalid. Normalize to forward slashes
+    // to sidestep the issue.
+    let resume_path_str = path.to_string_lossy().replace('\\', "/");
+    let resume_override = format!("experimental_resume=\"{resume_path_str}\"");
+    let mut cmd2 = AssertCommand::new("cargo");
+    cmd2.arg("run")
+        .arg("-p")
+        .arg("codex-cli")
+        .arg("--quiet")
+        .arg("--")
+        .arg("exec")
+        .arg("--skip-git-repo-check")
+        .arg("-c")
+        .arg(&resume_override)
+        .arg("-C")
+        .arg(env!("CARGO_MANIFEST_DIR"))
+        .arg(&prompt2);
+    cmd2.env("CODEX_HOME", home.path())
+        .env("OPENAI_API_KEY", "dummy")
+        .env("CODEX_RS_SSE_FIXTURE", &fixture)
+        .env("OPENAI_BASE_URL", "http://unused.local");
+    let output2 = cmd2.output().unwrap();
+    assert!(output2.status.success(), "resume codex-cli run failed");
+
+    // The rollout writer runs on a background async task; give it a moment to flush.
+    let mut new_len = orig_len;
+    let deadline = Instant::now() + Duration::from_secs(5);
+    let mut content2 = String::new();
+    while Instant::now() < deadline {
+        if let Ok(c) = std::fs::read_to_string(&path) {
+            let count = c.lines().count();
+            if count > orig_len {
+                content2 = c;
+                new_len = count;
+                break;
+            }
+        }
+        std::thread::sleep(Duration::from_millis(50));
+    }
+    if content2.is_empty() {
+        // last attempt
+        content2 = std::fs::read_to_string(&path).unwrap();
+        new_len = content2.lines().count();
+    }
+    assert!(new_len > orig_len, "rollout file did not grow after resume");
+    assert!(content2.contains(&marker), "rollout lost original marker");
+    assert!(
+        content2.contains(&marker2),
+        "rollout missing resumed marker"
+    );
}
```