feat: async ghost commit (#5618)
This commit is contained in:
@@ -76,6 +76,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
ResponseItem::CustomToolCall { .. } => {}
|
||||
ResponseItem::CustomToolCallOutput { .. } => {}
|
||||
ResponseItem::WebSearchCall { .. } => {}
|
||||
ResponseItem::GhostSnapshot { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,6 +271,10 @@ pub(crate) async fn stream_chat_completions(
|
||||
"content": output,
|
||||
}));
|
||||
}
|
||||
ResponseItem::GhostSnapshot { .. } => {
|
||||
// Ghost snapshots annotate history but are not sent to the model.
|
||||
continue;
|
||||
}
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::Other => {
|
||||
|
||||
@@ -104,8 +104,11 @@ use crate::state::SessionServices;
|
||||
use crate::state::SessionState;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::CompactTask;
|
||||
use crate::tasks::GhostSnapshotTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
@@ -128,6 +131,8 @@ use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_utils_readiness::Readiness;
|
||||
use codex_utils_readiness::ReadinessFlag;
|
||||
|
||||
pub mod compact;
|
||||
use self::compact::build_compacted_history;
|
||||
@@ -178,6 +183,7 @@ impl Codex {
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
original_config_do_not_use: Arc::clone(&config),
|
||||
features: config.features.clone(),
|
||||
};
|
||||
|
||||
// Generate a unique ID for the lifetime of this Codex session.
|
||||
@@ -271,6 +277,7 @@ pub(crate) struct TurnContext {
|
||||
pub(crate) is_review_mode: bool,
|
||||
pub(crate) final_output_json_schema: Option<Value>,
|
||||
pub(crate) codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub(crate) tool_call_gate: Arc<ReadinessFlag>,
|
||||
}
|
||||
|
||||
impl TurnContext {
|
||||
@@ -312,6 +319,9 @@ pub(crate) struct SessionConfiguration {
|
||||
/// operate deterministically.
|
||||
cwd: PathBuf,
|
||||
|
||||
/// Set of feature flags for this session
|
||||
features: Features,
|
||||
|
||||
// TODO(pakrym): Remove config from here
|
||||
original_config_do_not_use: Arc<Config>,
|
||||
}
|
||||
@@ -406,6 +416,7 @@ impl Session {
|
||||
is_review_mode: false,
|
||||
final_output_json_schema: None,
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
tool_call_gate: Arc::new(ReadinessFlag::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1096,6 +1107,43 @@ impl Session {
|
||||
self.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
async fn maybe_start_ghost_snapshot(
|
||||
self: &Arc<Self>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
cancellation_token: CancellationToken,
|
||||
) {
|
||||
if turn_context.is_review_mode
|
||||
|| !self
|
||||
.state
|
||||
.lock()
|
||||
.await
|
||||
.session_configuration
|
||||
.features
|
||||
.enabled(Feature::GhostCommit)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let token = match turn_context.tool_call_gate.subscribe().await {
|
||||
Ok(token) => token,
|
||||
Err(err) => {
|
||||
warn!("failed to subscribe to ghost snapshot readiness: {err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!("spawning ghost snapshot task");
|
||||
let task = GhostSnapshotTask::new(token);
|
||||
Arc::new(task)
|
||||
.run(
|
||||
Arc::new(SessionTaskContext::new(self.clone())),
|
||||
turn_context.clone(),
|
||||
Vec::new(),
|
||||
cancellation_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Returns the input if there was no task running to inject into
|
||||
pub async fn inject_input(&self, input: Vec<UserInput>) -> Result<(), Vec<UserInput>> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
@@ -1508,6 +1556,7 @@ async fn spawn_review_thread(
|
||||
is_review_mode: true,
|
||||
final_output_json_schema: None,
|
||||
codex_linux_sandbox_exe: parent_turn_context.codex_linux_sandbox_exe.clone(),
|
||||
tool_call_gate: Arc::new(ReadinessFlag::new()),
|
||||
};
|
||||
|
||||
// Seed the child task with the review prompt as the initial user message.
|
||||
@@ -1571,6 +1620,8 @@ pub(crate) async fn run_task(
|
||||
.await;
|
||||
}
|
||||
|
||||
sess.maybe_start_ghost_snapshot(Arc::clone(&turn_context), cancellation_token.child_token())
|
||||
.await;
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
@@ -1763,6 +1814,13 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_model_visible_history(input: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
||||
input
|
||||
.into_iter()
|
||||
.filter(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }))
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
@@ -1783,7 +1841,7 @@ async fn run_turn(
|
||||
.supports_parallel_tool_calls;
|
||||
let parallel_tool_calls = model_supports_parallel;
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
input: filter_model_visible_history(input),
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls,
|
||||
base_instructions_override: turn_context.base_instructions.clone(),
|
||||
@@ -2278,6 +2336,8 @@ fn is_mcp_client_startup_timeout_error(error: &anyhow::Error) -> bool {
|
||||
|| error_message.contains("timed out handshaking with MCP server")
|
||||
}
|
||||
|
||||
use crate::features::Feature;
|
||||
use crate::features::Features;
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context;
|
||||
|
||||
@@ -2594,6 +2654,7 @@ mod tests {
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
original_config_do_not_use: Arc::clone(&config),
|
||||
features: Features::default(),
|
||||
};
|
||||
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
@@ -2662,6 +2723,7 @@ mod tests {
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
original_config_do_not_use: Arc::clone(&config),
|
||||
features: Features::default(),
|
||||
};
|
||||
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
|
||||
@@ -2,6 +2,7 @@ use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use codex_protocol::protocol::TokenUsageInfo;
|
||||
use std::ops::Deref;
|
||||
use tracing::error;
|
||||
|
||||
/// Transcript of conversation history
|
||||
@@ -40,7 +41,9 @@ impl ConversationHistory {
|
||||
I::Item: std::ops::Deref<Target = ResponseItem>,
|
||||
{
|
||||
for item in items {
|
||||
if !is_api_message(&item) {
|
||||
let item_ref = item.deref();
|
||||
let is_ghost_snapshot = matches!(item_ref, ResponseItem::GhostSnapshot { .. });
|
||||
if !is_api_message(item_ref) && !is_ghost_snapshot {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -165,6 +168,7 @@ impl ConversationHistory {
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::GhostSnapshot { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::Message { .. } => {
|
||||
// nothing to do for these variants
|
||||
@@ -231,6 +235,7 @@ impl ConversationHistory {
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::GhostSnapshot { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::Message { .. } => {
|
||||
// nothing to do for these variants
|
||||
@@ -355,6 +360,7 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => true,
|
||||
ResponseItem::GhostSnapshot { .. } => false,
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,8 @@ pub enum Feature {
|
||||
WebSearchRequest,
|
||||
/// Enable the model-based risk assessments for sandboxed commands.
|
||||
SandboxCommandAssessment,
|
||||
/// Create a ghost commit at each turn.
|
||||
GhostCommit,
|
||||
}
|
||||
|
||||
impl Feature {
|
||||
@@ -248,4 +250,10 @@ pub const FEATURES: &[FeatureSpec] = &[
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::GhostCommit,
|
||||
key: "ghost_commit",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
];
|
||||
|
||||
@@ -26,7 +26,8 @@ pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => true,
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::GhostSnapshot { .. } => true,
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
112
codex-rs/core/src/tasks/ghost_snapshot.rs
Normal file
112
codex-rs/core/src/tasks/ghost_snapshot.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use crate::codex::TurnContext;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use async_trait::async_trait;
|
||||
use codex_git_tooling::CreateGhostCommitOptions;
|
||||
use codex_git_tooling::GitToolingError;
|
||||
use codex_git_tooling::create_ghost_commit;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_utils_readiness::Readiness;
|
||||
use codex_utils_readiness::Token;
|
||||
use std::borrow::ToOwned;
|
||||
use std::sync::Arc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub(crate) struct GhostSnapshotTask {
|
||||
token: Token,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for GhostSnapshotTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Regular
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
tokio::task::spawn(async move {
|
||||
let token = self.token;
|
||||
let ctx_for_task = Arc::clone(&ctx);
|
||||
let cancelled = tokio::select! {
|
||||
_ = cancellation_token.cancelled() => true,
|
||||
_ = async {
|
||||
let repo_path = ctx_for_task.cwd.clone();
|
||||
// Required to run in a dedicated blocking pool.
|
||||
match tokio::task::spawn_blocking(move || {
|
||||
let options = CreateGhostCommitOptions::new(&repo_path);
|
||||
create_ghost_commit(&options)
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(Ok(ghost_commit)) => {
|
||||
info!("ghost snapshot blocking task finished");
|
||||
session
|
||||
.session
|
||||
.record_conversation_items(&ctx, &[ResponseItem::GhostSnapshot {
|
||||
commit_id: ghost_commit.id().to_string(),
|
||||
parent: ghost_commit.parent().map(ToOwned::to_owned),
|
||||
}])
|
||||
.await;
|
||||
info!("ghost commit captured: {}", ghost_commit.id());
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
warn!(
|
||||
sub_id = ctx_for_task.sub_id.as_str(),
|
||||
"failed to capture ghost snapshot: {err}"
|
||||
);
|
||||
let message = match err {
|
||||
GitToolingError::NotAGitRepository { .. } => {
|
||||
"Snapshots disabled: current directory is not a Git repository."
|
||||
.to_string()
|
||||
}
|
||||
_ => format!("Snapshots disabled after ghost snapshot error: {err}."),
|
||||
};
|
||||
session
|
||||
.session
|
||||
.notify_background_event(&ctx_for_task, message)
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
sub_id = ctx_for_task.sub_id.as_str(),
|
||||
"ghost snapshot task panicked: {err}"
|
||||
);
|
||||
let message =
|
||||
format!("Snapshots disabled after ghost snapshot panic: {err}.");
|
||||
session
|
||||
.session
|
||||
.notify_background_event(&ctx_for_task, message)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
} => false,
|
||||
};
|
||||
|
||||
if cancelled {
|
||||
info!("ghost snapshot task cancelled");
|
||||
}
|
||||
|
||||
match ctx.tool_call_gate.mark_ready(token).await {
|
||||
Ok(true) => info!("ghost snapshot gate marked ready"),
|
||||
Ok(false) => warn!("ghost snapshot gate already ready"),
|
||||
Err(err) => warn!("failed to mark ghost snapshot ready: {err}"),
|
||||
}
|
||||
});
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl GhostSnapshotTask {
|
||||
pub(crate) fn new(token: Token) -> Self {
|
||||
Self { token }
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
mod compact;
|
||||
mod ghost_snapshot;
|
||||
mod regular;
|
||||
mod review;
|
||||
|
||||
@@ -25,6 +26,7 @@ use crate::state::TaskKind;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
|
||||
pub(crate) use compact::CompactTask;
|
||||
pub(crate) use ghost_snapshot::GhostSnapshotTask;
|
||||
pub(crate) use regular::RegularTask;
|
||||
pub(crate) use review::ReviewTask;
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_utils_readiness::Readiness;
|
||||
|
||||
pub(crate) struct ToolCallRuntime {
|
||||
router: Arc<ToolRouter>,
|
||||
@@ -53,12 +54,16 @@ impl ToolCallRuntime {
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
let aborted_response = Self::aborted_response(&call);
|
||||
let readiness = self.turn_context.tool_call_gate.clone();
|
||||
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => Ok(aborted_response),
|
||||
res = async {
|
||||
tracing::info!("waiting for tool gate");
|
||||
readiness.wait_ready().await;
|
||||
tracing::info!("tool gate released");
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user