fix: refactor login/src/server.rs so process_request() is a separate function (#2388)
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
use std::io::Cursor;
|
||||||
use std::io::{self};
|
use std::io::{self};
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
@@ -15,6 +16,8 @@ use crate::pkce::generate_pkce;
|
|||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use rand::RngCore;
|
use rand::RngCore;
|
||||||
|
use tiny_http::Header;
|
||||||
|
use tiny_http::Request;
|
||||||
use tiny_http::Response;
|
use tiny_http::Response;
|
||||||
use tiny_http::Server;
|
use tiny_http::Server;
|
||||||
|
|
||||||
@@ -149,116 +152,23 @@ pub fn run_login_server(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let url_raw = req.url().to_string();
|
let response = process_request(&req, &opts, &redirect_uri, &pkce, actual_port, &state);
|
||||||
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
|
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
|
||||||
Ok(u) => u,
|
match response {
|
||||||
Err(e) => {
|
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
|
||||||
eprintln!("URL parse error: {e}");
|
let _ = req.respond(r);
|
||||||
let _ = req.respond(Response::from_string("Bad Request").with_status_code(400));
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
};
|
HandledRequest::RedirectWithHeader(header) => {
|
||||||
let path = parsed_url.path().to_string();
|
let redirect = Response::empty(302).with_header(header);
|
||||||
|
let _ = req.respond(redirect);
|
||||||
match path.as_str() {
|
|
||||||
"/auth/callback" => {
|
|
||||||
let params: std::collections::HashMap<String, String> =
|
|
||||||
parsed_url.query_pairs().into_owned().collect();
|
|
||||||
if params.get("state").map(String::as_str) != Some(state.as_str()) {
|
|
||||||
let _ = req
|
|
||||||
.respond(Response::from_string("State mismatch").with_status_code(400));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let code = match params.get("code") {
|
|
||||||
Some(c) if !c.is_empty() => c.clone(),
|
|
||||||
_ => {
|
|
||||||
let _ = req.respond(
|
|
||||||
Response::from_string("Missing authorization code")
|
|
||||||
.with_status_code(400),
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match exchange_code_for_tokens(
|
|
||||||
&opts.issuer,
|
|
||||||
&opts.client_id,
|
|
||||||
&redirect_uri,
|
|
||||||
&pkce,
|
|
||||||
&code,
|
|
||||||
) {
|
|
||||||
Ok(tokens) => {
|
|
||||||
// Obtain API key via token-exchange and persist
|
|
||||||
let api_key =
|
|
||||||
obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token)
|
|
||||||
.ok();
|
|
||||||
if let Err(err) = persist_tokens(
|
|
||||||
&opts.codex_home,
|
|
||||||
api_key.clone(),
|
|
||||||
tokens.id_token.clone(),
|
|
||||||
Some(tokens.access_token.clone()),
|
|
||||||
Some(tokens.refresh_token.clone()),
|
|
||||||
) {
|
|
||||||
eprintln!("Persist error: {err}");
|
|
||||||
let _ = req.respond(
|
|
||||||
Response::from_string(format!(
|
|
||||||
"Unable to persist auth file: {err}"
|
|
||||||
))
|
|
||||||
.with_status_code(500),
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let success_url = compose_success_url(
|
|
||||||
actual_port,
|
|
||||||
&opts.issuer,
|
|
||||||
&tokens.id_token,
|
|
||||||
&tokens.access_token,
|
|
||||||
);
|
|
||||||
match tiny_http::Header::from_bytes(
|
|
||||||
&b"Location"[..],
|
|
||||||
success_url.as_bytes(),
|
|
||||||
) {
|
|
||||||
Ok(h) => {
|
|
||||||
let response = tiny_http::Response::empty(302).with_header(h);
|
|
||||||
let _ = req.respond(response);
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
let _ = req.respond(
|
|
||||||
Response::from_string("Internal Server Error")
|
|
||||||
.with_status_code(500),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
eprintln!("Token exchange error: {err}");
|
|
||||||
let _ = req.respond(
|
|
||||||
Response::from_string(format!("Token exchange failed: {err}"))
|
|
||||||
.with_status_code(500),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
"/success" => {
|
}
|
||||||
let body = include_str!("assets/success.html");
|
|
||||||
let mut resp = Response::from_data(body.as_bytes());
|
|
||||||
if let Ok(h) = tiny_http::Header::from_bytes(
|
|
||||||
&b"Content-Type"[..],
|
|
||||||
&b"text/html; charset=utf-8"[..],
|
|
||||||
) {
|
|
||||||
resp.add_header(h);
|
|
||||||
}
|
|
||||||
let _ = req.respond(resp);
|
|
||||||
shutdown_flag.store(true, Ordering::SeqCst);
|
|
||||||
|
|
||||||
// Login has succeeded, so disarm the timeout watcher.
|
if is_login_complete {
|
||||||
let _ = done_tx.send(());
|
shutdown_flag.store(true, Ordering::SeqCst);
|
||||||
return Ok(());
|
// Login has succeeded, so disarm the timeout watcher.
|
||||||
}
|
let _ = done_tx.send(());
|
||||||
_ => {
|
return Ok(());
|
||||||
let _ = req.respond(Response::from_string("Not Found").with_status_code(404));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,6 +191,107 @@ pub fn run_login_server(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum HandledRequest {
|
||||||
|
Response(Response<Cursor<Vec<u8>>>),
|
||||||
|
RedirectWithHeader(Header),
|
||||||
|
ResponseAndExit(Response<Cursor<Vec<u8>>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn process_request(
|
||||||
|
req: &Request,
|
||||||
|
opts: &ServerOptions,
|
||||||
|
redirect_uri: &str,
|
||||||
|
pkce: &PkceCodes,
|
||||||
|
actual_port: u16,
|
||||||
|
state: &str,
|
||||||
|
) -> HandledRequest {
|
||||||
|
let url_raw = req.url().to_string();
|
||||||
|
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
|
||||||
|
Ok(u) => u,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("URL parse error: {e}");
|
||||||
|
return HandledRequest::Response(
|
||||||
|
Response::from_string("Bad Request").with_status_code(400),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let path = parsed_url.path().to_string();
|
||||||
|
|
||||||
|
match path.as_str() {
|
||||||
|
"/auth/callback" => {
|
||||||
|
let params: std::collections::HashMap<String, String> =
|
||||||
|
parsed_url.query_pairs().into_owned().collect();
|
||||||
|
if params.get("state").map(String::as_str) != Some(state) {
|
||||||
|
return HandledRequest::Response(
|
||||||
|
Response::from_string("State mismatch").with_status_code(400),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let code = match params.get("code") {
|
||||||
|
Some(c) if !c.is_empty() => c.clone(),
|
||||||
|
_ => {
|
||||||
|
return HandledRequest::Response(
|
||||||
|
Response::from_string("Missing authorization code").with_status_code(400),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code)
|
||||||
|
{
|
||||||
|
Ok(tokens) => {
|
||||||
|
// Obtain API key via token-exchange and persist
|
||||||
|
let api_key =
|
||||||
|
obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token).ok();
|
||||||
|
if let Err(err) = persist_tokens(
|
||||||
|
&opts.codex_home,
|
||||||
|
api_key.clone(),
|
||||||
|
tokens.id_token.clone(),
|
||||||
|
Some(tokens.access_token.clone()),
|
||||||
|
Some(tokens.refresh_token.clone()),
|
||||||
|
) {
|
||||||
|
eprintln!("Persist error: {err}");
|
||||||
|
return HandledRequest::Response(
|
||||||
|
Response::from_string(format!("Unable to persist auth file: {err}"))
|
||||||
|
.with_status_code(500),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let success_url = compose_success_url(
|
||||||
|
actual_port,
|
||||||
|
&opts.issuer,
|
||||||
|
&tokens.id_token,
|
||||||
|
&tokens.access_token,
|
||||||
|
);
|
||||||
|
match tiny_http::Header::from_bytes(&b"Location"[..], success_url.as_bytes()) {
|
||||||
|
Ok(header) => HandledRequest::RedirectWithHeader(header),
|
||||||
|
Err(_) => HandledRequest::Response(
|
||||||
|
Response::from_string("Internal Server Error").with_status_code(500),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("Token exchange error: {err}");
|
||||||
|
HandledRequest::Response(
|
||||||
|
Response::from_string(format!("Token exchange failed: {err}"))
|
||||||
|
.with_status_code(500),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"/success" => {
|
||||||
|
let body = include_str!("assets/success.html");
|
||||||
|
let mut resp = Response::from_data(body.as_bytes());
|
||||||
|
if let Ok(h) = tiny_http::Header::from_bytes(
|
||||||
|
&b"Content-Type"[..],
|
||||||
|
&b"text/html; charset=utf-8"[..],
|
||||||
|
) {
|
||||||
|
resp.add_header(h);
|
||||||
|
}
|
||||||
|
HandledRequest::ResponseAndExit(resp)
|
||||||
|
}
|
||||||
|
_ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Spawns a detached thread that waits for either a completion signal on `done_rx`
|
/// Spawns a detached thread that waits for either a completion signal on `done_rx`
|
||||||
/// or the specified `timeout` to elapse. If the timeout elapses first it marks
|
/// or the specified `timeout` to elapse. If the timeout elapses first it marks
|
||||||
/// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so
|
/// the `shutdown_flag`, records `timeout_flag`, and unblocks the HTTP server so
|
||||||
|
|||||||
Reference in New Issue
Block a user