fix: refactor login/src/server.rs so process_request() is a separate function (#2388)

This commit is contained in:
Michael Bolin
2025-08-17 12:32:56 -07:00
committed by GitHub
parent 350b00d54b
commit 71cae06e66

View File

@@ -1,3 +1,4 @@
use std::io::Cursor;
use std::io::{self};
use std::path::Path;
use std::path::PathBuf;
@@ -15,6 +16,8 @@ use crate::pkce::generate_pkce;
use base64::Engine;
use chrono::Utc;
use rand::RngCore;
use tiny_http::Header;
use tiny_http::Request;
use tiny_http::Response;
use tiny_http::Server;
@@ -149,116 +152,23 @@ pub fn run_login_server(
}
};
let url_raw = req.url().to_string();
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
Ok(u) => u,
Err(e) => {
eprintln!("URL parse error: {e}");
let _ = req.respond(Response::from_string("Bad Request").with_status_code(400));
continue;
let response = process_request(&req, &opts, &redirect_uri, &pkce, actual_port, &state);
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
match response {
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
let _ = req.respond(r);
}
};
let path = parsed_url.path().to_string();
match path.as_str() {
"/auth/callback" => {
let params: std::collections::HashMap<String, String> =
parsed_url.query_pairs().into_owned().collect();
if params.get("state").map(String::as_str) != Some(state.as_str()) {
let _ = req
.respond(Response::from_string("State mismatch").with_status_code(400));
continue;
}
let code = match params.get("code") {
Some(c) if !c.is_empty() => c.clone(),
_ => {
let _ = req.respond(
Response::from_string("Missing authorization code")
.with_status_code(400),
);
continue;
}
};
match exchange_code_for_tokens(
&opts.issuer,
&opts.client_id,
&redirect_uri,
&pkce,
&code,
) {
Ok(tokens) => {
// Obtain API key via token-exchange and persist
let api_key =
obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token)
.ok();
if let Err(err) = persist_tokens(
&opts.codex_home,
api_key.clone(),
tokens.id_token.clone(),
Some(tokens.access_token.clone()),
Some(tokens.refresh_token.clone()),
) {
eprintln!("Persist error: {err}");
let _ = req.respond(
Response::from_string(format!(
"Unable to persist auth file: {err}"
))
.with_status_code(500),
);
continue;
}
let success_url = compose_success_url(
actual_port,
&opts.issuer,
&tokens.id_token,
&tokens.access_token,
);
match tiny_http::Header::from_bytes(
&b"Location"[..],
success_url.as_bytes(),
) {
Ok(h) => {
let response = tiny_http::Response::empty(302).with_header(h);
let _ = req.respond(response);
}
Err(_) => {
let _ = req.respond(
Response::from_string("Internal Server Error")
.with_status_code(500),
);
}
}
}
Err(err) => {
eprintln!("Token exchange error: {err}");
let _ = req.respond(
Response::from_string(format!("Token exchange failed: {err}"))
.with_status_code(500),
);
}
}
HandledRequest::RedirectWithHeader(header) => {
let redirect = Response::empty(302).with_header(header);
let _ = req.respond(redirect);
}
"/success" => {
let body = include_str!("assets/success.html");
let mut resp = Response::from_data(body.as_bytes());
if let Ok(h) = tiny_http::Header::from_bytes(
&b"Content-Type"[..],
&b"text/html; charset=utf-8"[..],
) {
resp.add_header(h);
}
let _ = req.respond(resp);
shutdown_flag.store(true, Ordering::SeqCst);
}
// Login has succeeded, so disarm the timeout watcher.
let _ = done_tx.send(());
return Ok(());
}
_ => {
let _ = req.respond(Response::from_string("Not Found").with_status_code(404));
}
if is_login_complete {
shutdown_flag.store(true, Ordering::SeqCst);
// Login has succeeded, so disarm the timeout watcher.
let _ = done_tx.send(());
return Ok(());
}
}
@@ -281,6 +191,107 @@ pub fn run_login_server(
})
}
enum HandledRequest {
Response(Response<Cursor<Vec<u8>>>),
RedirectWithHeader(Header),
ResponseAndExit(Response<Cursor<Vec<u8>>>),
}
fn process_request(
req: &Request,
opts: &ServerOptions,
redirect_uri: &str,
pkce: &PkceCodes,
actual_port: u16,
state: &str,
) -> HandledRequest {
let url_raw = req.url().to_string();
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
Ok(u) => u,
Err(e) => {
eprintln!("URL parse error: {e}");
return HandledRequest::Response(
Response::from_string("Bad Request").with_status_code(400),
);
}
};
let path = parsed_url.path().to_string();
match path.as_str() {
"/auth/callback" => {
let params: std::collections::HashMap<String, String> =
parsed_url.query_pairs().into_owned().collect();
if params.get("state").map(String::as_str) != Some(state) {
return HandledRequest::Response(
Response::from_string("State mismatch").with_status_code(400),
);
}
let code = match params.get("code") {
Some(c) if !c.is_empty() => c.clone(),
_ => {
return HandledRequest::Response(
Response::from_string("Missing authorization code").with_status_code(400),
);
}
};
match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code)
{
Ok(tokens) => {
// Obtain API key via token-exchange and persist
let api_key =
obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token).ok();
if let Err(err) = persist_tokens(
&opts.codex_home,
api_key.clone(),
tokens.id_token.clone(),
Some(tokens.access_token.clone()),
Some(tokens.refresh_token.clone()),
) {
eprintln!("Persist error: {err}");
return HandledRequest::Response(
Response::from_string(format!("Unable to persist auth file: {err}"))
.with_status_code(500),
);
}
let success_url = compose_success_url(
actual_port,
&opts.issuer,
&tokens.id_token,
&tokens.access_token,
);
match tiny_http::Header::from_bytes(&b"Location"[..], success_url.as_bytes()) {
Ok(header) => HandledRequest::RedirectWithHeader(header),
Err(_) => HandledRequest::Response(
Response::from_string("Internal Server Error").with_status_code(500),
),
}
}
Err(err) => {
eprintln!("Token exchange error: {err}");
HandledRequest::Response(
Response::from_string(format!("Token exchange failed: {err}"))
.with_status_code(500),
)
}
}
}
"/success" => {
let body = include_str!("assets/success.html");
let mut resp = Response::from_data(body.as_bytes());
if let Ok(h) = tiny_http::Header::from_bytes(
&b"Content-Type"[..],
&b"text/html; charset=utf-8"[..],
) {
resp.add_header(h);
}
HandledRequest::ResponseAndExit(resp)
}
_ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
}
}
/// Spawns a detached thread that waits for either a completion signal on `done_rx`
/// or the specified `timeout` to elapse. If the timeout elapses first it marks
/// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so