fix: change shutdown_flag from Arc<AtomicBool> to tokio::sync::Notify (#2394)

Prior to this PR, we had:

71cae06e66/codex-rs/login/src/server.rs (L141-L142)

which means that we could be blocked waiting for a new request in
`server_for_thread.recv()` and not notice that the state of
`shutdown_flag` had changed.

With this PR, we use `shutdown_flag: Notify` so that we can
`tokio::select!` on `shutdown_notify.notified()` and `rx.recv()` (which
is the "async stream" of requests read from `server_for_thread.recv()`)
and handle whichever one happens first.

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2394).
* #2399
* #2398
* #2396
* #2395
* __->__ #2394
* #2393
* #2389
This commit is contained in:
Michael Bolin
2025-08-18 17:32:03 -07:00
committed by GitHub
parent 38b84ffd43
commit d58df28286
2 changed files with 77 additions and 72 deletions

View File

@@ -52,7 +52,7 @@ impl ServerOptions {
pub struct LoginServer {
pub auth_url: String,
pub actual_port: u16,
pub shutdown_flag: Arc<AtomicBool>,
shutdown_flag: Arc<tokio::sync::Notify>,
server_handle: tokio::task::JoinHandle<io::Result<()>>,
server: Arc<Server>,
}
@@ -70,7 +70,7 @@ impl LoginServer {
pub fn cancel_handle(&self) -> ShutdownHandle {
ShutdownHandle {
shutdown_flag: self.shutdown_flag.clone(),
shutdown_notify: self.shutdown_flag.clone(),
server: self.server.clone(),
}
}
@@ -78,24 +78,32 @@ impl LoginServer {
#[derive(Clone)]
pub struct ShutdownHandle {
shutdown_flag: Arc<AtomicBool>,
shutdown_notify: Arc<tokio::sync::Notify>,
server: Arc<Server>,
}
impl std::fmt::Debug for ShutdownHandle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ShutdownHandle")
.field("shutdown_notify", &self.shutdown_notify)
.finish()
}
}
impl ShutdownHandle {
pub fn cancel(&self) {
shutdown(&self.shutdown_flag, &self.server);
shutdown(&self.shutdown_notify, &self.server);
}
}
pub fn shutdown(shutdown_flag: &AtomicBool, server: &Server) {
shutdown_flag.store(true, Ordering::SeqCst);
pub fn shutdown(shutdown_notify: &tokio::sync::Notify, server: &Server) {
shutdown_notify.notify_waiters();
server.unblock();
}
pub fn run_login_server(
opts: ServerOptions,
shutdown_flag: Option<Arc<AtomicBool>>,
shutdown_flag: Option<Arc<tokio::sync::Notify>>,
) -> io::Result<LoginServer> {
let pkce = generate_pkce();
let state = opts.force_state.clone().unwrap_or_else(generate_state);
@@ -118,9 +126,9 @@ pub fn run_login_server(
if opts.open_browser {
let _ = webbrowser::open(&auth_url);
}
let shutdown_flag: Arc<AtomicBool> =
shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false)));
let shutdown_flag_clone = shutdown_flag.clone();
let shutdown_notify: Arc<tokio::sync::Notify> =
shutdown_flag.unwrap_or_else(|| Arc::new(tokio::sync::Notify::new()));
let shutdown_notify_clone = shutdown_notify.clone();
let timeout_flag = Arc::new(AtomicBool::new(false));
// Channel used to signal completion to timeout watcher.
@@ -130,7 +138,7 @@ pub fn run_login_server(
spawn_timeout_watcher(
done_rx,
timeout,
shutdown_flag.clone(),
shutdown_notify.clone(),
timeout_flag.clone(),
server.clone(),
);
@@ -139,61 +147,62 @@ pub fn run_login_server(
let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
let _server_handle = {
let server = server.clone();
let shutdown_flag = shutdown_flag.clone();
thread::spawn(move || {
while !shutdown_flag.load(Ordering::SeqCst) {
match server.recv() {
Ok(request) => tx.blocking_send(request).map_err(|e| {
eprintln!("Failed to send request to channel: {e}");
io::Error::other("Failed to send request to channel")
})?,
Err(e) => {
// If we've been asked to shut down, break gracefully so that
// we can report timeout or cancellation status uniformly.
if shutdown_flag.load(Ordering::SeqCst) {
break;
} else {
return Err(io::Error::other(e));
}
}
};
thread::spawn(move || -> io::Result<()> {
while let Ok(request) = server.recv() {
tx.blocking_send(request).map_err(|e| {
eprintln!("Failed to send request to channel: {e}");
io::Error::other("Failed to send request to channel")
})?;
}
Ok(())
})
};
let server_for_task = server.clone();
let server_handle = tokio::spawn(async move {
while let Some(req) = rx.recv().await {
let url_raw = req.url().to_string();
let response =
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
match response {
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
loop {
tokio::select! {
_ = shutdown_notify.notified() => {
let _ = done_tx.send(());
if timeout_flag.load(Ordering::SeqCst) {
return Err(io::Error::other("Login timed out"));
} else {
return Err(io::Error::other("Login was not completed"));
}
}
HandledRequest::RedirectWithHeader(header) => {
let redirect = Response::empty(302).with_header(header);
let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
maybe_req = rx.recv() => {
let Some(req) = maybe_req else {
let _ = done_tx.send(());
if timeout_flag.load(Ordering::SeqCst) {
return Err(io::Error::other("Login timed out"));
} else {
return Err(io::Error::other("Login was not completed"));
}
};
let url_raw = req.url().to_string();
let response =
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
match response {
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
}
HandledRequest::RedirectWithHeader(header) => {
let redirect = Response::empty(302).with_header(header);
let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
}
}
if is_login_complete {
shutdown_notify.notify_waiters();
let _ = done_tx.send(());
server_for_task.unblock();
return Ok(());
}
}
}
if is_login_complete {
shutdown_flag.store(true, Ordering::SeqCst);
// Login has succeeded, so disarm the timeout watcher.
let _ = done_tx.send(());
return Ok(());
}
}
// Login has failed or timed out, so disarm the timeout watcher.
let _ = done_tx.send(());
if timeout_flag.load(Ordering::SeqCst) {
Err(io::Error::other("Login timed out"))
} else {
Err(io::Error::other("Login was not completed"))
}
});
@@ -201,7 +210,7 @@ pub fn run_login_server(
auth_url: auth_url.clone(),
actual_port,
server_handle,
shutdown_flag: shutdown_flag_clone,
shutdown_flag: shutdown_notify_clone,
server,
})
}
@@ -317,17 +326,14 @@ async fn process_request(
fn spawn_timeout_watcher(
done_rx: mpsc::Receiver<()>,
timeout: Duration,
shutdown_flag: Arc<AtomicBool>,
shutdown_notify: Arc<tokio::sync::Notify>,
timeout_flag: Arc<AtomicBool>,
server: Arc<Server>,
) {
thread::spawn(move || {
if done_rx.recv_timeout(timeout).is_err()
&& shutdown_flag
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
{
if done_rx.recv_timeout(timeout).is_err() {
timeout_flag.store(true, Ordering::SeqCst);
shutdown_notify.notify_waiters();
server.unblock();
}
});

View File

@@ -1,5 +1,6 @@
use codex_login::CLIENT_ID;
use codex_login::ServerOptions;
use codex_login::ShutdownHandle;
use codex_login::run_login_server;
use crossterm::event::KeyCode;
use crossterm::event::KeyEvent;
@@ -24,9 +25,6 @@ use crate::onboarding::onboarding_screen::KeyboardHandler;
use crate::onboarding::onboarding_screen::StepStateProvider;
use crate::shimmer::shimmer_spans;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use super::onboarding_screen::StepState;
// no additional imports
@@ -45,13 +43,14 @@ pub(crate) enum SignInState {
/// Used to manage the lifecycle of SpawnedLogin and ensure it gets cleaned up.
pub(crate) struct ContinueInBrowserState {
auth_url: String,
shutdown_flag: Option<Arc<AtomicBool>>,
shutdown_handle: Option<ShutdownHandle>,
_login_wait_handle: Option<tokio::task::JoinHandle<()>>,
}
impl Drop for ContinueInBrowserState {
fn drop(&mut self) {
if let Some(flag) = &self.shutdown_flag {
flag.store(true, Ordering::SeqCst);
if let Some(flag) = &self.shutdown_handle {
flag.cancel();
}
}
}
@@ -286,7 +285,7 @@ impl AuthModeWidget {
match server {
Ok(child) => {
let auth_url = child.auth_url.clone();
let shutdown_flag = child.shutdown_flag.clone();
let shutdown_handle = child.cancel_handle();
let event_tx = self.event_tx.clone();
let join_handle = tokio::spawn(async move {
@@ -295,7 +294,7 @@ impl AuthModeWidget {
self.sign_in_state =
SignInState::ChatGptContinueInBrowser(ContinueInBrowserState {
auth_url,
shutdown_flag: Some(shutdown_flag),
shutdown_handle: Some(shutdown_handle),
_login_wait_handle: Some(join_handle),
});
self.event_tx.send(AppEvent::RequestRedraw);