diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index 5ede774b..b1dee853 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -120,7 +120,7 @@ pub(crate) async fn stream_chat_completions( debug!( "POST to {}: {}", - provider.get_full_url(), + provider.get_full_url(&None), serde_json::to_string_pretty(&payload).unwrap_or_default() ); @@ -129,7 +129,7 @@ pub(crate) async fn stream_chat_completions( loop { attempt += 1; - let req_builder = provider.create_request_builder(client)?; + let req_builder = provider.create_request_builder(client, &None).await?; let res = req_builder .header(reqwest::header::ACCEPT, "text/event-stream") diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index b9ea6b13..1a8ae94f 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -30,7 +30,6 @@ use crate::config::Config; use crate::config_types::ReasoningEffort as ReasoningEffortConfig; use crate::config_types::ReasoningSummary as ReasoningSummaryConfig; use crate::error::CodexErr; -use crate::error::EnvVarError; use crate::error::Result; use crate::flags::CODEX_RS_SSE_FIXTURE; use crate::model_provider_info::ModelProviderInfo; @@ -122,24 +121,11 @@ impl ModelClient { return stream_from_fixture(path, self.provider.clone()).await; } - let auth = self.auth.as_ref().ok_or_else(|| { - CodexErr::EnvVar(EnvVarError { - var: "OPENAI_API_KEY".to_string(), - instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".to_string()), - }) - })?; + let auth = self.auth.clone(); - let store = prompt.store && auth.mode != AuthMode::ChatGPT; + let auth_mode = auth.as_ref().map(|a| a.mode); - let base_url = match self.provider.base_url.clone() { - Some(url) => url, - None => match auth.mode { - AuthMode::ChatGPT => "https://chatgpt.com/backend-api/codex".to_string(), - AuthMode::ApiKey => "https://api.openai.com/v1".to_string(), 
- }, - }; - - let token = auth.get_token().await?; + let store = prompt.store && auth_mode != Some(AuthMode::ChatGPT); let full_instructions = prompt.get_full_instructions(&self.config.model); let tools_json = create_tools_json_for_responses_api( @@ -180,35 +166,36 @@ impl ModelClient { include, }; - trace!( - "POST to {}: {}", - self.provider.get_full_url(), - serde_json::to_string(&payload)? - ); - let mut attempt = 0; let max_retries = self.provider.request_max_retries(); + trace!( + "POST to {}: {}", + self.provider.get_full_url(&auth), + serde_json::to_string(&payload)? + ); + loop { attempt += 1; let mut req_builder = self - .client - .post(format!("{base_url}/responses")) + .provider + .create_request_builder(&self.client, &auth) + .await?; + + req_builder = req_builder .header("OpenAI-Beta", "responses=experimental") .header("session_id", self.session_id.to_string()) - .bearer_auth(&token) .header(reqwest::header::ACCEPT, "text/event-stream") .json(&payload); - if auth.mode == AuthMode::ChatGPT { - if let Some(account_id) = auth.get_account_id().await { - req_builder = req_builder.header("chatgpt-account-id", account_id); - } + if let Some(auth) = auth.as_ref() + && auth.mode == AuthMode::ChatGPT + && let Some(account_id) = auth.get_account_id().await + { + req_builder = req_builder.header("chatgpt-account-id", account_id); } - req_builder = self.provider.apply_http_headers(req_builder); - let originator = self .config .internal_originator diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 29366377..49478660 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -5,8 +5,11 @@ //! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers` //! key. These override or extend the defaults at runtime. 
+use codex_login::AuthMode; +use codex_login::CodexAuth; use serde::Deserialize; use serde::Serialize; +use std::borrow::Cow; use std::collections::HashMap; use std::env::VarError; use std::time::Duration; @@ -88,25 +91,30 @@ impl ModelProviderInfo { /// When `require_api_key` is true and the provider declares an `env_key` /// but the variable is missing/empty, returns an [`Err`] identical to the /// one produced by [`ModelProviderInfo::api_key`]. - pub fn create_request_builder<'a>( + pub async fn create_request_builder<'a>( &'a self, client: &'a reqwest::Client, + auth: &Option<CodexAuth>, ) -> crate::error::Result<reqwest::RequestBuilder> { - let url = self.get_full_url(); + let auth: Cow<'_, Option<CodexAuth>> = if auth.is_some() { + Cow::Borrowed(auth) + } else { + Cow::Owned(self.get_fallback_auth()?) + }; + + let url = self.get_full_url(&auth); let mut builder = client.post(url); - let api_key = self.api_key()?; - if let Some(key) = api_key { - builder = builder.bearer_auth(key); + if let Some(auth) = auth.as_ref() { + builder = builder.bearer_auth(auth.get_token().await?); } Ok(self.apply_http_headers(builder)) } - pub(crate) fn get_full_url(&self) -> String { - let query_string = self - .query_params + fn get_query_string(&self) -> String { + self.query_params .as_ref() .map_or_else(String::new, |params| { let full_params = params @@ -115,16 +123,29 @@ impl ModelProviderInfo { .collect::<Vec<_>>() .join("&"); format!("?{full_params}") - }); + }) + } + + pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String { + let default_base_url = if matches!( + auth, + Some(CodexAuth { + mode: AuthMode::ChatGPT, + .. 
+ }) + ) { + "https://chatgpt.com/backend-api/codex" + } else { + "https://api.openai.com/v1" + }; + let query_string = self.get_query_string(); let base_url = self .base_url .clone() - .unwrap_or("https://api.openai.com/v1".to_string()); + .unwrap_or(default_base_url.to_string()); match self.wire_api { - WireApi::Responses => { - format!("{base_url}/responses{query_string}") - } + WireApi::Responses => format!("{base_url}/responses{query_string}"), WireApi::Chat => format!("{base_url}/chat/completions{query_string}"), } } @@ -132,10 +153,7 @@ impl ModelProviderInfo { /// Apply provider-specific HTTP headers (both static and environment-based) /// onto an existing `reqwest::RequestBuilder` and return the updated /// builder. - pub fn apply_http_headers( - &self, - mut builder: reqwest::RequestBuilder, - ) -> reqwest::RequestBuilder { + fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { if let Some(extra) = &self.http_headers { for (k, v) in extra { builder = builder.header(k, v); } @@ -157,7 +175,7 @@ impl ModelProviderInfo { /// If `env_key` is Some, returns the API key for this provider if present /// (and non-empty) in the environment. If `env_key` is required but /// cannot be found, returns an error. - fn api_key(&self) -> crate::error::Result<Option<String>> { + pub fn api_key(&self) -> crate::error::Result<Option<String>> { match &self.env_key { Some(env_key) => { let env_value = std::env::var(env_key); @@ -198,6 +216,14 @@ impl ModelProviderInfo { .map(Duration::from_millis) .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS)) } + + fn get_fallback_auth(&self) -> crate::error::Result<Option<CodexAuth>> { + let api_key = self.api_key()?; + if let Some(api_key) = api_key { + return Ok(Some(CodexAuth::from_api_key(api_key))); + } + Ok(None) + } } /// Built-in default provider list. 
diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs index a22a9438..06a110ea 100644 --- a/codex-rs/core/tests/client.rs +++ b/codex-rs/core/tests/client.rs @@ -4,6 +4,7 @@ use chrono::Utc; use codex_core::Codex; use codex_core::CodexSpawnOk; use codex_core::ModelProviderInfo; +use codex_core::WireApi; use codex_core::built_in_model_providers; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; @@ -21,8 +22,10 @@ use tempfile::TempDir; use wiremock::Mock; use wiremock::MockServer; use wiremock::ResponseTemplate; +use wiremock::matchers::header_regex; use wiremock::matchers::method; use wiremock::matchers::path; +use wiremock::matchers::query_param; /// Build minimal SSE stream with completed marker using the JSON fixture. fn sse_completed(id: &str) -> String { @@ -376,6 +379,81 @@ async fn includes_user_instructions_message_in_request() { .starts_with("be nice") ); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn azure_overrides_assign_properties_used_for_responses_url() { + #![allow(clippy::unwrap_used)] + + let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" }; + + // Mock server + let server = MockServer::start().await; + + // First request – must NOT include `previous_response_id`. 
+ let first = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("resp1"), "text/event-stream"); + + // Expect POST to /openai/responses with api-version query param + Mock::given(method("POST")) + .and(path("/openai/responses")) + .and(query_param("api-version", "2025-04-01-preview")) + .and(header_regex("Custom-Header", "Value")) + .and(header_regex( + "Authorization", + format!( + "Bearer {}", + std::env::var(existing_env_var_with_random_value).unwrap() + ) + .as_str(), + )) + .respond_with(first) + .expect(1) + .mount(&server) + .await; + + let provider = ModelProviderInfo { + name: "custom".to_string(), + base_url: Some(format!("{}/openai", server.uri())), + // Reuse the existing environment variable to avoid using unsafe code + env_key: Some(existing_env_var_with_random_value.to_string()), + query_params: Some(std::collections::HashMap::from([( + "api-version".to_string(), + "2025-04-01-preview".to_string(), + )])), + env_key_instructions: None, + wire_api: WireApi::Responses, + http_headers: Some(std::collections::HashMap::from([( + "Custom-Header".to_string(), + "Value".to_string(), + )])), + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_auth: false, + }; + + // Init session + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.model_provider = provider; + + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); + let CodexSpawnOk { codex, .. 
} = Codex::spawn(config, None, ctrl_c.clone()).await.unwrap(); + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello".into(), + }], + }) + .await + .unwrap(); + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; +} + fn auth_from_token(id_token: String) -> CodexAuth { CodexAuth::new( None, diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index 3d55c202..35f67e71 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -22,7 +22,7 @@ const SOURCE_FOR_PYTHON_SERVER: &str = include_str!("./login_with_chatgpt.py"); const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY"; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Copy)] pub enum AuthMode { ApiKey, ChatGPT,