Cleanup rust login server a bit more (#2331)

Remove some extra abstractions.

---------

Co-authored-by: easong-openai <easong@openai.com>
This commit is contained in:
pakrym-oai
2025-08-14 19:42:14 -07:00
committed by GitHub
parent d0b907d399
commit 76df07350a
5 changed files with 208 additions and 283 deletions

View File

@@ -16,10 +16,9 @@ use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
pub use crate::server::LoginServerInfo;
pub use crate::server::LoginServer;
pub use crate::server::ServerOptions;
pub use crate::server::run_server_blocking;
pub use crate::server::run_server_blocking_with_notify;
pub use crate::server::run_login_server;
pub use crate::token_data::TokenData;
use crate::token_data::parse_id_token;
@@ -252,65 +251,6 @@ pub fn logout(codex_home: &Path) -> std::io::Result<bool> {
}
}
/// Represents a running login server. The server can be stopped by calling `cancel()` on SpawnedLogin.
#[derive(Debug, Clone)]
pub struct SpawnedLogin {
url: Arc<Mutex<Option<String>>>,
done: Arc<Mutex<Option<bool>>>,
shutdown: Arc<std::sync::atomic::AtomicBool>,
}
impl SpawnedLogin {
pub fn get_login_url(&self) -> Option<String> {
self.url.lock().ok().and_then(|u| u.clone())
}
pub fn get_auth_result(&self) -> Option<bool> {
self.done.lock().ok().and_then(|d| *d)
}
pub fn cancel(&self) {
self.shutdown
.store(true, std::sync::atomic::Ordering::SeqCst);
}
}
pub fn spawn_login_with_chatgpt(codex_home: &Path) -> std::io::Result<SpawnedLogin> {
let (tx, rx) = std::sync::mpsc::channel::<LoginServerInfo>();
let shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false));
let done = Arc::new(Mutex::new(None::<bool>));
let url = Arc::new(Mutex::new(None::<String>));
let codex_home_buf = codex_home.to_path_buf();
let client_id = CLIENT_ID.to_string();
let shutdown_clone = shutdown.clone();
let done_clone = done.clone();
std::thread::spawn(move || {
let opts = ServerOptions::new(&codex_home_buf, &client_id);
let res = run_server_blocking_with_notify(opts, Some(tx), Some(shutdown_clone));
let success = res.is_ok();
if let Ok(mut lock) = done_clone.lock() {
*lock = Some(success);
}
});
let url_clone = url.clone();
std::thread::spawn(move || {
if let Ok(u) = rx.recv() {
if let Ok(mut lock) = url_clone.lock() {
*lock = Some(u.auth_url);
}
}
});
Ok(SpawnedLogin {
url,
done,
shutdown,
})
}
pub fn login_with_api_key(codex_home: &Path, api_key: &str) -> std::io::Result<()> {
let auth_dot_json = AuthDotJson {
openai_api_key: Some(api_key.to_string()),

View File

@@ -1,39 +1,40 @@
use std::io::{self};
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::thread;
use crate::AuthDotJson;
use crate::get_auth_file;
use crate::pkce::PkceCodes;
use crate::pkce::generate_pkce;
use base64::Engine;
use chrono::Utc;
use rand::RngCore;
use tiny_http::Response;
use tiny_http::Server;
use crate::AuthDotJson;
use crate::get_auth_file;
use crate::pkce::PkceCodes;
use crate::pkce::generate_pkce;
const DEFAULT_ISSUER: &str = "https://auth.openai.com";
const DEFAULT_PORT: u16 = 1455;
#[derive(Debug, Clone)]
pub struct ServerOptions<'a> {
pub codex_home: &'a Path,
pub client_id: &'a str,
pub issuer: &'a str,
pub struct ServerOptions {
pub codex_home: PathBuf,
pub client_id: String,
pub issuer: String,
pub port: u16,
pub open_browser: bool,
pub force_state: Option<String>,
}
impl<'a> ServerOptions<'a> {
pub fn new(codex_home: &'a Path, client_id: &'a str) -> Self {
impl ServerOptions {
pub fn new(codex_home: PathBuf, client_id: String) -> Self {
Self {
codex_home,
client_id,
issuer: DEFAULT_ISSUER,
client_id: client_id.to_string(),
issuer: DEFAULT_ISSUER.to_string(),
port: DEFAULT_PORT,
open_browser: true,
force_state: None,
@@ -41,21 +42,31 @@ impl<'a> ServerOptions<'a> {
}
}
#[allow(dead_code)]
pub fn run_server_blocking(opts: ServerOptions) -> io::Result<()> {
run_server_blocking_with_notify(opts, None, None)
}
pub struct LoginServerInfo {
#[derive(Debug)]
pub struct LoginServer {
pub auth_url: String,
pub actual_port: u16,
pub server_handle: thread::JoinHandle<io::Result<()>>,
pub shutdown_flag: Arc<AtomicBool>,
}
pub fn run_server_blocking_with_notify(
impl LoginServer {
pub fn block_until_done(self) -> io::Result<()> {
#[expect(clippy::expect_used)]
self.server_handle
.join()
.expect("can't join on the server thread")
}
pub fn cancel(&self) {
self.shutdown_flag.store(true, Ordering::SeqCst);
}
}
pub fn run_login_server(
opts: ServerOptions,
notify_started: Option<std::sync::mpsc::Sender<LoginServerInfo>>,
shutdown_flag: Option<Arc<AtomicBool>>,
) -> io::Result<()> {
) -> io::Result<LoginServer> {
let pkce = generate_pkce();
let state = opts.force_state.clone().unwrap_or_else(generate_state);
@@ -71,135 +82,138 @@ pub fn run_server_blocking_with_notify(
};
let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
let auth_url = build_authorize_url(opts.issuer, opts.client_id, &redirect_uri, &pkce, &state);
if let Some(tx) = &notify_started {
let _ = tx.send(LoginServerInfo {
auth_url: auth_url.clone(),
actual_port,
});
}
let auth_url = build_authorize_url(&opts.issuer, &opts.client_id, &redirect_uri, &pkce, &state);
if opts.open_browser {
let _ = webbrowser::open(&auth_url);
}
let shutdown_flag = shutdown_flag.unwrap_or_else(|| Arc::new(AtomicBool::new(false)));
while !shutdown_flag.load(Ordering::SeqCst) {
let req = match server.recv() {
Ok(r) => r,
Err(e) => return Err(io::Error::other(e)),
};
let shutdown_flag_clone = shutdown_flag.clone();
let server_handle = thread::spawn(move || {
while !shutdown_flag.load(Ordering::SeqCst) {
let req = match server.recv() {
Ok(r) => r,
Err(e) => return Err(io::Error::other(e)),
};
let url_raw = req.url().to_string();
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
Ok(u) => u,
Err(e) => {
eprintln!("URL parse error: {e}");
let _ = req.respond(Response::from_string("Bad Request").with_status_code(400));
continue;
}
};
let path = parsed_url.path().to_string();
match path.as_str() {
"/auth/callback" => {
let params: std::collections::HashMap<String, String> =
parsed_url.query_pairs().into_owned().collect();
if params.get("state").map(String::as_str) != Some(state.as_str()) {
let _ =
req.respond(Response::from_string("State mismatch").with_status_code(400));
let url_raw = req.url().to_string();
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
Ok(u) => u,
Err(e) => {
eprintln!("URL parse error: {e}");
let _ = req.respond(Response::from_string("Bad Request").with_status_code(400));
continue;
}
let code = match params.get("code") {
Some(c) if !c.is_empty() => c.clone(),
_ => {
let _ = req.respond(
Response::from_string("Missing authorization code")
.with_status_code(400),
);
};
let path = parsed_url.path().to_string();
match path.as_str() {
"/auth/callback" => {
let params: std::collections::HashMap<String, String> =
parsed_url.query_pairs().into_owned().collect();
if params.get("state").map(String::as_str) != Some(state.as_str()) {
let _ = req
.respond(Response::from_string("State mismatch").with_status_code(400));
continue;
}
};
match exchange_code_for_tokens(
opts.issuer,
opts.client_id,
&redirect_uri,
&pkce,
&code,
) {
Ok(tokens) => {
// Obtain API key via token-exchange and persist
let api_key =
obtain_api_key(opts.issuer, opts.client_id, &tokens.id_token).ok();
if let Err(err) = persist_tokens(
opts.codex_home,
api_key.clone(),
tokens.id_token.clone(),
Some(tokens.access_token.clone()),
Some(tokens.refresh_token.clone()),
) {
eprintln!("Persist error: {err}");
let code = match params.get("code") {
Some(c) if !c.is_empty() => c.clone(),
_ => {
let _ = req.respond(
Response::from_string(format!(
"Unable to persist auth file: {err}"
))
.with_status_code(500),
Response::from_string("Missing authorization code")
.with_status_code(400),
);
continue;
}
};
let success_url = compose_success_url(
actual_port,
opts.issuer,
&tokens.id_token,
&tokens.access_token,
);
match tiny_http::Header::from_bytes(
&b"Location"[..],
success_url.as_bytes(),
) {
Ok(h) => {
let response = tiny_http::Response::empty(302).with_header(h);
let _ = req.respond(response);
}
Err(_) => {
match exchange_code_for_tokens(
&opts.issuer,
&opts.client_id,
&redirect_uri,
&pkce,
&code,
) {
Ok(tokens) => {
// Obtain API key via token-exchange and persist
let api_key =
obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token)
.ok();
if let Err(err) = persist_tokens(
&opts.codex_home,
api_key.clone(),
tokens.id_token.clone(),
Some(tokens.access_token.clone()),
Some(tokens.refresh_token.clone()),
) {
eprintln!("Persist error: {err}");
let _ = req.respond(
Response::from_string("Internal Server Error")
.with_status_code(500),
Response::from_string(format!(
"Unable to persist auth file: {err}"
))
.with_status_code(500),
);
continue;
}
let success_url = compose_success_url(
actual_port,
&opts.issuer,
&tokens.id_token,
&tokens.access_token,
);
match tiny_http::Header::from_bytes(
&b"Location"[..],
success_url.as_bytes(),
) {
Ok(h) => {
let response = tiny_http::Response::empty(302).with_header(h);
let _ = req.respond(response);
}
Err(_) => {
let _ = req.respond(
Response::from_string("Internal Server Error")
.with_status_code(500),
);
}
}
}
}
Err(err) => {
eprintln!("Token exchange error: {err}");
let _ = req.respond(
Response::from_string(format!("Token exchange failed: {err}"))
.with_status_code(500),
);
Err(err) => {
eprintln!("Token exchange error: {err}");
let _ = req.respond(
Response::from_string(format!("Token exchange failed: {err}"))
.with_status_code(500),
);
}
}
}
}
"/success" => {
let body = include_str!("assets/success.html");
let mut resp = Response::from_data(body.as_bytes());
if let Ok(h) = tiny_http::Header::from_bytes(
&b"Content-Type"[..],
&b"text/html; charset=utf-8"[..],
) {
resp.add_header(h);
"/success" => {
let body = include_str!("assets/success.html");
let mut resp = Response::from_data(body.as_bytes());
if let Ok(h) = tiny_http::Header::from_bytes(
&b"Content-Type"[..],
&b"text/html; charset=utf-8"[..],
) {
resp.add_header(h);
}
let _ = req.respond(resp);
shutdown_flag.store(true, Ordering::SeqCst);
return Ok(());
}
_ => {
let _ = req.respond(Response::from_string("Not Found").with_status_code(404));
}
let _ = req.respond(resp);
shutdown_flag.store(true, Ordering::SeqCst);
}
_ => {
let _ = req.respond(Response::from_string("Not Found").with_status_code(404));
}
}
}
Err(io::Error::other("Login flow was not completed"))
});
Ok(())
Ok(LoginServer {
auth_url: auth_url.clone(),
actual_port,
server_handle,
shutdown_flag: shutdown_flag_clone,
})
}
fn build_authorize_url(