From 6e8c055fd50e9b88f00b77891cde185ed23c2676 Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Mon, 18 Aug 2025 17:23:40 -0700 Subject: [PATCH] fix: async-ify login flow (#2393) This replaces blocking I/O with async/non-blocking I/O in a number of cases. This facilitates the use of `tokio::sync::Notify` and `tokio::select!` in #2394. --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2393). * #2399 * #2398 * #2396 * #2395 * #2394 * __->__ #2393 * #2389 --- codex-rs/cli/src/login.rs | 2 +- codex-rs/login/src/server.rs | 155 ++++++++++-------- codex-rs/login/tests/login_server_e2e.rs | 20 +-- .../mcp-server/src/codex_message_processor.rs | 12 +- codex-rs/tui/src/onboarding/auth.rs | 36 ++-- 5 files changed, 126 insertions(+), 99 deletions(-) diff --git a/codex-rs/cli/src/login.rs b/codex-rs/cli/src/login.rs index 5f9dc5f9..fc40a027 100644 --- a/codex-rs/cli/src/login.rs +++ b/codex-rs/cli/src/login.rs @@ -21,7 +21,7 @@ pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> { server.actual_port, server.auth_url, ); - server.block_until_done() + server.block_until_done().await } pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! { diff --git a/codex-rs/login/src/server.rs b/codex-rs/login/src/server.rs index 19ef4c1c..060b333c 100644 --- a/codex-rs/login/src/server.rs +++ b/codex-rs/login/src/server.rs @@ -52,15 +52,15 @@ impl ServerOptions { pub struct LoginServer { pub auth_url: String, pub actual_port: u16, - pub server_handle: thread::JoinHandle>, pub shutdown_flag: Arc, - pub server: Arc, + server_handle: tokio::task::JoinHandle>, + server: Arc, } impl LoginServer { - pub fn block_until_done(self) -> io::Result<()> { + pub async fn block_until_done(self) -> io::Result<()> { self.server_handle - .join() + .await .map_err(|err| io::Error::other(format!("login server thread panicked: {err:?}")))? } @@ -118,7 +118,8 @@ pub fn run_login_server( if opts.open_browser { let _ = webbrowser::open(&auth_url); } - let shutdown_flag = shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false))); + let shutdown_flag: Arc = + shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false))); let shutdown_flag_clone = shutdown_flag.clone(); let timeout_flag = Arc::new(AtomicBool::new(false)); @@ -135,31 +136,46 @@ pub fn run_login_server( ); } - let server_for_thread = server.clone(); - let server_handle = thread::spawn(move || { - while !shutdown_flag.load(Ordering::SeqCst) { - let req = match server_for_thread.recv() { - Ok(r) => r, - Err(e) => { - // If we've been asked to shut down, break gracefully so that - // we can report timeout or cancellation status uniformly. - if shutdown_flag.load(Ordering::SeqCst) { - break; - } else { - return Err(io::Error::other(e)); + let (tx, mut rx) = tokio::sync::mpsc::channel::(16); + let _server_handle = { + let server = server.clone(); + let shutdown_flag = shutdown_flag.clone(); + thread::spawn(move || { + while !shutdown_flag.load(Ordering::SeqCst) { + match server.recv() { + Ok(request) => tx.blocking_send(request).map_err(|e| { + eprintln!("Failed to send request to channel: {e}"); + io::Error::other("Failed to send request to channel") + })?, + Err(e) => { + // If we've been asked to shut down, break gracefully so that + // we can report timeout or cancellation status uniformly. + if shutdown_flag.load(Ordering::SeqCst) { + break; + } else { + return Err(io::Error::other(e)); + } } - } - }; + }; + } + Ok(()) + }) + }; + + let server_handle = tokio::spawn(async move { + while let Some(req) = rx.recv().await { + let url_raw = req.url().to_string(); + let response = + process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await; - let response = process_request(&req, &opts, &redirect_uri, &pkce, actual_port, &state); let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_)); match response { HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => { - let _ = req.respond(r); + let _ = tokio::task::spawn_blocking(move || req.respond(r)).await; } HandledRequest::RedirectWithHeader(header) => { let redirect = Response::empty(302).with_header(header); - let _ = req.respond(redirect); + let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await; } } @@ -196,15 +212,14 @@ enum HandledRequest { ResponseAndExit(Response>>), } -fn process_request( - req: &Request, +async fn process_request( + url_raw: &str, opts: &ServerOptions, redirect_uri: &str, pkce: &PkceCodes, actual_port: u16, state: &str, ) -> HandledRequest { - let url_raw = req.url().to_string(); let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) { Ok(u) => u, Err(e) => { @@ -235,18 +250,22 @@ fn process_request( }; match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code) + .await { Ok(tokens) => { // Obtain API key via token-exchange and persist - let api_key = - obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token).ok(); - if let Err(err) = persist_tokens( + let api_key = obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token) + .await + .ok(); + if let Err(err) = persist_tokens_async( &opts.codex_home, api_key.clone(), tokens.id_token.clone(), Some(tokens.access_token.clone()), Some(tokens.refresh_token.clone()), - ) { + ) + .await + { eprintln!("Persist error: {err}"); return HandledRequest::Response( Response::from_string(format!("Unable to persist auth file: {err}")) @@ -352,7 +371,7 @@ struct ExchangedTokens { refresh_token: String, } -fn exchange_code_for_tokens( +async fn exchange_code_for_tokens( issuer: &str, client_id: &str, redirect_uri: &str, @@ -366,7 +385,7 @@ fn exchange_code_for_tokens( refresh_token: String, } - let client = reqwest::blocking::Client::new(); + let client = reqwest::Client::new(); let resp = client .post(format!("{issuer}/oauth/token")) .header("Content-Type", "application/x-www-form-urlencoded") @@ -378,6 +397,7 @@ fn exchange_code_for_tokens( urlencoding::encode(&pkce.code_verifier) )) .send() + .await .map_err(io::Error::other)?; if !resp.status().is_success() { @@ -387,7 +407,7 @@ fn exchange_code_for_tokens( ))); } - let tokens: TokenResponse = resp.json().map_err(io::Error::other)?; + let tokens: TokenResponse = resp.json().await.map_err(io::Error::other)?; Ok(ExchangedTokens { id_token: tokens.id_token, access_token: tokens.access_token, @@ -395,43 +415,49 @@ fn exchange_code_for_tokens( }) } -fn persist_tokens( +async fn persist_tokens_async( codex_home: &Path, api_key: Option, id_token: String, access_token: Option, refresh_token: Option, ) -> io::Result<()> { - let auth_file = get_auth_file(codex_home); - if let Some(parent) = auth_file.parent() { - if !parent.exists() { - std::fs::create_dir_all(parent).map_err(io::Error::other)?; + // Reuse existing synchronous logic but run it off the async runtime. + let codex_home = codex_home.to_path_buf(); + tokio::task::spawn_blocking(move || { + let auth_file = get_auth_file(&codex_home); + if let Some(parent) = auth_file.parent() { + if !parent.exists() { + std::fs::create_dir_all(parent).map_err(io::Error::other)?; + } } - } - let mut auth = read_or_default(&auth_file); - if let Some(key) = api_key { - auth.openai_api_key = Some(key); - } - let tokens = auth - .tokens - .get_or_insert_with(crate::token_data::TokenData::default); - tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?; - // Persist chatgpt_account_id if present in claims - if let Some(acc) = jwt_auth_claims(&id_token) - .get("chatgpt_account_id") - .and_then(|v| v.as_str()) - { - tokens.account_id = Some(acc.to_string()); - } - if let Some(at) = access_token { - tokens.access_token = at; - } - if let Some(rt) = refresh_token { - tokens.refresh_token = rt; - } - auth.last_refresh = Some(Utc::now()); - super::write_auth_json(&auth_file, &auth) + let mut auth = read_or_default(&auth_file); + if let Some(key) = api_key { + auth.openai_api_key = Some(key); + } + let tokens = auth + .tokens + .get_or_insert_with(crate::token_data::TokenData::default); + tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?; + // Persist chatgpt_account_id if present in claims + if let Some(acc) = jwt_auth_claims(&id_token) + .get("chatgpt_account_id") + .and_then(|v| v.as_str()) + { + tokens.account_id = Some(acc.to_string()); + } + if let Some(at) = access_token { + tokens.access_token = at; + } + if let Some(rt) = refresh_token { + tokens.refresh_token = rt; + } + auth.last_refresh = Some(Utc::now()); + super::write_auth_json(&auth_file, &auth) + }) + .await + .map_err(|e| io::Error::other(format!("persist task failed: {e}")))? } fn read_or_default(path: &Path) -> AuthDotJson { @@ -524,13 +550,13 @@ fn jwt_auth_claims(jwt: &str) -> serde_json::Map { serde_json::Map::new() } -fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result { +async fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result { // Token exchange for an API key access token #[derive(serde::Deserialize)] struct ExchangeResp { access_token: String, } - let client = reqwest::blocking::Client::new(); + let client = reqwest::Client::new(); let resp = client .post(format!("{issuer}/oauth/token")) .header("Content-Type", "application/x-www-form-urlencoded") @@ -543,6 +569,7 @@ fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result io::Result (SocketAddr, thread::JoinHandle<()>) { (addr, handle) } -#[test] -fn end_to_end_login_flow_persists_auth_json() { +#[tokio::test] +async fn end_to_end_login_flow_persists_auth_json() { if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { println!( "Skipping test because it cannot execute when network is disabled in a Codex sandbox." @@ -106,16 +106,16 @@ fn end_to_end_login_flow_persists_auth_json() { let login_port = server.actual_port; // Simulate browser callback, and follow redirect to /success - let client = reqwest::blocking::Client::builder() + let client = reqwest::Client::builder() .redirect(reqwest::redirect::Policy::limited(5)) .build() .unwrap(); let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=test_state_123"); - let resp = client.get(&url).send().unwrap(); + let resp = client.get(&url).send().await.unwrap(); assert!(resp.status().is_success()); // Wait for server shutdown - server.block_until_done().unwrap(); + server.block_until_done().await.unwrap(); // Validate auth.json let auth_path = codex_home.join("auth.json"); @@ -133,8 +133,8 @@ fn end_to_end_login_flow_persists_auth_json() { drop(issuer_handle); } -#[test] -fn creates_missing_codex_home_dir() { +#[tokio::test] +async fn creates_missing_codex_home_dir() { if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { println!( "Skipping test because it cannot execute when network is disabled in a Codex sandbox." @@ -164,12 +164,12 @@ fn creates_missing_codex_home_dir() { let server = run_login_server(opts, None).unwrap(); let login_port = server.actual_port; - let client = reqwest::blocking::Client::new(); + let client = reqwest::Client::new(); let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=state2"); - let resp = client.get(&url).send().unwrap(); + let resp = client.get(&url).send().await.unwrap(); assert!(resp.status().is_success()); - server.block_until_done().unwrap(); + server.block_until_done().await.unwrap(); let auth_path = codex_home.join("auth.json"); assert!( diff --git a/codex-rs/mcp-server/src/codex_message_processor.rs b/codex-rs/mcp-server/src/codex_message_processor.rs index d13bdbf3..7e5da55a 100644 --- a/codex-rs/mcp-server/src/codex_message_processor.rs +++ b/codex-rs/mcp-server/src/codex_message_processor.rs @@ -180,15 +180,9 @@ impl CodexMessageProcessor { let outgoing_clone = self.outgoing.clone(); let active_login = self.active_login.clone(); tokio::spawn(async move { - let result = - tokio::task::spawn_blocking(move || server.block_until_done()).await; - let (success, error_msg) = match result { - Ok(Ok(())) => (true, None), - Ok(Err(err)) => (false, Some(format!("Login server error: {err}"))), - Err(join_err) => ( - false, - Some(format!("failed to join login server thread: {join_err}")), - ), + let (success, error_msg) = match server.block_until_done().await { + Ok(()) => (true, None), + Err(err) => (false, Some(format!("Login server error: {err}"))), }; let notification = LoginChatGptCompleteNotification { login_id, diff --git a/codex-rs/tui/src/onboarding/auth.rs b/codex-rs/tui/src/onboarding/auth.rs index 8407e84e..7961d75b 100644 --- a/codex-rs/tui/src/onboarding/auth.rs +++ b/codex-rs/tui/src/onboarding/auth.rs @@ -27,7 +27,6 @@ use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; -use std::thread::JoinHandle; use super::onboarding_screen::StepState; // no additional imports @@ -47,7 +46,7 @@ pub(crate) enum SignInState { pub(crate) struct ContinueInBrowserState { auth_url: String, shutdown_flag: Option>, - _login_wait_handle: Option>, + _login_wait_handle: Option>, } impl Drop for ContinueInBrowserState { fn drop(&mut self) { @@ -288,11 +287,16 @@ impl AuthModeWidget { Ok(child) => { let auth_url = child.auth_url.clone(); let shutdown_flag = child.shutdown_flag.clone(); + + let event_tx = self.event_tx.clone(); + let join_handle = tokio::spawn(async move { + spawn_completion_poller(child, event_tx).await; + }); self.sign_in_state = SignInState::ChatGptContinueInBrowser(ContinueInBrowserState { auth_url, shutdown_flag: Some(shutdown_flag), - _login_wait_handle: Some(self.spawn_completion_poller(child)), + _login_wait_handle: Some(join_handle), }); self.event_tx.send(AppEvent::RequestRedraw); } @@ -313,19 +317,21 @@ impl AuthModeWidget { } self.event_tx.send(AppEvent::RequestRedraw); } +} - fn spawn_completion_poller(&self, child: codex_login::LoginServer) -> JoinHandle<()> { - let event_tx = self.event_tx.clone(); - std::thread::spawn(move || { - if let Ok(()) = child.block_until_done() { - event_tx.send(AppEvent::OnboardingAuthComplete(Ok(()))); - } else { - event_tx.send(AppEvent::OnboardingAuthComplete(Err( - "login failed".to_string() - ))); - } - }) - } +async fn spawn_completion_poller( + child: codex_login::LoginServer, + event_tx: AppEventSender, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + if let Ok(()) = child.block_until_done().await { + event_tx.send(AppEvent::OnboardingAuthComplete(Ok(()))); + } else { + event_tx.send(AppEvent::OnboardingAuthComplete(Err( + "login failed".to_string() + ))); + } + }) } impl StepStateProvider for AuthModeWidget {