diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock
index 9abce0c3..120050c2 100644
--- a/codex-rs/Cargo.lock
+++ b/codex-rs/Cargo.lock
@@ -673,7 +673,9 @@ dependencies = [
  "async-channel",
  "base64 0.22.1",
  "bytes",
+ "chrono",
  "codex-apply-patch",
+ "codex-login",
  "codex-mcp-client",
  "core_test_support",
  "dirs",
diff --git a/codex-rs/chatgpt/src/chatgpt_client.rs b/codex-rs/chatgpt/src/chatgpt_client.rs
index 4c4cb4c4..907783bb 100644
--- a/codex-rs/chatgpt/src/chatgpt_client.rs
+++ b/codex-rs/chatgpt/src/chatgpt_client.rs
@@ -21,10 +21,14 @@ pub(crate) async fn chatgpt_get_request(
     let token = get_chatgpt_token_data()
         .ok_or_else(|| anyhow::anyhow!("ChatGPT token not available"))?;
 
+    let account_id = token.account_id.ok_or_else(|| {
+        anyhow::anyhow!("ChatGPT account ID not available, please re-run `codex login`")
+    });
+
     let response = client
         .get(&url)
         .bearer_auth(&token.access_token)
-        .header("chatgpt-account-id", &token.account_id)
+        .header("chatgpt-account-id", account_id?)
         .header("Content-Type", "application/json")
         .header("User-Agent", "codex-cli")
         .send()
diff --git a/codex-rs/chatgpt/src/chatgpt_token.rs b/codex-rs/chatgpt/src/chatgpt_token.rs
index adf9a6ba..55ebc22a 100644
--- a/codex-rs/chatgpt/src/chatgpt_token.rs
+++ b/codex-rs/chatgpt/src/chatgpt_token.rs
@@ -18,7 +18,10 @@ pub fn set_chatgpt_token_data(value: TokenData) {
 
 /// Initialize the ChatGPT token from auth.json file
 pub async fn init_chatgpt_token_from_auth(codex_home: &Path) -> std::io::Result<()> {
-    let auth_json = codex_login::try_read_auth_json(codex_home).await?;
-    set_chatgpt_token_data(auth_json.tokens.clone());
+    let auth = codex_login::load_auth(codex_home)?;
+    if let Some(auth) = auth {
+        let token_data = auth.get_token_data().await?;
+        set_chatgpt_token_data(token_data);
+    }
     Ok(())
 }
diff --git a/codex-rs/cli/src/proto.rs b/codex-rs/cli/src/proto.rs
index 64b292d5..291e1680 100644
--- a/codex-rs/cli/src/proto.rs
+++ b/codex-rs/cli/src/proto.rs
@@ -9,6 +9,7 @@ use codex_core::config::Config;
 use codex_core::config::ConfigOverrides;
 use codex_core::protocol::Submission;
 use codex_core::util::notify_on_sigint;
+use codex_login::load_auth;
 use tokio::io::AsyncBufReadExt;
 use tokio::io::BufReader;
 use tracing::error;
@@ -35,8 +36,9 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> {
         .map_err(anyhow::Error::msg)?;
     let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?;
 
+    let auth = load_auth(&config.codex_home)?;
     let ctrl_c = notify_on_sigint();
-    let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await?;
+    let CodexSpawnOk { codex, .. } = Codex::spawn(config, auth, ctrl_c.clone()).await?;
     let codex = Arc::new(codex);
 
     // Task that reads JSON lines from stdin and forwards to Submission Queue
diff --git a/codex-rs/config.md b/codex-rs/config.md
index c45d8118..1a407a23 100644
--- a/codex-rs/config.md
+++ b/codex-rs/config.md
@@ -110,12 +110,15 @@ stream_idle_timeout_ms = 300000 # 5m idle timeout
 ```
 
 #### request_max_retries
+
 How many times Codex will retry a failed HTTP request to the model provider. Defaults to `4`.
 
 #### stream_max_retries
+
 Number of times Codex will attempt to reconnect when a streaming response is interrupted. Defaults to `10`.
 
 #### stream_idle_timeout_ms
+
 How long Codex will wait for activity on a streaming response before treating the connection as lost. Defaults to `300_000` (5 minutes).
 
 ## model_provider
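Reviewer note on the settings documented above: `request_max_retries` caps whole-request attempts, while `stream_max_retries`/`stream_idle_timeout_ms` govern SSE reconnects. A minimal sketch of how a per-request cap like this is typically applied (illustrative only; the function name and backoff policy are assumptions, not the actual loop in `core/src/client.rs`):

```rust
use std::time::Duration;

// Illustrative retry wrapper; the real loop also inspects status codes.
async fn send_with_retries(
    client: &reqwest::Client,
    url: &str,
    max_retries: u64, // maps to request_max_retries
) -> reqwest::Result<reqwest::Response> {
    let mut attempt = 0;
    loop {
        attempt += 1;
        match client.post(url).send().await {
            Ok(resp) => return Ok(resp),
            Err(err) if attempt > max_retries => return Err(err),
            // Linear backoff purely for illustration.
            Err(_) => tokio::time::sleep(Duration::from_millis(200 * attempt)).await,
        }
    }
}
```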
diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml
index 2e0489c9..5ebb5ef6 100644
--- a/codex-rs/core/Cargo.toml
+++ b/codex-rs/core/Cargo.toml
@@ -17,6 +17,8 @@ base64 = "0.22"
 bytes = "1.10.1"
 codex-apply-patch = { path = "../apply-patch" }
 codex-mcp-client = { path = "../mcp-client" }
+chrono = { version = "0.4", features = ["serde"] }
+codex-login = { path = "../login" }
 dirs = "6"
 env-flags = "0.1.1"
 eventsource-stream = "0.2.3"
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index aa31b67e..72104da2 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -3,6 +3,8 @@ use std::path::Path;
 use std::time::Duration;
 
 use bytes::Bytes;
+use codex_login::AuthMode;
+use codex_login::CodexAuth;
 use eventsource_stream::Eventsource;
 use futures::prelude::*;
 use reqwest::StatusCode;
@@ -28,6 +30,7 @@ use crate::config::Config;
 use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
 use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
 use crate::error::CodexErr;
+use crate::error::EnvVarError;
 use crate::error::Result;
 use crate::flags::CODEX_RS_SSE_FIXTURE;
 use crate::model_provider_info::ModelProviderInfo;
@@ -41,6 +44,7 @@ use std::sync::Arc;
 #[derive(Clone)]
 pub struct ModelClient {
     config: Arc<Config>,
+    auth: Option<CodexAuth>,
     client: reqwest::Client,
     provider: ModelProviderInfo,
     session_id: Uuid,
@@ -51,6 +55,7 @@ pub struct ModelClient {
 impl ModelClient {
     pub fn new(
         config: Arc<Config>,
+        auth: Option<CodexAuth>,
         provider: ModelProviderInfo,
         effort: ReasoningEffortConfig,
         summary: ReasoningSummaryConfig,
@@ -58,6 +63,7 @@ impl ModelClient {
     ) -> Self {
         Self {
             config,
+            auth,
             client: reqwest::Client::new(),
             provider,
             session_id,
@@ -115,6 +121,25 @@ impl ModelClient {
             return stream_from_fixture(path, self.provider.clone()).await;
         }
 
+        let auth = self.auth.as_ref().ok_or_else(|| {
+            CodexErr::EnvVar(EnvVarError {
+                var: "OPENAI_API_KEY".to_string(),
+                instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".to_string()),
+            })
+        })?;
+
+        let store = prompt.store && auth.mode != AuthMode::ChatGPT;
+
+        let base_url = match self.provider.base_url.clone() {
+            Some(url) => url,
+            None => match auth.mode {
+                AuthMode::ChatGPT => "https://chatgpt.com/backend-api/codex".to_string(),
+                AuthMode::ApiKey => "https://api.openai.com/v1".to_string(),
+            },
+        };
+
+        let token = auth.get_token().await?;
+
         let full_instructions = prompt.get_full_instructions(&self.config.model);
         let tools_json = create_tools_json_for_responses_api(
             prompt,
@@ -125,7 +150,7 @@
 
         // Request encrypted COT if we are not storing responses,
         // otherwise reasoning items will be referenced by ID
-        let include = if !prompt.store && reasoning.is_some() {
+        let include: Vec<String> = if !store && reasoning.is_some() {
             vec!["reasoning.encrypted_content".to_string()]
         } else {
             vec![]
@@ -139,8 +164,7 @@
             tool_choice: "auto",
             parallel_tool_calls: false,
             reasoning,
-            store: prompt.store,
-            // TODO: make this configurable
+            store,
             stream: true,
             include,
         };
@@ -153,17 +177,21 @@
 
         let mut attempt = 0;
         let max_retries = self.provider.request_max_retries();
+
         loop {
             attempt += 1;
 
             let req_builder = self
-                .provider
-                .create_request_builder(&self.client)?
+                .client
+                .post(format!("{base_url}/responses"))
                 .header("OpenAI-Beta", "responses=experimental")
                 .header("session_id", self.session_id.to_string())
+                .bearer_auth(&token)
                 .header(reqwest::header::ACCEPT, "text/event-stream")
                 .json(&payload);
 
+            let req_builder = self.provider.apply_http_headers(req_builder);
+
             let res = req_builder.send().await;
             if let Ok(resp) = &res {
                 trace!(
@@ -572,7 +600,7 @@ mod tests {
 
         let provider = ModelProviderInfo {
             name: "test".to_string(),
-            base_url: "https://test.com".to_string(),
+            base_url: Some("https://test.com".to_string()),
             env_key: Some("TEST_API_KEY".to_string()),
             env_key_instructions: None,
             wire_api: WireApi::Responses,
@@ -582,6 +610,7 @@ mod tests {
             request_max_retries: Some(0),
             stream_max_retries: Some(0),
             stream_idle_timeout_ms: Some(1000),
+            requires_auth: false,
         };
 
         let events = collect_events(
@@ -631,7 +660,7 @@ mod tests {
         let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
         let provider = ModelProviderInfo {
             name: "test".to_string(),
-            base_url: "https://test.com".to_string(),
+            base_url: Some("https://test.com".to_string()),
             env_key: Some("TEST_API_KEY".to_string()),
             env_key_instructions: None,
             wire_api: WireApi::Responses,
@@ -641,6 +670,7 @@ mod tests {
             request_max_retries: Some(0),
             stream_max_retries: Some(0),
             stream_idle_timeout_ms: Some(1000),
+            requires_auth: false,
         };
 
         let events = collect_events(&[sse1.as_bytes()], provider).await;
@@ -733,7 +763,7 @@ mod tests {
 
         let provider = ModelProviderInfo {
             name: "test".to_string(),
-            base_url: "https://test.com".to_string(),
+            base_url: Some("https://test.com".to_string()),
             env_key: Some("TEST_API_KEY".to_string()),
             env_key_instructions: None,
             wire_api: WireApi::Responses,
@@ -743,6 +773,7 @@ mod tests {
             request_max_retries: Some(0),
             stream_max_retries: Some(0),
             stream_idle_timeout_ms: Some(1000),
+            requires_auth: false,
         };
 
         let out = run_sse(evs, provider).await;
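Reviewer summary of the client changes above — the two auth-driven decisions in `ModelClient::stream`, condensed with the same identifiers the diff uses (no new API introduced here):

```rust
// ChatGPT-mode requests are never stored, so encrypted reasoning content is
// requested explicitly instead of being referenced by response ID.
let store = prompt.store && auth.mode != AuthMode::ChatGPT;

// An explicit provider base_url always wins; otherwise fall back per auth mode.
let base_url = match self.provider.base_url.clone() {
    Some(url) => url,
    None => match auth.mode {
        AuthMode::ChatGPT => "https://chatgpt.com/backend-api/codex".to_string(),
        AuthMode::ApiKey => "https://api.openai.com/v1".to_string(),
    },
};
```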
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 6efc878f..92ca7bf8 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -15,6 +15,7 @@ use async_channel::Sender;
 use codex_apply_patch::ApplyPatchAction;
 use codex_apply_patch::MaybeApplyPatchVerified;
 use codex_apply_patch::maybe_parse_apply_patch_verified;
+use codex_login::CodexAuth;
 use futures::prelude::*;
 use mcp_types::CallToolResult;
 use serde::Serialize;
@@ -103,7 +104,11 @@ pub struct CodexSpawnOk {
 
 impl Codex {
     /// Spawn a new [`Codex`] and initialize the session.
-    pub async fn spawn(config: Config, ctrl_c: Arc<Notify>) -> CodexResult<CodexSpawnOk> {
+    pub async fn spawn(
+        config: Config,
+        auth: Option<CodexAuth>,
+        ctrl_c: Arc<Notify>,
+    ) -> CodexResult<CodexSpawnOk> {
         // experimental resume path (undocumented)
         let resume_path = config.experimental_resume.clone();
         info!("resume_path: {resume_path:?}");
@@ -132,7 +137,7 @@ impl Codex {
         // Generate a unique ID for the lifetime of this Codex session.
         let session_id = Uuid::new_v4();
         tokio::spawn(submission_loop(
-            session_id, config, rx_sub, tx_event, ctrl_c,
+            session_id, config, auth, rx_sub, tx_event, ctrl_c,
         ));
         let codex = Codex {
             next_id: AtomicU64::new(0),
@@ -525,6 +530,7 @@ impl AgentTask {
 async fn submission_loop(
     mut session_id: Uuid,
     config: Arc<Config>,
+    auth: Option<CodexAuth>,
     rx_sub: Receiver<Submission>,
     tx_event: Sender<Event>,
     ctrl_c: Arc<Notify>,
@@ -636,6 +642,7 @@ async fn submission_loop(
 
         let client = ModelClient::new(
             config.clone(),
+            auth.clone(),
             provider.clone(),
             model_reasoning_effort,
             model_reasoning_summary,
diff --git a/codex-rs/core/src/codex_wrapper.rs b/codex-rs/core/src/codex_wrapper.rs
index b8057929..1e26a9eb 100644
--- a/codex-rs/core/src/codex_wrapper.rs
+++ b/codex-rs/core/src/codex_wrapper.rs
@@ -6,6 +6,7 @@ use crate::config::Config;
 use crate::protocol::Event;
 use crate::protocol::EventMsg;
 use crate::util::notify_on_sigint;
+use codex_login::load_auth;
 use tokio::sync::Notify;
 use uuid::Uuid;
 
@@ -25,11 +26,12 @@ pub struct CodexConversation {
 /// that callers can surface the information to the UI.
 pub async fn init_codex(config: Config) -> anyhow::Result<CodexConversation> {
     let ctrl_c = notify_on_sigint();
+    let auth = load_auth(&config.codex_home)?;
     let CodexSpawnOk {
         codex,
         init_id,
         session_id,
-    } = Codex::spawn(config, ctrl_c.clone()).await?;
+    } = Codex::spawn(config, auth, ctrl_c.clone()).await?;
 
     // The first event must be `SessionInitialized`. Validate and forward it to
     // the caller so that they can display it in the conversation history.
diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs
index 53ca8d5b..a65ec096 100644
--- a/codex-rs/core/src/config.rs
+++ b/codex-rs/core/src/config.rs
@@ -526,6 +526,7 @@ impl Config {
                 .chatgpt_base_url
                 .or(cfg.chatgpt_base_url)
                 .unwrap_or("https://chatgpt.com/backend-api/".to_string()),
+            experimental_resume,
             include_plan_tool: include_plan_tool.unwrap_or(false),
         };
 
@@ -794,7 +795,7 @@ disable_response_storage = true
 
         let openai_chat_completions_provider = ModelProviderInfo {
             name: "OpenAI using Chat Completions".to_string(),
-            base_url: "https://api.openai.com/v1".to_string(),
+            base_url: Some("https://api.openai.com/v1".to_string()),
             env_key: Some("OPENAI_API_KEY".to_string()),
             wire_api: crate::WireApi::Chat,
             env_key_instructions: None,
@@ -804,6 +805,7 @@ disable_response_storage = true
             request_max_retries: Some(4),
             stream_max_retries: Some(10),
             stream_idle_timeout_ms: Some(300_000),
+            requires_auth: false,
         };
         let model_provider_map = {
             let mut model_provider_map = built_in_model_providers();
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index b2dbded5..ffe64d7c 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -30,8 +30,8 @@ mod message_history;
 mod model_provider_info;
 pub use model_provider_info::ModelProviderInfo;
 pub use model_provider_info::WireApi;
+pub use model_provider_info::built_in_model_providers;
 mod models;
-pub mod openai_api_key;
 mod openai_model_info;
 mod openai_tools;
 pub mod plan_tool;
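Every `Codex::spawn` call site now threads an `Option<CodexAuth>`; the canonical pattern, as used in `codex_wrapper::init_codex` and `cli/src/proto.rs` above:

```rust
// None means "not logged in"; ModelClient surfaces an EnvVarError on first use.
let auth = load_auth(&config.codex_home)?;
let ctrl_c = notify_on_sigint();
let CodexSpawnOk { codex, .. } = Codex::spawn(config, auth, ctrl_c.clone()).await?;
```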
diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs
index 72ef58c6..4640f53a 100644
--- a/codex-rs/core/src/model_provider_info.rs
+++ b/codex-rs/core/src/model_provider_info.rs
@@ -12,7 +12,6 @@ use std::env::VarError;
 use std::time::Duration;
 
 use crate::error::EnvVarError;
-use crate::openai_api_key::get_openai_api_key;
 
 /// Value for the `OpenAI-Originator` header that is sent with requests to
 /// OpenAI.
@@ -30,7 +29,7 @@ const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
 #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "lowercase")]
 pub enum WireApi {
-    /// The experimental "Responses" API exposed by OpenAI at `/v1/responses`.
+    /// The Responses API exposed by OpenAI at `/v1/responses`.
     Responses,
 
     /// Regular Chat Completions compatible with `/v1/chat/completions`.
@@ -44,7 +43,7 @@ pub struct ModelProviderInfo {
     /// Friendly display name.
     pub name: String,
     /// Base URL for the provider's OpenAI-compatible API.
-    pub base_url: String,
+    pub base_url: Option<String>,
 
     /// Environment variable that stores the user's API key for this provider.
     pub env_key: Option<String>,
@@ -78,6 +77,10 @@ pub struct ModelProviderInfo {
     /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
     /// the connection as lost.
     pub stream_idle_timeout_ms: Option<u64>,
+
+    /// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
+    #[serde(default)]
+    pub requires_auth: bool,
 }
 
 impl ModelProviderInfo {
@@ -93,11 +96,11 @@ impl ModelProviderInfo {
         &'a self,
         client: &'a reqwest::Client,
     ) -> crate::error::Result<reqwest::RequestBuilder> {
-        let api_key = self.api_key()?;
-
         let url = self.get_full_url();
         let mut builder = client.post(url);
+
+        let api_key = self.api_key()?;
         if let Some(key) = api_key {
             builder = builder.bearer_auth(key);
         }
@@ -117,9 +120,15 @@ impl ModelProviderInfo {
                 .join("&");
             format!("?{full_params}")
         });
-        let base_url = &self.base_url;
+        let base_url = self
+            .base_url
+            .clone()
+            .unwrap_or("https://api.openai.com/v1".to_string());
+
         match self.wire_api {
-            WireApi::Responses => format!("{base_url}/responses{query_string}"),
+            WireApi::Responses => {
+                format!("{base_url}/responses{query_string}")
+            }
             WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
         }
     }
@@ -127,7 +136,10 @@ impl ModelProviderInfo {
     /// Apply provider-specific HTTP headers (both static and environment-based)
     /// onto an existing `reqwest::RequestBuilder` and return the updated
     /// builder.
-    fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
+    pub fn apply_http_headers(
+        &self,
+        mut builder: reqwest::RequestBuilder,
+    ) -> reqwest::RequestBuilder {
         if let Some(extra) = &self.http_headers {
             for (k, v) in extra {
                 builder = builder.header(k, v);
@@ -152,11 +164,7 @@ impl ModelProviderInfo {
     fn api_key(&self) -> crate::error::Result<Option<String>> {
         match &self.env_key {
             Some(env_key) => {
-                let env_value = if env_key == crate::openai_api_key::OPENAI_API_KEY_ENV_VAR {
-                    get_openai_api_key().map_or_else(|| Err(VarError::NotPresent), Ok)
-                } else {
-                    std::env::var(env_key)
-                };
+                let env_value = std::env::var(env_key);
                 env_value
                     .and_then(|v| {
                         if v.trim().is_empty() {
@@ -204,47 +212,51 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
     // providers are bundled with Codex CLI, so we only include the OpenAI
     // provider by default. Users are encouraged to add to `model_providers`
     // in config.toml to add their own providers.
-    [
-        (
-            "openai",
-            P {
-                name: "OpenAI".into(),
-                // Allow users to override the default OpenAI endpoint by
-                // exporting `OPENAI_BASE_URL`. This is useful when pointing
-                // Codex at a proxy, mock server, or Azure-style deployment
-                // without requiring a full TOML override for the built-in
-                // OpenAI provider.
-                base_url: std::env::var("OPENAI_BASE_URL")
-                    .ok()
-                    .filter(|v| !v.trim().is_empty())
-                    .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
-                env_key: Some("OPENAI_API_KEY".into()),
-                env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()),
-                wire_api: WireApi::Responses,
-                query_params: None,
-                http_headers: Some(
-                    [
-                        ("originator".to_string(), OPENAI_ORIGINATOR_HEADER.to_string()),
-                        ("version".to_string(), env!("CARGO_PKG_VERSION").to_string()),
-                    ]
-                    .into_iter()
-                    .collect(),
-                ),
-                env_http_headers: Some(
-                    [
-                        ("OpenAI-Organization".to_string(), "OPENAI_ORGANIZATION".to_string()),
-                        ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
-                    ]
-                    .into_iter()
-                    .collect(),
-                ),
-                // Use global defaults for retry/timeout unless overridden in config.toml.
-                request_max_retries: None,
-                stream_max_retries: None,
-                stream_idle_timeout_ms: None,
-            },
-        ),
-    ]
+    [(
+        "openai",
+        P {
+            name: "OpenAI".into(),
+            // Allow users to override the default OpenAI endpoint by
+            // exporting `OPENAI_BASE_URL`. This is useful when pointing
+            // Codex at a proxy, mock server, or Azure-style deployment
+            // without requiring a full TOML override for the built-in
+            // OpenAI provider.
+            base_url: std::env::var("OPENAI_BASE_URL")
+                .ok()
+                .filter(|v| !v.trim().is_empty()),
+            env_key: None,
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: Some(
+                [
+                    (
+                        "originator".to_string(),
+                        OPENAI_ORIGINATOR_HEADER.to_string(),
+                    ),
+                    ("version".to_string(), env!("CARGO_PKG_VERSION").to_string()),
+                ]
+                .into_iter()
+                .collect(),
+            ),
+            env_http_headers: Some(
+                [
+                    (
+                        "OpenAI-Organization".to_string(),
+                        "OPENAI_ORGANIZATION".to_string(),
+                    ),
+                    ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
+                ]
+                .into_iter()
+                .collect(),
+            ),
+            // Use global defaults for retry/timeout unless overridden in config.toml.
+            request_max_retries: None,
+            stream_max_retries: None,
+            stream_idle_timeout_ms: None,
+            requires_auth: true,
+        },
+    )]
     .into_iter()
     .map(|(k, v)| (k.to_string(), v))
     .collect()
@@ -264,7 +276,7 @@ base_url = "http://localhost:11434/v1"
 "#;
         let expected_provider = ModelProviderInfo {
             name: "Ollama".into(),
-            base_url: "http://localhost:11434/v1".into(),
+            base_url: Some("http://localhost:11434/v1".into()),
             env_key: None,
             env_key_instructions: None,
             wire_api: WireApi::Chat,
@@ -274,6 +286,7 @@
             request_max_retries: None,
             stream_max_retries: None,
             stream_idle_timeout_ms: None,
+            requires_auth: false,
         };
 
         let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
@@ -290,7 +303,7 @@ query_params = { api-version = "2025-04-01-preview" }
 "#;
         let expected_provider = ModelProviderInfo {
             name: "Azure".into(),
-            base_url: "https://xxxxx.openai.azure.com/openai".into(),
+            base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
             env_key: Some("AZURE_OPENAI_API_KEY".into()),
             env_key_instructions: None,
             wire_api: WireApi::Chat,
@@ -302,6 +315,7 @@
             request_max_retries: None,
             stream_max_retries: None,
             stream_idle_timeout_ms: None,
+            requires_auth: false,
         };
 
         let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
@@ -319,7 +333,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
 "#;
         let expected_provider = ModelProviderInfo {
             name: "Example".into(),
-            base_url: "https://example.com".into(),
+            base_url: Some("https://example.com".into()),
             env_key: Some("API_KEY".into()),
             env_key_instructions: None,
             wire_api: WireApi::Chat,
@@ -333,6 +347,7 @@
             request_max_retries: None,
             stream_max_retries: None,
             stream_idle_timeout_ms: None,
+            requires_auth: false,
         };
 
         let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
diff --git a/codex-rs/core/src/openai_api_key.rs b/codex-rs/core/src/openai_api_key.rs
deleted file mode 100644
index 728914c0..00000000
--- a/codex-rs/core/src/openai_api_key.rs
+++ /dev/null
@@ -1,24 +0,0 @@
-use std::env;
-use std::sync::LazyLock;
-use std::sync::RwLock;
-
-pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
-
-static OPENAI_API_KEY: LazyLock<RwLock<Option<String>>> = LazyLock::new(|| {
-    let val = env::var(OPENAI_API_KEY_ENV_VAR)
-        .ok()
-        .and_then(|s| if s.is_empty() { None } else { Some(s) });
-    RwLock::new(val)
-});
-
-pub fn get_openai_api_key() -> Option<String> {
-    #![allow(clippy::unwrap_used)]
-    OPENAI_API_KEY.read().unwrap().clone()
-}
-
-pub fn set_openai_api_key(value: String) {
-    #![allow(clippy::unwrap_used)]
-    if !value.is_empty() {
-        *OPENAI_API_KEY.write().unwrap() = Some(value);
-    }
-}
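The deleted module above was process-global mutable state guarded by a `RwLock`; its replacement is a plain value threaded through `Codex::spawn`. For API-key auth the construction is one line, as the updated tests below rely on (key value illustrative):

```rust
// Replaces get_openai_api_key()/set_openai_api_key(): auth is now explicit.
let auth = Some(CodexAuth::from_api_key("sk-example".to_string()));
let CodexSpawnOk { codex, .. } = Codex::spawn(config, auth, ctrl_c.clone()).await?;
```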
diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs
index 9de2d560..fbe63fb3 100644
--- a/codex-rs/core/tests/client.rs
+++ b/codex-rs/core/tests/client.rs
@@ -1,11 +1,19 @@
+use std::path::PathBuf;
+
+use chrono::Utc;
 use codex_core::Codex;
 use codex_core::CodexSpawnOk;
 use codex_core::ModelProviderInfo;
+use codex_core::built_in_model_providers;
 use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::InputItem;
 use codex_core::protocol::Op;
 use codex_core::protocol::SessionConfiguredEvent;
+use codex_login::AuthDotJson;
+use codex_login::AuthMode;
+use codex_login::CodexAuth;
+use codex_login::TokenData;
 use core_test_support::load_default_config_for_test;
 use core_test_support::load_sse_fixture_with_id;
 use core_test_support::wait_for_event;
@@ -48,32 +56,23 @@ async fn includes_session_id_and_model_headers_in_request() {
         .await;
 
     let model_provider = ModelProviderInfo {
-        name: "openai".into(),
-        base_url: format!("{}/v1", server.uri()),
-        // Environment variable that should exist in the test environment.
-        // ModelClient will return an error if the environment variable for the
-        // provider is not set.
-        env_key: Some("PATH".into()),
-        env_key_instructions: None,
-        wire_api: codex_core::WireApi::Responses,
-        query_params: None,
-        http_headers: Some(
-            [("originator".to_string(), "codex_cli_rs".to_string())]
-                .into_iter()
-                .collect(),
-        ),
-        env_http_headers: None,
-        request_max_retries: Some(0),
-        stream_max_retries: Some(0),
-        stream_idle_timeout_ms: None,
+        base_url: Some(format!("{}/v1", server.uri())),
+        ..built_in_model_providers()["openai"].clone()
     };
 
     // Init session
     let codex_home = TempDir::new().unwrap();
     let mut config = load_default_config_for_test(&codex_home);
     config.model_provider = model_provider;
+
     let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
-    let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
+    let CodexSpawnOk { codex, .. } = Codex::spawn(
+        config,
+        Some(CodexAuth::from_api_key("Test API Key".to_string())),
+        ctrl_c.clone(),
+    )
+    .await
+    .unwrap();
 
     codex
         .submit(Op::UserInput {
@@ -95,15 +94,20 @@ async fn includes_session_id_and_model_headers_in_request() {
 
     // get request from the server
     let request = &server.received_requests().await.unwrap()[0];
-    let request_body = request.headers.get("session_id").unwrap();
-    let originator = request.headers.get("originator").unwrap();
+    let request_session_id = request.headers.get("session_id").unwrap();
+    let request_originator = request.headers.get("originator").unwrap();
+    let request_authorization = request.headers.get("authorization").unwrap();
 
     assert!(current_session_id.is_some());
     assert_eq!(
-        request_body.to_str().unwrap(),
+        request_session_id.to_str().unwrap(),
         current_session_id.as_ref().unwrap()
     );
-    assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
+    assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
+    assert_eq!(
+        request_authorization.to_str().unwrap(),
+        "Bearer Test API Key"
+    );
 }
 
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
@@ -126,22 +130,9 @@ async fn includes_base_instructions_override_in_request() {
         .await;
 
     let model_provider = ModelProviderInfo {
-        name: "openai".into(),
-        base_url: format!("{}/v1", server.uri()),
-        // Environment variable that should exist in the test environment.
-        // ModelClient will return an error if the environment variable for the
-        // provider is not set.
-        env_key: Some("PATH".into()),
-        env_key_instructions: None,
-        wire_api: codex_core::WireApi::Responses,
-        query_params: None,
-        http_headers: None,
-        env_http_headers: None,
-        request_max_retries: Some(0),
-        stream_max_retries: Some(0),
-        stream_idle_timeout_ms: None,
+        base_url: Some(format!("{}/v1", server.uri())),
+        ..built_in_model_providers()["openai"].clone()
     };
-
     let codex_home = TempDir::new().unwrap();
     let mut config = load_default_config_for_test(&codex_home);
 
@@ -149,7 +140,13 @@ async fn includes_base_instructions_override_in_request() {
     config.model_provider = model_provider;
 
     let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
-    let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
+    let CodexSpawnOk { codex, .. } = Codex::spawn(
+        config,
+        Some(CodexAuth::from_api_key("Test API Key".to_string())),
+        ctrl_c.clone(),
+    )
+    .await
+    .unwrap();
 
     codex
         .submit(Op::UserInput {
@@ -172,3 +169,108 @@ async fn includes_base_instructions_override_in_request() {
         .contains("test instructions")
     );
 }
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn chatgpt_auth_sends_correct_request() {
+    #![allow(clippy::unwrap_used)]
+
+    if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
+        println!(
+            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
+        );
+        return;
+    }
+
+    // Mock server
+    let server = MockServer::start().await;
+
+    // First request – must NOT include `previous_response_id`.
+    let first = ResponseTemplate::new(200)
+        .insert_header("content-type", "text/event-stream")
+        .set_body_raw(sse_completed("resp1"), "text/event-stream");
+
+    Mock::given(method("POST"))
+        .and(path("/api/codex/responses"))
+        .respond_with(first)
+        .expect(1)
+        .mount(&server)
+        .await;
+
+    let model_provider = ModelProviderInfo {
+        base_url: Some(format!("{}/api/codex", server.uri())),
+        ..built_in_model_providers()["openai"].clone()
+    };
+
+    // Init session
+    let codex_home = TempDir::new().unwrap();
+    let mut config = load_default_config_for_test(&codex_home);
+    config.model_provider = model_provider;
+    let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
+    let CodexSpawnOk { codex, .. } = Codex::spawn(
+        config,
+        Some(auth_from_token("Access Token".to_string())),
+        ctrl_c.clone(),
+    )
+    .await
+    .unwrap();
+
+    codex
+        .submit(Op::UserInput {
+            items: vec![InputItem::Text {
+                text: "hello".into(),
+            }],
+        })
+        .await
+        .unwrap();
+
+    let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) =
+        wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_))).await
+    else {
+        unreachable!()
+    };
+
+    let current_session_id = Some(session_id.to_string());
+    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+
+    // get request from the server
+    let request = &server.received_requests().await.unwrap()[0];
+    let request_session_id = request.headers.get("session_id").unwrap();
+    let request_originator = request.headers.get("originator").unwrap();
+    let request_authorization = request.headers.get("authorization").unwrap();
+    let request_body = request.body_json::<serde_json::Value>().unwrap();
+
+    assert!(current_session_id.is_some());
+    assert_eq!(
+        request_session_id.to_str().unwrap(),
+        current_session_id.as_ref().unwrap()
+    );
+    assert_eq!(request_originator.to_str().unwrap(), "codex_cli_rs");
+    assert_eq!(
+        request_authorization.to_str().unwrap(),
+        "Bearer Access Token"
+    );
+    assert!(!request_body["store"].as_bool().unwrap());
+    assert!(request_body["stream"].as_bool().unwrap());
+    assert_eq!(
+        request_body["include"][0].as_str().unwrap(),
+        "reasoning.encrypted_content"
+    );
+}
+
+fn auth_from_token(id_token: String) -> CodexAuth {
+    CodexAuth::new(
+        None,
+        AuthMode::ChatGPT,
+        PathBuf::new(),
+        Some(AuthDotJson {
+            tokens: TokenData {
+                id_token,
+                access_token: "Access Token".to_string(),
+                refresh_token: "test".to_string(),
+                account_id: None,
+            },
+            last_refresh: Utc::now(),
+            openai_api_key: None,
+        }),
+    )
+}
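The rewritten tests above patch the built-in provider instead of hand-rolling every field, so they stay in sync as fields such as `requires_auth` are added. The same struct-update pattern works for any one-off override (URL illustrative):

```rust
// Override only base_url; inherit headers, wire_api, retry defaults, and
// requires_auth from the built-in "openai" entry.
let model_provider = ModelProviderInfo {
    base_url: Some("http://localhost:4000/v1".to_string()),
    ..built_in_model_providers()["openai"].clone()
};
```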
diff --git a/codex-rs/core/tests/live_agent.rs b/codex-rs/core/tests/live_agent.rs
index 98953430..95408e20 100644
--- a/codex-rs/core/tests/live_agent.rs
+++ b/codex-rs/core/tests/live_agent.rs
@@ -50,7 +50,7 @@ async fn spawn_codex() -> Result<Codex> {
     config.model_provider.request_max_retries = Some(2);
     config.model_provider.stream_max_retries = Some(2);
 
     let CodexSpawnOk { codex: agent, .. } =
-        Codex::spawn(config, std::sync::Arc::new(Notify::new())).await?;
+        Codex::spawn(config, None, std::sync::Arc::new(Notify::new())).await?;
 
     Ok(agent)
 }
diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs
index 8e5d83a0..d2fc0355 100644
--- a/codex-rs/core/tests/stream_no_completed.rs
+++ b/codex-rs/core/tests/stream_no_completed.rs
@@ -10,6 +10,7 @@ use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::InputItem;
 use codex_core::protocol::Op;
+use codex_login::CodexAuth;
 use core_test_support::load_default_config_for_test;
 use core_test_support::load_sse_fixture;
 use core_test_support::load_sse_fixture_with_id;
@@ -75,7 +76,7 @@ async fn retries_on_early_close() {
 
     let model_provider = ModelProviderInfo {
         name: "openai".into(),
-        base_url: format!("{}/v1", server.uri()),
+        base_url: Some(format!("{}/v1", server.uri())),
         // Environment variable that should exist in the test environment.
         // ModelClient will return an error if the environment variable for the
         // provider is not set.
@@ -89,13 +90,20 @@ async fn retries_on_early_close() {
         request_max_retries: Some(0),
         stream_max_retries: Some(1),
         stream_idle_timeout_ms: Some(2000),
+        requires_auth: false,
     };
 
     let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
     let codex_home = TempDir::new().unwrap();
     let mut config = load_default_config_for_test(&codex_home);
     config.model_provider = model_provider;
-    let CodexSpawnOk { codex, .. } = Codex::spawn(config, ctrl_c).await.unwrap();
+    let CodexSpawnOk { codex, .. } = Codex::spawn(
+        config,
+        Some(CodexAuth::from_api_key("Test API Key".to_string())),
+        ctrl_c,
+    )
+    .await
+    .unwrap();
 
     codex
         .submit(Op::UserInput {
diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs
index ab92ecf6..47dbbca9 100644
--- a/codex-rs/login/src/lib.rs
+++ b/codex-rs/login/src/lib.rs
@@ -1,20 +1,152 @@
 use chrono::DateTime;
+
 use chrono::Utc;
 use serde::Deserialize;
 use serde::Serialize;
+use std::env;
 use std::fs::OpenOptions;
 use std::io::Read;
 use std::io::Write;
 #[cfg(unix)]
 use std::os::unix::fs::OpenOptionsExt;
 use std::path::Path;
+use std::path::PathBuf;
 use std::process::Stdio;
+use std::sync::Arc;
+use std::sync::Mutex;
 use std::time::Duration;
 use tokio::process::Command;
 
 const SOURCE_FOR_PYTHON_SERVER: &str = include_str!("./login_with_chatgpt.py");
 
 const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
+const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY";
+
+#[derive(Clone, Debug, PartialEq)]
+pub enum AuthMode {
+    ApiKey,
+    ChatGPT,
+}
+
+#[derive(Debug, Clone)]
+pub struct CodexAuth {
+    pub api_key: Option<String>,
+    pub mode: AuthMode,
+    auth_dot_json: Arc<Mutex<Option<AuthDotJson>>>,
+    auth_file: PathBuf,
+}
+
+impl PartialEq for CodexAuth {
+    fn eq(&self, other: &Self) -> bool {
+        self.mode == other.mode
+    }
+}
+
+impl CodexAuth {
+    pub fn new(
+        api_key: Option<String>,
+        mode: AuthMode,
+        auth_file: PathBuf,
+        auth_dot_json: Option<AuthDotJson>,
+    ) -> Self {
+        let auth_dot_json = Arc::new(Mutex::new(auth_dot_json));
+        Self {
+            api_key,
+            mode,
+            auth_file,
+            auth_dot_json,
+        }
+    }
+
+    pub fn from_api_key(api_key: String) -> Self {
+        Self {
+            api_key: Some(api_key),
+            mode: AuthMode::ApiKey,
+            auth_file: PathBuf::new(),
+            auth_dot_json: Arc::new(Mutex::new(None)),
+        }
+    }
+
+    pub async fn get_token_data(&self) -> Result<TokenData, std::io::Error> {
+        #[expect(clippy::unwrap_used)]
+        let auth_dot_json = self.auth_dot_json.lock().unwrap().clone();
+
+        match auth_dot_json {
+            Some(auth_dot_json) => {
+                if auth_dot_json.last_refresh < Utc::now() - chrono::Duration::days(28) {
+                    let refresh_response = tokio::time::timeout(
+                        Duration::from_secs(60),
+                        try_refresh_token(auth_dot_json.tokens.refresh_token.clone()),
+                    )
+                    .await
+                    .map_err(|_| {
+                        std::io::Error::other("timed out while refreshing OpenAI API key")
+                    })?
+                    .map_err(std::io::Error::other)?;
+
+                    let updated_auth_dot_json = update_tokens(
+                        &self.auth_file,
+                        refresh_response.id_token,
+                        refresh_response.access_token,
+                        refresh_response.refresh_token,
+                    )
+                    .await?;
+
+                    #[expect(clippy::unwrap_used)]
+                    let mut auth_dot_json = self.auth_dot_json.lock().unwrap();
+                    *auth_dot_json = Some(updated_auth_dot_json);
+                }
+                Ok(auth_dot_json.tokens.clone())
+            }
+            None => Err(std::io::Error::other("Token data is not available.")),
+        }
+    }
+
+    pub async fn get_token(&self) -> Result<String, std::io::Error> {
+        match self.mode {
+            AuthMode::ApiKey => Ok(self.api_key.clone().unwrap_or_default()),
+            AuthMode::ChatGPT => {
+                let id_token = self.get_token_data().await?.access_token;
+
+                Ok(id_token)
+            }
+        }
+    }
+}
+
+// Loads the available auth information from the auth.json or OPENAI_API_KEY environment variable.
+pub fn load_auth(codex_home: &Path) -> std::io::Result<Option<CodexAuth>> {
+    let auth_file = codex_home.join("auth.json");
+
+    let auth_dot_json = try_read_auth_json(&auth_file).ok();
+
+    let auth_json_api_key = auth_dot_json
+        .as_ref()
+        .and_then(|a| a.openai_api_key.clone())
+        .filter(|s| !s.is_empty());
+
+    let openai_api_key = env::var(OPENAI_API_KEY_ENV_VAR)
+        .ok()
+        .filter(|s| !s.is_empty())
+        .or(auth_json_api_key);
+
+    if openai_api_key.is_none() && auth_dot_json.is_none() {
+        return Ok(None);
+    }
+
+    let mode = if openai_api_key.is_some() {
+        AuthMode::ApiKey
+    } else {
+        AuthMode::ChatGPT
+    };
+
+    Ok(Some(CodexAuth {
+        api_key: openai_api_key,
+        mode,
+        auth_file,
+        auth_dot_json: Arc::new(Mutex::new(auth_dot_json)),
+    }))
+}
 
 /// Run `python3 -c {{SOURCE_FOR_PYTHON_SERVER}}` with the CODEX_HOME
 /// environment variable set to the provided `codex_home` path. If the
@@ -25,14 +157,12 @@ const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
 /// If `capture_output` is true, the subprocess's output will be captured and
 /// recorded in memory. Otherwise, the subprocess's output will be sent to the
 /// current process's stdout/stderr.
-pub async fn login_with_chatgpt(
-    codex_home: &Path,
-    capture_output: bool,
-) -> std::io::Result<String> {
+pub async fn login_with_chatgpt(codex_home: &Path, capture_output: bool) -> std::io::Result<()> {
     let child = Command::new("python3")
         .arg("-c")
         .arg(SOURCE_FOR_PYTHON_SERVER)
         .env("CODEX_HOME", codex_home)
+        .env("CODEX_CLIENT_ID", CLIENT_ID)
         .stdin(Stdio::null())
         .stdout(if capture_output {
             Stdio::piped()
@@ -48,7 +178,7 @@ pub async fn login_with_chatgpt(codex_home: &Path, capture_output: bool) -> std
     let output = child.wait_with_output().await?;
 
     if output.status.success() {
-        try_read_openai_api_key(codex_home).await
+        Ok(())
     } else {
         let stderr = String::from_utf8_lossy(&output.stderr);
         Err(std::io::Error::other(format!(
@@ -57,65 +187,54 @@
     }
 }
 
-/// Attempt to read the `OPENAI_API_KEY` from the `auth.json` file in the given
-/// `CODEX_HOME` directory, refreshing it, if necessary.
-pub async fn try_read_openai_api_key(codex_home: &Path) -> std::io::Result<String> {
-    let auth_dot_json = try_read_auth_json(codex_home).await?;
-    Ok(auth_dot_json.openai_api_key)
-}
-
 /// Attempt to read and refresh the `auth.json` file in the given `CODEX_HOME` directory.
 /// Returns the full AuthDotJson structure after refreshing if necessary.
-pub async fn try_read_auth_json(codex_home: &Path) -> std::io::Result<AuthDotJson> {
-    let auth_path = codex_home.join("auth.json");
-    let mut file = std::fs::File::open(&auth_path)?;
+pub fn try_read_auth_json(auth_file: &Path) -> std::io::Result<AuthDotJson> {
+    let mut file = std::fs::File::open(auth_file)?;
     let mut contents = String::new();
     file.read_to_string(&mut contents)?;
     let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?;
 
-    if is_expired(&auth_dot_json) {
-        let refresh_response =
-            tokio::time::timeout(Duration::from_secs(60), try_refresh_token(&auth_dot_json))
-                .await
-                .map_err(|_| std::io::Error::other("timed out while refreshing OpenAI API key"))?
-                .map_err(std::io::Error::other)?;
-        let mut auth_dot_json = auth_dot_json;
-        auth_dot_json.tokens.id_token = refresh_response.id_token;
-        if let Some(refresh_token) = refresh_response.refresh_token {
-            auth_dot_json.tokens.refresh_token = refresh_token;
-        }
-        auth_dot_json.last_refresh = Utc::now();
+    Ok(auth_dot_json)
+}
 
-        let mut options = OpenOptions::new();
-        options.truncate(true).write(true).create(true);
-        #[cfg(unix)]
-        {
-            options.mode(0o600);
-        }
-
-        let json_data = serde_json::to_string(&auth_dot_json)?;
-        {
-            let mut file = options.open(&auth_path)?;
-            file.write_all(json_data.as_bytes())?;
-            file.flush()?;
-        }
-
-        Ok(auth_dot_json)
-    } else {
-        Ok(auth_dot_json)
+async fn update_tokens(
+    auth_file: &Path,
+    id_token: String,
+    access_token: Option<String>,
+    refresh_token: Option<String>,
+) -> std::io::Result<AuthDotJson> {
+    let mut options = OpenOptions::new();
+    options.truncate(true).write(true).create(true);
+    #[cfg(unix)]
+    {
+        options.mode(0o600);
     }
+    let mut auth_dot_json = try_read_auth_json(auth_file)?;
+
+    auth_dot_json.tokens.id_token = id_token.to_string();
+    if let Some(access_token) = access_token {
+        auth_dot_json.tokens.access_token = access_token.to_string();
+    }
+    if let Some(refresh_token) = refresh_token {
+        auth_dot_json.tokens.refresh_token = refresh_token.to_string();
+    }
+    auth_dot_json.last_refresh = Utc::now();
+
+    let json_data = serde_json::to_string_pretty(&auth_dot_json)?;
+    {
+        let mut file = options.open(auth_file)?;
+        file.write_all(json_data.as_bytes())?;
+        file.flush()?;
+    }
+    Ok(auth_dot_json)
 }
 
-fn is_expired(auth_dot_json: &AuthDotJson) -> bool {
-    let last_refresh = auth_dot_json.last_refresh;
-    last_refresh < Utc::now() - chrono::Duration::days(28)
-}
-
-async fn try_refresh_token(auth_dot_json: &AuthDotJson) -> std::io::Result<RefreshResponse> {
+async fn try_refresh_token(refresh_token: String) -> std::io::Result<RefreshResponse> {
     let refresh_request = RefreshRequest {
         client_id: CLIENT_ID,
         grant_type: "refresh_token",
-        refresh_token: auth_dot_json.tokens.refresh_token.clone(),
+        refresh_token,
         scope: "openid profile email",
     };
 
@@ -150,24 +269,25 @@ struct RefreshRequest {
     scope: &'static str,
 }
 
-#[derive(Deserialize)]
+#[derive(Deserialize, Clone)]
 struct RefreshResponse {
     id_token: String,
+    access_token: Option<String>,
     refresh_token: Option<String>,
 }
 
 /// Expected structure for $CODEX_HOME/auth.json.
-#[derive(Deserialize, Serialize)]
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
 pub struct AuthDotJson {
     #[serde(rename = "OPENAI_API_KEY")]
-    pub openai_api_key: String,
+    pub openai_api_key: Option<String>,
 
     pub tokens: TokenData,
 
     pub last_refresh: DateTime<Utc>,
 }
 
-#[derive(Deserialize, Serialize, Clone)]
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
 pub struct TokenData {
     /// This is a JWT.
     pub id_token: String,
@@ -177,5 +297,5 @@ pub struct TokenData {
 
     pub refresh_token: String,
 
-    pub account_id: String,
+    pub account_id: Option<String>,
 }
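To summarize the new `codex-login` surface: `load_auth` prefers `OPENAI_API_KEY` from the environment, then the key persisted in `auth.json`, and only then falls back to ChatGPT tokens (refreshed when older than 28 days). A sketch of how a caller branches on the result (path handling and strings illustrative):

```rust
use codex_login::{load_auth, AuthMode};
use std::path::Path;

fn describe_auth(codex_home: &Path) -> std::io::Result<&'static str> {
    Ok(match load_auth(codex_home)? {
        // Env var OPENAI_API_KEY, or the key cached in auth.json.
        Some(auth) if auth.mode == AuthMode::ApiKey => "api key",
        // Tokens from auth.json; get_token_data() refreshes after 28 days.
        Some(_) => "chatgpt tokens",
        // Neither source present: callers surface the login prompt.
        None => "not logged in",
    })
}
```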
diff --git a/codex-rs/login/src/login_with_chatgpt.py b/codex-rs/login/src/login_with_chatgpt.py
index ccb051c0..2dbf5be5 100644
--- a/codex-rs/login/src/login_with_chatgpt.py
+++ b/codex-rs/login/src/login_with_chatgpt.py
@@ -41,7 +41,6 @@ from typing import Any, Dict  # for type hints
 
 REQUIRED_PORT = 1455
 URL_BASE = f"http://localhost:{REQUIRED_PORT}"
 DEFAULT_ISSUER = "https://auth.openai.com"
-DEFAULT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
 
 EXIT_CODE_WHEN_ADDRESS_ALREADY_IN_USE = 13
@@ -58,7 +57,7 @@ class TokenData:
 class AuthBundle:
     """Aggregates authentication data produced after successful OAuth flow."""
 
-    api_key: str
+    api_key: str | None
     token_data: TokenData
     last_refresh: str
 
@@ -78,12 +77,18 @@ def main() -> None:
         eprint("ERROR: CODEX_HOME environment variable is not set")
         sys.exit(1)
 
+    client_id = os.getenv("CODEX_CLIENT_ID")
+    if not client_id:
+        eprint("ERROR: CODEX_CLIENT_ID environment variable is not set")
+        sys.exit(1)
+
     # Spawn server.
     try:
         httpd = _ApiKeyHTTPServer(
             ("127.0.0.1", REQUIRED_PORT),
             _ApiKeyHTTPHandler,
             codex_home=codex_home,
+            client_id=client_id,
             verbose=args.verbose,
         )
     except OSError as e:
@@ -157,7 +162,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
             return
 
         try:
-            auth_bundle, success_url = self._exchange_code_for_api_key(code)
+            auth_bundle, success_url = self._exchange_code(code)
         except Exception as exc:  # noqa: BLE001 – propagate to client
             self.send_error(500, f"Token exchange failed: {exc}")
             return
@@ -211,68 +216,22 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
         if getattr(self.server, "verbose", False):  # type: ignore[attr-defined]
             super().log_message(fmt, *args)
 
-    def _exchange_code_for_api_key(self, code: str) -> tuple[AuthBundle, str]:
-        """Perform token + token-exchange to obtain an OpenAI API key.
+    def _obtain_api_key(
+        self,
+        token_claims: Dict[str, Any],
+        access_claims: Dict[str, Any],
+        token_data: TokenData,
+    ) -> tuple[str | None, str | None]:
+        """Obtain an API key from the auth service.
 
-        Returns (AuthBundle, success_url).
+        Returns (api_key, success_url) if successful, None otherwise.
         """
-        token_endpoint = f"{self.server.issuer}/oauth/token"
-
-        # 1. Authorization-code -> (id_token, access_token, refresh_token)
-        data = urllib.parse.urlencode(
-            {
-                "grant_type": "authorization_code",
-                "code": code,
-                "redirect_uri": self.server.redirect_uri,
-                "client_id": self.server.client_id,
-                "code_verifier": self.server.pkce.code_verifier,
-            }
-        ).encode()
-
-        token_data: TokenData
-
-        with urllib.request.urlopen(
-            urllib.request.Request(
-                token_endpoint,
-                data=data,
-                method="POST",
-                headers={"Content-Type": "application/x-www-form-urlencoded"},
-            )
-        ) as resp:
-            payload = json.loads(resp.read().decode())
-
-        # Extract chatgpt_account_id from id_token
-        id_token_parts = payload["id_token"].split(".")
-        if len(id_token_parts) != 3:
-            raise ValueError("Invalid ID token")
-        id_token_claims = _decode_jwt_segment(id_token_parts[1])
-        auth_claims = id_token_claims.get("https://api.openai.com/auth", {})
-        chatgpt_account_id = auth_claims.get("chatgpt_account_id", "")
-
-        token_data = TokenData(
-            id_token=payload["id_token"],
-            access_token=payload["access_token"],
-            refresh_token=payload["refresh_token"],
-            account_id=chatgpt_account_id,
-        )
-
-        access_token_parts = token_data.access_token.split(".")
-        if len(access_token_parts) != 3:
-            raise ValueError("Invalid access token")
-
-        access_token_claims = _decode_jwt_segment(access_token_parts[1])
-
-        token_claims = id_token_claims.get("https://api.openai.com/auth", {})
-        access_claims = access_token_claims.get("https://api.openai.com/auth", {})
-
         org_id = token_claims.get("organization_id")
-        if not org_id:
-            raise ValueError("Missing organization in id_token claims")
         project_id = token_claims.get("project_id")
-        if not project_id:
-            raise ValueError("Missing project in id_token claims")
+
+        if not org_id or not project_id:
+            return (None, None)
 
         random_id = secrets.token_hex(6)
@@ -292,7 +251,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
         exchanged_access_token: str
         with urllib.request.urlopen(
             urllib.request.Request(
-                token_endpoint,
+                self.server.token_endpoint,
                 data=exchange_data,
                 method="POST",
                 headers={"Content-Type": "application/x-www-form-urlencoded"},
@@ -340,6 +299,65 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
         except Exception as exc:  # pragma: no cover – best-effort only
             eprint(f"Unable to redeem ChatGPT subscriber API credits: {exc}")
 
+        return (exchanged_access_token, success_url)
+
+    def _exchange_code(self, code: str) -> tuple[AuthBundle, str]:
+        """Perform token + token-exchange to obtain an OpenAI API key.
+
+        Returns (AuthBundle, success_url).
+        """
+
+        # 1. Authorization-code -> (id_token, access_token, refresh_token)
+        data = urllib.parse.urlencode(
+            {
+                "grant_type": "authorization_code",
+                "code": code,
+                "redirect_uri": self.server.redirect_uri,
+                "client_id": self.server.client_id,
+                "code_verifier": self.server.pkce.code_verifier,
+            }
+        ).encode()
+
+        token_data: TokenData
+
+        with urllib.request.urlopen(
+            urllib.request.Request(
+                self.server.token_endpoint,
+                data=data,
+                method="POST",
+                headers={"Content-Type": "application/x-www-form-urlencoded"},
+            )
+        ) as resp:
+            payload = json.loads(resp.read().decode())
+
+        # Extract chatgpt_account_id from id_token
+        id_token_parts = payload["id_token"].split(".")
+        if len(id_token_parts) != 3:
+            raise ValueError("Invalid ID token")
+        id_token_claims = _decode_jwt_segment(id_token_parts[1])
+        auth_claims = id_token_claims.get("https://api.openai.com/auth", {})
+        chatgpt_account_id = auth_claims.get("chatgpt_account_id", "")
+
+        token_data = TokenData(
+            id_token=payload["id_token"],
+            access_token=payload["access_token"],
+            refresh_token=payload["refresh_token"],
+            account_id=chatgpt_account_id,
+        )
+
+        access_token_parts = token_data.access_token.split(".")
+        if len(access_token_parts) != 3:
+            raise ValueError("Invalid access token")
+
+        access_token_claims = _decode_jwt_segment(access_token_parts[1])
+
+        token_claims = id_token_claims.get("https://api.openai.com/auth", {})
+        access_claims = access_token_claims.get("https://api.openai.com/auth", {})
+
+        exchanged_access_token, success_url = self._obtain_api_key(
+            token_claims, access_claims, token_data
+        )
+
         # Persist refresh_token/id_token for future use (redeem credits etc.)
         last_refresh_str = (
             datetime.datetime.now(datetime.timezone.utc)
@@ -353,7 +371,7 @@ class _ApiKeyHTTPHandler(http.server.BaseHTTPRequestHandler):
             last_refresh=last_refresh_str,
         )
 
-        return (auth_bundle, success_url)
+        return (auth_bundle, success_url or f"{URL_BASE}/success")
 
     def request_shutdown(self) -> None:
         # shutdown() must be invoked from another thread to avoid
@@ -413,6 +431,7 @@ class _ApiKeyHTTPServer(http.server.HTTPServer):
         request_handler_class: type[http.server.BaseHTTPRequestHandler],
         *,
         codex_home: str,
+        client_id: str,
         verbose: bool = False,
     ) -> None:
         super().__init__(server_address, request_handler_class, bind_and_activate=True)
@@ -422,7 +441,8 @@ class _ApiKeyHTTPServer(http.server.HTTPServer):
         self.verbose: bool = verbose
 
         self.issuer: str = DEFAULT_ISSUER
-        self.client_id: str = DEFAULT_CLIENT_ID
+        self.token_endpoint: str = f"{self.issuer}/oauth/token"
+        self.client_id: str = client_id
         port = server_address[1]
         self.redirect_uri: str = f"http://localhost:{port}/auth/callback"
         self.pkce: PkceCodes = _generate_pkce()
@@ -581,8 +601,8 @@ def maybe_redeem_credits(
         granted = redeem_data.get("granted_chatgpt_subscriber_api_credits", 0)
         if granted and granted > 0:
             eprint(
-                f"""Thanks for being a ChatGPT {'Plus' if plan_type=='plus' else 'Pro'} subscriber!
-If you haven't already redeemed, you should receive {'$5' if plan_type=='plus' else '$50'} in API credits.
+                f"""Thanks for being a ChatGPT {"Plus" if plan_type == "plus" else "Pro"} subscriber!
+If you haven't already redeemed, you should receive {"$5" if plan_type == "plus" else "$50"} in API credits.
 Credits: https://platform.openai.com/settings/organization/billing/credit-grants
 More info: https://help.openai.com/en/articles/11381614""",
diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs
index 7bc041a5..424b5ac2 100644
--- a/codex-rs/tui/src/lib.rs
+++ b/codex-rs/tui/src/lib.rs
@@ -6,16 +6,14 @@ use app::App;
 use codex_core::config::Config;
 use codex_core::config::ConfigOverrides;
 use codex_core::config_types::SandboxMode;
-use codex_core::openai_api_key::OPENAI_API_KEY_ENV_VAR;
-use codex_core::openai_api_key::get_openai_api_key;
-use codex_core::openai_api_key::set_openai_api_key;
 use codex_core::protocol::AskForApproval;
 use codex_core::util::is_inside_git_repo;
-use codex_login::try_read_openai_api_key;
+use codex_login::load_auth;
 use log_layer::TuiLogLayer;
 use std::fs::OpenOptions;
 use std::io::Write;
 use std::path::PathBuf;
+use tracing::error;
 use tracing_appender::non_blocking;
 use tracing_subscriber::EnvFilter;
 use tracing_subscriber::prelude::*;
@@ -140,7 +138,7 @@ pub async fn run_main(
         .with(tui_layer)
         .try_init();
 
-    let show_login_screen = should_show_login_screen(&config).await;
+    let show_login_screen = should_show_login_screen(&config);
     if show_login_screen {
         std::io::stdout()
             .write_all(b"No API key detected.\nLogin with your ChatGPT account? [Yn] ")?;
@@ -153,8 +151,8 @@ pub async fn run_main(
         }
         // Spawn a task to run the login command.
         // Block until the login command is finished.
-        let new_key = codex_login::login_with_chatgpt(&config.codex_home, false).await?;
-        set_openai_api_key(new_key);
+        codex_login::login_with_chatgpt(&config.codex_home, false).await?;
+
         std::io::stdout().write_all(b"Login successful.\n")?;
     }
 
@@ -217,28 +215,21 @@ fn restore() {
     }
 }
 
-async fn should_show_login_screen(config: &Config) -> bool {
-    if is_in_need_of_openai_api_key(config) {
+#[allow(clippy::unwrap_used)]
+fn should_show_login_screen(config: &Config) -> bool {
+    if config.model_provider.requires_auth {
         // Reading the OpenAI API key is an async operation because it may need
         // to refresh the token. Block on it.
         let codex_home = config.codex_home.clone();
-        if let Ok(openai_api_key) = try_read_openai_api_key(&codex_home).await {
-            set_openai_api_key(openai_api_key);
-            false
-        } else {
-            true
+        match load_auth(&codex_home) {
+            Ok(Some(_)) => false,
+            Ok(None) => true,
+            Err(err) => {
+                error!("Failed to read auth.json: {err}");
+                true
+            }
         }
     } else {
         false
     }
 }
-
-fn is_in_need_of_openai_api_key(config: &Config) -> bool {
-    let is_using_openai_key = config
-        .model_provider
-        .env_key
-        .as_ref()
-        .map(|s| s == OPENAI_API_KEY_ENV_VAR)
-        .unwrap_or(false);
-    is_using_openai_key && get_openai_api_key().is_none()
-}
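End to end, the TUI gate above now reads: if the provider `requires_auth` and `load_auth` finds nothing, run the ChatGPT login and continue. Roughly (prompt handling elided; see `run_main` above for the full version):

```rust
// Sketch of the login gate after this change.
if should_show_login_screen(&config) {
    // Blocks until the local OAuth server from login_with_chatgpt.py exits;
    // on success the credentials are persisted to auth.json.
    codex_login::login_with_chatgpt(&config.codex_home, false).await?;
    // Subsequent load_auth() calls (e.g. in init_codex) will now find them.
}
```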