Prefer env var auth over default codex auth (#1861)
## Summary - Prioritize provider-specific API keys over default Codex auth when building requests - Add test to ensure provider env var auth overrides default auth ## Testing - `just fmt` - `just fix` *(fails: `let` expressions in this position are unstable)* - `cargo test --all-features` *(fails: `let` expressions in this position are unstable)* ------ https://chatgpt.com/codex/tasks/task_i_68926a104f7483208f2c8fd36763e0e3
This commit is contained in:
@@ -623,7 +623,7 @@ mod tests {
|
|||||||
request_max_retries: Some(0),
|
request_max_retries: Some(0),
|
||||||
stream_max_retries: Some(0),
|
stream_max_retries: Some(0),
|
||||||
stream_idle_timeout_ms: Some(1000),
|
stream_idle_timeout_ms: Some(1000),
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let events = collect_events(
|
let events = collect_events(
|
||||||
@@ -683,7 +683,7 @@ mod tests {
|
|||||||
request_max_retries: Some(0),
|
request_max_retries: Some(0),
|
||||||
stream_max_retries: Some(0),
|
stream_max_retries: Some(0),
|
||||||
stream_idle_timeout_ms: Some(1000),
|
stream_idle_timeout_ms: Some(1000),
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let events = collect_events(&[sse1.as_bytes()], provider).await;
|
let events = collect_events(&[sse1.as_bytes()], provider).await;
|
||||||
@@ -786,7 +786,7 @@ mod tests {
|
|||||||
request_max_retries: Some(0),
|
request_max_retries: Some(0),
|
||||||
stream_max_retries: Some(0),
|
stream_max_retries: Some(0),
|
||||||
stream_idle_timeout_ms: Some(1000),
|
stream_idle_timeout_ms: Some(1000),
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let out = run_sse(evs, provider).await;
|
let out = run_sse(evs, provider).await;
|
||||||
|
|||||||
@@ -842,7 +842,7 @@ disable_response_storage = true
|
|||||||
request_max_retries: Some(4),
|
request_max_retries: Some(4),
|
||||||
stream_max_retries: Some(10),
|
stream_max_retries: Some(10),
|
||||||
stream_idle_timeout_ms: Some(300_000),
|
stream_idle_timeout_ms: Some(300_000),
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
let model_provider_map = {
|
let model_provider_map = {
|
||||||
let mut model_provider_map = built_in_model_providers();
|
let mut model_provider_map = built_in_model_providers();
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ use codex_login::AuthMode;
|
|||||||
use codex_login::CodexAuth;
|
use codex_login::CodexAuth;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use std::borrow::Cow;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::env::VarError;
|
use std::env::VarError;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
@@ -79,7 +78,7 @@ pub struct ModelProviderInfo {
|
|||||||
|
|
||||||
/// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
|
/// Whether this provider requires some form of standard authentication (API key, ChatGPT token).
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub requires_auth: bool,
|
pub requires_openai_auth: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelProviderInfo {
|
impl ModelProviderInfo {
|
||||||
@@ -87,26 +86,32 @@ impl ModelProviderInfo {
|
|||||||
/// reqwest Client applying:
|
/// reqwest Client applying:
|
||||||
/// • provider-specific headers (static + env based)
|
/// • provider-specific headers (static + env based)
|
||||||
/// • Bearer auth header when an API key is available.
|
/// • Bearer auth header when an API key is available.
|
||||||
|
/// • Auth token for OAuth.
|
||||||
///
|
///
|
||||||
/// When `require_api_key` is true and the provider declares an `env_key`
|
/// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
|
||||||
/// but the variable is missing/empty, returns an [`Err`] identical to the
|
|
||||||
/// one produced by [`ModelProviderInfo::api_key`].
|
/// one produced by [`ModelProviderInfo::api_key`].
|
||||||
pub async fn create_request_builder<'a>(
|
pub async fn create_request_builder<'a>(
|
||||||
&'a self,
|
&'a self,
|
||||||
client: &'a reqwest::Client,
|
client: &'a reqwest::Client,
|
||||||
auth: &Option<CodexAuth>,
|
auth: &Option<CodexAuth>,
|
||||||
) -> crate::error::Result<reqwest::RequestBuilder> {
|
) -> crate::error::Result<reqwest::RequestBuilder> {
|
||||||
let auth: Cow<'_, Option<CodexAuth>> = if auth.is_some() {
|
let effective_auth = match self.api_key() {
|
||||||
Cow::Borrowed(auth)
|
Ok(Some(key)) => Some(CodexAuth::from_api_key(key)),
|
||||||
} else {
|
Ok(None) => auth.clone(),
|
||||||
Cow::Owned(self.get_fallback_auth()?)
|
Err(err) => {
|
||||||
|
if auth.is_some() {
|
||||||
|
auth.clone()
|
||||||
|
} else {
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let url = self.get_full_url(&auth);
|
let url = self.get_full_url(&effective_auth);
|
||||||
|
|
||||||
let mut builder = client.post(url);
|
let mut builder = client.post(url);
|
||||||
|
|
||||||
if let Some(auth) = auth.as_ref() {
|
if let Some(auth) = effective_auth.as_ref() {
|
||||||
builder = builder.bearer_auth(auth.get_token().await?);
|
builder = builder.bearer_auth(auth.get_token().await?);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,14 +221,6 @@ impl ModelProviderInfo {
|
|||||||
.map(Duration::from_millis)
|
.map(Duration::from_millis)
|
||||||
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
|
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_fallback_auth(&self) -> crate::error::Result<Option<CodexAuth>> {
|
|
||||||
let api_key = self.api_key()?;
|
|
||||||
if let Some(api_key) = api_key {
|
|
||||||
return Ok(Some(CodexAuth::from_api_key(api_key)));
|
|
||||||
}
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_OLLAMA_PORT: u32 = 11434;
|
const DEFAULT_OLLAMA_PORT: u32 = 11434;
|
||||||
@@ -275,7 +272,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
requires_auth: true,
|
requires_openai_auth: true,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
(BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
|
(BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
|
||||||
@@ -319,7 +316,7 @@ pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -347,7 +344,7 @@ base_url = "http://localhost:11434/v1"
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||||
@@ -376,7 +373,7 @@ query_params = { api-version = "2025-04-01-preview" }
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||||
@@ -408,7 +405,7 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||||
|
|||||||
@@ -458,7 +458,7 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
|
|||||||
request_max_retries: None,
|
request_max_retries: None,
|
||||||
stream_max_retries: None,
|
stream_max_retries: None,
|
||||||
stream_idle_timeout_ms: None,
|
stream_idle_timeout_ms: None,
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Init session
|
// Init session
|
||||||
@@ -481,6 +481,86 @@ async fn azure_overrides_assign_properties_used_for_responses_url() {
|
|||||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn env_var_overrides_loaded_auth() {
|
||||||
|
#![allow(clippy::unwrap_used)]
|
||||||
|
|
||||||
|
let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" };
|
||||||
|
|
||||||
|
// Mock server
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
|
||||||
|
// First request – must NOT include `previous_response_id`.
|
||||||
|
let first = ResponseTemplate::new(200)
|
||||||
|
.insert_header("content-type", "text/event-stream")
|
||||||
|
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||||
|
|
||||||
|
// Expect POST to /openai/responses with api-version query param
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/openai/responses"))
|
||||||
|
.and(query_param("api-version", "2025-04-01-preview"))
|
||||||
|
.and(header_regex("Custom-Header", "Value"))
|
||||||
|
.and(header_regex(
|
||||||
|
"Authorization",
|
||||||
|
format!(
|
||||||
|
"Bearer {}",
|
||||||
|
std::env::var(existing_env_var_with_random_value).unwrap()
|
||||||
|
)
|
||||||
|
.as_str(),
|
||||||
|
))
|
||||||
|
.respond_with(first)
|
||||||
|
.expect(1)
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let provider = ModelProviderInfo {
|
||||||
|
name: "custom".to_string(),
|
||||||
|
base_url: Some(format!("{}/openai", server.uri())),
|
||||||
|
// Reuse the existing environment variable to avoid using unsafe code
|
||||||
|
env_key: Some(existing_env_var_with_random_value.to_string()),
|
||||||
|
query_params: Some(std::collections::HashMap::from([(
|
||||||
|
"api-version".to_string(),
|
||||||
|
"2025-04-01-preview".to_string(),
|
||||||
|
)])),
|
||||||
|
env_key_instructions: None,
|
||||||
|
wire_api: WireApi::Responses,
|
||||||
|
http_headers: Some(std::collections::HashMap::from([(
|
||||||
|
"Custom-Header".to_string(),
|
||||||
|
"Value".to_string(),
|
||||||
|
)])),
|
||||||
|
env_http_headers: None,
|
||||||
|
request_max_retries: None,
|
||||||
|
stream_max_retries: None,
|
||||||
|
stream_idle_timeout_ms: None,
|
||||||
|
requires_openai_auth: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Init session
|
||||||
|
let codex_home = TempDir::new().unwrap();
|
||||||
|
let mut config = load_default_config_for_test(&codex_home);
|
||||||
|
config.model_provider = provider;
|
||||||
|
|
||||||
|
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
|
let CodexSpawnOk { codex, .. } = Codex::spawn(
|
||||||
|
config,
|
||||||
|
Some(auth_from_token("Default Access Token".to_string())),
|
||||||
|
ctrl_c.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![InputItem::Text {
|
||||||
|
text: "hello".into(),
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
|
}
|
||||||
|
|
||||||
fn auth_from_token(id_token: String) -> CodexAuth {
|
fn auth_from_token(id_token: String) -> CodexAuth {
|
||||||
CodexAuth::new(
|
CodexAuth::new(
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ async fn retries_on_early_close() {
|
|||||||
request_max_retries: Some(0),
|
request_max_retries: Some(0),
|
||||||
stream_max_retries: Some(1),
|
stream_max_retries: Some(1),
|
||||||
stream_idle_timeout_ms: Some(2000),
|
stream_idle_timeout_ms: Some(2000),
|
||||||
requires_auth: false,
|
requires_openai_auth: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||||
|
|||||||
@@ -287,7 +287,7 @@ fn restore() {
|
|||||||
|
|
||||||
#[allow(clippy::unwrap_used)]
|
#[allow(clippy::unwrap_used)]
|
||||||
fn should_show_login_screen(config: &Config) -> bool {
|
fn should_show_login_screen(config: &Config) -> bool {
|
||||||
if config.model_provider.requires_auth {
|
if config.model_provider.requires_openai_auth {
|
||||||
// Reading the OpenAI API key is an async operation because it may need
|
// Reading the OpenAI API key is an async operation because it may need
|
||||||
// to refresh the token. Block on it.
|
// to refresh the token. Block on it.
|
||||||
let codex_home = config.codex_home.clone();
|
let codex_home = config.codex_home.clone();
|
||||||
|
|||||||
Reference in New Issue
Block a user