feat: async ghost commit (#5618)
This commit is contained in:
2
codex-rs/Cargo.lock
generated
2
codex-rs/Cargo.lock
generated
@@ -1063,10 +1063,12 @@ dependencies = [
|
||||
"codex-apply-patch",
|
||||
"codex-async-utils",
|
||||
"codex-file-search",
|
||||
"codex-git-tooling",
|
||||
"codex-otel",
|
||||
"codex-protocol",
|
||||
"codex-rmcp-client",
|
||||
"codex-utils-pty",
|
||||
"codex-utils-readiness",
|
||||
"codex-utils-string",
|
||||
"codex-utils-tokenizer",
|
||||
"core-foundation 0.9.4",
|
||||
|
||||
@@ -24,10 +24,12 @@ codex-apply-patch = { workspace = true }
|
||||
codex-file-search = { workspace = true }
|
||||
codex-otel = { workspace = true, features = ["otel"] }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-git-tooling = { workspace = true }
|
||||
codex-rmcp-client = { workspace = true }
|
||||
codex-async-utils = { workspace = true }
|
||||
codex-utils-string = { workspace = true }
|
||||
codex-utils-pty = { workspace = true }
|
||||
codex-utils-readiness = { workspace = true }
|
||||
codex-utils-tokenizer = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
dunce = { workspace = true }
|
||||
|
||||
@@ -76,6 +76,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
ResponseItem::CustomToolCall { .. } => {}
|
||||
ResponseItem::CustomToolCallOutput { .. } => {}
|
||||
ResponseItem::WebSearchCall { .. } => {}
|
||||
ResponseItem::GhostSnapshot { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,6 +271,10 @@ pub(crate) async fn stream_chat_completions(
|
||||
"content": output,
|
||||
}));
|
||||
}
|
||||
ResponseItem::GhostSnapshot { .. } => {
|
||||
// Ghost snapshots annotate history but are not sent to the model.
|
||||
continue;
|
||||
}
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::Other => {
|
||||
|
||||
@@ -104,8 +104,11 @@ use crate::state::SessionServices;
|
||||
use crate::state::SessionState;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::CompactTask;
|
||||
use crate::tasks::GhostSnapshotTask;
|
||||
use crate::tasks::RegularTask;
|
||||
use crate::tasks::ReviewTask;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use crate::tools::ToolRouter;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::parallel::ToolCallRuntime;
|
||||
@@ -128,6 +131,8 @@ use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::InitialHistory;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_utils_readiness::Readiness;
|
||||
use codex_utils_readiness::ReadinessFlag;
|
||||
|
||||
pub mod compact;
|
||||
use self::compact::build_compacted_history;
|
||||
@@ -178,6 +183,7 @@ impl Codex {
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
original_config_do_not_use: Arc::clone(&config),
|
||||
features: config.features.clone(),
|
||||
};
|
||||
|
||||
// Generate a unique ID for the lifetime of this Codex session.
|
||||
@@ -271,6 +277,7 @@ pub(crate) struct TurnContext {
|
||||
pub(crate) is_review_mode: bool,
|
||||
pub(crate) final_output_json_schema: Option<Value>,
|
||||
pub(crate) codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
pub(crate) tool_call_gate: Arc<ReadinessFlag>,
|
||||
}
|
||||
|
||||
impl TurnContext {
|
||||
@@ -312,6 +319,9 @@ pub(crate) struct SessionConfiguration {
|
||||
/// operate deterministically.
|
||||
cwd: PathBuf,
|
||||
|
||||
/// Set of feature flags for this session
|
||||
features: Features,
|
||||
|
||||
// TODO(pakrym): Remove config from here
|
||||
original_config_do_not_use: Arc<Config>,
|
||||
}
|
||||
@@ -406,6 +416,7 @@ impl Session {
|
||||
is_review_mode: false,
|
||||
final_output_json_schema: None,
|
||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||
tool_call_gate: Arc::new(ReadinessFlag::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1096,6 +1107,43 @@ impl Session {
|
||||
self.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
async fn maybe_start_ghost_snapshot(
|
||||
self: &Arc<Self>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
cancellation_token: CancellationToken,
|
||||
) {
|
||||
if turn_context.is_review_mode
|
||||
|| !self
|
||||
.state
|
||||
.lock()
|
||||
.await
|
||||
.session_configuration
|
||||
.features
|
||||
.enabled(Feature::GhostCommit)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
let token = match turn_context.tool_call_gate.subscribe().await {
|
||||
Ok(token) => token,
|
||||
Err(err) => {
|
||||
warn!("failed to subscribe to ghost snapshot readiness: {err}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!("spawning ghost snapshot task");
|
||||
let task = GhostSnapshotTask::new(token);
|
||||
Arc::new(task)
|
||||
.run(
|
||||
Arc::new(SessionTaskContext::new(self.clone())),
|
||||
turn_context.clone(),
|
||||
Vec::new(),
|
||||
cancellation_token,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Returns the input if there was no task running to inject into
|
||||
pub async fn inject_input(&self, input: Vec<UserInput>) -> Result<(), Vec<UserInput>> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
@@ -1508,6 +1556,7 @@ async fn spawn_review_thread(
|
||||
is_review_mode: true,
|
||||
final_output_json_schema: None,
|
||||
codex_linux_sandbox_exe: parent_turn_context.codex_linux_sandbox_exe.clone(),
|
||||
tool_call_gate: Arc::new(ReadinessFlag::new()),
|
||||
};
|
||||
|
||||
// Seed the child task with the review prompt as the initial user message.
|
||||
@@ -1571,6 +1620,8 @@ pub(crate) async fn run_task(
|
||||
.await;
|
||||
}
|
||||
|
||||
sess.maybe_start_ghost_snapshot(Arc::clone(&turn_context), cancellation_token.child_token())
|
||||
.await;
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
@@ -1763,6 +1814,13 @@ fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_model_visible_history(input: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
||||
input
|
||||
.into_iter()
|
||||
.filter(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }))
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
@@ -1783,7 +1841,7 @@ async fn run_turn(
|
||||
.supports_parallel_tool_calls;
|
||||
let parallel_tool_calls = model_supports_parallel;
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
input: filter_model_visible_history(input),
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls,
|
||||
base_instructions_override: turn_context.base_instructions.clone(),
|
||||
@@ -2278,6 +2336,8 @@ fn is_mcp_client_startup_timeout_error(error: &anyhow::Error) -> bool {
|
||||
|| error_message.contains("timed out handshaking with MCP server")
|
||||
}
|
||||
|
||||
use crate::features::Feature;
|
||||
use crate::features::Features;
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context;
|
||||
|
||||
@@ -2594,6 +2654,7 @@ mod tests {
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
original_config_do_not_use: Arc::clone(&config),
|
||||
features: Features::default(),
|
||||
};
|
||||
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
@@ -2662,6 +2723,7 @@ mod tests {
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
cwd: config.cwd.clone(),
|
||||
original_config_do_not_use: Arc::clone(&config),
|
||||
features: Features::default(),
|
||||
};
|
||||
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
|
||||
@@ -2,6 +2,7 @@ use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use codex_protocol::protocol::TokenUsageInfo;
|
||||
use std::ops::Deref;
|
||||
use tracing::error;
|
||||
|
||||
/// Transcript of conversation history
|
||||
@@ -40,7 +41,9 @@ impl ConversationHistory {
|
||||
I::Item: std::ops::Deref<Target = ResponseItem>,
|
||||
{
|
||||
for item in items {
|
||||
if !is_api_message(&item) {
|
||||
let item_ref = item.deref();
|
||||
let is_ghost_snapshot = matches!(item_ref, ResponseItem::GhostSnapshot { .. });
|
||||
if !is_api_message(item_ref) && !is_ghost_snapshot {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -165,6 +168,7 @@ impl ConversationHistory {
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::GhostSnapshot { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::Message { .. } => {
|
||||
// nothing to do for these variants
|
||||
@@ -231,6 +235,7 @@ impl ConversationHistory {
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::GhostSnapshot { .. }
|
||||
| ResponseItem::Other
|
||||
| ResponseItem::Message { .. } => {
|
||||
// nothing to do for these variants
|
||||
@@ -355,6 +360,7 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
||||
| ResponseItem::LocalShellCall { .. }
|
||||
| ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => true,
|
||||
ResponseItem::GhostSnapshot { .. } => false,
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,6 +41,8 @@ pub enum Feature {
|
||||
WebSearchRequest,
|
||||
/// Enable the model-based risk assessments for sandboxed commands.
|
||||
SandboxCommandAssessment,
|
||||
/// Create a ghost commit at each turn.
|
||||
GhostCommit,
|
||||
}
|
||||
|
||||
impl Feature {
|
||||
@@ -248,4 +250,10 @@ pub const FEATURES: &[FeatureSpec] = &[
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
FeatureSpec {
|
||||
id: Feature::GhostCommit,
|
||||
key: "ghost_commit",
|
||||
stage: Stage::Experimental,
|
||||
default_enabled: false,
|
||||
},
|
||||
];
|
||||
|
||||
@@ -26,7 +26,8 @@ pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
|
||||
| ResponseItem::FunctionCallOutput { .. }
|
||||
| ResponseItem::CustomToolCall { .. }
|
||||
| ResponseItem::CustomToolCallOutput { .. }
|
||||
| ResponseItem::WebSearchCall { .. } => true,
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::GhostSnapshot { .. } => true,
|
||||
ResponseItem::Other => false,
|
||||
}
|
||||
}
|
||||
|
||||
112
codex-rs/core/src/tasks/ghost_snapshot.rs
Normal file
112
codex-rs/core/src/tasks/ghost_snapshot.rs
Normal file
@@ -0,0 +1,112 @@
|
||||
use crate::codex::TurnContext;
|
||||
use crate::state::TaskKind;
|
||||
use crate::tasks::SessionTask;
|
||||
use crate::tasks::SessionTaskContext;
|
||||
use async_trait::async_trait;
|
||||
use codex_git_tooling::CreateGhostCommitOptions;
|
||||
use codex_git_tooling::GitToolingError;
|
||||
use codex_git_tooling::create_ghost_commit;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use codex_utils_readiness::Readiness;
|
||||
use codex_utils_readiness::Token;
|
||||
use std::borrow::ToOwned;
|
||||
use std::sync::Arc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
pub(crate) struct GhostSnapshotTask {
|
||||
token: Token,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionTask for GhostSnapshotTask {
|
||||
fn kind(&self) -> TaskKind {
|
||||
TaskKind::Regular
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
ctx: Arc<TurnContext>,
|
||||
_input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
tokio::task::spawn(async move {
|
||||
let token = self.token;
|
||||
let ctx_for_task = Arc::clone(&ctx);
|
||||
let cancelled = tokio::select! {
|
||||
_ = cancellation_token.cancelled() => true,
|
||||
_ = async {
|
||||
let repo_path = ctx_for_task.cwd.clone();
|
||||
// Required to run in a dedicated blocking pool.
|
||||
match tokio::task::spawn_blocking(move || {
|
||||
let options = CreateGhostCommitOptions::new(&repo_path);
|
||||
create_ghost_commit(&options)
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(Ok(ghost_commit)) => {
|
||||
info!("ghost snapshot blocking task finished");
|
||||
session
|
||||
.session
|
||||
.record_conversation_items(&ctx, &[ResponseItem::GhostSnapshot {
|
||||
commit_id: ghost_commit.id().to_string(),
|
||||
parent: ghost_commit.parent().map(ToOwned::to_owned),
|
||||
}])
|
||||
.await;
|
||||
info!("ghost commit captured: {}", ghost_commit.id());
|
||||
}
|
||||
Ok(Err(err)) => {
|
||||
warn!(
|
||||
sub_id = ctx_for_task.sub_id.as_str(),
|
||||
"failed to capture ghost snapshot: {err}"
|
||||
);
|
||||
let message = match err {
|
||||
GitToolingError::NotAGitRepository { .. } => {
|
||||
"Snapshots disabled: current directory is not a Git repository."
|
||||
.to_string()
|
||||
}
|
||||
_ => format!("Snapshots disabled after ghost snapshot error: {err}."),
|
||||
};
|
||||
session
|
||||
.session
|
||||
.notify_background_event(&ctx_for_task, message)
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
sub_id = ctx_for_task.sub_id.as_str(),
|
||||
"ghost snapshot task panicked: {err}"
|
||||
);
|
||||
let message =
|
||||
format!("Snapshots disabled after ghost snapshot panic: {err}.");
|
||||
session
|
||||
.session
|
||||
.notify_background_event(&ctx_for_task, message)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
} => false,
|
||||
};
|
||||
|
||||
if cancelled {
|
||||
info!("ghost snapshot task cancelled");
|
||||
}
|
||||
|
||||
match ctx.tool_call_gate.mark_ready(token).await {
|
||||
Ok(true) => info!("ghost snapshot gate marked ready"),
|
||||
Ok(false) => warn!("ghost snapshot gate already ready"),
|
||||
Err(err) => warn!("failed to mark ghost snapshot ready: {err}"),
|
||||
}
|
||||
});
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl GhostSnapshotTask {
|
||||
pub(crate) fn new(token: Token) -> Self {
|
||||
Self { token }
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
mod compact;
|
||||
mod ghost_snapshot;
|
||||
mod regular;
|
||||
mod review;
|
||||
|
||||
@@ -25,6 +26,7 @@ use crate::state::TaskKind;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
|
||||
pub(crate) use compact::CompactTask;
|
||||
pub(crate) use ghost_snapshot::GhostSnapshotTask;
|
||||
pub(crate) use regular::RegularTask;
|
||||
pub(crate) use review::ReviewTask;
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ use crate::tools::router::ToolCall;
|
||||
use crate::tools::router::ToolRouter;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_utils_readiness::Readiness;
|
||||
|
||||
pub(crate) struct ToolCallRuntime {
|
||||
router: Arc<ToolRouter>,
|
||||
@@ -53,12 +54,16 @@ impl ToolCallRuntime {
|
||||
let tracker = Arc::clone(&self.tracker);
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
let aborted_response = Self::aborted_response(&call);
|
||||
let readiness = self.turn_context.tool_call_gate.clone();
|
||||
|
||||
let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => Ok(aborted_response),
|
||||
res = async {
|
||||
tracing::info!("waiting for tool gate");
|
||||
readiness.wait_ready().await;
|
||||
tracing::info!("tool gate released");
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
|
||||
@@ -41,6 +41,29 @@ fn network_disabled() -> bool {
|
||||
std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok()
|
||||
}
|
||||
|
||||
fn filter_out_ghost_snapshot_entries(items: &[Value]) -> Vec<Value> {
|
||||
items
|
||||
.iter()
|
||||
.filter(|item| !is_ghost_snapshot_message(item))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn is_ghost_snapshot_message(item: &Value) -> bool {
|
||||
if item.get("type").and_then(Value::as_str) != Some("message") {
|
||||
return false;
|
||||
}
|
||||
if item.get("role").and_then(Value::as_str) != Some("user") {
|
||||
return false;
|
||||
}
|
||||
item.get("content")
|
||||
.and_then(Value::as_array)
|
||||
.and_then(|content| content.first())
|
||||
.and_then(|entry| entry.get("text"))
|
||||
.and_then(Value::as_str)
|
||||
.is_some_and(|text| text.trim_start().starts_with("<ghost_snapshot>"))
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
/// Scenario: compact an initial conversation, resume it, fork one turn back, and
|
||||
/// ensure the model-visible history matches expectations at each request.
|
||||
@@ -556,13 +579,15 @@ async fn compact_resume_after_second_compaction_preserves_history() {
|
||||
let resume_input_array = input_after_resume
|
||||
.as_array()
|
||||
.expect("input after resume should be an array");
|
||||
let compact_filtered = filter_out_ghost_snapshot_entries(compact_input_array);
|
||||
let resume_filtered = filter_out_ghost_snapshot_entries(resume_input_array);
|
||||
assert!(
|
||||
compact_input_array.len() <= resume_input_array.len(),
|
||||
compact_filtered.len() <= resume_filtered.len(),
|
||||
"after-resume input should have at least as many items as after-compact"
|
||||
);
|
||||
assert_eq!(
|
||||
compact_input_array.as_slice(),
|
||||
&resume_input_array[..compact_input_array.len()]
|
||||
compact_filtered.as_slice(),
|
||||
&resume_filtered[..compact_filtered.len()]
|
||||
);
|
||||
// hard coded test
|
||||
let prompt = requests[0]["instructions"]
|
||||
|
||||
@@ -116,6 +116,12 @@ pub enum ResponseItem {
|
||||
status: Option<String>,
|
||||
action: WebSearchAction,
|
||||
},
|
||||
// Generated by the harness but considered exactly as a model response.
|
||||
GhostSnapshot {
|
||||
commit_id: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
parent: Option<String>,
|
||||
},
|
||||
#[serde(other)]
|
||||
Other,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Readiness flag with token-based authorization and async waiting (Tokio).
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::fmt;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicI32;
|
||||
use std::sync::atomic::Ordering;
|
||||
@@ -71,6 +72,10 @@ impl ReadinessFlag {
|
||||
.map_err(|_| errors::ReadinessError::TokenLockFailed)?;
|
||||
Ok(f(&mut guard))
|
||||
}
|
||||
|
||||
fn load_ready(&self) -> bool {
|
||||
self.ready.load(Ordering::Acquire)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ReadinessFlag {
|
||||
@@ -79,14 +84,37 @@ impl Default for ReadinessFlag {
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for ReadinessFlag {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("ReadinessFlag")
|
||||
.field("ready", &self.load_ready())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Readiness for ReadinessFlag {
|
||||
fn is_ready(&self) -> bool {
|
||||
self.ready.load(Ordering::Acquire)
|
||||
if self.load_ready() {
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Ok(tokens) = self.tokens.try_lock()
|
||||
&& tokens.is_empty()
|
||||
{
|
||||
let was_ready = self.ready.swap(true, Ordering::AcqRel);
|
||||
drop(tokens);
|
||||
if !was_ready {
|
||||
let _ = self.tx.send(true);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
self.load_ready()
|
||||
}
|
||||
|
||||
async fn subscribe(&self) -> Result<Token, errors::ReadinessError> {
|
||||
if self.is_ready() {
|
||||
if self.load_ready() {
|
||||
return Err(errors::ReadinessError::FlagAlreadyReady);
|
||||
}
|
||||
|
||||
@@ -97,7 +125,7 @@ impl Readiness for ReadinessFlag {
|
||||
// check above and inserting the token.
|
||||
let inserted = self
|
||||
.with_tokens(|tokens| {
|
||||
if self.is_ready() {
|
||||
if self.load_ready() {
|
||||
return false;
|
||||
}
|
||||
tokens.insert(token);
|
||||
@@ -113,7 +141,7 @@ impl Readiness for ReadinessFlag {
|
||||
}
|
||||
|
||||
async fn mark_ready(&self, token: Token) -> Result<bool, errors::ReadinessError> {
|
||||
if self.is_ready() {
|
||||
if self.load_ready() {
|
||||
return Ok(false);
|
||||
}
|
||||
if token.0 == 0 {
|
||||
@@ -202,7 +230,8 @@ mod tests {
|
||||
async fn mark_ready_rejects_unknown_token() -> Result<(), ReadinessError> {
|
||||
let flag = ReadinessFlag::new();
|
||||
assert!(!flag.mark_ready(Token(42)).await?);
|
||||
assert!(!flag.is_ready());
|
||||
assert!(!flag.load_ready());
|
||||
assert!(flag.is_ready());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -233,6 +262,19 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn is_ready_without_subscribers_marks_flag_ready() -> Result<(), ReadinessError> {
|
||||
let flag = ReadinessFlag::new();
|
||||
|
||||
assert!(flag.is_ready());
|
||||
assert!(flag.is_ready());
|
||||
assert_matches!(
|
||||
flag.subscribe().await,
|
||||
Err(ReadinessError::FlagAlreadyReady)
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn subscribe_returns_error_when_lock_is_held() {
|
||||
let flag = ReadinessFlag::new();
|
||||
|
||||
Reference in New Issue
Block a user