fix: async-ify login flow (#2393)

This replaces blocking I/O with async/non-blocking I/O in a number of
cases. This facilitates the use of `tokio::sync::Notify` and
`tokio::select!` in #2394.









---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2393).
* #2399
* #2398
* #2396
* #2395
* #2394
* __->__ #2393
* #2389
This commit is contained in:
Michael Bolin
2025-08-18 17:23:40 -07:00
committed by GitHub
parent 37e5b087a7
commit 6e8c055fd5
5 changed files with 126 additions and 99 deletions

View File

@@ -21,7 +21,7 @@ pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> {
server.actual_port, server.auth_url, server.actual_port, server.auth_url,
); );
server.block_until_done() server.block_until_done().await
} }
pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! { pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! {

View File

@@ -52,15 +52,15 @@ impl ServerOptions {
pub struct LoginServer { pub struct LoginServer {
pub auth_url: String, pub auth_url: String,
pub actual_port: u16, pub actual_port: u16,
pub server_handle: thread::JoinHandle<io::Result<()>>,
pub shutdown_flag: Arc<AtomicBool>, pub shutdown_flag: Arc<AtomicBool>,
pub server: Arc<Server>, server_handle: tokio::task::JoinHandle<io::Result<()>>,
server: Arc<Server>,
} }
impl LoginServer { impl LoginServer {
pub fn block_until_done(self) -> io::Result<()> { pub async fn block_until_done(self) -> io::Result<()> {
self.server_handle self.server_handle
.join() .await
.map_err(|err| io::Error::other(format!("login server thread panicked: {err:?}")))? .map_err(|err| io::Error::other(format!("login server thread panicked: {err:?}")))?
} }
@@ -118,7 +118,8 @@ pub fn run_login_server(
if opts.open_browser { if opts.open_browser {
let _ = webbrowser::open(&auth_url); let _ = webbrowser::open(&auth_url);
} }
let shutdown_flag = shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false))); let shutdown_flag: Arc<AtomicBool> =
shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false)));
let shutdown_flag_clone = shutdown_flag.clone(); let shutdown_flag_clone = shutdown_flag.clone();
let timeout_flag = Arc::new(AtomicBool::new(false)); let timeout_flag = Arc::new(AtomicBool::new(false));
@@ -135,31 +136,46 @@ pub fn run_login_server(
); );
} }
let server_for_thread = server.clone(); let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
let server_handle = thread::spawn(move || { let _server_handle = {
while !shutdown_flag.load(Ordering::SeqCst) { let server = server.clone();
let req = match server_for_thread.recv() { let shutdown_flag = shutdown_flag.clone();
Ok(r) => r, thread::spawn(move || {
Err(e) => { while !shutdown_flag.load(Ordering::SeqCst) {
// If we've been asked to shut down, break gracefully so that match server.recv() {
// we can report timeout or cancellation status uniformly. Ok(request) => tx.blocking_send(request).map_err(|e| {
if shutdown_flag.load(Ordering::SeqCst) { eprintln!("Failed to send request to channel: {e}");
break; io::Error::other("Failed to send request to channel")
} else { })?,
return Err(io::Error::other(e)); Err(e) => {
// If we've been asked to shut down, break gracefully so that
// we can report timeout or cancellation status uniformly.
if shutdown_flag.load(Ordering::SeqCst) {
break;
} else {
return Err(io::Error::other(e));
}
} }
} };
}; }
Ok(())
})
};
let server_handle = tokio::spawn(async move {
while let Some(req) = rx.recv().await {
let url_raw = req.url().to_string();
let response =
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
let response = process_request(&req, &opts, &redirect_uri, &pkce, actual_port, &state);
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_)); let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
match response { match response {
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => { HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
let _ = req.respond(r); let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
} }
HandledRequest::RedirectWithHeader(header) => { HandledRequest::RedirectWithHeader(header) => {
let redirect = Response::empty(302).with_header(header); let redirect = Response::empty(302).with_header(header);
let _ = req.respond(redirect); let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
} }
} }
@@ -196,15 +212,14 @@ enum HandledRequest {
ResponseAndExit(Response<Cursor<Vec<u8>>>), ResponseAndExit(Response<Cursor<Vec<u8>>>),
} }
fn process_request( async fn process_request(
req: &Request, url_raw: &str,
opts: &ServerOptions, opts: &ServerOptions,
redirect_uri: &str, redirect_uri: &str,
pkce: &PkceCodes, pkce: &PkceCodes,
actual_port: u16, actual_port: u16,
state: &str, state: &str,
) -> HandledRequest { ) -> HandledRequest {
let url_raw = req.url().to_string();
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) { let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
@@ -235,18 +250,22 @@ fn process_request(
}; };
match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code) match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code)
.await
{ {
Ok(tokens) => { Ok(tokens) => {
// Obtain API key via token-exchange and persist // Obtain API key via token-exchange and persist
let api_key = let api_key = obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token)
obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token).ok(); .await
if let Err(err) = persist_tokens( .ok();
if let Err(err) = persist_tokens_async(
&opts.codex_home, &opts.codex_home,
api_key.clone(), api_key.clone(),
tokens.id_token.clone(), tokens.id_token.clone(),
Some(tokens.access_token.clone()), Some(tokens.access_token.clone()),
Some(tokens.refresh_token.clone()), Some(tokens.refresh_token.clone()),
) { )
.await
{
eprintln!("Persist error: {err}"); eprintln!("Persist error: {err}");
return HandledRequest::Response( return HandledRequest::Response(
Response::from_string(format!("Unable to persist auth file: {err}")) Response::from_string(format!("Unable to persist auth file: {err}"))
@@ -352,7 +371,7 @@ struct ExchangedTokens {
refresh_token: String, refresh_token: String,
} }
fn exchange_code_for_tokens( async fn exchange_code_for_tokens(
issuer: &str, issuer: &str,
client_id: &str, client_id: &str,
redirect_uri: &str, redirect_uri: &str,
@@ -366,7 +385,7 @@ fn exchange_code_for_tokens(
refresh_token: String, refresh_token: String,
} }
let client = reqwest::blocking::Client::new(); let client = reqwest::Client::new();
let resp = client let resp = client
.post(format!("{issuer}/oauth/token")) .post(format!("{issuer}/oauth/token"))
.header("Content-Type", "application/x-www-form-urlencoded") .header("Content-Type", "application/x-www-form-urlencoded")
@@ -378,6 +397,7 @@ fn exchange_code_for_tokens(
urlencoding::encode(&pkce.code_verifier) urlencoding::encode(&pkce.code_verifier)
)) ))
.send() .send()
.await
.map_err(io::Error::other)?; .map_err(io::Error::other)?;
if !resp.status().is_success() { if !resp.status().is_success() {
@@ -387,7 +407,7 @@ fn exchange_code_for_tokens(
))); )));
} }
let tokens: TokenResponse = resp.json().map_err(io::Error::other)?; let tokens: TokenResponse = resp.json().await.map_err(io::Error::other)?;
Ok(ExchangedTokens { Ok(ExchangedTokens {
id_token: tokens.id_token, id_token: tokens.id_token,
access_token: tokens.access_token, access_token: tokens.access_token,
@@ -395,43 +415,49 @@ fn exchange_code_for_tokens(
}) })
} }
fn persist_tokens( async fn persist_tokens_async(
codex_home: &Path, codex_home: &Path,
api_key: Option<String>, api_key: Option<String>,
id_token: String, id_token: String,
access_token: Option<String>, access_token: Option<String>,
refresh_token: Option<String>, refresh_token: Option<String>,
) -> io::Result<()> { ) -> io::Result<()> {
let auth_file = get_auth_file(codex_home); // Reuse existing synchronous logic but run it off the async runtime.
if let Some(parent) = auth_file.parent() { let codex_home = codex_home.to_path_buf();
if !parent.exists() { tokio::task::spawn_blocking(move || {
std::fs::create_dir_all(parent).map_err(io::Error::other)?; let auth_file = get_auth_file(&codex_home);
if let Some(parent) = auth_file.parent() {
if !parent.exists() {
std::fs::create_dir_all(parent).map_err(io::Error::other)?;
}
} }
}
let mut auth = read_or_default(&auth_file); let mut auth = read_or_default(&auth_file);
if let Some(key) = api_key { if let Some(key) = api_key {
auth.openai_api_key = Some(key); auth.openai_api_key = Some(key);
} }
let tokens = auth let tokens = auth
.tokens .tokens
.get_or_insert_with(crate::token_data::TokenData::default); .get_or_insert_with(crate::token_data::TokenData::default);
tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?; tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?;
// Persist chatgpt_account_id if present in claims // Persist chatgpt_account_id if present in claims
if let Some(acc) = jwt_auth_claims(&id_token) if let Some(acc) = jwt_auth_claims(&id_token)
.get("chatgpt_account_id") .get("chatgpt_account_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
{ {
tokens.account_id = Some(acc.to_string()); tokens.account_id = Some(acc.to_string());
} }
if let Some(at) = access_token { if let Some(at) = access_token {
tokens.access_token = at; tokens.access_token = at;
} }
if let Some(rt) = refresh_token { if let Some(rt) = refresh_token {
tokens.refresh_token = rt; tokens.refresh_token = rt;
} }
auth.last_refresh = Some(Utc::now()); auth.last_refresh = Some(Utc::now());
super::write_auth_json(&auth_file, &auth) super::write_auth_json(&auth_file, &auth)
})
.await
.map_err(|e| io::Error::other(format!("persist task failed: {e}")))?
} }
fn read_or_default(path: &Path) -> AuthDotJson { fn read_or_default(path: &Path) -> AuthDotJson {
@@ -524,13 +550,13 @@ fn jwt_auth_claims(jwt: &str) -> serde_json::Map<String, serde_json::Value> {
serde_json::Map::new() serde_json::Map::new()
} }
fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result<String> { async fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result<String> {
// Token exchange for an API key access token // Token exchange for an API key access token
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct ExchangeResp { struct ExchangeResp {
access_token: String, access_token: String,
} }
let client = reqwest::blocking::Client::new(); let client = reqwest::Client::new();
let resp = client let resp = client
.post(format!("{issuer}/oauth/token")) .post(format!("{issuer}/oauth/token"))
.header("Content-Type", "application/x-www-form-urlencoded") .header("Content-Type", "application/x-www-form-urlencoded")
@@ -543,6 +569,7 @@ fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result<S
urlencoding::encode("urn:ietf:params:oauth:token-type:id_token") urlencoding::encode("urn:ietf:params:oauth:token-type:id_token")
)) ))
.send() .send()
.await
.map_err(io::Error::other)?; .map_err(io::Error::other)?;
if !resp.status().is_success() { if !resp.status().is_success() {
return Err(io::Error::other(format!( return Err(io::Error::other(format!(
@@ -550,6 +577,6 @@ fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result<S
resp.status() resp.status()
))); )));
} }
let body: ExchangeResp = resp.json().map_err(io::Error::other)?; let body: ExchangeResp = resp.json().await.map_err(io::Error::other)?;
Ok(body.access_token) Ok(body.access_token)
} }

View File

@@ -73,8 +73,8 @@ fn start_mock_issuer() -> (SocketAddr, thread::JoinHandle<()>) {
(addr, handle) (addr, handle)
} }
#[test] #[tokio::test]
fn end_to_end_login_flow_persists_auth_json() { async fn end_to_end_login_flow_persists_auth_json() {
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
println!( println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox." "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
@@ -106,16 +106,16 @@ fn end_to_end_login_flow_persists_auth_json() {
let login_port = server.actual_port; let login_port = server.actual_port;
// Simulate browser callback, and follow redirect to /success // Simulate browser callback, and follow redirect to /success
let client = reqwest::blocking::Client::builder() let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::limited(5)) .redirect(reqwest::redirect::Policy::limited(5))
.build() .build()
.unwrap(); .unwrap();
let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=test_state_123"); let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=test_state_123");
let resp = client.get(&url).send().unwrap(); let resp = client.get(&url).send().await.unwrap();
assert!(resp.status().is_success()); assert!(resp.status().is_success());
// Wait for server shutdown // Wait for server shutdown
server.block_until_done().unwrap(); server.block_until_done().await.unwrap();
// Validate auth.json // Validate auth.json
let auth_path = codex_home.join("auth.json"); let auth_path = codex_home.join("auth.json");
@@ -133,8 +133,8 @@ fn end_to_end_login_flow_persists_auth_json() {
drop(issuer_handle); drop(issuer_handle);
} }
#[test] #[tokio::test]
fn creates_missing_codex_home_dir() { async fn creates_missing_codex_home_dir() {
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
println!( println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox." "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
@@ -164,12 +164,12 @@ fn creates_missing_codex_home_dir() {
let server = run_login_server(opts, None).unwrap(); let server = run_login_server(opts, None).unwrap();
let login_port = server.actual_port; let login_port = server.actual_port;
let client = reqwest::blocking::Client::new(); let client = reqwest::Client::new();
let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=state2"); let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=state2");
let resp = client.get(&url).send().unwrap(); let resp = client.get(&url).send().await.unwrap();
assert!(resp.status().is_success()); assert!(resp.status().is_success());
server.block_until_done().unwrap(); server.block_until_done().await.unwrap();
let auth_path = codex_home.join("auth.json"); let auth_path = codex_home.join("auth.json");
assert!( assert!(

View File

@@ -180,15 +180,9 @@ impl CodexMessageProcessor {
let outgoing_clone = self.outgoing.clone(); let outgoing_clone = self.outgoing.clone();
let active_login = self.active_login.clone(); let active_login = self.active_login.clone();
tokio::spawn(async move { tokio::spawn(async move {
let result = let (success, error_msg) = match server.block_until_done().await {
tokio::task::spawn_blocking(move || server.block_until_done()).await; Ok(()) => (true, None),
let (success, error_msg) = match result { Err(err) => (false, Some(format!("Login server error: {err}"))),
Ok(Ok(())) => (true, None),
Ok(Err(err)) => (false, Some(format!("Login server error: {err}"))),
Err(join_err) => (
false,
Some(format!("failed to join login server thread: {join_err}")),
),
}; };
let notification = LoginChatGptCompleteNotification { let notification = LoginChatGptCompleteNotification {
login_id, login_id,

View File

@@ -27,7 +27,6 @@ use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::thread::JoinHandle;
use super::onboarding_screen::StepState; use super::onboarding_screen::StepState;
// no additional imports // no additional imports
@@ -47,7 +46,7 @@ pub(crate) enum SignInState {
pub(crate) struct ContinueInBrowserState { pub(crate) struct ContinueInBrowserState {
auth_url: String, auth_url: String,
shutdown_flag: Option<Arc<AtomicBool>>, shutdown_flag: Option<Arc<AtomicBool>>,
_login_wait_handle: Option<JoinHandle<()>>, _login_wait_handle: Option<tokio::task::JoinHandle<()>>,
} }
impl Drop for ContinueInBrowserState { impl Drop for ContinueInBrowserState {
fn drop(&mut self) { fn drop(&mut self) {
@@ -288,11 +287,16 @@ impl AuthModeWidget {
Ok(child) => { Ok(child) => {
let auth_url = child.auth_url.clone(); let auth_url = child.auth_url.clone();
let shutdown_flag = child.shutdown_flag.clone(); let shutdown_flag = child.shutdown_flag.clone();
let event_tx = self.event_tx.clone();
let join_handle = tokio::spawn(async move {
spawn_completion_poller(child, event_tx).await;
});
self.sign_in_state = self.sign_in_state =
SignInState::ChatGptContinueInBrowser(ContinueInBrowserState { SignInState::ChatGptContinueInBrowser(ContinueInBrowserState {
auth_url, auth_url,
shutdown_flag: Some(shutdown_flag), shutdown_flag: Some(shutdown_flag),
_login_wait_handle: Some(self.spawn_completion_poller(child)), _login_wait_handle: Some(join_handle),
}); });
self.event_tx.send(AppEvent::RequestRedraw); self.event_tx.send(AppEvent::RequestRedraw);
} }
@@ -313,19 +317,21 @@ impl AuthModeWidget {
} }
self.event_tx.send(AppEvent::RequestRedraw); self.event_tx.send(AppEvent::RequestRedraw);
} }
}
fn spawn_completion_poller(&self, child: codex_login::LoginServer) -> JoinHandle<()> { async fn spawn_completion_poller(
let event_tx = self.event_tx.clone(); child: codex_login::LoginServer,
std::thread::spawn(move || { event_tx: AppEventSender,
if let Ok(()) = child.block_until_done() { ) -> tokio::task::JoinHandle<()> {
event_tx.send(AppEvent::OnboardingAuthComplete(Ok(()))); tokio::spawn(async move {
} else { if let Ok(()) = child.block_until_done().await {
event_tx.send(AppEvent::OnboardingAuthComplete(Err( event_tx.send(AppEvent::OnboardingAuthComplete(Ok(())));
"login failed".to_string() } else {
))); event_tx.send(AppEvent::OnboardingAuthComplete(Err(
} "login failed".to_string()
}) )));
} }
})
} }
impl StepStateProvider for AuthModeWidget { impl StepStateProvider for AuthModeWidget {