From d58df2828683763ec59249e34d507afd3765add5 Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Mon, 18 Aug 2025 17:32:03 -0700 Subject: [PATCH] fix: change `shutdown_flag` from `Arc` to `tokio::sync::Notify` (#2394) Prior to this PR, we had: https://github.com/openai/codex/blob/71cae06e6643b4adf644d9769208c3c5fcd1f2be/codex-rs/login/src/server.rs#L141-L142 which means that we could be blocked waiting for a new request in `server_for_thread.recv()` and not notice that the state of `shutdown_flag` had changed. With this PR, we use `shutdown_flag: Notify` so that we can `tokio::select!` on `shutdown_notify.notified()` and `rx.recv()` (which is the "async stream" of requests read from `server_for_thread.recv()`) and handle whichever one happens first. --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2394). * #2399 * #2398 * #2396 * #2395 * __->__ #2394 * #2393 * #2389 --- codex-rs/login/src/server.rs | 134 +++++++++++++++------------- codex-rs/tui/src/onboarding/auth.rs | 15 ++-- 2 files changed, 77 insertions(+), 72 deletions(-) diff --git a/codex-rs/login/src/server.rs b/codex-rs/login/src/server.rs index 060b333c..419874e7 100644 --- a/codex-rs/login/src/server.rs +++ b/codex-rs/login/src/server.rs @@ -52,7 +52,7 @@ impl ServerOptions { pub struct LoginServer { pub auth_url: String, pub actual_port: u16, - pub shutdown_flag: Arc, + shutdown_flag: Arc, server_handle: tokio::task::JoinHandle>, server: Arc, } @@ -70,7 +70,7 @@ impl LoginServer { pub fn cancel_handle(&self) -> ShutdownHandle { ShutdownHandle { - shutdown_flag: self.shutdown_flag.clone(), + shutdown_notify: self.shutdown_flag.clone(), server: self.server.clone(), } } @@ -78,24 +78,32 @@ impl LoginServer { #[derive(Clone)] pub struct ShutdownHandle { - shutdown_flag: Arc, + shutdown_notify: Arc, server: Arc, } +impl std::fmt::Debug for ShutdownHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ShutdownHandle") + .field("shutdown_notify", &self.shutdown_notify) + .finish() + } +} + impl ShutdownHandle { pub fn cancel(&self) { - shutdown(&self.shutdown_flag, &self.server); + shutdown(&self.shutdown_notify, &self.server); } } -pub fn shutdown(shutdown_flag: &AtomicBool, server: &Server) { - shutdown_flag.store(true, Ordering::SeqCst); +pub fn shutdown(shutdown_notify: &tokio::sync::Notify, server: &Server) { + shutdown_notify.notify_waiters(); server.unblock(); } pub fn run_login_server( opts: ServerOptions, - shutdown_flag: Option>, + shutdown_flag: Option>, ) -> io::Result { let pkce = generate_pkce(); let state = opts.force_state.clone().unwrap_or_else(generate_state); @@ -118,9 +126,9 @@ pub fn run_login_server( if opts.open_browser { let _ = webbrowser::open(&auth_url); } - let shutdown_flag: Arc = - shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false))); - let shutdown_flag_clone = shutdown_flag.clone(); + let shutdown_notify: Arc = + shutdown_flag.unwrap_or_else(|| Arc::new(tokio::sync::Notify::new())); + let shutdown_notify_clone = shutdown_notify.clone(); let timeout_flag = Arc::new(AtomicBool::new(false)); // Channel used to signal completion to timeout watcher. @@ -130,7 +138,7 @@ pub fn run_login_server( spawn_timeout_watcher( done_rx, timeout, - shutdown_flag.clone(), + shutdown_notify.clone(), timeout_flag.clone(), server.clone(), ); @@ -139,61 +147,62 @@ pub fn run_login_server( let (tx, mut rx) = tokio::sync::mpsc::channel::(16); let _server_handle = { let server = server.clone(); - let shutdown_flag = shutdown_flag.clone(); - thread::spawn(move || { - while !shutdown_flag.load(Ordering::SeqCst) { - match server.recv() { - Ok(request) => tx.blocking_send(request).map_err(|e| { - eprintln!("Failed to send request to channel: {e}"); - io::Error::other("Failed to send request to channel") - })?, - Err(e) => { - // If we've been asked to shut down, break gracefully so that - // we can report timeout or cancellation status uniformly. - if shutdown_flag.load(Ordering::SeqCst) { - break; - } else { - return Err(io::Error::other(e)); - } - } - }; + thread::spawn(move || -> io::Result<()> { + while let Ok(request) = server.recv() { + tx.blocking_send(request).map_err(|e| { + eprintln!("Failed to send request to channel: {e}"); + io::Error::other("Failed to send request to channel") + })?; } Ok(()) }) }; + let server_for_task = server.clone(); let server_handle = tokio::spawn(async move { - while let Some(req) = rx.recv().await { - let url_raw = req.url().to_string(); - let response = - process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await; - - let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_)); - match response { - HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => { - let _ = tokio::task::spawn_blocking(move || req.respond(r)).await; + loop { + tokio::select! { + _ = shutdown_notify.notified() => { + let _ = done_tx.send(()); + if timeout_flag.load(Ordering::SeqCst) { + return Err(io::Error::other("Login timed out")); + } else { + return Err(io::Error::other("Login was not completed")); + } } - HandledRequest::RedirectWithHeader(header) => { - let redirect = Response::empty(302).with_header(header); - let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await; + maybe_req = rx.recv() => { + let Some(req) = maybe_req else { + let _ = done_tx.send(()); + if timeout_flag.load(Ordering::SeqCst) { + return Err(io::Error::other("Login timed out")); + } else { + return Err(io::Error::other("Login was not completed")); + } + }; + + let url_raw = req.url().to_string(); + let response = + process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await; + + let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_)); + match response { + HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => { + let _ = tokio::task::spawn_blocking(move || req.respond(r)).await; + } + HandledRequest::RedirectWithHeader(header) => { + let redirect = Response::empty(302).with_header(header); + let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await; + } + } + + if is_login_complete { + shutdown_notify.notify_waiters(); + let _ = done_tx.send(()); + server_for_task.unblock(); + return Ok(()); + } } } - - if is_login_complete { - shutdown_flag.store(true, Ordering::SeqCst); - // Login has succeeded, so disarm the timeout watcher. - let _ = done_tx.send(()); - return Ok(()); - } - } - - // Login has failed or timed out, so disarm the timeout watcher. - let _ = done_tx.send(()); - - if timeout_flag.load(Ordering::SeqCst) { - Err(io::Error::other("Login timed out")) - } else { - Err(io::Error::other("Login was not completed")) } }); @@ -201,7 +210,7 @@ pub fn run_login_server( auth_url: auth_url.clone(), actual_port, server_handle, - shutdown_flag: shutdown_flag_clone, + shutdown_flag: shutdown_notify_clone, server, }) } @@ -317,17 +326,14 @@ async fn process_request( fn spawn_timeout_watcher( done_rx: mpsc::Receiver<()>, timeout: Duration, - shutdown_flag: Arc, + shutdown_notify: Arc, timeout_flag: Arc, server: Arc, ) { thread::spawn(move || { - if done_rx.recv_timeout(timeout).is_err() - && shutdown_flag - .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) - .is_ok() - { + if done_rx.recv_timeout(timeout).is_err() { timeout_flag.store(true, Ordering::SeqCst); + shutdown_notify.notify_waiters(); server.unblock(); } }); diff --git a/codex-rs/tui/src/onboarding/auth.rs b/codex-rs/tui/src/onboarding/auth.rs index 7961d75b..7166e349 100644 --- a/codex-rs/tui/src/onboarding/auth.rs +++ b/codex-rs/tui/src/onboarding/auth.rs @@ -1,5 +1,6 @@ use codex_login::CLIENT_ID; use codex_login::ServerOptions; +use codex_login::ShutdownHandle; use codex_login::run_login_server; use crossterm::event::KeyCode; use crossterm::event::KeyEvent; @@ -24,9 +25,6 @@ use crate::onboarding::onboarding_screen::KeyboardHandler; use crate::onboarding::onboarding_screen::StepStateProvider; use crate::shimmer::shimmer_spans; use std::path::PathBuf; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering; use super::onboarding_screen::StepState; // no additional imports @@ -45,13 +43,14 @@ pub(crate) enum SignInState { /// Used to manage the lifecycle of SpawnedLogin and ensure it gets cleaned up. pub(crate) struct ContinueInBrowserState { auth_url: String, - shutdown_flag: Option>, + shutdown_handle: Option, _login_wait_handle: Option>, } + impl Drop for ContinueInBrowserState { fn drop(&mut self) { - if let Some(flag) = &self.shutdown_flag { - flag.store(true, Ordering::SeqCst); + if let Some(flag) = &self.shutdown_handle { + flag.cancel(); } } } @@ -286,7 +285,7 @@ impl AuthModeWidget { match server { Ok(child) => { let auth_url = child.auth_url.clone(); - let shutdown_flag = child.shutdown_flag.clone(); + let shutdown_handle = child.cancel_handle(); let event_tx = self.event_tx.clone(); let join_handle = tokio::spawn(async move { @@ -295,7 +294,7 @@ impl AuthModeWidget { self.sign_in_state = SignInState::ChatGptContinueInBrowser(ContinueInBrowserState { auth_url, - shutdown_flag: Some(shutdown_flag), + shutdown_handle: Some(shutdown_handle), _login_wait_handle: Some(join_handle), }); self.event_tx.send(AppEvent::RequestRedraw);