From 71cae06e6643b4adf644d9769208c3c5fcd1f2be Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Sun, 17 Aug 2025 12:32:56 -0700 Subject: [PATCH] fix: refactor login/src/server.rs so process_request() is a separate function (#2388) --- codex-rs/login/src/server.rs | 225 ++++++++++++++++++----------------- 1 file changed, 118 insertions(+), 107 deletions(-) diff --git a/codex-rs/login/src/server.rs b/codex-rs/login/src/server.rs index 566b562d..ef85df69 100644 --- a/codex-rs/login/src/server.rs +++ b/codex-rs/login/src/server.rs @@ -1,3 +1,4 @@ +use std::io::Cursor; use std::io::{self}; use std::path::Path; use std::path::PathBuf; @@ -15,6 +16,8 @@ use crate::pkce::generate_pkce; use base64::Engine; use chrono::Utc; use rand::RngCore; +use tiny_http::Header; +use tiny_http::Request; use tiny_http::Response; use tiny_http::Server; @@ -149,116 +152,23 @@ pub fn run_login_server( } }; - let url_raw = req.url().to_string(); - let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) { - Ok(u) => u, - Err(e) => { - eprintln!("URL parse error: {e}"); - let _ = req.respond(Response::from_string("Bad Request").with_status_code(400)); - continue; + let response = process_request(&req, &opts, &redirect_uri, &pkce, actual_port, &state); + let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_)); + match response { + HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => { + let _ = req.respond(r); } - }; - let path = parsed_url.path().to_string(); - - match path.as_str() { - "/auth/callback" => { - let params: std::collections::HashMap = - parsed_url.query_pairs().into_owned().collect(); - if params.get("state").map(String::as_str) != Some(state.as_str()) { - let _ = req - .respond(Response::from_string("State mismatch").with_status_code(400)); - continue; - } - let code = match params.get("code") { - Some(c) if !c.is_empty() => c.clone(), - _ => { - let _ = req.respond( - Response::from_string("Missing authorization code") - .with_status_code(400), - ); - continue; - } - }; - - match exchange_code_for_tokens( - &opts.issuer, - &opts.client_id, - &redirect_uri, - &pkce, - &code, - ) { - Ok(tokens) => { - // Obtain API key via token-exchange and persist - let api_key = - obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token) - .ok(); - if let Err(err) = persist_tokens( - &opts.codex_home, - api_key.clone(), - tokens.id_token.clone(), - Some(tokens.access_token.clone()), - Some(tokens.refresh_token.clone()), - ) { - eprintln!("Persist error: {err}"); - let _ = req.respond( - Response::from_string(format!( - "Unable to persist auth file: {err}" - )) - .with_status_code(500), - ); - continue; - } - - let success_url = compose_success_url( - actual_port, - &opts.issuer, - &tokens.id_token, - &tokens.access_token, - ); - match tiny_http::Header::from_bytes( - &b"Location"[..], - success_url.as_bytes(), - ) { - Ok(h) => { - let response = tiny_http::Response::empty(302).with_header(h); - let _ = req.respond(response); - } - Err(_) => { - let _ = req.respond( - Response::from_string("Internal Server Error") - .with_status_code(500), - ); - } - } - } - Err(err) => { - eprintln!("Token exchange error: {err}"); - let _ = req.respond( - Response::from_string(format!("Token exchange failed: {err}")) - .with_status_code(500), - ); - } - } + HandledRequest::RedirectWithHeader(header) => { + let redirect = Response::empty(302).with_header(header); + let _ = req.respond(redirect); } - "/success" => { - let body = include_str!("assets/success.html"); - let mut resp = Response::from_data(body.as_bytes()); - if let Ok(h) = tiny_http::Header::from_bytes( - &b"Content-Type"[..], - &b"text/html; charset=utf-8"[..], - ) { - resp.add_header(h); - } - let _ = req.respond(resp); - shutdown_flag.store(true, Ordering::SeqCst); + } - // Login has succeeded, so disarm the timeout watcher. - let _ = done_tx.send(()); - return Ok(()); - } - _ => { - let _ = req.respond(Response::from_string("Not Found").with_status_code(404)); - } + if is_login_complete { + shutdown_flag.store(true, Ordering::SeqCst); + // Login has succeeded, so disarm the timeout watcher. + let _ = done_tx.send(()); + return Ok(()); } } @@ -281,6 +191,107 @@ pub fn run_login_server( }) } +enum HandledRequest { + Response(Response>>), + RedirectWithHeader(Header), + ResponseAndExit(Response>>), +} + +fn process_request( + req: &Request, + opts: &ServerOptions, + redirect_uri: &str, + pkce: &PkceCodes, + actual_port: u16, + state: &str, +) -> HandledRequest { + let url_raw = req.url().to_string(); + let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) { + Ok(u) => u, + Err(e) => { + eprintln!("URL parse error: {e}"); + return HandledRequest::Response( + Response::from_string("Bad Request").with_status_code(400), + ); + } + }; + let path = parsed_url.path().to_string(); + + match path.as_str() { + "/auth/callback" => { + let params: std::collections::HashMap = + parsed_url.query_pairs().into_owned().collect(); + if params.get("state").map(String::as_str) != Some(state) { + return HandledRequest::Response( + Response::from_string("State mismatch").with_status_code(400), + ); + } + let code = match params.get("code") { + Some(c) if !c.is_empty() => c.clone(), + _ => { + return HandledRequest::Response( + Response::from_string("Missing authorization code").with_status_code(400), + ); + } + }; + + match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code) + { + Ok(tokens) => { + // Obtain API key via token-exchange and persist + let api_key = + obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token).ok(); + if let Err(err) = persist_tokens( + &opts.codex_home, + api_key.clone(), + tokens.id_token.clone(), + Some(tokens.access_token.clone()), + Some(tokens.refresh_token.clone()), + ) { + eprintln!("Persist error: {err}"); + return HandledRequest::Response( + Response::from_string(format!("Unable to persist auth file: {err}")) + .with_status_code(500), + ); + } + + let success_url = compose_success_url( + actual_port, + &opts.issuer, + &tokens.id_token, + &tokens.access_token, + ); + match tiny_http::Header::from_bytes(&b"Location"[..], success_url.as_bytes()) { + Ok(header) => HandledRequest::RedirectWithHeader(header), + Err(_) => HandledRequest::Response( + Response::from_string("Internal Server Error").with_status_code(500), + ), + } + } + Err(err) => { + eprintln!("Token exchange error: {err}"); + HandledRequest::Response( + Response::from_string(format!("Token exchange failed: {err}")) + .with_status_code(500), + ) + } + } + } + "/success" => { + let body = include_str!("assets/success.html"); + let mut resp = Response::from_data(body.as_bytes()); + if let Ok(h) = tiny_http::Header::from_bytes( + &b"Content-Type"[..], + &b"text/html; charset=utf-8"[..], + ) { + resp.add_header(h); + } + HandledRequest::ResponseAndExit(resp) + } + _ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)), + } +} + /// Spawns a detached thread that waits for either a completion signal on `done_rx` /// or the specified `timeout` to elapse. If the timeout elapses first it marks /// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so