Add session loading support to Codex (#1602)
## Summary - extend rollout format to store all session data in JSON - add resume/write helpers for rollouts - track session state after each conversation - support `LoadSession` op to resume a previous rollout - allow starting Codex with an existing session via `experimental_resume` config variable We need a way later for exploring the available sessions in a user friendly way. ## Testing - `cargo test --no-run` *(fails: `cargo: command not found`)* ------ https://chatgpt.com/codex/tasks/task_i_68792a29dd5c832190bf6930d3466fba This video is outdated. you should use `-c experimental_resume:<full path>` instead of `--resume <full path>` https://github.com/user-attachments/assets/7a9975c7-aa04-4f4e-899a-9e87defd947a
This commit is contained in:
@@ -1,33 +1,47 @@
|
||||
//! Functionality to persist a Codex conversation *rollout* – a linear list of
|
||||
//! [`ResponseItem`] objects exchanged during a session – to disk so that
|
||||
//! sessions can be replayed or inspected later (mirrors the behaviour of the
|
||||
//! upstream TypeScript implementation).
|
||||
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.
|
||||
|
||||
use std::fs::File;
|
||||
use std::fs::{self};
|
||||
use std::io::Error as IoError;
|
||||
use std::path::Path;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use time::OffsetDateTime;
|
||||
use time::format_description::FormatItem;
|
||||
use time::macros::format_description;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::sync::mpsc::{self};
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::models::ResponseItem;
|
||||
|
||||
/// Folder inside `~/.codex` that holds saved rollouts.
|
||||
const SESSIONS_SUBDIR: &str = "sessions";
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SessionMeta {
|
||||
id: String,
|
||||
timestamp: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
instructions: Option<String>,
|
||||
#[derive(Serialize, Deserialize, Clone, Default)]
|
||||
pub struct SessionMeta {
|
||||
pub id: Uuid,
|
||||
pub timestamp: String,
|
||||
pub instructions: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SessionStateSnapshot {
|
||||
pub previous_response_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
pub struct SavedSession {
|
||||
pub session: SessionMeta,
|
||||
#[serde(default)]
|
||||
pub items: Vec<ResponseItem>,
|
||||
#[serde(default)]
|
||||
pub state: SessionStateSnapshot,
|
||||
pub session_id: Uuid,
|
||||
}
|
||||
|
||||
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
|
||||
@@ -41,7 +55,13 @@ struct SessionMeta {
|
||||
/// ```
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RolloutRecorder {
|
||||
tx: Sender<String>,
|
||||
tx: Sender<RolloutCmd>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum RolloutCmd {
|
||||
AddItems(Vec<ResponseItem>),
|
||||
UpdateState(SessionStateSnapshot),
|
||||
}
|
||||
|
||||
impl RolloutRecorder {
|
||||
@@ -59,7 +79,6 @@ impl RolloutRecorder {
|
||||
timestamp,
|
||||
} = create_log_file(config, uuid)?;
|
||||
|
||||
// Build the static session metadata JSON first.
|
||||
let timestamp_format: &[FormatItem] = format_description!(
|
||||
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
|
||||
);
|
||||
@@ -69,46 +88,29 @@ impl RolloutRecorder {
|
||||
|
||||
let meta = SessionMeta {
|
||||
timestamp,
|
||||
id: session_id.to_string(),
|
||||
id: session_id,
|
||||
instructions,
|
||||
};
|
||||
|
||||
// A reasonably-sized bounded channel. If the buffer fills up the send
|
||||
// future will yield, which is fine – we only need to ensure we do not
|
||||
// perform *blocking* I/O on the caller’s thread.
|
||||
let (tx, mut rx) = mpsc::channel::<String>(256);
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
|
||||
// Spawn a Tokio task that owns the file handle and performs async
|
||||
// writes. Using `tokio::fs::File` keeps everything on the async I/O
|
||||
// driver instead of blocking the runtime.
|
||||
tokio::task::spawn(async move {
|
||||
let mut file = tokio::fs::File::from_std(file);
|
||||
tokio::task::spawn(rollout_writer(
|
||||
tokio::fs::File::from_std(file),
|
||||
rx,
|
||||
Some(meta),
|
||||
));
|
||||
|
||||
while let Some(line) = rx.recv().await {
|
||||
// Write line + newline, then flush to disk.
|
||||
if let Err(e) = file.write_all(line.as_bytes()).await {
|
||||
tracing::warn!("rollout writer: failed to write line: {e}");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = file.write_all(b"\n").await {
|
||||
tracing::warn!("rollout writer: failed to write newline: {e}");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = file.flush().await {
|
||||
tracing::warn!("rollout writer: failed to flush: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let recorder = Self { tx };
|
||||
// Ensure SessionMeta is the first item in the file.
|
||||
recorder.record_item(&meta).await?;
|
||||
Ok(recorder)
|
||||
Ok(Self { tx })
|
||||
}
|
||||
|
||||
/// Append `items` to the rollout file.
|
||||
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
match item {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
@@ -117,27 +119,86 @@ impl RolloutRecorder {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => {}
|
||||
| ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
|
||||
// These should never be serialized.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
self.record_item(item).await?;
|
||||
}
|
||||
Ok(())
|
||||
if filtered.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
self.tx
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
}
|
||||
|
||||
async fn record_item(&self, item: &impl Serialize) -> std::io::Result<()> {
|
||||
// Serialize the item to JSON first so that the writer thread only has
|
||||
// to perform the actual write.
|
||||
let json = serde_json::to_string(item)
|
||||
.map_err(|e| IoError::other(format!("failed to serialize response items: {e}")))?;
|
||||
|
||||
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
|
||||
self.tx
|
||||
.send(json)
|
||||
.send(RolloutCmd::UpdateState(state))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout item: {e}")))
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||
}
|
||||
|
||||
pub async fn resume(path: &Path) -> std::io::Result<(Self, SavedSession)> {
|
||||
info!("Resuming rollout from {path:?}");
|
||||
let text = tokio::fs::read_to_string(path).await?;
|
||||
let mut lines = text.lines();
|
||||
let meta_line = lines
|
||||
.next()
|
||||
.ok_or_else(|| IoError::other("empty session file"))?;
|
||||
let session: SessionMeta = serde_json::from_str(meta_line)
|
||||
.map_err(|e| IoError::other(format!("failed to parse session meta: {e}")))?;
|
||||
let mut items = Vec::new();
|
||||
let mut state = SessionStateSnapshot::default();
|
||||
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let v: Value = match serde_json::from_str(line) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if v.get("record_type")
|
||||
.and_then(|rt| rt.as_str())
|
||||
.map(|s| s == "state")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Ok(s) = serde_json::from_value::<SessionStateSnapshot>(v.clone()) {
|
||||
state = s
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if let Ok(item) = serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => items.push(item),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let saved = SavedSession {
|
||||
session: session.clone(),
|
||||
items: items.clone(),
|
||||
state: state.clone(),
|
||||
session_id: session.id,
|
||||
};
|
||||
|
||||
let file = std::fs::OpenOptions::new()
|
||||
.append(true)
|
||||
.read(true)
|
||||
.open(path)?;
|
||||
|
||||
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
|
||||
tokio::task::spawn(rollout_writer(tokio::fs::File::from_std(file), rx, None));
|
||||
info!("Resumed rollout successfully from {path:?}");
|
||||
Ok((Self { tx }, saved))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,3 +246,54 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
|
||||
timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
async fn rollout_writer(
|
||||
mut file: tokio::fs::File,
|
||||
mut rx: mpsc::Receiver<RolloutCmd>,
|
||||
meta: Option<SessionMeta>,
|
||||
) {
|
||||
if let Some(meta) = meta {
|
||||
if let Ok(json) = serde_json::to_string(&meta) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
}
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
for item in items {
|
||||
match item {
|
||||
ResponseItem::Message { .. }
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::FunctionCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. } => {
|
||||
if let Ok(json) = serde_json::to_string(&item) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
}
|
||||
}
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
}
|
||||
}
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
RolloutCmd::UpdateState(state) => {
|
||||
#[derive(Serialize)]
|
||||
struct StateLine<'a> {
|
||||
record_type: &'static str,
|
||||
#[serde(flatten)]
|
||||
state: &'a SessionStateSnapshot,
|
||||
}
|
||||
if let Ok(json) = serde_json::to_string(&StateLine {
|
||||
record_type: "state",
|
||||
state: &state,
|
||||
}) {
|
||||
let _ = file.write_all(json.as_bytes()).await;
|
||||
let _ = file.write_all(b"\n").await;
|
||||
let _ = file.flush().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user