fix: change shutdown_flag from Arc<AtomicBool> to tokio::sync::Notify (#2394)
Prior to this PR, we had:
71cae06e66/codex-rs/login/src/server.rs (L141-L142)
which meant that we could block waiting for a new request in
`server_for_thread.recv()` and never notice that the state of
`shutdown_flag` had changed.
With this PR, we use `shutdown_flag: Notify` so that we can
`tokio::select!` on `shutdown_notify.notified()` and `rx.recv()` (which
is the "async stream" of requests read from `server_for_thread.recv()`)
and handle whichever one happens first.
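A minimal sketch of that shape (illustrative only, not the exact PR code; `Request` stands in for the real request type):

```rust
use std::sync::Arc;
use tokio::sync::{Notify, mpsc};

struct Request; // stand-in for the real request type

async fn request_loop(mut rx: mpsc::Receiver<Request>, shutdown: Arc<Notify>) {
    loop {
        tokio::select! {
            // Completes as soon as `shutdown.notify_waiters()` runs,
            // even if no request ever arrives on `rx`.
            _ = shutdown.notified() => break,
            maybe_req = rx.recv() => {
                let Some(_req) = maybe_req else { break };
                // ... handle the request ...
            }
        }
    }
}
```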
---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2394).
* #2399
* #2398
* #2396
* #2395
* __->__ #2394
* #2393
* #2389
`codex-rs/login/src/server.rs`:

```diff
@@ -52,7 +52,7 @@ impl ServerOptions {
 pub struct LoginServer {
     pub auth_url: String,
     pub actual_port: u16,
-    pub shutdown_flag: Arc<AtomicBool>,
+    shutdown_flag: Arc<tokio::sync::Notify>,
     server_handle: tokio::task::JoinHandle<io::Result<()>>,
     server: Arc<Server>,
 }
@@ -70,7 +70,7 @@ impl LoginServer {
 
     pub fn cancel_handle(&self) -> ShutdownHandle {
         ShutdownHandle {
-            shutdown_flag: self.shutdown_flag.clone(),
+            shutdown_notify: self.shutdown_flag.clone(),
             server: self.server.clone(),
         }
     }
```
```diff
@@ -78,24 +78,32 @@ impl LoginServer {
 
 #[derive(Clone)]
 pub struct ShutdownHandle {
-    shutdown_flag: Arc<AtomicBool>,
+    shutdown_notify: Arc<tokio::sync::Notify>,
     server: Arc<Server>,
 }
 
+impl std::fmt::Debug for ShutdownHandle {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ShutdownHandle")
+            .field("shutdown_notify", &self.shutdown_notify)
+            .finish()
+    }
+}
+
 impl ShutdownHandle {
     pub fn cancel(&self) {
-        shutdown(&self.shutdown_flag, &self.server);
+        shutdown(&self.shutdown_notify, &self.server);
     }
 }
 
-pub fn shutdown(shutdown_flag: &AtomicBool, server: &Server) {
-    shutdown_flag.store(true, Ordering::SeqCst);
+pub fn shutdown(shutdown_notify: &tokio::sync::Notify, server: &Server) {
+    shutdown_notify.notify_waiters();
     server.unblock();
 }
 
 pub fn run_login_server(
     opts: ServerOptions,
-    shutdown_flag: Option<Arc<AtomicBool>>,
+    shutdown_flag: Option<Arc<tokio::sync::Notify>>,
 ) -> io::Result<LoginServer> {
     let pkce = generate_pkce();
     let state = opts.force_state.clone().unwrap_or_else(generate_state);
```
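One `Notify` detail worth keeping in mind here (standard tokio behavior, not specific to this PR): `notify_waiters()` wakes only tasks that are already suspended in `notified()`; unlike `notify_one()`, it does not store a permit for a future waiter. A small sketch:

```rust
use std::sync::Arc;
use tokio::sync::Notify;

#[tokio::main]
async fn main() {
    let notify = Arc::new(Notify::new());

    // Nobody is waiting yet, so this wakeup is simply dropped.
    notify.notify_waiters();

    // `notify_one()` stores a single permit instead, so this await
    // returns immediately even though it started after the notify.
    notify.notify_one();
    notify.notified().await;
}
```

A missed `notify_waiters()` cannot strand the async task in this design, because `shutdown()` also calls `server.unblock()`: that ends the blocking `recv()` thread, drops `tx`, and makes `rx.recv()` yield `None`.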
```diff
@@ -118,9 +126,9 @@ pub fn run_login_server(
     if opts.open_browser {
         let _ = webbrowser::open(&auth_url);
     }
-    let shutdown_flag: Arc<AtomicBool> =
-        shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false)));
-    let shutdown_flag_clone = shutdown_flag.clone();
+    let shutdown_notify: Arc<tokio::sync::Notify> =
+        shutdown_flag.unwrap_or_else(|| Arc::new(tokio::sync::Notify::new()));
+    let shutdown_notify_clone = shutdown_notify.clone();
     let timeout_flag = Arc::new(AtomicBool::new(false));
 
     // Channel used to signal completion to timeout watcher.
@@ -130,7 +138,7 @@ pub fn run_login_server(
     spawn_timeout_watcher(
         done_rx,
         timeout,
-        shutdown_flag.clone(),
+        shutdown_notify.clone(),
         timeout_flag.clone(),
         server.clone(),
     );
```
```diff
@@ -139,61 +147,62 @@ pub fn run_login_server(
     let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
     let _server_handle = {
         let server = server.clone();
-        let shutdown_flag = shutdown_flag.clone();
-        thread::spawn(move || {
-            while !shutdown_flag.load(Ordering::SeqCst) {
-                match server.recv() {
-                    Ok(request) => tx.blocking_send(request).map_err(|e| {
-                        eprintln!("Failed to send request to channel: {e}");
-                        io::Error::other("Failed to send request to channel")
-                    })?,
-                    Err(e) => {
-                        // If we've been asked to shut down, break gracefully so that
-                        // we can report timeout or cancellation status uniformly.
-                        if shutdown_flag.load(Ordering::SeqCst) {
-                            break;
-                        } else {
-                            return Err(io::Error::other(e));
-                        }
-                    }
-                };
+        thread::spawn(move || -> io::Result<()> {
+            while let Ok(request) = server.recv() {
+                tx.blocking_send(request).map_err(|e| {
+                    eprintln!("Failed to send request to channel: {e}");
+                    io::Error::other("Failed to send request to channel")
+                })?;
             }
             Ok(())
         })
     };
 
+    let server_for_task = server.clone();
     let server_handle = tokio::spawn(async move {
-        while let Some(req) = rx.recv().await {
-            let url_raw = req.url().to_string();
-            let response =
-                process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
-
-            let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
-            match response {
-                HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
-                    let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
-                }
-                HandledRequest::RedirectWithHeader(header) => {
-                    let redirect = Response::empty(302).with_header(header);
-                    let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
-                }
-            }
-
-            if is_login_complete {
-                shutdown_flag.store(true, Ordering::SeqCst);
-                // Login has succeeded, so disarm the timeout watcher.
-                let _ = done_tx.send(());
-                return Ok(());
-            }
-        }
-
-        // Login has failed or timed out, so disarm the timeout watcher.
-        let _ = done_tx.send(());
-
-        if timeout_flag.load(Ordering::SeqCst) {
-            Err(io::Error::other("Login timed out"))
-        } else {
-            Err(io::Error::other("Login was not completed"))
-        }
+        loop {
+            tokio::select! {
+                _ = shutdown_notify.notified() => {
+                    let _ = done_tx.send(());
+                    if timeout_flag.load(Ordering::SeqCst) {
+                        return Err(io::Error::other("Login timed out"));
+                    } else {
+                        return Err(io::Error::other("Login was not completed"));
+                    }
+                }
+                maybe_req = rx.recv() => {
+                    let Some(req) = maybe_req else {
+                        let _ = done_tx.send(());
+                        if timeout_flag.load(Ordering::SeqCst) {
+                            return Err(io::Error::other("Login timed out"));
+                        } else {
+                            return Err(io::Error::other("Login was not completed"));
+                        }
+                    };
+
+                    let url_raw = req.url().to_string();
+                    let response =
+                        process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
+
+                    let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
+                    match response {
+                        HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
+                            let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
+                        }
+                        HandledRequest::RedirectWithHeader(header) => {
+                            let redirect = Response::empty(302).with_header(header);
+                            let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
+                        }
+                    }
+
+                    if is_login_complete {
+                        shutdown_notify.notify_waiters();
+                        let _ = done_tx.send(());
+                        server_for_task.unblock();
+                        return Ok(());
+                    }
+                }
+            }
+        }
     });
```
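A usage sketch of the resulting API (hypothetical caller, not code from this PR; `ServerOptions` construction elided):

```rust
use std::time::Duration;
use codex_login::{ServerOptions, run_login_server};

fn start_login(opts: ServerOptions) -> std::io::Result<()> {
    // `None` lets `run_login_server` create its own `Notify` internally.
    let login_server = run_login_server(opts, None)?;
    let handle = login_server.cancel_handle();

    // `cancel()` is `notify_waiters()` plus `server.unblock()`, so it wakes
    // the `select!` loop and releases the blocking `recv()` thread.
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_secs(300)).await;
        handle.cancel();
    });
    Ok(())
}
```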
```diff
@@ -201,7 +210,7 @@ pub fn run_login_server(
         auth_url: auth_url.clone(),
         actual_port,
         server_handle,
-        shutdown_flag: shutdown_flag_clone,
+        shutdown_flag: shutdown_notify_clone,
         server,
     })
 }
```
```diff
@@ -317,17 +326,14 @@ async fn process_request(
 fn spawn_timeout_watcher(
     done_rx: mpsc::Receiver<()>,
     timeout: Duration,
-    shutdown_flag: Arc<AtomicBool>,
+    shutdown_notify: Arc<tokio::sync::Notify>,
     timeout_flag: Arc<AtomicBool>,
     server: Arc<Server>,
 ) {
     thread::spawn(move || {
-        if done_rx.recv_timeout(timeout).is_err()
-            && shutdown_flag
-                .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
-                .is_ok()
-        {
+        if done_rx.recv_timeout(timeout).is_err() {
             timeout_flag.store(true, Ordering::SeqCst);
+            shutdown_notify.notify_waiters();
             server.unblock();
         }
     });
```
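The timeout watcher stays a plain OS thread built on std's channel timeout; the same disarm-or-fire shape in isolation (illustrative names, not PR code):

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn spawn_watcher(
    done_rx: mpsc::Receiver<()>,
    timeout: Duration,
    on_timeout: impl FnOnce() + Send + 'static,
) {
    thread::spawn(move || {
        // `recv_timeout` returns Err on both timeout and sender drop;
        // sending on `done_tx` before the deadline disarms the watcher.
        if done_rx.recv_timeout(timeout).is_err() {
            on_timeout();
        }
    });
}
```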
And in the TUI onboarding widget, which switches from a raw `Arc<AtomicBool>` to the new `ShutdownHandle`:

```diff
@@ -1,5 +1,6 @@
 use codex_login::CLIENT_ID;
 use codex_login::ServerOptions;
+use codex_login::ShutdownHandle;
 use codex_login::run_login_server;
 use crossterm::event::KeyCode;
 use crossterm::event::KeyEvent;
@@ -24,9 +25,6 @@ use crate::onboarding::onboarding_screen::KeyboardHandler;
 use crate::onboarding::onboarding_screen::StepStateProvider;
 use crate::shimmer::shimmer_spans;
 use std::path::PathBuf;
-use std::sync::Arc;
-use std::sync::atomic::AtomicBool;
-use std::sync::atomic::Ordering;
 
 use super::onboarding_screen::StepState;
 // no additional imports
@@ -45,13 +43,14 @@ pub(crate) enum SignInState {
 /// Used to manage the lifecycle of SpawnedLogin and ensure it gets cleaned up.
 pub(crate) struct ContinueInBrowserState {
     auth_url: String,
-    shutdown_flag: Option<Arc<AtomicBool>>,
+    shutdown_handle: Option<ShutdownHandle>,
     _login_wait_handle: Option<tokio::task::JoinHandle<()>>,
 }
 
 impl Drop for ContinueInBrowserState {
     fn drop(&mut self) {
-        if let Some(flag) = &self.shutdown_flag {
-            flag.store(true, Ordering::SeqCst);
+        if let Some(flag) = &self.shutdown_handle {
+            flag.cancel();
         }
     }
 }
@@ -286,7 +285,7 @@ impl AuthModeWidget {
         match server {
             Ok(child) => {
                 let auth_url = child.auth_url.clone();
-                let shutdown_flag = child.shutdown_flag.clone();
+                let shutdown_handle = child.cancel_handle();
 
                 let event_tx = self.event_tx.clone();
                 let join_handle = tokio::spawn(async move {
@@ -295,7 +294,7 @@ impl AuthModeWidget {
                 self.sign_in_state =
                     SignInState::ChatGptContinueInBrowser(ContinueInBrowserState {
                         auth_url,
-                        shutdown_flag: Some(shutdown_flag),
+                        shutdown_handle: Some(shutdown_handle),
                         _login_wait_handle: Some(join_handle),
                     });
                 self.event_tx.send(AppEvent::RequestRedraw);
```
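The `Drop` impl ties cancellation to the lifetime of the widget state; the same guard shape in isolation (a sketch, not PR code):

```rust
use codex_login::ShutdownHandle;

struct LoginGuard {
    handle: Option<ShutdownHandle>,
}

impl Drop for LoginGuard {
    fn drop(&mut self) {
        // Dropping the guard cancels the login server, so an abandoned
        // onboarding screen cannot leak the background server.
        if let Some(handle) = &self.handle {
            handle.cancel();
        }
    }
}
```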
|