fix: refactor login/src/server.rs so process_request() is a separate function (#2388)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
use std::io::Cursor;
|
||||
use std::io::{self};
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
@@ -15,6 +16,8 @@ use crate::pkce::generate_pkce;
|
||||
use base64::Engine;
|
||||
use chrono::Utc;
|
||||
use rand::RngCore;
|
||||
use tiny_http::Header;
|
||||
use tiny_http::Request;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
|
||||
@@ -149,116 +152,23 @@ pub fn run_login_server(
|
||||
}
|
||||
};
|
||||
|
||||
let url_raw = req.url().to_string();
|
||||
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
eprintln!("URL parse error: {e}");
|
||||
let _ = req.respond(Response::from_string("Bad Request").with_status_code(400));
|
||||
continue;
|
||||
let response = process_request(&req, &opts, &redirect_uri, &pkce, actual_port, &state);
|
||||
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
|
||||
match response {
|
||||
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
|
||||
let _ = req.respond(r);
|
||||
}
|
||||
};
|
||||
let path = parsed_url.path().to_string();
|
||||
|
||||
match path.as_str() {
|
||||
"/auth/callback" => {
|
||||
let params: std::collections::HashMap<String, String> =
|
||||
parsed_url.query_pairs().into_owned().collect();
|
||||
if params.get("state").map(String::as_str) != Some(state.as_str()) {
|
||||
let _ = req
|
||||
.respond(Response::from_string("State mismatch").with_status_code(400));
|
||||
continue;
|
||||
}
|
||||
let code = match params.get("code") {
|
||||
Some(c) if !c.is_empty() => c.clone(),
|
||||
_ => {
|
||||
let _ = req.respond(
|
||||
Response::from_string("Missing authorization code")
|
||||
.with_status_code(400),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match exchange_code_for_tokens(
|
||||
&opts.issuer,
|
||||
&opts.client_id,
|
||||
&redirect_uri,
|
||||
&pkce,
|
||||
&code,
|
||||
) {
|
||||
Ok(tokens) => {
|
||||
// Obtain API key via token-exchange and persist
|
||||
let api_key =
|
||||
obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token)
|
||||
.ok();
|
||||
if let Err(err) = persist_tokens(
|
||||
&opts.codex_home,
|
||||
api_key.clone(),
|
||||
tokens.id_token.clone(),
|
||||
Some(tokens.access_token.clone()),
|
||||
Some(tokens.refresh_token.clone()),
|
||||
) {
|
||||
eprintln!("Persist error: {err}");
|
||||
let _ = req.respond(
|
||||
Response::from_string(format!(
|
||||
"Unable to persist auth file: {err}"
|
||||
))
|
||||
.with_status_code(500),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let success_url = compose_success_url(
|
||||
actual_port,
|
||||
&opts.issuer,
|
||||
&tokens.id_token,
|
||||
&tokens.access_token,
|
||||
);
|
||||
match tiny_http::Header::from_bytes(
|
||||
&b"Location"[..],
|
||||
success_url.as_bytes(),
|
||||
) {
|
||||
Ok(h) => {
|
||||
let response = tiny_http::Response::empty(302).with_header(h);
|
||||
let _ = req.respond(response);
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = req.respond(
|
||||
Response::from_string("Internal Server Error")
|
||||
.with_status_code(500),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("Token exchange error: {err}");
|
||||
let _ = req.respond(
|
||||
Response::from_string(format!("Token exchange failed: {err}"))
|
||||
.with_status_code(500),
|
||||
);
|
||||
}
|
||||
}
|
||||
HandledRequest::RedirectWithHeader(header) => {
|
||||
let redirect = Response::empty(302).with_header(header);
|
||||
let _ = req.respond(redirect);
|
||||
}
|
||||
"/success" => {
|
||||
let body = include_str!("assets/success.html");
|
||||
let mut resp = Response::from_data(body.as_bytes());
|
||||
if let Ok(h) = tiny_http::Header::from_bytes(
|
||||
&b"Content-Type"[..],
|
||||
&b"text/html; charset=utf-8"[..],
|
||||
) {
|
||||
resp.add_header(h);
|
||||
}
|
||||
let _ = req.respond(resp);
|
||||
shutdown_flag.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
// Login has succeeded, so disarm the timeout watcher.
|
||||
let _ = done_tx.send(());
|
||||
return Ok(());
|
||||
}
|
||||
_ => {
|
||||
let _ = req.respond(Response::from_string("Not Found").with_status_code(404));
|
||||
}
|
||||
if is_login_complete {
|
||||
shutdown_flag.store(true, Ordering::SeqCst);
|
||||
// Login has succeeded, so disarm the timeout watcher.
|
||||
let _ = done_tx.send(());
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,6 +191,107 @@ pub fn run_login_server(
|
||||
})
|
||||
}
|
||||
|
||||
enum HandledRequest {
|
||||
Response(Response<Cursor<Vec<u8>>>),
|
||||
RedirectWithHeader(Header),
|
||||
ResponseAndExit(Response<Cursor<Vec<u8>>>),
|
||||
}
|
||||
|
||||
fn process_request(
|
||||
req: &Request,
|
||||
opts: &ServerOptions,
|
||||
redirect_uri: &str,
|
||||
pkce: &PkceCodes,
|
||||
actual_port: u16,
|
||||
state: &str,
|
||||
) -> HandledRequest {
|
||||
let url_raw = req.url().to_string();
|
||||
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
eprintln!("URL parse error: {e}");
|
||||
return HandledRequest::Response(
|
||||
Response::from_string("Bad Request").with_status_code(400),
|
||||
);
|
||||
}
|
||||
};
|
||||
let path = parsed_url.path().to_string();
|
||||
|
||||
match path.as_str() {
|
||||
"/auth/callback" => {
|
||||
let params: std::collections::HashMap<String, String> =
|
||||
parsed_url.query_pairs().into_owned().collect();
|
||||
if params.get("state").map(String::as_str) != Some(state) {
|
||||
return HandledRequest::Response(
|
||||
Response::from_string("State mismatch").with_status_code(400),
|
||||
);
|
||||
}
|
||||
let code = match params.get("code") {
|
||||
Some(c) if !c.is_empty() => c.clone(),
|
||||
_ => {
|
||||
return HandledRequest::Response(
|
||||
Response::from_string("Missing authorization code").with_status_code(400),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code)
|
||||
{
|
||||
Ok(tokens) => {
|
||||
// Obtain API key via token-exchange and persist
|
||||
let api_key =
|
||||
obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token).ok();
|
||||
if let Err(err) = persist_tokens(
|
||||
&opts.codex_home,
|
||||
api_key.clone(),
|
||||
tokens.id_token.clone(),
|
||||
Some(tokens.access_token.clone()),
|
||||
Some(tokens.refresh_token.clone()),
|
||||
) {
|
||||
eprintln!("Persist error: {err}");
|
||||
return HandledRequest::Response(
|
||||
Response::from_string(format!("Unable to persist auth file: {err}"))
|
||||
.with_status_code(500),
|
||||
);
|
||||
}
|
||||
|
||||
let success_url = compose_success_url(
|
||||
actual_port,
|
||||
&opts.issuer,
|
||||
&tokens.id_token,
|
||||
&tokens.access_token,
|
||||
);
|
||||
match tiny_http::Header::from_bytes(&b"Location"[..], success_url.as_bytes()) {
|
||||
Ok(header) => HandledRequest::RedirectWithHeader(header),
|
||||
Err(_) => HandledRequest::Response(
|
||||
Response::from_string("Internal Server Error").with_status_code(500),
|
||||
),
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("Token exchange error: {err}");
|
||||
HandledRequest::Response(
|
||||
Response::from_string(format!("Token exchange failed: {err}"))
|
||||
.with_status_code(500),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
"/success" => {
|
||||
let body = include_str!("assets/success.html");
|
||||
let mut resp = Response::from_data(body.as_bytes());
|
||||
if let Ok(h) = tiny_http::Header::from_bytes(
|
||||
&b"Content-Type"[..],
|
||||
&b"text/html; charset=utf-8"[..],
|
||||
) {
|
||||
resp.add_header(h);
|
||||
}
|
||||
HandledRequest::ResponseAndExit(resp)
|
||||
}
|
||||
_ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawns a detached thread that waits for either a completion signal on `done_rx`
|
||||
/// or the specified `timeout` to elapse. If the timeout elapses first it marks
|
||||
/// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so
|
||||
|
||||
Reference in New Issue
Block a user