diff --git a/codex-rs/cli/src/login.rs b/codex-rs/cli/src/login.rs index 895eeb10..a5ee7fa4 100644 --- a/codex-rs/cli/src/login.rs +++ b/codex-rs/cli/src/login.rs @@ -4,42 +4,24 @@ use codex_core::config::ConfigOverrides; use codex_login::AuthMode; use codex_login::CLIENT_ID; use codex_login::CodexAuth; -use codex_login::LoginServerInfo; use codex_login::OPENAI_API_KEY_ENV_VAR; use codex_login::ServerOptions; use codex_login::login_with_api_key; use codex_login::logout; -use codex_login::run_server_blocking_with_notify; +use codex_login::run_login_server; use std::env; -use std::path::Path; -use std::sync::mpsc; +use std::path::PathBuf; -pub async fn login_with_chatgpt(codex_home: &Path) -> std::io::Result<()> { - let (tx, rx) = mpsc::channel::(); - let client_id = CLIENT_ID; - let codex_home = codex_home.to_path_buf(); - tokio::spawn(async move { - match rx.recv() { - Ok(LoginServerInfo { - auth_url, - actual_port, - }) => { - eprintln!( - "Starting local login server on http://localhost:{actual_port}.\nIf your browser did not open, navigate to this URL to authenticate:\n\n{auth_url}", - ); - } - _ => { - tracing::error!("Failed to receive login server info"); - } - } - }); +pub async fn login_with_chatgpt(codex_home: PathBuf) -> std::io::Result<()> { + let opts = ServerOptions::new(codex_home, CLIENT_ID.to_string()); + let server = run_login_server(opts, None)?; - tokio::task::spawn_blocking(move || { - let opts = ServerOptions::new(&codex_home, client_id); - run_server_blocking_with_notify(opts, Some(tx), None) - }) - .await - .map_err(std::io::Error::other)??; + eprintln!( + "Starting local login server on http://localhost:{}.\nIf your browser did not open, navigate to this URL to authenticate:\n\n{}", + server.actual_port, server.auth_url, + ); + + server.block_until_done()?; eprintln!("Successfully logged in"); Ok(()) @@ -48,7 +30,7 @@ pub async fn login_with_chatgpt(codex_home: &Path) -> std::io::Result<()> { pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! { let config = load_config_or_exit(cli_config_overrides); - match login_with_chatgpt(&config.codex_home).await { + match login_with_chatgpt(config.codex_home).await { Ok(_) => { eprintln!("Successfully logged in"); std::process::exit(0); diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index c40c1e6b..7a5f0277 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -16,10 +16,9 @@ use std::sync::Arc; use std::sync::Mutex; use std::time::Duration; -pub use crate::server::LoginServerInfo; +pub use crate::server::LoginServer; pub use crate::server::ServerOptions; -pub use crate::server::run_server_blocking; -pub use crate::server::run_server_blocking_with_notify; +pub use crate::server::run_login_server; pub use crate::token_data::TokenData; use crate::token_data::parse_id_token; @@ -252,65 +251,6 @@ pub fn logout(codex_home: &Path) -> std::io::Result { } } -/// Represents a running login server. The server can be stopped by calling `cancel()` on SpawnedLogin. -#[derive(Debug, Clone)] -pub struct SpawnedLogin { - url: Arc>>, - done: Arc>>, - shutdown: Arc, -} - -impl SpawnedLogin { - pub fn get_login_url(&self) -> Option { - self.url.lock().ok().and_then(|u| u.clone()) - } - - pub fn get_auth_result(&self) -> Option { - self.done.lock().ok().and_then(|d| *d) - } - - pub fn cancel(&self) { - self.shutdown - .store(true, std::sync::atomic::Ordering::SeqCst); - } -} - -pub fn spawn_login_with_chatgpt(codex_home: &Path) -> std::io::Result { - let (tx, rx) = std::sync::mpsc::channel::(); - let shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let done = Arc::new(Mutex::new(None::)); - let url = Arc::new(Mutex::new(None::)); - - let codex_home_buf = codex_home.to_path_buf(); - let client_id = CLIENT_ID.to_string(); - - let shutdown_clone = shutdown.clone(); - let done_clone = done.clone(); - std::thread::spawn(move || { - let opts = ServerOptions::new(&codex_home_buf, &client_id); - let res = run_server_blocking_with_notify(opts, Some(tx), Some(shutdown_clone)); - let success = res.is_ok(); - if let Ok(mut lock) = done_clone.lock() { - *lock = Some(success); - } - }); - - let url_clone = url.clone(); - std::thread::spawn(move || { - if let Ok(u) = rx.recv() { - if let Ok(mut lock) = url_clone.lock() { - *lock = Some(u.auth_url); - } - } - }); - - Ok(SpawnedLogin { - url, - done, - shutdown, - }) -} - pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> { let auth_dot_json = AuthDotJson { openai_api_key: Some(api_key.to_string()), diff --git a/codex-rs/login/src/server.rs b/codex-rs/login/src/server.rs index 550ee703..9365905f 100644 --- a/codex-rs/login/src/server.rs +++ b/codex-rs/login/src/server.rs @@ -1,39 +1,40 @@ use std::io::{self}; use std::path::Path; +use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; +use std::thread; +use crate::AuthDotJson; +use crate::get_auth_file; +use crate::pkce::PkceCodes; +use crate::pkce::generate_pkce; use base64::Engine; use chrono::Utc; use rand::RngCore; use tiny_http::Response; use tiny_http::Server; -use crate::AuthDotJson; -use crate::get_auth_file; -use crate::pkce::PkceCodes; -use crate::pkce::generate_pkce; - const DEFAULT_ISSUER: &str = "https://auth.openai.com"; const DEFAULT_PORT: u16 = 1455; #[derive(Debug, Clone)] -pub struct ServerOptions<'a> { - pub codex_home: &'a Path, - pub client_id: &'a str, - pub issuer: &'a str, +pub struct ServerOptions { + pub codex_home: PathBuf, + pub client_id: String, + pub issuer: String, pub port: u16, pub open_browser: bool, pub force_state: Option, } -impl<'a> ServerOptions<'a> { - pub fn new(codex_home: &'a Path, client_id: &'a str) -> Self { +impl ServerOptions { + pub fn new(codex_home: PathBuf, client_id: String) -> Self { Self { codex_home, - client_id, - issuer: DEFAULT_ISSUER, + client_id: client_id.to_string(), + issuer: DEFAULT_ISSUER.to_string(), port: DEFAULT_PORT, open_browser: true, force_state: None, @@ -41,21 +42,31 @@ impl<'a> ServerOptions<'a> { } } -#[allow(dead_code)] -pub fn run_server_blocking(opts: ServerOptions) -> io::Result<()> { - run_server_blocking_with_notify(opts, None, None) -} - -pub struct LoginServerInfo { +#[derive(Debug)] +pub struct LoginServer { pub auth_url: String, pub actual_port: u16, + pub server_handle: thread::JoinHandle>, + pub shutdown_flag: Arc, } -pub fn run_server_blocking_with_notify( +impl LoginServer { + pub fn block_until_done(self) -> io::Result<()> { + #[expect(clippy::expect_used)] + self.server_handle + .join() + .expect("can't join on the server thread") + } + + pub fn cancel(&self) { + self.shutdown_flag.store(true, Ordering::SeqCst); + } +} + +pub fn run_login_server( opts: ServerOptions, - notify_started: Option>, shutdown_flag: Option>, -) -> io::Result<()> { +) -> io::Result { let pkce = generate_pkce(); let state = opts.force_state.clone().unwrap_or_else(generate_state); @@ -71,135 +82,138 @@ pub fn run_server_blocking_with_notify( }; let redirect_uri = format!("http://localhost:{actual_port}/auth/callback"); - let auth_url = build_authorize_url(opts.issuer, opts.client_id, &redirect_uri, &pkce, &state); - - if let Some(tx) = ¬ify_started { - let _ = tx.send(LoginServerInfo { - auth_url: auth_url.clone(), - actual_port, - }); - } + let auth_url = build_authorize_url(&opts.issuer, &opts.client_id, &redirect_uri, &pkce, &state); if opts.open_browser { let _ = webbrowser::open(&auth_url); } - let shutdown_flag = shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false))); - while !shutdown_flag.load(Ordering::SeqCst) { - let req = match server.recv() { - Ok(r) => r, - Err(e) => return Err(io::Error::other(e)), - }; + let shutdown_flag_clone = shutdown_flag.clone(); + let server_handle = thread::spawn(move || { + while !shutdown_flag.load(Ordering::SeqCst) { + let req = match server.recv() { + Ok(r) => r, + Err(e) => return Err(io::Error::other(e)), + }; - let url_raw = req.url().to_string(); - let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) { - Ok(u) => u, - Err(e) => { - eprintln!("URL parse error: {e}"); - let _ = req.respond(Response::from_string("Bad Request").with_status_code(400)); - continue; - } - }; - let path = parsed_url.path().to_string(); - - match path.as_str() { - "/auth/callback" => { - let params: std::collections::HashMap = - parsed_url.query_pairs().into_owned().collect(); - if params.get("state").map(String::as_str) != Some(state.as_str()) { - let _ = - req.respond(Response::from_string("State mismatch").with_status_code(400)); + let url_raw = req.url().to_string(); + let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) { + Ok(u) => u, + Err(e) => { + eprintln!("URL parse error: {e}"); + let _ = req.respond(Response::from_string("Bad Request").with_status_code(400)); continue; } - let code = match params.get("code") { - Some(c) if !c.is_empty() => c.clone(), - _ => { - let _ = req.respond( - Response::from_string("Missing authorization code") - .with_status_code(400), - ); + }; + let path = parsed_url.path().to_string(); + + match path.as_str() { + "/auth/callback" => { + let params: std::collections::HashMap = + parsed_url.query_pairs().into_owned().collect(); + if params.get("state").map(String::as_str) != Some(state.as_str()) { + let _ = req + .respond(Response::from_string("State mismatch").with_status_code(400)); continue; } - }; - - match exchange_code_for_tokens( - opts.issuer, - opts.client_id, - &redirect_uri, - &pkce, - &code, - ) { - Ok(tokens) => { - // Obtain API key via token-exchange and persist - let api_key = - obtain_api_key(opts.issuer, opts.client_id, &tokens.id_token).ok(); - if let Err(err) = persist_tokens( - opts.codex_home, - api_key.clone(), - tokens.id_token.clone(), - Some(tokens.access_token.clone()), - Some(tokens.refresh_token.clone()), - ) { - eprintln!("Persist error: {err}"); + let code = match params.get("code") { + Some(c) if !c.is_empty() => c.clone(), + _ => { let _ = req.respond( - Response::from_string(format!( - "Unable to persist auth file: {err}" - )) - .with_status_code(500), + Response::from_string("Missing authorization code") + .with_status_code(400), ); continue; } + }; - let success_url = compose_success_url( - actual_port, - opts.issuer, - &tokens.id_token, - &tokens.access_token, - ); - match tiny_http::Header::from_bytes( - &b"Location"[..], - success_url.as_bytes(), - ) { - Ok(h) => { - let response = tiny_http::Response::empty(302).with_header(h); - let _ = req.respond(response); - } - Err(_) => { + match exchange_code_for_tokens( + &opts.issuer, + &opts.client_id, + &redirect_uri, + &pkce, + &code, + ) { + Ok(tokens) => { + // Obtain API key via token-exchange and persist + let api_key = + obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token) + .ok(); + if let Err(err) = persist_tokens( + &opts.codex_home, + api_key.clone(), + tokens.id_token.clone(), + Some(tokens.access_token.clone()), + Some(tokens.refresh_token.clone()), + ) { + eprintln!("Persist error: {err}"); let _ = req.respond( - Response::from_string("Internal Server Error") - .with_status_code(500), + Response::from_string(format!( + "Unable to persist auth file: {err}" + )) + .with_status_code(500), ); + continue; + } + + let success_url = compose_success_url( + actual_port, + &opts.issuer, + &tokens.id_token, + &tokens.access_token, + ); + match tiny_http::Header::from_bytes( + &b"Location"[..], + success_url.as_bytes(), + ) { + Ok(h) => { + let response = tiny_http::Response::empty(302).with_header(h); + let _ = req.respond(response); + } + Err(_) => { + let _ = req.respond( + Response::from_string("Internal Server Error") + .with_status_code(500), + ); + } } } - } - Err(err) => { - eprintln!("Token exchange error: {err}"); - let _ = req.respond( - Response::from_string(format!("Token exchange failed: {err}")) - .with_status_code(500), - ); + Err(err) => { + eprintln!("Token exchange error: {err}"); + let _ = req.respond( + Response::from_string(format!("Token exchange failed: {err}")) + .with_status_code(500), + ); + } } } - } - "/success" => { - let body = include_str!("assets/success.html"); - let mut resp = Response::from_data(body.as_bytes()); - if let Ok(h) = tiny_http::Header::from_bytes( - &b"Content-Type"[..], - &b"text/html; charset=utf-8"[..], - ) { - resp.add_header(h); + "/success" => { + let body = include_str!("assets/success.html"); + let mut resp = Response::from_data(body.as_bytes()); + if let Ok(h) = tiny_http::Header::from_bytes( + &b"Content-Type"[..], + &b"text/html; charset=utf-8"[..], + ) { + resp.add_header(h); + } + let _ = req.respond(resp); + shutdown_flag.store(true, Ordering::SeqCst); + return Ok(()); + } + _ => { + let _ = req.respond(Response::from_string("Not Found").with_status_code(404)); } - let _ = req.respond(resp); - shutdown_flag.store(true, Ordering::SeqCst); - } - _ => { - let _ = req.respond(Response::from_string("Not Found").with_status_code(404)); } } - } + Err(io::Error::other("Login flow was not completed")) + }); - Ok(()) + Ok(LoginServer { + auth_url: auth_url.clone(), + actual_port, + server_handle, + shutdown_flag: shutdown_flag_clone, + }) } fn build_authorize_url( diff --git a/codex-rs/login/tests/login_server_e2e.rs b/codex-rs/login/tests/login_server_e2e.rs index 3d004800..b3e12468 100644 --- a/codex-rs/login/tests/login_server_e2e.rs +++ b/codex-rs/login/tests/login_server_e2e.rs @@ -1,12 +1,11 @@ -#![expect(clippy::unwrap_used)] +#![allow(clippy::unwrap_used)] use std::net::SocketAddr; use std::net::TcpListener; use std::thread; use base64::Engine; -use codex_login::LoginServerInfo; use codex_login::ServerOptions; -use codex_login::run_server_blocking_with_notify; +use codex_login::run_login_server; use tempfile::tempdir; // See spawn.rs for details @@ -94,21 +93,16 @@ fn end_to_end_login_flow_persists_auth_json() { // Run server in background let server_home = codex_home.clone(); - let (tx, rx) = std::sync::mpsc::channel::(); - let server_thread = thread::spawn(move || { - let opts = ServerOptions { - codex_home: &server_home, - client_id: codex_login::CLIENT_ID, - issuer: &issuer, - port: 0, - open_browser: false, - force_state: Some(state), - }; - run_server_blocking_with_notify(opts, Some(tx), None).unwrap(); - }); - - let server_info = rx.recv().unwrap(); - let login_port = server_info.actual_port; + let opts = ServerOptions { + codex_home: server_home, + client_id: codex_login::CLIENT_ID.to_string(), + issuer, + port: 0, + open_browser: false, + force_state: Some(state), + }; + let server = run_login_server(opts, None).unwrap(); + let login_port = server.actual_port; // Simulate browser callback, and follow redirect to /success let client = reqwest::blocking::Client::builder() @@ -120,9 +114,7 @@ fn end_to_end_login_flow_persists_auth_json() { assert!(resp.status().is_success()); // Wait for server shutdown - server_thread - .join() - .unwrap_or_else(|_| panic!("server thread panicked")); + server.block_until_done().unwrap(); // Validate auth.json let auth_path = codex_home.join("auth.json"); @@ -159,30 +151,23 @@ fn creates_missing_codex_home_dir() { // Run server in background let server_home = codex_home.clone(); - let (tx, rx) = std::sync::mpsc::channel::(); - let server_thread = thread::spawn(move || { - let opts = ServerOptions { - codex_home: &server_home, - client_id: codex_login::CLIENT_ID, - issuer: &issuer, - port: 0, - open_browser: false, - force_state: Some(state), - }; - run_server_blocking_with_notify(opts, Some(tx), None).unwrap() - }); - - let server_info = rx.recv().unwrap(); - let login_port = server_info.actual_port; + let opts = ServerOptions { + codex_home: server_home, + client_id: codex_login::CLIENT_ID.to_string(), + issuer, + port: 0, + open_browser: false, + force_state: Some(state), + }; + let server = run_login_server(opts, None).unwrap(); + let login_port = server.actual_port; let client = reqwest::blocking::Client::new(); let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=state2"); let resp = client.get(&url).send().unwrap(); assert!(resp.status().is_success()); - server_thread - .join() - .unwrap_or_else(|_| panic!("server thread panicked")); + server.block_until_done().unwrap(); let auth_path = codex_home.join("auth.json"); assert!( diff --git a/codex-rs/tui/src/onboarding/auth.rs b/codex-rs/tui/src/onboarding/auth.rs index a8864e4d..de986786 100644 --- a/codex-rs/tui/src/onboarding/auth.rs +++ b/codex-rs/tui/src/onboarding/auth.rs @@ -1,3 +1,6 @@ +use codex_login::CLIENT_ID; +use codex_login::ServerOptions; +use codex_login::run_login_server; use crossterm::event::KeyCode; use crossterm::event::KeyEvent; use ratatui::buffer::Buffer; @@ -22,6 +25,10 @@ use crate::onboarding::onboarding_screen::KeyboardHandler; use crate::onboarding::onboarding_screen::StepStateProvider; use crate::shimmer::shimmer_spans; use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; +use std::thread::JoinHandle; use super::onboarding_screen::StepState; // no additional imports @@ -39,12 +46,14 @@ pub(crate) enum SignInState { #[derive(Debug)] /// Used to manage the lifecycle of SpawnedLogin and ensure it gets cleaned up. pub(crate) struct ContinueInBrowserState { - login_child: Option, + auth_url: String, + shutdown_flag: Option>, + _login_wait_handle: Option>, } impl Drop for ContinueInBrowserState { fn drop(&mut self) { - if let Some(child) = &self.login_child { - child.cancel(); + if let Some(flag) = &self.shutdown_flag { + flag.store(true, Ordering::SeqCst); } } } @@ -184,16 +193,12 @@ impl AuthModeWidget { let mut lines = vec![Line::from(spans), Line::from("")]; if let SignInState::ChatGptContinueInBrowser(state) = &self.sign_in_state { - if let Some(url) = state - .login_child - .as_ref() - .and_then(|child| child.get_login_url()) - { + if !state.auth_url.is_empty() { lines.push(Line::from(" If the link doesn't open automatically, open the following link to authenticate:")); lines.push(Line::from(vec![ Span::raw(" "), Span::styled( - url, + state.auth_url.as_str(), Style::default() .fg(LIGHT_BLUE) .add_modifier(Modifier::UNDERLINED), @@ -289,12 +294,17 @@ impl AuthModeWidget { fn start_chatgpt_login(&mut self) { self.error = None; - match codex_login::spawn_login_with_chatgpt(&self.codex_home) { + let opts = ServerOptions::new(self.codex_home.clone(), CLIENT_ID.to_string()); + let server = run_login_server(opts, None); + match server { Ok(child) => { - self.spawn_completion_poller(child.clone()); + let auth_url = child.auth_url.clone(); + let shutdown_flag = child.shutdown_flag.clone(); self.sign_in_state = SignInState::ChatGptContinueInBrowser(ContinueInBrowserState { - login_child: Some(child), + auth_url, + shutdown_flag: Some(shutdown_flag), + _login_wait_handle: Some(self.spawn_completion_poller(child)), }); self.event_tx.send(AppEvent::RequestRedraw); } @@ -316,23 +326,17 @@ impl AuthModeWidget { self.event_tx.send(AppEvent::RequestRedraw); } - fn spawn_completion_poller(&self, child: codex_login::SpawnedLogin) { + fn spawn_completion_poller(&self, child: codex_login::LoginServer) -> JoinHandle<()> { let event_tx = self.event_tx.clone(); std::thread::spawn(move || { - loop { - if let Some(success) = child.get_auth_result() { - if success { - event_tx.send(AppEvent::OnboardingAuthComplete(Ok(()))); - } else { - event_tx.send(AppEvent::OnboardingAuthComplete(Err( - "login failed".to_string() - ))); - } - break; - } - std::thread::sleep(std::time::Duration::from_millis(250)); + if let Ok(()) = child.block_until_done() { + event_tx.send(AppEvent::OnboardingAuthComplete(Ok(()))); + } else { + event_tx.send(AppEvent::OnboardingAuthComplete(Err( + "login failed".to_string() + ))); } - }); + }) } }