diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 9748cde7..ed05fb5d 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -623,7 +623,7 @@ mod tests { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(1000), - requires_auth: false, + requires_openai_auth: false, }; let events = collect_events( @@ -683,7 +683,7 @@ mod tests { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(1000), - requires_auth: false, + requires_openai_auth: false, }; let events = collect_events(&[sse1.as_bytes()], provider).await; @@ -786,7 +786,7 @@ mod tests { request_max_retries: Some(0), stream_max_retries: Some(0), stream_idle_timeout_ms: Some(1000), - requires_auth: false, + requires_openai_auth: false, }; let out = run_sse(evs, provider).await; diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index f48cc934..63a2e594 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -842,7 +842,7 @@ disable_response_storage = true request_max_retries: Some(4), stream_max_retries: Some(10), stream_idle_timeout_ms: Some(300_000), - requires_auth: false, + requires_openai_auth: false, }; let model_provider_map = { let mut model_provider_map = built_in_model_providers(); diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index db369df3..a9802111 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -9,7 +9,6 @@ use codex_login::AuthMode; use codex_login::CodexAuth; use serde::Deserialize; use serde::Serialize; -use std::borrow::Cow; use std::collections::HashMap; use std::env::VarError; use std::time::Duration; @@ -79,7 +78,7 @@ pub struct ModelProviderInfo { /// Whether this provider requires some form of standard authentication (API key, ChatGPT token). 
#[serde(default)] - pub requires_auth: bool, + pub requires_openai_auth: bool, } impl ModelProviderInfo { @@ -87,26 +86,32 @@ impl ModelProviderInfo { /// reqwest Client applying: /// • provider-specific headers (static + env based) /// • Bearer auth header when an API key is available. + /// • Auth token for OAuth. /// - /// When `require_api_key` is true and the provider declares an `env_key` - /// but the variable is missing/empty, returns an [`Err`] identical to the + /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the /// one produced by [`ModelProviderInfo::api_key`]. pub async fn create_request_builder<'a>( &'a self, client: &'a reqwest::Client, auth: &Option<CodexAuth>, ) -> crate::error::Result<reqwest::RequestBuilder> { - let auth: Cow<'_, Option<CodexAuth>> = if auth.is_some() { - Cow::Borrowed(auth) - } else { - Cow::Owned(self.get_fallback_auth()?) + let effective_auth = match self.api_key() { + Ok(Some(key)) => Some(CodexAuth::from_api_key(key)), + Ok(None) => auth.clone(), + Err(err) => { + if auth.is_some() { + auth.clone() + } else { + return Err(err); + } + } }; - let url = self.get_full_url(&auth); + let url = self.get_full_url(&effective_auth); let mut builder = client.post(url); - if let Some(auth) = auth.as_ref() { + if let Some(auth) = effective_auth.as_ref() { builder = builder.bearer_auth(auth.get_token().await?); } @@ -216,14 +221,6 @@ impl ModelProviderInfo { .map(Duration::from_millis) .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS)) } - - fn get_fallback_auth(&self) -> crate::error::Result<Option<CodexAuth>> { - let api_key = self.api_key()?; - if let Some(api_key) = api_key { - return Ok(Some(CodexAuth::from_api_key(api_key))); - } - Ok(None) - } } const DEFAULT_OLLAMA_PORT: u32 = 11434; @@ -275,7 +272,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> { request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, - requires_auth: true, + requires_openai_auth: true, }, ),
(BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()), @@ -319,7 +316,7 @@ pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo { request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, - requires_auth: false, + requires_openai_auth: false, } } @@ -347,7 +344,7 @@ base_url = "http://localhost:11434/v1" request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, - requires_auth: false, + requires_openai_auth: false, }; let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); @@ -376,7 +373,7 @@ query_params = { api-version = "2025-04-01-preview" } request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, - requires_auth: false, + requires_openai_auth: false, }; let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); @@ -408,7 +405,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, - requires_auth: false, + requires_openai_auth: false, }; let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs index 00f91a87..60eb9224 100644 --- a/codex-rs/core/tests/client.rs +++ b/codex-rs/core/tests/client.rs @@ -458,7 +458,7 @@ async fn azure_overrides_assign_properties_used_for_responses_url() { request_max_retries: None, stream_max_retries: None, stream_idle_timeout_ms: None, - requires_auth: false, + requires_openai_auth: false, }; // Init session @@ -481,6 +481,86 @@ async fn azure_overrides_assign_properties_used_for_responses_url() { wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn env_var_overrides_loaded_auth() { + #![allow(clippy::unwrap_used)] + + let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { 
"USER" }; + + // Mock server + let server = MockServer::start().await; + + // First request – must NOT include `previous_response_id`. + let first = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("resp1"), "text/event-stream"); + + // Expect POST to /openai/responses with api-version query param + Mock::given(method("POST")) + .and(path("/openai/responses")) + .and(query_param("api-version", "2025-04-01-preview")) + .and(header_regex("Custom-Header", "Value")) + .and(header_regex( + "Authorization", + format!( + "Bearer {}", + std::env::var(existing_env_var_with_random_value).unwrap() + ) + .as_str(), + )) + .respond_with(first) + .expect(1) + .mount(&server) + .await; + + let provider = ModelProviderInfo { + name: "custom".to_string(), + base_url: Some(format!("{}/openai", server.uri())), + // Reuse the existing environment variable to avoid using unsafe code + env_key: Some(existing_env_var_with_random_value.to_string()), + query_params: Some(std::collections::HashMap::from([( + "api-version".to_string(), + "2025-04-01-preview".to_string(), + )])), + env_key_instructions: None, + wire_api: WireApi::Responses, + http_headers: Some(std::collections::HashMap::from([( + "Custom-Header".to_string(), + "Value".to_string(), + )])), + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_openai_auth: false, + }; + + // Init session + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.model_provider = provider; + + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); + let CodexSpawnOk { codex, .. 
} = Codex::spawn( + config, + Some(auth_from_token("Default Access Token".to_string())), + ctrl_c.clone(), + ) + .await + .unwrap(); + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello".into(), + }], + }) + .await + .unwrap(); + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; +} + fn auth_from_token(id_token: String) -> CodexAuth { CodexAuth::new( None, diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs index 3e30d937..8a4216b1 100644 --- a/codex-rs/core/tests/stream_no_completed.rs +++ b/codex-rs/core/tests/stream_no_completed.rs @@ -90,7 +90,7 @@ async fn retries_on_early_close() { request_max_retries: Some(0), stream_max_retries: Some(1), stream_idle_timeout_ms: Some(2000), - requires_auth: false, + requires_openai_auth: false, }; let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 50535e59..0228a568 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -287,7 +287,7 @@ fn restore() { #[allow(clippy::unwrap_used)] fn should_show_login_screen(config: &Config) -> bool { - if config.model_provider.requires_auth { + if config.model_provider.requires_openai_auth { // Reading the OpenAI API key is an async operation because it may need // to refresh the token. Block on it. let codex_home = config.codex_home.clone();