From c221eab0b5cad59ce3dafebf7ca630f217263cc6 Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Mon, 7 Jul 2025 13:09:16 -0700 Subject: [PATCH] feat: support custom HTTP headers for model providers (#1473) This adds support for two new model provider config options: - `http_headers` for hardcoded (key, value) pairs - `env_http_headers` for headers whose values should be read from environment variables This also updates the built-in `openai` provider to use this feature to set the following headers: - `originator` => `codex_cli_rs` - `version` => [CLI version] - `OpenAI-Organization` => `OPENAI_ORGANIZATION` env var - `OpenAI-Project` => `OPENAI_PROJECT` env var for consistency with the TypeScript implementation: https://github.com/openai/codex/blob/bd5a9e8ba96c7d9c58ecaf5e61ec62d14ac6378d/codex-cli/src/utils/agent/agent-loop.ts#L321-L329 While here, this also consolidates some logic that was duplicated across `client.rs` and `chat_completions.rs` by introducing `ModelProviderInfo.create_request_builder()`. Resolves https://github.com/openai/codex/discussions/1152 --- codex-rs/config.md | 16 +++ codex-rs/core/src/chat_completions.rs | 12 +-- codex-rs/core/src/client.rs | 27 ++--- codex-rs/core/src/config.rs | 2 + codex-rs/core/src/model_provider_info.rs | 113 +++++++++++++++++++- codex-rs/core/tests/previous_response_id.rs | 2 + codex-rs/core/tests/stream_no_completed.rs | 2 + 7 files changed, 147 insertions(+), 27 deletions(-) diff --git a/codex-rs/config.md b/codex-rs/config.md index f7e72581..2eaae760 100644 --- a/codex-rs/config.md +++ b/codex-rs/config.md @@ -76,6 +76,22 @@ env_key = "AZURE_OPENAI_API_KEY" # Or "OPENAI_API_KEY", whichever you use. query_params = { api-version = "2025-04-01-preview" } ``` +It is also possible to configure a provider to include extra HTTP headers with a request. 
These can be hardcoded values (`http_headers`) or values read from environment variables (`env_http_headers`): + +```toml +[model_providers.example] +# name, base_url, ... + +# This will add the HTTP header `X-Example-Header` with value `example-value` +# to each request to the model provider. +http_headers = { "X-Example-Header" = "example-value" } + +# This will add the HTTP header `X-Example-Features` with the value of the +# `EXAMPLE_FEATURES` environment variable to each request to the model provider +# _if_ the environment variable is set and its value is non-empty. +env_http_headers = { "X-Example-Features" = "EXAMPLE_FEATURES" } +``` + ## model_provider Identifies which provider to use from the `model_providers` map. Defaults to `"openai"`. diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index ce2ab053..816fc80f 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -114,22 +114,18 @@ pub(crate) async fn stream_chat_completions( "tools": tools_json, }); - let url = provider.get_full_url(); - debug!( - "POST to {url}: {}", + "POST to {}: {}", + provider.get_full_url(), serde_json::to_string_pretty(&payload).unwrap_or_default() ); - let api_key = provider.api_key()?; let mut attempt = 0; loop { attempt += 1; - let mut req_builder = client.post(&url); - if let Some(api_key) = &api_key { - req_builder = req_builder.bearer_auth(api_key.clone()); - } + let req_builder = provider.create_request_builder(client)?; + let res = req_builder .header(reqwest::header::ACCEPT, "text/event-stream") .json(&payload) diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 91a84bf3..9dcb7289 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -26,7 +26,6 @@ use crate::client_common::create_reasoning_param_for_request; use crate::config_types::ReasoningEffort as ReasoningEffortConfig; use crate::config_types::ReasoningSummary as 
ReasoningSummaryConfig; use crate::error::CodexErr; -use crate::error::EnvVarError; use crate::error::Result; use crate::flags::CODEX_RS_SSE_FIXTURE; use crate::flags::OPENAI_REQUEST_MAX_RETRIES; @@ -123,28 +122,24 @@ impl ModelClient { stream: true, }; - let url = self.provider.get_full_url(); - trace!("POST to {url}: {}", serde_json::to_string(&payload)?); + trace!( + "POST to {}: {}", + self.provider.get_full_url(), + serde_json::to_string(&payload)? + ); let mut attempt = 0; loop { attempt += 1; - let api_key = self.provider.api_key()?.ok_or_else(|| { - CodexErr::EnvVar(EnvVarError { - var: self.provider.env_key.clone().unwrap_or_default(), - instructions: None, - }) - })?; - let res = self - .client - .post(&url) - .bearer_auth(api_key) + let req_builder = self + .provider + .create_request_builder(&self.client)? .header("OpenAI-Beta", "responses=experimental") .header(reqwest::header::ACCEPT, "text/event-stream") - .json(&payload) - .send() - .await; + .json(&payload); + + let res = req_builder.send().await; match res { Ok(resp) if resp.status().is_success() => { let (tx_event, rx_event) = mpsc::channel::>(16); diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 240c6eaf..18c4ec23 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -659,6 +659,8 @@ disable_response_storage = true wire_api: crate::WireApi::Chat, env_key_instructions: None, query_params: None, + http_headers: None, + env_http_headers: None, }; let model_provider_map = { let mut model_provider_map = built_in_model_providers(); diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index b8326ace..5d51b10f 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -13,6 +13,10 @@ use std::env::VarError; use crate::error::EnvVarError; use crate::openai_api_key::get_openai_api_key; +/// Value for the `originator` header that is sent with requests to +/// 
OpenAI. +const OPENAI_ORIGINATOR_HEADER: &str = "codex_cli_rs"; + /// Wire protocol that the provider speaks. Most third-party services only /// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI /// itself (and a handful of others) additionally expose the more modern @@ -50,9 +54,43 @@ pub struct ModelProviderInfo { /// Optional query parameters to append to the base URL. pub query_params: Option>, + + /// Additional HTTP headers to include in requests to this provider where + /// the (key, value) pairs are the header name and value. + pub http_headers: Option>, + + /// Optional HTTP headers to include in requests to this provider where the + /// (key, value) pairs are the header name and _environment variable_ whose + /// value should be used. If the environment variable is not set, or the + /// value is empty, the header will not be included in the request. + pub env_http_headers: Option>, } impl ModelProviderInfo { + /// Construct a `POST` RequestBuilder for the given URL using the provided + /// reqwest Client applying: + /// • provider-specific headers (static + env based) + /// • Bearer auth header when an API key is available. + /// + /// If the provider declares an `env_key` but the corresponding environment + /// variable is missing/empty, returns an [`Err`] identical to the one + /// produced by [`ModelProviderInfo::api_key`]. 
+ pub fn create_request_builder<'a>( + &'a self, + client: &'a reqwest::Client, + ) -> crate::error::Result { + let api_key = self.api_key()?; + + let url = self.get_full_url(); + + let mut builder = client.post(url); + if let Some(key) = api_key { + builder = builder.bearer_auth(key); + } + + Ok(self.apply_http_headers(builder)) + } + pub(crate) fn get_full_url(&self) -> String { let query_string = self .query_params @@ -71,13 +109,33 @@ impl ModelProviderInfo { WireApi::Chat => format!("{base_url}/chat/completions{query_string}"), } } -} -impl ModelProviderInfo { + /// Apply provider-specific HTTP headers (both static and environment-based) + /// onto an existing `reqwest::RequestBuilder` and return the updated + /// builder. + fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(extra) = &self.http_headers { + for (k, v) in extra { + builder = builder.header(k, v); + } + } + + if let Some(env_headers) = &self.env_http_headers { + for (header, env_var) in env_headers { + if let Ok(val) = std::env::var(env_var) { + if !val.trim().is_empty() { + builder = builder.header(header, val); + } + } + } + } + builder + } + /// If `env_key` is Some, returns the API key for this provider if present /// (and non-empty) in the environment. If `env_key` is required but /// cannot be found, returns an error. 
- pub fn api_key(&self) -> crate::error::Result> { + fn api_key(&self) -> crate::error::Result> { match &self.env_key { Some(env_key) => { let env_value = if env_key == crate::openai_api_key::OPENAI_API_KEY_ENV_VAR { @@ -123,6 +181,22 @@ pub fn built_in_model_providers() -> HashMap { env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()), wire_api: WireApi::Responses, query_params: None, + http_headers: Some( + [ + ("originator".to_string(), OPENAI_ORIGINATOR_HEADER.to_string()), + ("version".to_string(), env!("CARGO_PKG_VERSION").to_string()), + ] + .into_iter() + .collect(), + ), + env_http_headers: Some( + [ + ("OpenAI-Organization".to_string(), "OPENAI_ORGANIZATION".to_string()), + ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()), + ] + .into_iter() + .collect(), + ), }, ), ] @@ -135,6 +209,7 @@ pub fn built_in_model_providers() -> HashMap { mod tests { #![allow(clippy::unwrap_used)] use super::*; + use pretty_assertions::assert_eq; #[test] fn test_deserialize_ollama_model_provider_toml() { @@ -149,6 +224,8 @@ base_url = "http://localhost:11434/v1" env_key_instructions: None, wire_api: WireApi::Chat, query_params: None, + http_headers: None, + env_http_headers: None, }; let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); @@ -172,6 +249,36 @@ query_params = { api-version = "2025-04-01-preview" } query_params: Some(maplit::hashmap! 
{ "api-version".to_string() => "2025-04-01-preview".to_string(), }), + http_headers: None, + env_http_headers: None, + }; + + let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); + assert_eq!(expected_provider, provider); + } + + #[test] + fn test_deserialize_example_model_provider_toml() { + let azure_provider_toml = r#" +name = "Example" +base_url = "https://example.com" +env_key = "API_KEY" +http_headers = { "X-Example-Header" = "example-value" } +env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } + "#; + let expected_provider = ModelProviderInfo { + name: "Example".into(), + base_url: "https://example.com".into(), + env_key: Some("API_KEY".into()), + env_key_instructions: None, + wire_api: WireApi::Chat, + query_params: None, + http_headers: Some(maplit::hashmap! { + "X-Example-Header".to_string() => "example-value".to_string(), + }), + env_http_headers: Some(maplit::hashmap! { + "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(), + }), }; let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs index e072e9c3..a23b119c 100644 --- a/codex-rs/core/tests/previous_response_id.rs +++ b/codex-rs/core/tests/previous_response_id.rs @@ -108,6 +108,8 @@ async fn keeps_previous_response_id_between_tasks() { env_key_instructions: None, wire_api: codex_core::WireApi::Responses, query_params: None, + http_headers: None, + env_http_headers: None, }; // Init session diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs index c1ef10c3..43e533bd 100644 --- a/codex-rs/core/tests/stream_no_completed.rs +++ b/codex-rs/core/tests/stream_no_completed.rs @@ -97,6 +97,8 @@ async fn retries_on_early_close() { env_key_instructions: None, wire_api: codex_core::WireApi::Responses, query_params: None, + http_headers: None, + env_http_headers: None, }; let ctrl_c 
= std::sync::Arc::new(tokio::sync::Notify::new());