fix: change shutdown_flag from Arc<AtomicBool> to tokio::sync::Notify (#2394)
Prior to this PR, we had:
71cae06e66/codex-rs/login/src/server.rs (L141-L142)
which means that we could be blocked waiting for a new request in
`server_for_thread.recv()` and not notice that the state of
`shutdown_flag` had changed.
With this PR, we use `shutdown_flag: Notify` so that we can
`tokio::select!` on `shutdown_notify.notified()` and `rx.recv()` (which
is the "async stream" of requests read from `server_for_thread.recv()`)
and handle whichever one happens first.
---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2394).
* #2399
* #2398
* #2396
* #2395
* __->__ #2394
* #2393
* #2389
This commit is contained in:
@@ -52,7 +52,7 @@ impl ServerOptions {
|
||||
pub struct LoginServer {
|
||||
pub auth_url: String,
|
||||
pub actual_port: u16,
|
||||
pub shutdown_flag: Arc<AtomicBool>,
|
||||
shutdown_flag: Arc<tokio::sync::Notify>,
|
||||
server_handle: tokio::task::JoinHandle<io::Result<()>>,
|
||||
server: Arc<Server>,
|
||||
}
|
||||
@@ -70,7 +70,7 @@ impl LoginServer {
|
||||
|
||||
pub fn cancel_handle(&self) -> ShutdownHandle {
|
||||
ShutdownHandle {
|
||||
shutdown_flag: self.shutdown_flag.clone(),
|
||||
shutdown_notify: self.shutdown_flag.clone(),
|
||||
server: self.server.clone(),
|
||||
}
|
||||
}
|
||||
@@ -78,24 +78,32 @@ impl LoginServer {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ShutdownHandle {
|
||||
shutdown_flag: Arc<AtomicBool>,
|
||||
shutdown_notify: Arc<tokio::sync::Notify>,
|
||||
server: Arc<Server>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ShutdownHandle {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ShutdownHandle")
|
||||
.field("shutdown_notify", &self.shutdown_notify)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ShutdownHandle {
|
||||
pub fn cancel(&self) {
|
||||
shutdown(&self.shutdown_flag, &self.server);
|
||||
shutdown(&self.shutdown_notify, &self.server);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shutdown(shutdown_flag: &AtomicBool, server: &Server) {
|
||||
shutdown_flag.store(true, Ordering::SeqCst);
|
||||
pub fn shutdown(shutdown_notify: &tokio::sync::Notify, server: &Server) {
|
||||
shutdown_notify.notify_waiters();
|
||||
server.unblock();
|
||||
}
|
||||
|
||||
pub fn run_login_server(
|
||||
opts: ServerOptions,
|
||||
shutdown_flag: Option<Arc<AtomicBool>>,
|
||||
shutdown_flag: Option<Arc<tokio::sync::Notify>>,
|
||||
) -> io::Result<LoginServer> {
|
||||
let pkce = generate_pkce();
|
||||
let state = opts.force_state.clone().unwrap_or_else(generate_state);
|
||||
@@ -118,9 +126,9 @@ pub fn run_login_server(
|
||||
if opts.open_browser {
|
||||
let _ = webbrowser::open(&auth_url);
|
||||
}
|
||||
let shutdown_flag: Arc<AtomicBool> =
|
||||
shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false)));
|
||||
let shutdown_flag_clone = shutdown_flag.clone();
|
||||
let shutdown_notify: Arc<tokio::sync::Notify> =
|
||||
shutdown_flag.unwrap_or_else(|| Arc::new(tokio::sync::Notify::new()));
|
||||
let shutdown_notify_clone = shutdown_notify.clone();
|
||||
let timeout_flag = Arc::new(AtomicBool::new(false));
|
||||
|
||||
// Channel used to signal completion to timeout watcher.
|
||||
@@ -130,7 +138,7 @@ pub fn run_login_server(
|
||||
spawn_timeout_watcher(
|
||||
done_rx,
|
||||
timeout,
|
||||
shutdown_flag.clone(),
|
||||
shutdown_notify.clone(),
|
||||
timeout_flag.clone(),
|
||||
server.clone(),
|
||||
);
|
||||
@@ -139,61 +147,62 @@ pub fn run_login_server(
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
|
||||
let _server_handle = {
|
||||
let server = server.clone();
|
||||
let shutdown_flag = shutdown_flag.clone();
|
||||
thread::spawn(move || {
|
||||
while !shutdown_flag.load(Ordering::SeqCst) {
|
||||
match server.recv() {
|
||||
Ok(request) => tx.blocking_send(request).map_err(|e| {
|
||||
eprintln!("Failed to send request to channel: {e}");
|
||||
io::Error::other("Failed to send request to channel")
|
||||
})?,
|
||||
Err(e) => {
|
||||
// If we've been asked to shut down, break gracefully so that
|
||||
// we can report timeout or cancellation status uniformly.
|
||||
if shutdown_flag.load(Ordering::SeqCst) {
|
||||
break;
|
||||
} else {
|
||||
return Err(io::Error::other(e));
|
||||
}
|
||||
}
|
||||
};
|
||||
thread::spawn(move || -> io::Result<()> {
|
||||
while let Ok(request) = server.recv() {
|
||||
tx.blocking_send(request).map_err(|e| {
|
||||
eprintln!("Failed to send request to channel: {e}");
|
||||
io::Error::other("Failed to send request to channel")
|
||||
})?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
};
|
||||
|
||||
let server_for_task = server.clone();
|
||||
let server_handle = tokio::spawn(async move {
|
||||
while let Some(req) = rx.recv().await {
|
||||
let url_raw = req.url().to_string();
|
||||
let response =
|
||||
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
|
||||
|
||||
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
|
||||
match response {
|
||||
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_notify.notified() => {
|
||||
let _ = done_tx.send(());
|
||||
if timeout_flag.load(Ordering::SeqCst) {
|
||||
return Err(io::Error::other("Login timed out"));
|
||||
} else {
|
||||
return Err(io::Error::other("Login was not completed"));
|
||||
}
|
||||
}
|
||||
HandledRequest::RedirectWithHeader(header) => {
|
||||
let redirect = Response::empty(302).with_header(header);
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
|
||||
maybe_req = rx.recv() => {
|
||||
let Some(req) = maybe_req else {
|
||||
let _ = done_tx.send(());
|
||||
if timeout_flag.load(Ordering::SeqCst) {
|
||||
return Err(io::Error::other("Login timed out"));
|
||||
} else {
|
||||
return Err(io::Error::other("Login was not completed"));
|
||||
}
|
||||
};
|
||||
|
||||
let url_raw = req.url().to_string();
|
||||
let response =
|
||||
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
|
||||
|
||||
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
|
||||
match response {
|
||||
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
|
||||
}
|
||||
HandledRequest::RedirectWithHeader(header) => {
|
||||
let redirect = Response::empty(302).with_header(header);
|
||||
let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
|
||||
}
|
||||
}
|
||||
|
||||
if is_login_complete {
|
||||
shutdown_notify.notify_waiters();
|
||||
let _ = done_tx.send(());
|
||||
server_for_task.unblock();
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if is_login_complete {
|
||||
shutdown_flag.store(true, Ordering::SeqCst);
|
||||
// Login has succeeded, so disarm the timeout watcher.
|
||||
let _ = done_tx.send(());
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Login has failed or timed out, so disarm the timeout watcher.
|
||||
let _ = done_tx.send(());
|
||||
|
||||
if timeout_flag.load(Ordering::SeqCst) {
|
||||
Err(io::Error::other("Login timed out"))
|
||||
} else {
|
||||
Err(io::Error::other("Login was not completed"))
|
||||
}
|
||||
});
|
||||
|
||||
@@ -201,7 +210,7 @@ pub fn run_login_server(
|
||||
auth_url: auth_url.clone(),
|
||||
actual_port,
|
||||
server_handle,
|
||||
shutdown_flag: shutdown_flag_clone,
|
||||
shutdown_flag: shutdown_notify_clone,
|
||||
server,
|
||||
})
|
||||
}
|
||||
@@ -317,17 +326,14 @@ async fn process_request(
|
||||
fn spawn_timeout_watcher(
|
||||
done_rx: mpsc::Receiver<()>,
|
||||
timeout: Duration,
|
||||
shutdown_flag: Arc<AtomicBool>,
|
||||
shutdown_notify: Arc<tokio::sync::Notify>,
|
||||
timeout_flag: Arc<AtomicBool>,
|
||||
server: Arc<Server>,
|
||||
) {
|
||||
thread::spawn(move || {
|
||||
if done_rx.recv_timeout(timeout).is_err()
|
||||
&& shutdown_flag
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_ok()
|
||||
{
|
||||
if done_rx.recv_timeout(timeout).is_err() {
|
||||
timeout_flag.store(true, Ordering::SeqCst);
|
||||
shutdown_notify.notify_waiters();
|
||||
server.unblock();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use codex_login::CLIENT_ID;
|
||||
use codex_login::ServerOptions;
|
||||
use codex_login::ShutdownHandle;
|
||||
use codex_login::run_login_server;
|
||||
use crossterm::event::KeyCode;
|
||||
use crossterm::event::KeyEvent;
|
||||
@@ -24,9 +25,6 @@ use crate::onboarding::onboarding_screen::KeyboardHandler;
|
||||
use crate::onboarding::onboarding_screen::StepStateProvider;
|
||||
use crate::shimmer::shimmer_spans;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use super::onboarding_screen::StepState;
|
||||
// no additional imports
|
||||
@@ -45,13 +43,14 @@ pub(crate) enum SignInState {
|
||||
/// Used to manage the lifecycle of SpawnedLogin and ensure it gets cleaned up.
|
||||
pub(crate) struct ContinueInBrowserState {
|
||||
auth_url: String,
|
||||
shutdown_flag: Option<Arc<AtomicBool>>,
|
||||
shutdown_handle: Option<ShutdownHandle>,
|
||||
_login_wait_handle: Option<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
impl Drop for ContinueInBrowserState {
|
||||
fn drop(&mut self) {
|
||||
if let Some(flag) = &self.shutdown_flag {
|
||||
flag.store(true, Ordering::SeqCst);
|
||||
if let Some(flag) = &self.shutdown_handle {
|
||||
flag.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -286,7 +285,7 @@ impl AuthModeWidget {
|
||||
match server {
|
||||
Ok(child) => {
|
||||
let auth_url = child.auth_url.clone();
|
||||
let shutdown_flag = child.shutdown_flag.clone();
|
||||
let shutdown_handle = child.cancel_handle();
|
||||
|
||||
let event_tx = self.event_tx.clone();
|
||||
let join_handle = tokio::spawn(async move {
|
||||
@@ -295,7 +294,7 @@ impl AuthModeWidget {
|
||||
self.sign_in_state =
|
||||
SignInState::ChatGptContinueInBrowser(ContinueInBrowserState {
|
||||
auth_url,
|
||||
shutdown_flag: Some(shutdown_flag),
|
||||
shutdown_handle: Some(shutdown_handle),
|
||||
_login_wait_handle: Some(join_handle),
|
||||
});
|
||||
self.event_tx.send(AppEvent::RequestRedraw);
|
||||
|
||||
Reference in New Issue
Block a user