Folds the top-level `shutdown()` function into a method of `ShutdownHandle` and then simply stores `ShutdownHandle` on `LoginServer` since the two fields it contains were always being used together, anyway. --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/openai/codex/pull/2396). * #2399 * #2398 * __->__ #2396 * #2395 * #2394 * #2393 * #2389
534 lines
18 KiB
Rust
534 lines
18 KiB
Rust
use std::io::Cursor;
|
|
use std::io::{self};
|
|
use std::path::Path;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::thread;
|
|
|
|
use crate::AuthDotJson;
|
|
use crate::get_auth_file;
|
|
use crate::pkce::PkceCodes;
|
|
use crate::pkce::generate_pkce;
|
|
use base64::Engine;
|
|
use chrono::Utc;
|
|
use rand::RngCore;
|
|
use tiny_http::Header;
|
|
use tiny_http::Request;
|
|
use tiny_http::Response;
|
|
use tiny_http::Server;
|
|
|
|
const DEFAULT_ISSUER: &str = "https://auth.openai.com";
|
|
const DEFAULT_PORT: u16 = 1455;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ServerOptions {
|
|
pub codex_home: PathBuf,
|
|
pub client_id: String,
|
|
pub issuer: String,
|
|
pub port: u16,
|
|
pub open_browser: bool,
|
|
pub force_state: Option<String>,
|
|
}
|
|
|
|
impl ServerOptions {
|
|
pub fn new(codex_home: PathBuf, client_id: String) -> Self {
|
|
Self {
|
|
codex_home,
|
|
client_id: client_id.to_string(),
|
|
issuer: DEFAULT_ISSUER.to_string(),
|
|
port: DEFAULT_PORT,
|
|
open_browser: true,
|
|
force_state: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct LoginServer {
|
|
pub auth_url: String,
|
|
pub actual_port: u16,
|
|
server_handle: tokio::task::JoinHandle<io::Result<()>>,
|
|
shutdown_handle: ShutdownHandle,
|
|
}
|
|
|
|
impl LoginServer {
|
|
pub async fn block_until_done(self) -> io::Result<()> {
|
|
self.server_handle
|
|
.await
|
|
.map_err(|err| io::Error::other(format!("login server thread panicked: {err:?}")))?
|
|
}
|
|
|
|
pub fn cancel(&self) {
|
|
self.shutdown_handle.shutdown();
|
|
}
|
|
|
|
pub fn cancel_handle(&self) -> ShutdownHandle {
|
|
self.shutdown_handle.clone()
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ShutdownHandle {
|
|
shutdown_notify: Arc<tokio::sync::Notify>,
|
|
server: Arc<Server>,
|
|
}
|
|
|
|
impl std::fmt::Debug for ShutdownHandle {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("ShutdownHandle")
|
|
.field("shutdown_notify", &self.shutdown_notify)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl ShutdownHandle {
|
|
pub fn shutdown(&self) {
|
|
self.shutdown_notify.notify_waiters();
|
|
self.server.unblock();
|
|
}
|
|
}
|
|
|
|
pub fn run_login_server(
|
|
opts: ServerOptions,
|
|
shutdown_flag: Option<Arc<tokio::sync::Notify>>,
|
|
) -> io::Result<LoginServer> {
|
|
let pkce = generate_pkce();
|
|
let state = opts.force_state.clone().unwrap_or_else(generate_state);
|
|
|
|
let server = Server::http(format!("127.0.0.1:{}", opts.port)).map_err(io::Error::other)?;
|
|
let actual_port = match server.server_addr().to_ip() {
|
|
Some(addr) => addr.port(),
|
|
None => {
|
|
return Err(io::Error::new(
|
|
io::ErrorKind::AddrInUse,
|
|
"Unable to determine the server port",
|
|
));
|
|
}
|
|
};
|
|
let server = Arc::new(server);
|
|
|
|
let redirect_uri = format!("http://localhost:{actual_port}/auth/callback");
|
|
let auth_url = build_authorize_url(&opts.issuer, &opts.client_id, &redirect_uri, &pkce, &state);
|
|
|
|
if opts.open_browser {
|
|
let _ = webbrowser::open(&auth_url);
|
|
}
|
|
|
|
// Map blocking reads from server.recv() to an async channel.
|
|
let (tx, mut rx) = tokio::sync::mpsc::channel::<Request>(16);
|
|
let _server_handle = {
|
|
let server = server.clone();
|
|
thread::spawn(move || -> io::Result<()> {
|
|
while let Ok(request) = server.recv() {
|
|
tx.blocking_send(request).map_err(|e| {
|
|
eprintln!("Failed to send request to channel: {e}");
|
|
io::Error::other("Failed to send request to channel")
|
|
})?;
|
|
}
|
|
Ok(())
|
|
})
|
|
};
|
|
|
|
let shutdown_notify = shutdown_flag.unwrap_or_else(|| Arc::new(tokio::sync::Notify::new()));
|
|
let server_handle = {
|
|
let shutdown_notify = shutdown_notify.clone();
|
|
let server = server.clone();
|
|
tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown_notify.notified() => {
|
|
return Err(io::Error::other("Login was not completed"));
|
|
}
|
|
maybe_req = rx.recv() => {
|
|
let Some(req) = maybe_req else {
|
|
return Err(io::Error::other("Login was not completed"));
|
|
};
|
|
|
|
let url_raw = req.url().to_string();
|
|
let response =
|
|
process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await;
|
|
|
|
let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_));
|
|
match response {
|
|
HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => {
|
|
let _ = tokio::task::spawn_blocking(move || req.respond(r)).await;
|
|
}
|
|
HandledRequest::RedirectWithHeader(header) => {
|
|
let redirect = Response::empty(302).with_header(header);
|
|
let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await;
|
|
}
|
|
}
|
|
|
|
if is_login_complete {
|
|
shutdown_notify.notify_waiters();
|
|
server.unblock();
|
|
return Ok(());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
})
|
|
};
|
|
|
|
Ok(LoginServer {
|
|
auth_url,
|
|
actual_port,
|
|
server_handle,
|
|
shutdown_handle: ShutdownHandle {
|
|
shutdown_notify,
|
|
server,
|
|
},
|
|
})
|
|
}
|
|
|
|
enum HandledRequest {
|
|
Response(Response<Cursor<Vec<u8>>>),
|
|
RedirectWithHeader(Header),
|
|
ResponseAndExit(Response<Cursor<Vec<u8>>>),
|
|
}
|
|
|
|
async fn process_request(
|
|
url_raw: &str,
|
|
opts: &ServerOptions,
|
|
redirect_uri: &str,
|
|
pkce: &PkceCodes,
|
|
actual_port: u16,
|
|
state: &str,
|
|
) -> HandledRequest {
|
|
let parsed_url = match url::Url::parse(&format!("http://localhost{url_raw}")) {
|
|
Ok(u) => u,
|
|
Err(e) => {
|
|
eprintln!("URL parse error: {e}");
|
|
return HandledRequest::Response(
|
|
Response::from_string("Bad Request").with_status_code(400),
|
|
);
|
|
}
|
|
};
|
|
let path = parsed_url.path().to_string();
|
|
|
|
match path.as_str() {
|
|
"/auth/callback" => {
|
|
let params: std::collections::HashMap<String, String> =
|
|
parsed_url.query_pairs().into_owned().collect();
|
|
if params.get("state").map(String::as_str) != Some(state) {
|
|
return HandledRequest::Response(
|
|
Response::from_string("State mismatch").with_status_code(400),
|
|
);
|
|
}
|
|
let code = match params.get("code") {
|
|
Some(c) if !c.is_empty() => c.clone(),
|
|
_ => {
|
|
return HandledRequest::Response(
|
|
Response::from_string("Missing authorization code").with_status_code(400),
|
|
);
|
|
}
|
|
};
|
|
|
|
match exchange_code_for_tokens(&opts.issuer, &opts.client_id, redirect_uri, pkce, &code)
|
|
.await
|
|
{
|
|
Ok(tokens) => {
|
|
// Obtain API key via token-exchange and persist
|
|
let api_key = obtain_api_key(&opts.issuer, &opts.client_id, &tokens.id_token)
|
|
.await
|
|
.ok();
|
|
if let Err(err) = persist_tokens_async(
|
|
&opts.codex_home,
|
|
api_key.clone(),
|
|
tokens.id_token.clone(),
|
|
Some(tokens.access_token.clone()),
|
|
Some(tokens.refresh_token.clone()),
|
|
)
|
|
.await
|
|
{
|
|
eprintln!("Persist error: {err}");
|
|
return HandledRequest::Response(
|
|
Response::from_string(format!("Unable to persist auth file: {err}"))
|
|
.with_status_code(500),
|
|
);
|
|
}
|
|
|
|
let success_url = compose_success_url(
|
|
actual_port,
|
|
&opts.issuer,
|
|
&tokens.id_token,
|
|
&tokens.access_token,
|
|
);
|
|
match tiny_http::Header::from_bytes(&b"Location"[..], success_url.as_bytes()) {
|
|
Ok(header) => HandledRequest::RedirectWithHeader(header),
|
|
Err(_) => HandledRequest::Response(
|
|
Response::from_string("Internal Server Error").with_status_code(500),
|
|
),
|
|
}
|
|
}
|
|
Err(err) => {
|
|
eprintln!("Token exchange error: {err}");
|
|
HandledRequest::Response(
|
|
Response::from_string(format!("Token exchange failed: {err}"))
|
|
.with_status_code(500),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
"/success" => {
|
|
let body = include_str!("assets/success.html");
|
|
let mut resp = Response::from_data(body.as_bytes());
|
|
if let Ok(h) = tiny_http::Header::from_bytes(
|
|
&b"Content-Type"[..],
|
|
&b"text/html; charset=utf-8"[..],
|
|
) {
|
|
resp.add_header(h);
|
|
}
|
|
HandledRequest::ResponseAndExit(resp)
|
|
}
|
|
_ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)),
|
|
}
|
|
}
|
|
|
|
fn build_authorize_url(
|
|
issuer: &str,
|
|
client_id: &str,
|
|
redirect_uri: &str,
|
|
pkce: &PkceCodes,
|
|
state: &str,
|
|
) -> String {
|
|
let query = vec![
|
|
("response_type", "code"),
|
|
("client_id", client_id),
|
|
("redirect_uri", redirect_uri),
|
|
("scope", "openid profile email offline_access"),
|
|
("code_challenge", &pkce.code_challenge),
|
|
("code_challenge_method", "S256"),
|
|
("id_token_add_organizations", "true"),
|
|
("codex_cli_simplified_flow", "true"),
|
|
("state", state),
|
|
];
|
|
let qs = query
|
|
.into_iter()
|
|
.map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
|
|
.collect::<Vec<_>>()
|
|
.join("&");
|
|
format!("{issuer}/oauth/authorize?{qs}")
|
|
}
|
|
|
|
fn generate_state() -> String {
|
|
let mut bytes = [0u8; 32];
|
|
rand::thread_rng().fill_bytes(&mut bytes);
|
|
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
|
|
}
|
|
|
|
struct ExchangedTokens {
|
|
id_token: String,
|
|
access_token: String,
|
|
refresh_token: String,
|
|
}
|
|
|
|
async fn exchange_code_for_tokens(
|
|
issuer: &str,
|
|
client_id: &str,
|
|
redirect_uri: &str,
|
|
pkce: &PkceCodes,
|
|
code: &str,
|
|
) -> io::Result<ExchangedTokens> {
|
|
#[derive(serde::Deserialize)]
|
|
struct TokenResponse {
|
|
id_token: String,
|
|
access_token: String,
|
|
refresh_token: String,
|
|
}
|
|
|
|
let client = reqwest::Client::new();
|
|
let resp = client
|
|
.post(format!("{issuer}/oauth/token"))
|
|
.header("Content-Type", "application/x-www-form-urlencoded")
|
|
.body(format!(
|
|
"grant_type=authorization_code&code={}&redirect_uri={}&client_id={}&code_verifier={}",
|
|
urlencoding::encode(code),
|
|
urlencoding::encode(redirect_uri),
|
|
urlencoding::encode(client_id),
|
|
urlencoding::encode(&pkce.code_verifier)
|
|
))
|
|
.send()
|
|
.await
|
|
.map_err(io::Error::other)?;
|
|
|
|
if !resp.status().is_success() {
|
|
return Err(io::Error::other(format!(
|
|
"token endpoint returned status {}",
|
|
resp.status()
|
|
)));
|
|
}
|
|
|
|
let tokens: TokenResponse = resp.json().await.map_err(io::Error::other)?;
|
|
Ok(ExchangedTokens {
|
|
id_token: tokens.id_token,
|
|
access_token: tokens.access_token,
|
|
refresh_token: tokens.refresh_token,
|
|
})
|
|
}
|
|
|
|
async fn persist_tokens_async(
|
|
codex_home: &Path,
|
|
api_key: Option<String>,
|
|
id_token: String,
|
|
access_token: Option<String>,
|
|
refresh_token: Option<String>,
|
|
) -> io::Result<()> {
|
|
// Reuse existing synchronous logic but run it off the async runtime.
|
|
let codex_home = codex_home.to_path_buf();
|
|
tokio::task::spawn_blocking(move || {
|
|
let auth_file = get_auth_file(&codex_home);
|
|
if let Some(parent) = auth_file.parent() {
|
|
if !parent.exists() {
|
|
std::fs::create_dir_all(parent).map_err(io::Error::other)?;
|
|
}
|
|
}
|
|
|
|
let mut auth = read_or_default(&auth_file);
|
|
if let Some(key) = api_key {
|
|
auth.openai_api_key = Some(key);
|
|
}
|
|
let tokens = auth
|
|
.tokens
|
|
.get_or_insert_with(crate::token_data::TokenData::default);
|
|
tokens.id_token = crate::token_data::parse_id_token(&id_token).map_err(io::Error::other)?;
|
|
// Persist chatgpt_account_id if present in claims
|
|
if let Some(acc) = jwt_auth_claims(&id_token)
|
|
.get("chatgpt_account_id")
|
|
.and_then(|v| v.as_str())
|
|
{
|
|
tokens.account_id = Some(acc.to_string());
|
|
}
|
|
if let Some(at) = access_token {
|
|
tokens.access_token = at;
|
|
}
|
|
if let Some(rt) = refresh_token {
|
|
tokens.refresh_token = rt;
|
|
}
|
|
auth.last_refresh = Some(Utc::now());
|
|
super::write_auth_json(&auth_file, &auth)
|
|
})
|
|
.await
|
|
.map_err(|e| io::Error::other(format!("persist task failed: {e}")))?
|
|
}
|
|
|
|
fn read_or_default(path: &Path) -> AuthDotJson {
|
|
match super::try_read_auth_json(path) {
|
|
Ok(auth) => auth,
|
|
Err(_) => AuthDotJson {
|
|
openai_api_key: None,
|
|
tokens: None,
|
|
last_refresh: None,
|
|
},
|
|
}
|
|
}
|
|
|
|
fn compose_success_url(port: u16, issuer: &str, id_token: &str, access_token: &str) -> String {
|
|
let token_claims = jwt_auth_claims(id_token);
|
|
let access_claims = jwt_auth_claims(access_token);
|
|
|
|
let org_id = token_claims
|
|
.get("organization_id")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("");
|
|
let project_id = token_claims
|
|
.get("project_id")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("");
|
|
let completed_onboarding = token_claims
|
|
.get("completed_platform_onboarding")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
let is_org_owner = token_claims
|
|
.get("is_org_owner")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(false);
|
|
let needs_setup = (!completed_onboarding) && is_org_owner;
|
|
let plan_type = access_claims
|
|
.get("chatgpt_plan_type")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or("");
|
|
|
|
let platform_url = if issuer == DEFAULT_ISSUER {
|
|
"https://platform.openai.com"
|
|
} else {
|
|
"https://platform.api.openai.org"
|
|
};
|
|
|
|
let mut params = vec![
|
|
("id_token", id_token.to_string()),
|
|
("needs_setup", needs_setup.to_string()),
|
|
("org_id", org_id.to_string()),
|
|
("project_id", project_id.to_string()),
|
|
("plan_type", plan_type.to_string()),
|
|
("platform_url", platform_url.to_string()),
|
|
];
|
|
let qs = params
|
|
.drain(..)
|
|
.map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v)))
|
|
.collect::<Vec<_>>()
|
|
.join("&");
|
|
format!("http://localhost:{port}/success?{qs}")
|
|
}
|
|
|
|
fn jwt_auth_claims(jwt: &str) -> serde_json::Map<String, serde_json::Value> {
|
|
let mut parts = jwt.split('.');
|
|
let (_h, payload_b64, _s) = match (parts.next(), parts.next(), parts.next()) {
|
|
(Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
|
|
_ => {
|
|
eprintln!("Invalid JWT format while extracting claims");
|
|
return serde_json::Map::new();
|
|
}
|
|
};
|
|
match base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64) {
|
|
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
|
Ok(mut v) => {
|
|
if let Some(obj) = v
|
|
.get_mut("https://api.openai.com/auth")
|
|
.and_then(|x| x.as_object_mut())
|
|
{
|
|
return obj.clone();
|
|
}
|
|
eprintln!("JWT payload missing expected 'https://api.openai.com/auth' object");
|
|
}
|
|
Err(e) => {
|
|
eprintln!("Failed to parse JWT JSON payload: {e}");
|
|
}
|
|
},
|
|
Err(e) => {
|
|
eprintln!("Failed to base64url-decode JWT payload: {e}");
|
|
}
|
|
}
|
|
serde_json::Map::new()
|
|
}
|
|
|
|
async fn obtain_api_key(issuer: &str, client_id: &str, id_token: &str) -> io::Result<String> {
|
|
// Token exchange for an API key access token
|
|
#[derive(serde::Deserialize)]
|
|
struct ExchangeResp {
|
|
access_token: String,
|
|
}
|
|
let client = reqwest::Client::new();
|
|
let resp = client
|
|
.post(format!("{issuer}/oauth/token"))
|
|
.header("Content-Type", "application/x-www-form-urlencoded")
|
|
.body(format!(
|
|
"grant_type={}&client_id={}&requested_token={}&subject_token={}&subject_token_type={}",
|
|
urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
|
|
urlencoding::encode(client_id),
|
|
urlencoding::encode("openai-api-key"),
|
|
urlencoding::encode(id_token),
|
|
urlencoding::encode("urn:ietf:params:oauth:token-type:id_token")
|
|
))
|
|
.send()
|
|
.await
|
|
.map_err(io::Error::other)?;
|
|
if !resp.status().is_success() {
|
|
return Err(io::Error::other(format!(
|
|
"api key exchange failed with status {}",
|
|
resp.status()
|
|
)));
|
|
}
|
|
let body: ExchangeResp = resp.json().await.map_err(io::Error::other)?;
|
|
Ok(body.access_token)
|
|
}
|