From e9b597cfa3c78a239c914ad1ff0ff2c6f5e3310a Mon Sep 17 00:00:00 2001 From: easong-openai Date: Thu, 14 Aug 2025 17:11:26 -0700 Subject: [PATCH] Port login server to rust (#2294) Port the login server to rust. --------- Co-authored-by: pakrym-oai --- codex-rs/Cargo.lock | 213 +++- codex-rs/cli/src/login.rs | 44 +- codex-rs/login/Cargo.toml | 8 +- codex-rs/login/src/assets/success.html | 198 ++++ codex-rs/login/src/lib.rs | 204 ++-- codex-rs/login/src/login_with_chatgpt.py | 933 ------------------ codex-rs/login/src/pkce.rs | 27 + codex-rs/login/src/server.rs | 443 +++++++++ codex-rs/login/src/token_data.rs | 17 +- codex-rs/login/tests/login_server_e2e.rs | 192 ++++ codex-rs/tui/src/app.rs | 4 + codex-rs/tui/src/onboarding/auth.rs | 30 +- .../tui/src/onboarding/onboarding_screen.rs | 12 + 13 files changed, 1228 insertions(+), 1097 deletions(-) create mode 100644 codex-rs/login/src/assets/success.html delete mode 100644 codex-rs/login/src/login_with_chatgpt.py create mode 100644 codex-rs/login/src/pkce.rs create mode 100644 codex-rs/login/src/server.rs create mode 100644 codex-rs/login/tests/login_server_e2e.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 41392633..8f077bc0 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -203,6 +203,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "ascii" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" + [[package]] name = "ascii-canvas" version = "3.0.0" @@ -481,6 +487,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cesu8" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" + [[package]] name = "cfg-expr" version = "0.15.8" @@ -518,6 +530,12 @@ dependencies = [ "windows-link", ] +[[package]] +name = "chunked_transfer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" + [[package]] name = "clap" version = "4.5.43" @@ -798,12 +816,18 @@ dependencies = [ "base64 0.22.1", "chrono", "pretty_assertions", + "rand 0.8.5", "reqwest", "serde", "serde_json", + "sha2", "tempfile", "thiserror 2.0.12", + "tiny_http", "tokio", + "url", + "urlencoding", + "webbrowser", ] [[package]] @@ -951,6 +975,16 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "memchr", +] + [[package]] name = "compact_str" version = "0.8.1" @@ -1005,6 +1039,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -2455,6 +2499,28 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "jni" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +dependencies = [ + "cesu8", + "cfg-if", + "combine", + "jni-sys", + "log", + "thiserror 1.0.69", + "walkdir", + "windows-sys 0.45.0", +] + +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "jobserver" version = "0.1.33" @@ -2791,6 +2857,12 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -2944,6 +3016,31 @@ dependencies = [ "libc", ] +[[package]] +name = "objc2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "561f357ba7f3a2a61563a186a163d0a3a5247e1089524a3981d49adb775078bc" +dependencies = [ + "objc2-encode", +] + +[[package]] +name = "objc2-encode" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" + +[[package]] +name = "objc2-foundation" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c" +dependencies = [ + "bitflags 2.9.1", + "objc2", +] + [[package]] name = "object" version = "0.36.7" @@ -3670,6 +3767,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2", @@ -3992,7 +4090,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.9.1", - "core-foundation", + "core-foundation 0.9.4", "core-foundation-sys", "libc", "security-framework-sys", @@ -4151,6 +4249,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -4515,7 +4624,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.9.1", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -4710,6 +4819,18 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tiny_http" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389915df6413a2e74fb181895f933386023c71110878cd0825588928e64cdc82" +dependencies = [ + "ascii", + "chunked_transfer", + "httpdate", + "log", +] + [[package]] name = "tinystr" version = "0.8.1" @@ -5162,6 +5283,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -5385,6 +5512,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webbrowser" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aaf4f3c0ba838e82b4e5ccc4157003fb8c324ee24c058470ffb82820becbde98" +dependencies = [ + "core-foundation 0.10.1", + "jni", + "log", + "ndk-context", + "objc2", + "objc2-foundation", + "url", + "web-sys", +] + [[package]] name = "weezl" version = "0.1.10" @@ -5573,6 +5716,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -5600,6 +5752,21 @@ dependencies = [ "windows-targets 0.53.2", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.52.6" @@ -5632,6 +5799,12 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" @@ -5644,6 +5817,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.52.6" @@ -5656,6 +5835,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -5680,6 +5865,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.52.6" @@ -5692,6 +5883,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.52.6" @@ -5704,6 +5901,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" @@ -5716,6 +5919,12 @@ version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.52.6" diff --git a/codex-rs/cli/src/login.rs b/codex-rs/cli/src/login.rs index 1a70bd27..895eeb10 100644 --- a/codex-rs/cli/src/login.rs +++ b/codex-rs/cli/src/login.rs @@ -1,20 +1,54 @@ -use std::env; - use codex_common::CliConfigOverrides; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_login::AuthMode; +use codex_login::CLIENT_ID; use codex_login::CodexAuth; +use codex_login::LoginServerInfo; use codex_login::OPENAI_API_KEY_ENV_VAR; +use codex_login::ServerOptions; use codex_login::login_with_api_key; -use codex_login::login_with_chatgpt; use codex_login::logout; +use codex_login::run_server_blocking_with_notify; +use std::env; +use std::path::Path; +use std::sync::mpsc; + +pub async fn login_with_chatgpt(codex_home: &Path) -> std::io::Result<()> { + let (tx, rx) = mpsc::channel::(); + let client_id = CLIENT_ID; + let codex_home = codex_home.to_path_buf(); + tokio::spawn(async move { + match rx.recv() { + Ok(LoginServerInfo { + auth_url, + actual_port, + }) => { + eprintln!( + "Starting local login server on http://localhost:{actual_port}.\nIf your browser did not open, navigate to this URL to authenticate:\n\n{auth_url}", + ); + } + _ => { + tracing::error!("Failed to receive login server info"); + } + } + }); + + tokio::task::spawn_blocking(move || { + let opts = ServerOptions::new(&codex_home, client_id); + run_server_blocking_with_notify(opts, Some(tx), None) + }) + .await + .map_err(std::io::Error::other)??; + + eprintln!("Successfully logged in"); + Ok(()) +} pub async fn run_login_with_chatgpt(cli_config_overrides: CliConfigOverrides) -> ! { let config = load_config_or_exit(cli_config_overrides); - let capture_output = false; - match login_with_chatgpt(&config.codex_home, capture_output).await { + match login_with_chatgpt(&config.codex_home).await { Ok(_) => { eprintln!("Successfully logged in"); std::process::exit(0); diff --git a/codex-rs/login/Cargo.toml b/codex-rs/login/Cargo.toml index 85c11505..c1e21ca6 100644 --- a/codex-rs/login/Cargo.toml +++ b/codex-rs/login/Cargo.toml @@ -9,11 +9,14 @@ workspace = true [dependencies] base64 = "0.22" chrono = { version = "0.4", features = ["serde"] } -reqwest = { version = "0.12", features = ["json"] } +rand = "0.8" +reqwest = { version = "0.12", features = ["json", "blocking"] } serde = { version = "1", features = ["derive"] } serde_json = "1" +sha2 = "0.10" tempfile = "3" thiserror = "2.0.12" +tiny_http = "0.12" tokio = { version = "1", features = [ "io-std", "macros", @@ -21,6 +24,9 @@ tokio = { version = "1", features = [ "rt-multi-thread", "signal", ] } +url = "2" +urlencoding = "2.1" +webbrowser = "1.0" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/codex-rs/login/src/assets/success.html b/codex-rs/login/src/assets/success.html new file mode 100644 index 00000000..eb2a0ee7 --- /dev/null +++ b/codex-rs/login/src/assets/success.html @@ -0,0 +1,198 @@ + + + + + Sign into Codex CLI + + + + +
+
+
+ +
Signed in to Codex CLI
+
+ + +
+
+ + + + \ No newline at end of file diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index a1dad79e..d4358d27 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -1,5 +1,4 @@ use chrono::DateTime; - use chrono::Utc; use serde::Deserialize; use serde::Serialize; @@ -9,27 +8,26 @@ use std::fs::OpenOptions; use std::fs::remove_file; use std::io::Read; use std::io::Write; -use std::io::{self}; #[cfg(unix)] use std::os::unix::fs::OpenOptionsExt; use std::path::Path; use std::path::PathBuf; -use std::process::Child; -use std::process::Stdio; use std::sync::Arc; use std::sync::Mutex; use std::time::Duration; -use tempfile::NamedTempFile; -use tokio::process::Command; +pub use crate::server::LoginServerInfo; +pub use crate::server::ServerOptions; +pub use crate::server::run_server_blocking; +pub use crate::server::run_server_blocking_with_notify; pub use crate::token_data::TokenData; use crate::token_data::parse_id_token; +mod pkce; +mod server; mod token_data; -const SOURCE_FOR_PYTHON_SERVER: &str = include_str!("./login_with_chatgpt.py"); - -const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; +pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY"; #[derive(Clone, Debug, PartialEq, Copy)] @@ -254,139 +252,65 @@ pub fn logout(codex_home: &Path) -> std::io::Result { } } -/// Represents a running login subprocess. The child can be killed by holding -/// the mutex and calling `kill()`. +/// Represents a running login server. The server can be stopped by calling `cancel()` on SpawnedLogin. #[derive(Debug, Clone)] pub struct SpawnedLogin { - pub child: Arc>, - pub stdout: Arc>>, - pub stderr: Arc>>, + url: Arc>>, + done: Arc>>, + shutdown: Arc, } impl SpawnedLogin { - /// Returns the login URL, if one has been emitted by the login subprocess. - /// - /// The Python helper prints the URL to stderr; we capture it and extract - /// the last whitespace-separated token that starts with "http". pub fn get_login_url(&self) -> Option { - self.stderr - .lock() - .ok() - .and_then(|buffer| String::from_utf8(buffer.clone()).ok()) - .and_then(|output| { - output - .split_whitespace() - .filter(|part| part.starts_with("http")) - .next_back() - .map(|s| s.to_string()) - }) + self.url.lock().ok().and_then(|u| u.clone()) + } + + pub fn get_auth_result(&self) -> Option { + self.done.lock().ok().and_then(|d| *d) + } + + pub fn cancel(&self) { + self.shutdown + .store(true, std::sync::atomic::Ordering::SeqCst); } } -// Helpers for streaming child output into shared buffers -struct AppendWriter { - buf: Arc>>, -} - -impl Write for AppendWriter { - fn write(&mut self, data: &[u8]) -> io::Result { - if let Ok(mut b) = self.buf.lock() { - b.extend_from_slice(data); - } - Ok(data.len()) - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -fn spawn_pipe_reader(mut reader: R, buf: Arc>>) { - std::thread::spawn(move || { - let _ = io::copy(&mut reader, &mut AppendWriter { buf }); - }); -} - -/// Spawn the ChatGPT login Python server as a child process and return a handle to its process. pub fn spawn_login_with_chatgpt(codex_home: &Path) -> std::io::Result { - let script_path = write_login_script_to_disk()?; - let mut cmd = std::process::Command::new("python3"); - cmd.arg(&script_path) - .env("CODEX_HOME", codex_home) - .env("CODEX_CLIENT_ID", CLIENT_ID) - .stdin(Stdio::null()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()); + let (tx, rx) = std::sync::mpsc::channel::(); + let shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let done = Arc::new(Mutex::new(None::)); + let url = Arc::new(Mutex::new(None::)); - let mut child = cmd.spawn()?; + let codex_home_buf = codex_home.to_path_buf(); + let client_id = CLIENT_ID.to_string(); - let stdout_buf = Arc::new(Mutex::new(Vec::new())); - let stderr_buf = Arc::new(Mutex::new(Vec::new())); + let shutdown_clone = shutdown.clone(); + let done_clone = done.clone(); + std::thread::spawn(move || { + let opts = ServerOptions::new(&codex_home_buf, &client_id); + let res = run_server_blocking_with_notify(opts, Some(tx), Some(shutdown_clone)); + let success = res.is_ok(); + if let Ok(mut lock) = done_clone.lock() { + *lock = Some(success); + } + }); - if let Some(out) = child.stdout.take() { - spawn_pipe_reader(out, stdout_buf.clone()); - } - if let Some(err) = child.stderr.take() { - spawn_pipe_reader(err, stderr_buf.clone()); - } + let url_clone = url.clone(); + std::thread::spawn(move || { + if let Ok(u) = rx.recv() { + if let Ok(mut lock) = url_clone.lock() { + *lock = Some(u.auth_url); + } + } + }); Ok(SpawnedLogin { - child: Arc::new(Mutex::new(child)), - stdout: stdout_buf, - stderr: stderr_buf, + url, + done, + shutdown, }) } -/// Run `python3 -c {{SOURCE_FOR_PYTHON_SERVER}}` with the CODEX_HOME -/// environment variable set to the provided `codex_home` path. If the -/// subprocess exits 0, read the OPENAI_API_KEY property out of -/// CODEX_HOME/auth.json and return Ok(OPENAI_API_KEY). Otherwise, return Err -/// with any information from the subprocess. -/// -/// If `capture_output` is true, the subprocess's output will be captured and -/// recorded in memory. Otherwise, the subprocess's output will be sent to the -/// current process's stdout/stderr. -pub async fn login_with_chatgpt(codex_home: &Path, capture_output: bool) -> std::io::Result<()> { - let script_path = write_login_script_to_disk()?; - let child = Command::new("python3") - .arg(&script_path) - .env("CODEX_HOME", codex_home) - .env("CODEX_CLIENT_ID", CLIENT_ID) - .stdin(Stdio::null()) - .stdout(if capture_output { - Stdio::piped() - } else { - Stdio::inherit() - }) - .stderr(if capture_output { - Stdio::piped() - } else { - Stdio::inherit() - }) - .spawn()?; - - let output = child.wait_with_output().await?; - if output.status.success() { - Ok(()) - } else { - let stderr = String::from_utf8_lossy(&output.stderr); - Err(std::io::Error::other(format!( - "login_with_chatgpt subprocess failed: {stderr}" - ))) - } -} - -fn write_login_script_to_disk() -> std::io::Result { - // Write the embedded Python script to a file to avoid very long - // command-line arguments (Windows error 206). - let mut tmp = NamedTempFile::new()?; - tmp.write_all(SOURCE_FOR_PYTHON_SERVER.as_bytes())?; - tmp.flush()?; - - let (_file, path) = tmp.keep()?; - Ok(path) -} - pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> { let auth_dot_json = AuthDotJson { openai_api_key: Some(api_key.to_string()), @@ -538,7 +462,7 @@ mod tests { } #[tokio::test] - async fn pro_account_with_no_api_key_uses_chatgpt_auth() { + async fn roundtrip_auth_dot_json() { let codex_home = tempdir().unwrap(); write_auth_file( AuthFileParams { @@ -549,6 +473,26 @@ mod tests { ) .expect("failed to write auth file"); + let file = get_auth_file(codex_home.path()); + let auth_dot_json = try_read_auth_json(&file).unwrap(); + write_auth_json(&file, &auth_dot_json).unwrap(); + + let same_auth_dot_json = try_read_auth_json(&file).unwrap(); + assert_eq!(auth_dot_json, same_auth_dot_json); + } + + #[tokio::test] + async fn pro_account_with_no_api_key_uses_chatgpt_auth() { + let codex_home = tempdir().unwrap(); + let fake_jwt = write_auth_file( + AuthFileParams { + openai_api_key: None, + chatgpt_plan_type: "pro".to_string(), + }, + codex_home.path(), + ) + .expect("failed to write auth file"); + let CodexAuth { api_key, mode, @@ -567,6 +511,7 @@ mod tests { id_token: IdTokenInfo { email: Some("user@example.com".to_string()), chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)), + raw_jwt: fake_jwt, }, access_token: "test-access-token".to_string(), refresh_token: "test-refresh-token".to_string(), @@ -588,7 +533,7 @@ mod tests { #[tokio::test] async fn pro_account_with_api_key_still_uses_chatgpt_auth() { let codex_home = tempdir().unwrap(); - write_auth_file( + let fake_jwt = write_auth_file( AuthFileParams { openai_api_key: Some("sk-test-key".to_string()), chatgpt_plan_type: "pro".to_string(), @@ -615,6 +560,7 @@ mod tests { id_token: IdTokenInfo { email: Some("user@example.com".to_string()), chatgpt_plan_type: Some(PlanType::Known(KnownPlan::Pro)), + raw_jwt: fake_jwt, }, access_token: "test-access-token".to_string(), refresh_token: "test-refresh-token".to_string(), @@ -662,7 +608,7 @@ mod tests { chatgpt_plan_type: String, } - fn write_auth_file(params: AuthFileParams, codex_home: &Path) -> std::io::Result<()> { + fn write_auth_file(params: AuthFileParams, codex_home: &Path) -> std::io::Result { let auth_file = get_auth_file(codex_home); // Create a minimal valid JWT for the id_token field. #[derive(Serialize)] @@ -700,7 +646,9 @@ mod tests { "last_refresh": LAST_REFRESH, }); let auth_json = serde_json::to_string_pretty(&auth_json_data)?; - std::fs::write(auth_file, auth_json) + std::fs::write(auth_file, auth_json)?; + + Ok(fake_jwt) } #[test] diff --git a/codex-rs/login/src/login_with_chatgpt.py b/codex-rs/login/src/login_with_chatgpt.py deleted file mode 100644 index 252c4e06..00000000 --- a/codex-rs/login/src/login_with_chatgpt.py +++ /dev/null @@ -1,933 +0,0 @@ -"""Script that spawns a local webserver for retrieving an OpenAI API key. - -- Listens on 127.0.0.1:1455 -- Opens http://localhost:1455/auth/callback in the browser -- If the user successfully navigates the auth flow, - $CODEX_HOME/auth.json will be written with the API key. -- User will be redirected to http://localhost:1455/success upon success. - -The script should exit with a non-zero code if the user fails to navigate the -auth flow. - -To test this script locally without overwriting your existing auth.json file: - -``` -rm -rf /tmp/codex_home && mkdir /tmp/codex_home -CODEX_HOME=/tmp/codex_home python3 codex-rs/login/src/login_with_chatgpt.py -``` -""" - -from __future__ import annotations - -import argparse -import base64 -import datetime -import errno -import hashlib -import http.server -import json -import os -import secrets -import sys -import threading -import time -import urllib.parse -import urllib.request -import webbrowser -from dataclasses import dataclass -from typing import Any, Dict # for type hints - -# Required port for OAuth client. -REQUIRED_PORT = 1455 -URL_BASE = f"http://localhost:{REQUIRED_PORT}" -DEFAULT_ISSUER = "https://auth.openai.com" - -EXIT_CODE_WHEN_ADDRESS_ALREADY_IN_USE = 13 - -CA_CONTEXT = None -CODEX_LOGIN_TRACE = os.environ.get("CODEX_LOGIN_TRACE", "false") in ["true", "1"] - -try: - - def trace(msg: str) -> None: - if CODEX_LOGIN_TRACE: - print(msg) - - def attempt_request(method: str) -> bool: - try: - with urllib.request.urlopen( - urllib.request.Request( - f"{DEFAULT_ISSUER}/.well-known/openid-configuration", - method="GET", - ), - context=CA_CONTEXT, - ) as resp: - if resp.status != 200: - trace(f"Request using {method} failed: {resp.status}") - return False - - trace(f"Request using {method} succeeded") - return True - except Exception as e: - trace(f"Request using {method} failed: {e}") - return False - - status = attempt_request("default settings") - if not status: - try: - import truststore - - truststore.inject_into_ssl() - status = attempt_request("truststore") - except Exception as e: - trace(f"Failed to use truststore: {e}") - - if not status: - try: - import ssl - import certifi as _certifi - - CA_CONTEXT = ssl.create_default_context(cafile=_certifi.where()) - status = attempt_request("certify") - except Exception as e: - trace(f"Failed to use certify: {e}") - - -except Exception: - pass - - -@dataclass -class TokenData: - id_token: str - access_token: str - refresh_token: str - account_id: str - - -@dataclass -class AuthBundle: - """Aggregates authentication data produced after successful OAuth flow.""" - - api_key: str | None - token_data: TokenData - last_refresh: str - - -def main() -> None: - parser = argparse.ArgumentParser(description="Retrieve API key via local HTTP flow") - parser.add_argument( - "--no-browser", - action="store_true", - help="Do not automatically open the browser", - ) - parser.add_argument("--verbose", action="store_true", help="Enable request logging") - args = parser.parse_args() - - codex_home = os.environ.get("CODEX_HOME") - if not codex_home: - eprint("ERROR: CODEX_HOME environment variable is not set") - sys.exit(1) - - client_id = os.getenv("CODEX_CLIENT_ID") - if not client_id: - eprint("ERROR: CODEX_CLIENT_ID environment variable is not set") - sys.exit(1) - - # Spawn server. - try: - httpd = _ApiKeyHTTPServer( - ("127.0.0.1", REQUIRED_PORT), - _ApiKeyHTTPHandler, - codex_home=codex_home, - client_id=client_id, - verbose=args.verbose, - ) - except OSError as e: - eprint(f"ERROR: {e}") - if e.errno == errno.EADDRINUSE: - # Caller might want to handle this case specially. - sys.exit(EXIT_CODE_WHEN_ADDRESS_ALREADY_IN_USE) - else: - sys.exit(1) - - auth_url = httpd.auth_url() - - with httpd: - eprint(f"Starting local login server on {URL_BASE}") - if not args.no_browser: - try: - webbrowser.open(auth_url, new=1, autoraise=True) - except Exception as e: - eprint(f"Failed to open browser: {e}") - - eprint( - f". If your browser did not open, navigate to this URL to authenticate: \n\n{auth_url}" - ) - - # Run the server in the main thread until `shutdown()` is called by the - # request handler. - try: - httpd.serve_forever() - except KeyboardInterrupt: - eprint("\nKeyboard interrupt received, exiting.") - - # Server has been shut down by the request handler. Exit with the code - # it set (0 on success, non-zero on failure). - sys.exit(httpd.exit_code) - - -class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler): - """A minimal request handler that captures an *api key* from query/post.""" - - # We store the result in the server instance itself. - server: "_ApiKeyHTTPServer" # type: ignore[override] - helpful annotation - - def do_GET(self) -> None: # noqa: N802 – required by BaseHTTPRequestHandler - path = urllib.parse.urlparse(self.path).path - - if path == "/success": - # Serve confirmation page then gracefully shut down the server so - # the main thread can exit with the previously captured exit code. - self._send_html(LOGIN_SUCCESS_HTML) - - # Ensure the data is flushed to the client before we stop. - try: - self.wfile.flush() - except Exception as e: - eprint(f"Failed to flush response: {e}") - - self.request_shutdown() - elif path == "/auth/callback": - query = urllib.parse.urlparse(self.path).query - params = urllib.parse.parse_qs(query) - - # Validate state ------------------------------------------------- - if params.get("state", [None])[0] != self.server.state: - self.send_error(400, "State parameter mismatch") - return - - # Standard OAuth flow ----------------------------------------- - code = params.get("code", [None])[0] - if not code: - self.send_error(400, "Missing authorization code") - return - - try: - auth_bundle, success_url = self._exchange_code(code) - except Exception as exc: # noqa: BLE001 – propagate to client - self.send_error(500, f"Token exchange failed: {exc}") - return - - # Persist API key along with additional token metadata. - if _write_auth_file( - auth=auth_bundle, - codex_home=self.server.codex_home, - ): - self.server.exit_code = 0 - self._send_redirect(success_url) - else: - self.send_error(500, "Unable to persist auth file") - else: - self.send_error(404, "Endpoint not supported") - - def do_POST(self) -> None: # noqa: N802 – required by BaseHTTPRequestHandler - self.send_error(404, "Endpoint not supported") - - def send_error(self, code, message=None, explain=None) -> None: - """Send an error response and stop the server. - - We avoid calling `sys.exit()` directly from the request-handling thread - so that the response has a chance to be written to the socket. Instead - we shut the server down; the main thread will then exit with the - appropriate status code. - """ - super().send_error(code, message, explain) - try: - self.wfile.flush() - except Exception as e: - eprint(f"Failed to flush response: {e}") - - self.request_shutdown() - - def _send_redirect(self, url: str) -> None: - self.send_response(302) - self.send_header("Location", url) - self.end_headers() - - def _send_html(self, body: str) -> None: - encoded = body.encode() - self.send_response(200) - self.send_header("Content-Type", "text/html; charset=utf-8") - self.send_header("Content-Length", str(len(encoded))) - self.end_headers() - self.wfile.write(encoded) - - # Silence logging for cleanliness unless --verbose flag is used. - def log_message(self, fmt: str, *args): # type: ignore[override] - if getattr(self.server, "verbose", False): # type: ignore[attr-defined] - super().log_message(fmt, *args) - - def _obtain_api_key( - self, - token_claims: Dict[str, Any], - access_claims: Dict[str, Any], - token_data: TokenData, - ) -> tuple[str | None, str | None]: - """Obtain an API key from the auth service. - - Returns (api_key, success_url) if successful, None otherwise. - """ - - org_id = token_claims.get("organization_id") - project_id = token_claims.get("project_id") - - if not org_id or not project_id: - return (None, None) - - random_id = secrets.token_hex(6) - - # 2. Token exchange to obtain API key - today = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d") - exchange_data = urllib.parse.urlencode( - { - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "client_id": self.server.client_id, - "requested_token": "openai-api-key", - "subject_token": token_data.id_token, - "subject_token_type": "urn:ietf:params:oauth:token-type:id_token", - "name": f"Codex CLI [auto-generated] ({today}) [{random_id}]", - } - ).encode() - - exchanged_access_token: str - with urllib.request.urlopen( - urllib.request.Request( - self.server.token_endpoint, - data=exchange_data, - method="POST", - headers={"Content-Type": "application/x-www-form-urlencoded"}, - ), - context=CA_CONTEXT, - ) as resp: - exchange_payload = json.loads(resp.read().decode()) - exchanged_access_token = exchange_payload["access_token"] - - # Determine whether the organization still requires additional - # setup (e.g., adding a payment method) based on the ID-token - # claim provided by the auth service. - completed_onboarding = token_claims.get("completed_platform_onboarding") == True - chatgpt_plan_type = access_claims.get("chatgpt_plan_type") - is_org_owner = token_claims.get("is_org_owner") == True - needs_setup = not completed_onboarding and is_org_owner - - # Build the success URL on the same host/port as the callback and - # include the required query parameters for the front-end page. - success_url_query = { - "id_token": token_data.id_token, - "needs_setup": "true" if needs_setup else "false", - "org_id": org_id, - "project_id": project_id, - "plan_type": chatgpt_plan_type, - "platform_url": ( - "https://platform.openai.com" - if self.server.issuer == "https://auth.openai.com" - else "https://platform.api.openai.org" - ), - } - success_url = f"{URL_BASE}/success?{urllib.parse.urlencode(success_url_query)}" - - # Attempt to redeem complimentary API credits for eligible ChatGPT - # Plus / Pro subscribers. Any errors are logged but do not interrupt - # the login flow. - - try: - maybe_redeem_credits( - issuer=self.server.issuer, - client_id=self.server.client_id, - id_token=token_data.id_token, - refresh_token=token_data.refresh_token, - codex_home=self.server.codex_home, - ) - except Exception as exc: # pragma: no cover – best-effort only - eprint(f"Unable to redeem ChatGPT subscriber API credits: {exc}") - - return (exchanged_access_token, success_url) - - def _exchange_code(self, code: str) -> tuple[AuthBundle, str]: - """Perform token + token-exchange to obtain an OpenAI API key. - - Returns (AuthBundle, success_url). - """ - - # 1. Authorization-code -> (id_token, access_token, refresh_token) - data = urllib.parse.urlencode( - { - "grant_type": "authorization_code", - "code": code, - "redirect_uri": self.server.redirect_uri, - "client_id": self.server.client_id, - "code_verifier": self.server.pkce.code_verifier, - } - ).encode() - - token_data: TokenData - - with urllib.request.urlopen( - urllib.request.Request( - self.server.token_endpoint, - data=data, - method="POST", - headers={"Content-Type": "application/x-www-form-urlencoded"}, - ), - context=CA_CONTEXT, - ) as resp: - payload = json.loads(resp.read().decode()) - - # Extract chatgpt_account_id from id_token - id_token_parts = payload["id_token"].split(".") - if len(id_token_parts) != 3: - raise ValueError("Invalid ID token") - id_token_claims = _decode_jwt_segment(id_token_parts[1]) - auth_claims = id_token_claims.get("https://api.openai.com/auth", {}) - chatgpt_account_id = auth_claims.get("chatgpt_account_id", "") - - token_data = TokenData( - id_token=payload["id_token"], - access_token=payload["access_token"], - refresh_token=payload["refresh_token"], - account_id=chatgpt_account_id, - ) - - access_token_parts = token_data.access_token.split(".") - if len(access_token_parts) != 3: - raise ValueError("Invalid access token") - - access_token_claims = _decode_jwt_segment(access_token_parts[1]) - - token_claims = id_token_claims.get("https://api.openai.com/auth", {}) - access_claims = access_token_claims.get("https://api.openai.com/auth", {}) - - exchanged_access_token, success_url = self._obtain_api_key( - token_claims, access_claims, token_data - ) - - # Persist refresh_token/id_token for future use (redeem credits etc.) - last_refresh_str = ( - datetime.datetime.now(datetime.timezone.utc) - .isoformat() - .replace("+00:00", "Z") - ) - - auth_bundle = AuthBundle( - api_key=exchanged_access_token, - token_data=token_data, - last_refresh=last_refresh_str, - ) - - return (auth_bundle, success_url or f"{URL_BASE}/success") - - def request_shutdown(self) -> None: - # shutdown() must be invoked from another thread to avoid - # deadlocking the serve_forever() loop, which is running in this - # same thread. A short-lived helper thread does the trick. - threading.Thread(target=self.server.shutdown, daemon=True).start() - - -def _write_auth_file(*, auth: AuthBundle, codex_home: str) -> bool: - """Persist *api_key* to $CODEX_HOME/auth.json. - - Returns True on success, False otherwise. Any error is printed to - *stderr* so that the Rust layer can surface the problem. - """ - if not os.path.isdir(codex_home): - try: - os.makedirs(codex_home, exist_ok=True) - except Exception as exc: # pragma: no cover – unlikely - eprint(f"ERROR: unable to create CODEX_HOME directory: {exc}") - return False - - auth_path = os.path.join(codex_home, "auth.json") - auth_json_contents = { - "OPENAI_API_KEY": auth.api_key, - "tokens": { - "id_token": auth.token_data.id_token, - "access_token": auth.token_data.access_token, - "refresh_token": auth.token_data.refresh_token, - "account_id": auth.token_data.account_id, - }, - "last_refresh": auth.last_refresh, - } - try: - with open(auth_path, "w", encoding="utf-8") as fp: - if hasattr(os, "fchmod"): # POSIX-safe - os.fchmod(fp.fileno(), 0o600) - json.dump(auth_json_contents, fp, indent=2) - except Exception as exc: # pragma: no cover – permissions/filesystem - eprint(f"ERROR: unable to write auth file: {exc}") - return False - - return True - - -@dataclass -class PkceCodes: - code_verifier: str - code_challenge: str - - -class _ApiKeyHTTPServer(http.server.HTTPServer): - """HTTPServer with shutdown helper & self-contained OAuth configuration.""" - - def __init__( - self, - server_address: tuple[str, int], - request_handler_class: type[http.server.BaseHTTPRequestHandler], - *, - codex_home: str, - client_id: str, - verbose: bool = False, - ) -> None: - super().__init__(server_address, request_handler_class, bind_and_activate=True) - - self.exit_code = 1 - self.codex_home = codex_home - self.verbose: bool = verbose - - self.issuer: str = DEFAULT_ISSUER - self.token_endpoint: str = f"{self.issuer}/oauth/token" - self.client_id: str = client_id - port = server_address[1] - self.redirect_uri: str = f"http://localhost:{port}/auth/callback" - self.pkce: PkceCodes = _generate_pkce() - self.state: str = secrets.token_hex(32) - - def auth_url(self) -> str: - """Return fully-formed OpenID authorization URL.""" - params = { - "response_type": "code", - "client_id": self.client_id, - "redirect_uri": self.redirect_uri, - "scope": "openid profile email offline_access", - "code_challenge": self.pkce.code_challenge, - "code_challenge_method": "S256", - "id_token_add_organizations": "true", - "codex_cli_simplified_flow": "true", - "state": self.state, - } - return f"{self.issuer}/oauth/authorize?" + urllib.parse.urlencode(params) - - -def maybe_redeem_credits( - *, - issuer: str, - client_id: str, - id_token: str | None, - refresh_token: str, - codex_home: str, -) -> None: - """Attempt to redeem complimentary API credits for ChatGPT subscribers. - - The operation is best-effort: any error results in a warning being printed - and the function returning early without raising. - """ - id_claims: Dict[str, Any] | None = parse_id_token_claims(id_token or "") - - # Refresh expired ID token, if possible - token_expired = True - if id_claims and isinstance(id_claims.get("exp"), int): - token_expired = _current_timestamp_ms() >= int(id_claims["exp"]) * 1000 - - if token_expired: - eprint("Refreshing credentials...") - new_refresh_token: str | None = None - new_id_token: str | None = None - - try: - payload = json.dumps( - { - "client_id": client_id, - "grant_type": "refresh_token", - "refresh_token": refresh_token, - "scope": "openid profile email", - } - ).encode() - - req = urllib.request.Request( - url="https://auth.openai.com/oauth/token", - data=payload, - method="POST", - headers={"Content-Type": "application/json"}, - ) - - with urllib.request.urlopen(req, context=CA_CONTEXT) as resp: - refresh_data = json.loads(resp.read().decode()) - new_id_token = refresh_data.get("id_token") - new_id_claims = parse_id_token_claims(new_id_token or "") - new_refresh_token = refresh_data.get("refresh_token") - except Exception as err: - eprint("Unable to refresh ID token via token-exchange:", err) - return - - if not new_id_token or not new_refresh_token: - return - - # Update auth.json with new tokens. - try: - auth_dir = codex_home - auth_path = os.path.join(auth_dir, "auth.json") - with open(auth_path, "r", encoding="utf-8") as fp: - existing = json.load(fp) - - tokens = existing.setdefault("tokens", {}) - tokens["id_token"] = new_id_token - # Note this does not touch the access_token? - tokens["refresh_token"] = new_refresh_token - tokens["last_refresh"] = ( - datetime.datetime.now(datetime.timezone.utc) - .isoformat() - .replace("+00:00", "Z") - ) - - with open(auth_path, "w", encoding="utf-8") as fp: - if hasattr(os, "fchmod"): - os.fchmod(fp.fileno(), 0o600) - json.dump(existing, fp, indent=2) - except Exception as err: - eprint("Unable to update refresh token in auth file:", err) - - if not new_id_claims: - # Still couldn't parse claims. - return - - id_token = new_id_token - id_claims = new_id_claims - - # Done refreshing credentials: now try to redeem credits. - if not id_token: - eprint("No ID token available, cannot redeem credits.") - return - - auth_claims = id_claims.get("https://api.openai.com/auth", {}) - - # Subscription eligibility check (Plus or Pro, >7 days active) - sub_start_str = auth_claims.get("chatgpt_subscription_active_start") - if isinstance(sub_start_str, str): - try: - sub_start_ts = datetime.datetime.fromisoformat(sub_start_str.rstrip("Z")) - if datetime.datetime.now( - datetime.timezone.utc - ) - sub_start_ts < datetime.timedelta(days=7): - eprint( - "Sorry, your subscription must be active for more than 7 days to redeem credits." - ) - return - except ValueError: - # Malformed; ignore - pass - - completed_onboarding = bool(auth_claims.get("completed_platform_onboarding")) - is_org_owner = bool(auth_claims.get("is_org_owner")) - needs_setup = not completed_onboarding and is_org_owner - plan_type = auth_claims.get("chatgpt_plan_type") - - if needs_setup or plan_type not in {"plus", "pro"}: - eprint("Only users with Plus or Pro subscriptions can redeem free API credits.") - return - - api_host = ( - "https://api.openai.com" - if issuer == "https://auth.openai.com" - else "https://api.openai.org" - ) - - try: - redeem_payload = json.dumps({"id_token": id_token}).encode() - req = urllib.request.Request( - url=f"{api_host}/v1/billing/redeem_credits", - data=redeem_payload, - method="POST", - headers={"Content-Type": "application/json"}, - ) - - with urllib.request.urlopen(req, context=CA_CONTEXT) as resp: - redeem_data = json.loads(resp.read().decode()) - - granted = redeem_data.get("granted_chatgpt_subscriber_api_credits", 0) - if granted and granted > 0: - eprint( - f"""Thanks for being a ChatGPT {"Plus" if plan_type == "plus" else "Pro"} subscriber! -If you haven't already redeemed, you should receive {"$5" if plan_type == "plus" else "$50"} in API credits. - -Credits: https://platform.openai.com/settings/organization/billing/credit-grants -More info: https://help.openai.com/en/articles/11381614""", - ) - else: - eprint( - f"""It looks like no credits were granted: - -{json.dumps(redeem_data, indent=2)} - -Credits: https://platform.openai.com/settings/organization/billing/credit-grants -More info: https://help.openai.com/en/articles/11381614""" - ) - except Exception as err: - eprint("Credit redemption request failed:", err) - - -def _generate_pkce() -> PkceCodes: - """Generate PKCE *code_verifier* and *code_challenge* (S256).""" - code_verifier = secrets.token_hex(64) - digest = hashlib.sha256(code_verifier.encode()).digest() - code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() - return PkceCodes(code_verifier, code_challenge) - - -def eprint(*args, **kwargs) -> None: - print(*args, file=sys.stderr, **kwargs) - - -# Parse ID-token claims (if provided) -# -# interface IDTokenClaims { -# "exp": number; // specifically, an int -# "https://api.openai.com/auth": { -# organization_id: string; -# project_id: string; -# completed_platform_onboarding: boolean; -# is_org_owner: boolean; -# chatgpt_subscription_active_start: string; -# chatgpt_subscription_active_until: string; -# chatgpt_plan_type: string; -# }; -# } -def parse_id_token_claims(id_token: str) -> Dict[str, Any] | None: - if id_token: - parts = id_token.split(".") - if len(parts) == 3: - return _decode_jwt_segment(parts[1]) - return None - - -def _decode_jwt_segment(segment: str) -> Dict[str, Any]: - """Return the decoded JSON payload from a JWT segment. - - Adds required padding for urlsafe_b64decode. - """ - padded = segment + "=" * (-len(segment) % 4) - try: - data = base64.urlsafe_b64decode(padded.encode()) - return json.loads(data.decode()) - except Exception: - return {} - - -def _current_timestamp_ms() -> int: - return int(time.time() * 1000) - - -LOGIN_SUCCESS_HTML = """ - - - - Sign into Codex CLI - - - - -
-
-
- -
Signed in to Codex CLI
-
- - -
-
- - -""" - -# Unconditionally call `main()` instead of gating it behind -# `if __name__ == "__main__"` because this script is either: -# -# - invoked as a string passed to `python3 -c` -# - run via `python3 login_with_chatgpt.py` for testing as part of local -# development -main() diff --git a/codex-rs/login/src/pkce.rs b/codex-rs/login/src/pkce.rs new file mode 100644 index 00000000..3c413b11 --- /dev/null +++ b/codex-rs/login/src/pkce.rs @@ -0,0 +1,27 @@ +use base64::Engine; +use rand::RngCore; +use sha2::Digest; +use sha2::Sha256; + +#[derive(Debug, Clone)] +pub struct PkceCodes { + pub code_verifier: String, + pub code_challenge: String, +} + +pub fn generate_pkce() -> PkceCodes { + let mut bytes = [0u8; 64]; + rand::thread_rng().fill_bytes(&mut bytes); + + // Verifier: URL-safe base64 without padding (43..128 chars) + let code_verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + + // Challenge (S256): BASE64URL-ENCODE(SHA256(verifier)) without padding + let digest = Sha256::digest(code_verifier.as_bytes()); + let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest); + + PkceCodes { + code_verifier, + code_challenge, + } +} diff --git a/codex-rs/login/src/server.rs b/codex-rs/login/src/server.rs new file mode 100644 index 00000000..550ee703 --- /dev/null +++ b/codex-rs/login/src/server.rs @@ -0,0 +1,443 @@ +use std::io::{self}; +use std::path::Path; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; + +use base64::Engine; +use chrono::Utc; +use rand::RngCore; +use tiny_http::Response; +use tiny_http::Server; + +use crate::AuthDotJson; +use crate::get_auth_file; +use crate::pkce::PkceCodes; +use crate::pkce::generate_pkce; + +const DEFAULT_ISSUER: &str = "https://auth.openai.com"; +const DEFAULT_PORT: u16 = 1455; + +#[derive(Debug, Clone)] +pub struct ServerOptions<'a> { + pub codex_home: &'a Path, + pub client_id: &'a str, + pub issuer: &'a str, + pub port: u16, + pub open_browser: bool, + pub force_state: Option, +} + +impl<'a> ServerOptions<'a> { + pub fn new(codex_home: &'a Path, client_id: &'a str) -> Self { + Self { + codex_home, + client_id, + issuer: DEFAULT_ISSUER, + port: DEFAULT_PORT, + open_browser: true, + force_state: None, + } + } +} + +#[allow(dead_code)] +pub fn run_server_blocking(opts: ServerOptions) -> io::Result<()> { + run_server_blocking_with_notify(opts, None, None) +} + +pub struct LoginServerInfo { + pub auth_url: String, + pub actual_port: u16, +} + +pub fn run_server_blocking_with_notify( + opts: ServerOptions, + notify_started: Option>, + shutdown_flag: Option>, +) -> io::Result<()> { + let pkce = generate_pkce(); + let state = opts.force_state.clone().unwrap_or_else(generate_state); + + let server = Server::http(format!("127.0.0.1:{}", opts.port)).map_err(io::Error::other)?; + let actual_port = match server.server_addr().to_ip() { + Some(addr) => addr.port(), + None => { + return Err(io::Error::new( + io::ErrorKind::AddrInUse, + "Unable to determine the server port", + )); + } + }; + + let redirect_uri = format!("http://localhost:{actual_port}/auth/callback"); + let auth_url = build_authorize_url(opts.issuer, opts.client_id, &redirect_uri, &pkce, &state); + + if let Some(tx) = ¬ify_started { + let _ = tx.send(LoginServerInfo { + auth_url: auth_url.clone(), + actual_port, + }); + } + + if opts.open_browser { + let _ = webbrowser::open(&auth_url); + } + + let shutdown_flag = shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false))); + while !shutdown_flag.load(Ordering::SeqCst) { + let req = match server.recv() { + Ok(r) => r, + Err(e) => return Err(io::Error::other(e)), + }; + + let url_raw = req.url().to_string(); + let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) { + Ok(u) => u, + Err(e) => { + eprintln!("URL parse error: {e}"); + let _ = req.respond(Response::from_string("Bad Request").with_status_code(400)); + continue; + } + }; + let path = parsed_url.path().to_string(); + + match path.as_str() { + "/auth/callback" => { + let params: std::collections::HashMap = + parsed_url.query_pairs().into_owned().collect(); + if params.get("state").map(String::as_str) != Some(state.as_str()) { + let _ = + req.respond(Response::from_string("State mismatch").with_status_code(400)); + continue; + } + let code = match params.get("code") { + Some(c) if !c.is_empty() => c.clone(), + _ => { + let _ = req.respond( + Response::from_string("Missing authorization code") + .with_status_code(400), + ); + continue; + } + }; + + match exchange_code_for_tokens( + opts.issuer, + opts.client_id, + &redirect_uri, + &pkce, + &code, + ) { + Ok(tokens) => { + // Obtain API key via token-exchange and persist + let api_key = + obtain_api_key(opts.issuer, opts.client_id, &tokens.id_token).ok(); + if let Err(err) = persist_tokens( + opts.codex_home, + api_key.clone(), + tokens.id_token.clone(), + Some(tokens.access_token.clone()), + Some(tokens.refresh_token.clone()), + ) { + eprintln!("Persist error: {err}"); + let _ = req.respond( + Response::from_string(format!( + "Unable to persist auth file: {err}" + )) + .with_status_code(500), + ); + continue; + } + + let success_url = compose_success_url( + actual_port, + opts.issuer, + &tokens.id_token, + &tokens.access_token, + ); + match tiny_http::Header::from_bytes( + &b"Location"[..], + success_url.as_bytes(), + ) { + Ok(h) => { + let response = tiny_http::Response::empty(302).with_header(h); + let _ = req.respond(response); + } + Err(_) => { + let _ = req.respond( + Response::from_string("Internal Server Error") + .with_status_code(500), + ); + } + } + } + Err(err) => { + eprintln!("Token exchange error: {err}"); + let _ = req.respond( + Response::from_string(format!("Token exchange failed: {err}")) + .with_status_code(500), + ); + } + } + } + "/success" => { + let body = include_str!("assets/success.html"); + let mut resp = Response::from_data(body.as_bytes()); + if let Ok(h) = tiny_http::Header::from_bytes( + &b"Content-Type"[..], + &b"text/html; charset=utf-8"[..], + ) { + resp.add_header(h); + } + let _ = req.respond(resp); + shutdown_flag.store(true, Ordering::SeqCst); + } + _ => { + let _ = req.respond(Response::from_string("Not Found").with_status_code(404)); + } + } + } + + Ok(()) +} + +fn build_authorize_url( + issuer: &str, + client_id: &str, + redirect_uri: &str, + pkce: &PkceCodes, + state: &str, +) -> String { + let query = vec![ + ("response_type", "code"), + ("client_id", client_id), + ("redirect_uri", redirect_uri), + ("scope", "openid profile email offline_access"), + ("code_challenge", &pkce.code_challenge), + ("code_challenge_method", "S256"), + ("id_token_add_organizations", "true"), + ("codex_cli_simplified_flow", "true"), + ("state", state), + ]; + let qs = query + .into_iter() + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v))) + .collect::>() + .join("&"); + format!("{issuer}/oauth/authorize?{qs}") +} + +fn generate_state() -> String { + let mut bytes = [0u8; 32]; + rand::thread_rng().fill_bytes(&mut bytes); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) +} + +struct ExchangedTokens { + id_token: String, + access_token: String, + refresh_token: String, +} + +fn exchange_code_for_tokens( + issuer: &str, + client_id: &str, + redirect_uri: &str, + pkce: &PkceCodes, + code: &str, +) -> io::Result { + #[derive(serde::Deserialize)] + struct TokenResponse { + id_token: String, + access_token: String, + refresh_token: String, + } + + let client = reqwest::blocking::Client::new(); + let resp = client + .post(format!("{issuer}/oauth/token")) + .header("Content-Type", "application/x-www-form-urlencoded") + .body(format!( + "grant_type=authorization_code&code={}&redirect_uri={}&client_id={}&code_verifier={}", + urlencoding::encode(code), + urlencoding::encode(redirect_uri), + urlencoding::encode(client_id), + urlencoding::encode(&pkce.code_verifier) + )) + .send() + .map_err(io::Error::other)?; + + if !resp.status().is_success() { + return Err(io::Error::other(format!( + "token endpoint returned status {}", + resp.status() + ))); + } + + let tokens: TokenResponse = resp.json().map_err(io::Error::other)?; + Ok(ExchangedTokens { + id_token: tokens.id_token, + access_token: tokens.access_token, + refresh_token: tokens.refresh_token, + }) +} + +fn persist_tokens( + codex_home: &Path, + api_key: Option, + id_token: String, + access_token: Option, + refresh_token: Option, +) -> io::Result<()> { + let auth_file = get_auth_file(codex_home); + if let Some(parent) = auth_file.parent() { + if !parent.exists() { + std::fs::create_dir_all(parent).map_err(io::Error::other)?; + } + } + + let mut auth = read_or_default(&auth_file); + if let Some(key) = api_key { + auth.openai_api_key = Some(key); + } + let tokens = auth + .tokens + .get_or_insert_with(crate::token_data::TokenData::default); + tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?; + // Persist chatgpt_account_id if present in claims + if let Some(acc) = jwt_auth_claims(&id_token) + .get("chatgpt_account_id") + .and_then(|v| v.as_str()) + { + tokens.account_id = Some(acc.to_string()); + } + if let Some(at) = access_token { + tokens.access_token = at; + } + if let Some(rt) = refresh_token { + tokens.refresh_token = rt; + } + auth.last_refresh = Some(Utc::now()); + super::write_auth_json(&auth_file, &auth) +} + +fn read_or_default(path: &Path) -> AuthDotJson { + match super::try_read_auth_json(path) { + Ok(auth) => auth, + Err(_) => AuthDotJson { + openai_api_key: None, + tokens: None, + last_refresh: None, + }, + } +} + +fn compose_success_url(port: u16, issuer: &str, id_token: &str, access_token: &str) -> String { + let token_claims = jwt_auth_claims(id_token); + let access_claims = jwt_auth_claims(access_token); + + let org_id = token_claims + .get("organization_id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let project_id = token_claims + .get("project_id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let completed_onboarding = token_claims + .get("completed_platform_onboarding") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let is_org_owner = token_claims + .get("is_org_owner") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let needs_setup = (!completed_onboarding) && is_org_owner; + let plan_type = access_claims + .get("chatgpt_plan_type") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let platform_url = if issuer == DEFAULT_ISSUER { + "https://platform.openai.com" + } else { + "https://platform.api.openai.org" + }; + + let mut params = vec![ + ("id_token", id_token.to_string()), + ("needs_setup", needs_setup.to_string()), + ("org_id", org_id.to_string()), + ("project_id", project_id.to_string()), + ("plan_type", plan_type.to_string()), + ("platform_url", platform_url.to_string()), + ]; + let qs = params + .drain(..) + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v))) + .collect::>() + .join("&"); + format!("http://localhost:{port}/success?{qs}") +} + +fn jwt_auth_claims(jwt: &str) -> serde_json::Map { + let mut parts = jwt.split('.'); + let (_h, payload_b64, _s) = match (parts.next(), parts.next(), parts.next()) { + (Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s), + _ => { + eprintln!("Invalid JWT format while extracting claims"); + return serde_json::Map::new(); + } + }; + match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64) { + Ok(bytes) => match serde_json::from_slice::(&bytes) { + Ok(mut v) => { + if let Some(obj) = v + .get_mut("https://api.openai.com/auth") + .and_then(|x| x.as_object_mut()) + { + return obj.clone(); + } + eprintln!("JWT payload missing expected 'https://api.openai.com/auth' object"); + } + Err(e) => { + eprintln!("Failed to parse JWT JSON payload: {e}"); + } + }, + Err(e) => { + eprintln!("Failed to base64url-decode JWT payload: {e}"); + } + } + serde_json::Map::new() +} + +fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result { + // Token exchange for an API key access token + #[derive(serde::Deserialize)] + struct ExchangeResp { + access_token: String, + } + let client = reqwest::blocking::Client::new(); + let resp = client + .post(format!("{issuer}/oauth/token")) + .header("Content-Type", "application/x-www-form-urlencoded") + .body(format!( + "grant_type={}&client_id={}&requested_token={}&subject_token={}&subject_token_type={}", + urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"), + urlencoding::encode(client_id), + urlencoding::encode("openai-api-key"), + urlencoding::encode(id_token), + urlencoding::encode("urn:ietf:params:oauth:token-type:id_token") + )) + .send() + .map_err(io::Error::other)?; + if !resp.status().is_success() { + return Err(io::Error::other(format!( + "api key exchange failed with status {}", + resp.status() + ))); + } + let body: ExchangeResp = resp.json().map_err(io::Error::other)?; + Ok(body.access_token) +} diff --git a/codex-rs/login/src/token_data.rs b/codex-rs/login/src/token_data.rs index fb4d8395..1cb537fa 100644 --- a/codex-rs/login/src/token_data.rs +++ b/codex-rs/login/src/token_data.rs @@ -6,7 +6,10 @@ use thiserror::Error; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)] pub struct TokenData { /// Flat info parsed from the JWT in auth.json. - #[serde(deserialize_with = "deserialize_id_token")] + #[serde( + deserialize_with = "deserialize_id_token", + serialize_with = "serialize_id_token" + )] pub id_token: IdTokenInfo, /// This is a JWT. @@ -29,13 +32,14 @@ impl TokenData { } /// Flat subset of useful claims in id_token from auth.json. -#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize)] +#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct IdTokenInfo { pub email: Option, /// The ChatGPT subscription plan type /// (e.g., "free", "plus", "pro", "business", "enterprise", "edu"). /// (Note: ae has not verified that those are the exact values.) pub(crate) chatgpt_plan_type: Option, + pub raw_jwt: String, } impl IdTokenInfo { @@ -126,6 +130,7 @@ pub(crate) fn parse_id_token(id_token: &str) -> Result(id_token: &IdTokenInfo, serializer: S) -> Result +where + S: serde::Serializer, +{ + serializer.serialize_str(&id_token.raw_jwt) +} + #[cfg(test)] mod tests { use super::*; @@ -145,7 +157,6 @@ mod tests { #[test] #[expect(clippy::expect_used, clippy::unwrap_used)] fn id_token_info_parses_email_and_plan() { - // Build a fake JWT with a URL-safe base64 payload containing email and plan. #[derive(Serialize)] struct Header { alg: &'static str, diff --git a/codex-rs/login/tests/login_server_e2e.rs b/codex-rs/login/tests/login_server_e2e.rs new file mode 100644 index 00000000..3fe0e320 --- /dev/null +++ b/codex-rs/login/tests/login_server_e2e.rs @@ -0,0 +1,192 @@ +#![allow(clippy::unwrap_used)] +use std::net::SocketAddr; +use std::net::TcpListener; +use std::thread; + +use base64::Engine; +use codex_login::LoginServerInfo; +use codex_login::ServerOptions; +use codex_login::run_server_blocking_with_notify; +use tempfile::tempdir; + +// See spawn.rs for details +pub const CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR: &str = "CODEX_SANDBOX_NETWORK_DISABLED"; + +fn start_mock_issuer() -> (SocketAddr, thread::JoinHandle<()>) { + // Bind to a random available port + let listener = TcpListener::bind(("127.0.0.1", 0)).unwrap(); + let addr = listener.local_addr().unwrap(); + let server = tiny_http::Server::from_listener(listener, None).unwrap(); + + let handle = thread::spawn(move || { + while let Ok(mut req) = server.recv() { + let url = req.url().to_string(); + if url.starts_with("/oauth/token") { + // Read body + let mut body = String::new(); + let _ = req.as_reader().read_to_string(&mut body); + // Build minimal JWT with plan=pro + #[derive(serde::Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_plan_type": "pro", + "chatgpt_account_id": "acc-123" + } + }); + let b64 = |b: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(b); + let header_bytes = serde_json::to_vec(&header).unwrap(); + let payload_bytes = serde_json::to_vec(&payload).unwrap(); + let id_token = format!( + "{}.{}.{}", + b64(&header_bytes), + b64(&payload_bytes), + b64(b"sig") + ); + + let tokens = serde_json::json!({ + "id_token": id_token, + "access_token": "access-123", + "refresh_token": "refresh-123", + }); + let data = serde_json::to_vec(&tokens).unwrap(); + let mut resp = tiny_http::Response::from_data(data); + resp.add_header( + tiny_http::Header::from_bytes(&b"Content-Type"[..], &b"application/json"[..]) + .unwrap_or_else(|_| panic!("header bytes")), + ); + let _ = req.respond(resp); + } else { + let _ = req + .respond(tiny_http::Response::from_string("not found").with_status_code(404)); + } + } + }); + + (addr, handle) +} + +#[test] +fn end_to_end_login_flow_persists_auth_json() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let (issuer_addr, issuer_handle) = start_mock_issuer(); + let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port()); + + let tmp = tempdir().unwrap(); + let codex_home = tmp.path().to_path_buf(); + + let state = "test_state_123".to_string(); + + // Run server in background + let server_home = codex_home.clone(); + + let (tx, rx) = std::sync::mpsc::channel::(); + let server_thread = thread::spawn(move || { + let opts = ServerOptions { + codex_home: &server_home, + client_id: codex_login::CLIENT_ID, + issuer: &issuer, + port: 0, + open_browser: false, + force_state: Some(state), + }; + run_server_blocking_with_notify(opts, Some(tx), None).unwrap(); + }); + + let server_info = rx.recv().unwrap(); + let login_port = server_info.actual_port; + + // Simulate browser callback, and follow redirect to /success + let client = reqwest::blocking::Client::builder() + .redirect(reqwest::redirect::Policy::limited(5)) + .build() + .unwrap(); + let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=test_state_123"); + let resp = client.get(&url).send().unwrap(); + assert!(resp.status().is_success()); + + // Wait for server shutdown + server_thread + .join() + .unwrap_or_else(|_| panic!("server thread panicked")); + + // Validate auth.json + let auth_path = codex_home.join("auth.json"); + let data = std::fs::read_to_string(&auth_path).unwrap(); + let json: serde_json::Value = serde_json::from_str(&data).unwrap(); + assert!( + !json["OPENAI_API_KEY"].is_null(), + "OPENAI_API_KEY should be set" + ); + assert_eq!(json["tokens"]["access_token"], "access-123"); + assert_eq!(json["tokens"]["refresh_token"], "refresh-123"); + assert_eq!(json["tokens"]["account_id"], "acc-123"); + + // Stop mock issuer + drop(issuer_handle); +} + +#[test] +fn creates_missing_codex_home_dir() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let (issuer_addr, _issuer_handle) = start_mock_issuer(); + let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port()); + + let tmp = tempdir().unwrap(); + let codex_home = tmp.path().join("missing-subdir"); // does not exist + + let state = "state2".to_string(); + + // Run server in background + let server_home = codex_home.clone(); + let (tx, rx) = std::sync::mpsc::channel::(); + let server_thread = thread::spawn(move || { + let opts = ServerOptions { + codex_home: &server_home, + client_id: codex_login::CLIENT_ID, + issuer: &issuer, + port: 0, + open_browser: false, + force_state: Some(state), + }; + run_server_blocking_with_notify(opts, Some(tx), None).unwrap() + }); + + let server_info = rx.recv().unwrap(); + let login_port = server_info.actual_port; + + let client = reqwest::blocking::Client::new(); + let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=state2"); + let resp = client.get(&url).send().unwrap(); + assert!(resp.status().is_success()); + + server_thread + .join() + .unwrap_or_else(|_| panic!("server thread panicked")); + + let auth_path = codex_home.join("auth.json"); + assert!( + auth_path.exists(), + "auth.json should be created even if parent dir was missing" + ); +} diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs index 2bef9fb9..256289b2 100644 --- a/codex-rs/tui/src/app.rs +++ b/codex-rs/tui/src/app.rs @@ -522,6 +522,10 @@ impl App<'_> { } fn draw_next_frame(&mut self, terminal: &mut tui::Tui) -> Result<()> { + if matches!(self.app_state, AppState::Onboarding { .. }) { + terminal.clear()?; + } + let screen_size = terminal.size()?; let last_known_screen_size = terminal.last_known_screen_size; if screen_size != last_known_screen_size { diff --git a/codex-rs/tui/src/onboarding/auth.rs b/codex-rs/tui/src/onboarding/auth.rs index ed573a41..a8864e4d 100644 --- a/codex-rs/tui/src/onboarding/auth.rs +++ b/codex-rs/tui/src/onboarding/auth.rs @@ -44,11 +44,7 @@ pub(crate) struct ContinueInBrowserState { impl Drop for ContinueInBrowserState { fn drop(&mut self) { if let Some(child) = &self.login_child { - if let Ok(mut locked) = child.child.lock() { - // Best-effort terminate and reap the child to avoid zombies. - let _ = locked.kill(); - let _ = locked.wait(); - } + child.cancel(); } } } @@ -321,32 +317,16 @@ impl AuthModeWidget { } fn spawn_completion_poller(&self, child: codex_login::SpawnedLogin) { - let child_arc = child.child.clone(); - let stderr_buf = child.stderr.clone(); let event_tx = self.event_tx.clone(); std::thread::spawn(move || { loop { - let done = { - if let Ok(mut locked) = child_arc.lock() { - match locked.try_wait() { - Ok(Some(status)) => Some(status.success()), - Ok(None) => None, - Err(_) => Some(false), - } - } else { - Some(false) - } - }; - if let Some(success) = done { + if let Some(success) = child.get_auth_result() { if success { event_tx.send(AppEvent::OnboardingAuthComplete(Ok(()))); } else { - let err = stderr_buf - .lock() - .ok() - .and_then(|b| String::from_utf8(b.clone()).ok()) - .unwrap_or_else(|| "login_with_chatgpt subprocess failed".to_string()); - event_tx.send(AppEvent::OnboardingAuthComplete(Err(err))); + event_tx.send(AppEvent::OnboardingAuthComplete(Err( + "login failed".to_string() + ))); } break; } diff --git a/codex-rs/tui/src/onboarding/onboarding_screen.rs b/codex-rs/tui/src/onboarding/onboarding_screen.rs index a104f777..a481c8c7 100644 --- a/codex-rs/tui/src/onboarding/onboarding_screen.rs +++ b/codex-rs/tui/src/onboarding/onboarding_screen.rs @@ -2,6 +2,8 @@ use codex_core::util::is_inside_git_repo; use crossterm::event::KeyEvent; use ratatui::buffer::Buffer; use ratatui::layout::Rect; +use ratatui::prelude::Widget; +use ratatui::widgets::Clear; use ratatui::widgets::WidgetRef; use codex_login::AuthMode; @@ -113,6 +115,14 @@ impl OnboardingScreen { Ok(()) => { state.sign_in_state = SignInState::ChatGptSuccessMessage; self.event_tx.send(AppEvent::RequestRedraw); + let tx1 = self.event_tx.clone(); + let tx2 = self.event_tx.clone(); + std::thread::spawn(move || { + std::thread::sleep(std::time::Duration::from_millis(150)); + tx1.send(AppEvent::RequestRedraw); + std::thread::sleep(std::time::Duration::from_millis(200)); + tx2.send(AppEvent::RequestRedraw); + }); } Err(e) => { state.sign_in_state = SignInState::PickMode; @@ -171,6 +181,7 @@ impl KeyboardHandler for OnboardingScreen { impl WidgetRef for &OnboardingScreen { fn render_ref(&self, area: Rect, buf: &mut Buffer) { + Clear.render(area, buf); // Render steps top-to-bottom, measuring each step's height dynamically. let mut y = area.y; let bottom = area.y.saturating_add(area.height); @@ -218,6 +229,7 @@ impl WidgetRef for &OnboardingScreen { width, height: h, }; + Clear.render(target, buf); step.render_ref(target, buf); y = y.saturating_add(h); }