diff --git a/codex-rs/config.md b/codex-rs/config.md
index de9e4ec9..f7e72581 100644
--- a/codex-rs/config.md
+++ b/codex-rs/config.md
@@ -41,8 +41,11 @@ base_url = "https://api.openai.com/v1"
 # using Codex with this provider. The value of the environment variable must be
 # non-empty and will be used in the `Bearer TOKEN` HTTP header for the POST request.
 env_key = "OPENAI_API_KEY"
-# Valid values for wire_api are "chat" and "responses".
+# Valid values for wire_api are "chat" and "responses". Defaults to "chat" if omitted.
 wire_api = "chat"
+# If necessary, extra query params to add to the URL.
+# See the Azure example below.
+query_params = {}
 ```
 
 Note this makes it possible to use Codex CLI with non-OpenAI models, so long as they use a wire API that is compatible with the OpenAI chat completions API. For example, you could define the following provider to use Codex CLI with Ollama running locally:
@@ -51,7 +54,6 @@ Note this makes it possible to use Codex CLI with non-OpenAI models, so long as
 [model_providers.ollama]
 name = "Ollama"
 base_url = "http://localhost:11434/v1"
-wire_api = "chat"
 ```
 
 Or a third-party provider (using a distinct environment variable for the API key):
@@ -61,7 +63,17 @@ Or a third-party provider (using a distinct environment variable for the API key
 name = "Mistral"
 base_url = "https://api.mistral.ai/v1"
 env_key = "MISTRAL_API_KEY"
-wire_api = "chat"
+```
+
+Note that Azure requires `api-version` to be passed as a query parameter, so be sure to specify it as part of `query_params` when defining the Azure provider:
+
+```toml
+[model_providers.azure]
+name = "Azure"
+# Make sure you set the appropriate subdomain for this URL.
+base_url = "https://YOUR_PROJECT_NAME.openai.azure.com/openai"
+env_key = "AZURE_OPENAI_API_KEY" # Or "OPENAI_API_KEY", whichever you use.
+query_params = { api-version = "2025-04-01-preview" }
 ```
 
 ## model_provider
diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index dfe06d1f..ce2ab053 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -114,8 +114,7 @@ pub(crate) async fn stream_chat_completions(
         "tools": tools_json,
     });
 
-    let base_url = provider.base_url.trim_end_matches('/');
-    let url = format!("{}/chat/completions", base_url);
+    let url = provider.get_full_url();
 
     debug!(
         "POST to {url}: {}",
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 6daa3a89..91a84bf3 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -123,9 +123,7 @@ impl ModelClient {
             stream: true,
         };
 
-        let base_url = self.provider.base_url.clone();
-        let base_url = base_url.trim_end_matches('/');
-        let url = format!("{}/responses", base_url);
+        let url = self.provider.get_full_url();
         trace!("POST to {url}: {}", serde_json::to_string(&payload)?);
 
         let mut attempt = 0;
diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs
index 6652d7c7..240c6eaf 100644
--- a/codex-rs/core/src/config.rs
+++ b/codex-rs/core/src/config.rs
@@ -658,6 +658,7 @@ disable_response_storage = true
             env_key: Some("OPENAI_API_KEY".to_string()),
             wire_api: crate::WireApi::Chat,
             env_key_instructions: None,
+            query_params: None,
         };
         let model_provider_map = {
             let mut model_provider_map = built_in_model_providers();
diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs
index a0e0aeb2..b8326ace 100644
--- a/codex-rs/core/src/model_provider_info.rs
+++ b/codex-rs/core/src/model_provider_info.rs
@@ -23,9 +23,10 @@ use crate::openai_api_key::get_openai_api_key;
 #[serde(rename_all = "lowercase")]
 pub enum WireApi {
     /// The experimental “Responses” API exposed by OpenAI at `/v1/responses`.
-    #[default]
     Responses,
 
+    /// Regular Chat Completions compatible with `/v1/chat/completions`.
+    #[default]
     Chat,
 }
 
@@ -44,7 +45,32 @@ pub struct ModelProviderInfo {
     pub env_key_instructions: Option<String>,
 
     /// Which wire protocol this provider expects.
+    #[serde(default)]
     pub wire_api: WireApi,
+
+    /// Optional query parameters to append to the base URL.
+    pub query_params: Option<HashMap<String, String>>,
+}
+
+impl ModelProviderInfo {
+    pub(crate) fn get_full_url(&self) -> String {
+        let query_string = self
+            .query_params
+            .as_ref()
+            .map_or_else(String::new, |params| {
+                let full_params = params
+                    .iter()
+                    .map(|(k, v)| format!("{k}={v}"))
+                    .collect::<Vec<_>>()
+                    .join("&");
+                format!("?{full_params}")
+            });
+        let base_url = &self.base_url;
+        match self.wire_api {
+            WireApi::Responses => format!("{base_url}/responses{query_string}"),
+            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
+        }
+    }
 }
 
 impl ModelProviderInfo {
@@ -96,6 +122,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
                 env_key: Some("OPENAI_API_KEY".into()),
                 env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()),
                 wire_api: WireApi::Responses,
+                query_params: None,
             },
         ),
     ]
@@ -103,3 +130,51 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
     .map(|(k, v)| (k.to_string(), v))
     .collect()
 }
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+    use super::*;
+
+    #[test]
+    fn test_deserialize_ollama_model_provider_toml() {
+        let ollama_provider_toml = r#"
+name = "Ollama"
+base_url = "http://localhost:11434/v1"
+        "#;
+        let expected_provider = ModelProviderInfo {
+            name: "Ollama".into(),
+            base_url: "http://localhost:11434/v1".into(),
+            env_key: None,
+            env_key_instructions: None,
+            wire_api: WireApi::Chat,
+            query_params: None,
+        };
+
+        let provider: ModelProviderInfo = toml::from_str(ollama_provider_toml).unwrap();
+        assert_eq!(expected_provider, provider);
+    }
+
+    #[test]
+    fn test_deserialize_azure_model_provider_toml() {
+        let azure_provider_toml = r#"
+name = "Azure"
+base_url = "https://xxxxx.openai.azure.com/openai"
+env_key = "AZURE_OPENAI_API_KEY"
+query_params = { api-version = "2025-04-01-preview" }
+        "#;
+        let expected_provider = ModelProviderInfo {
+            name: "Azure".into(),
+            base_url: "https://xxxxx.openai.azure.com/openai".into(),
+            env_key: Some("AZURE_OPENAI_API_KEY".into()),
+            env_key_instructions: None,
+            wire_api: WireApi::Chat,
+            query_params: Some(maplit::hashmap! {
+                "api-version".to_string() => "2025-04-01-preview".to_string(),
+            }),
+        };
+
+        let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
+        assert_eq!(expected_provider, provider);
+    }
+}
diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs
index b9c89f35..e072e9c3 100644
--- a/codex-rs/core/tests/previous_response_id.rs
+++ b/codex-rs/core/tests/previous_response_id.rs
@@ -107,6 +107,7 @@ async fn keeps_previous_response_id_between_tasks() {
         env_key: Some("PATH".into()),
         env_key_instructions: None,
         wire_api: codex_core::WireApi::Responses,
+        query_params: None,
     };
 
     // Init session
diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs
index 02c03681..c1ef10c3 100644
--- a/codex-rs/core/tests/stream_no_completed.rs
+++ b/codex-rs/core/tests/stream_no_completed.rs
@@ -96,6 +96,7 @@ async fn retries_on_early_close() {
         env_key: Some("PATH".into()),
         env_key_instructions: None,
         wire_api: codex_core::WireApi::Responses,
+        query_params: None,
     };
 
     let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
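For illustration, here is a minimal sketch (not part of the patch) of what the new `get_full_url()` produces for the Azure provider documented in `config.md` above. Since `get_full_url()` is `pub(crate)`, a check like this would have to live inside the crate, e.g. as an additional unit test; the field values simply mirror the documentation example:

```rust
use std::collections::HashMap;

#[test]
fn get_full_url_appends_query_params() {
    // Provider definition mirroring the Azure example in config.md.
    let azure = ModelProviderInfo {
        name: "Azure".into(),
        base_url: "https://YOUR_PROJECT_NAME.openai.azure.com/openai".into(),
        env_key: Some("AZURE_OPENAI_API_KEY".into()),
        env_key_instructions: None,
        wire_api: WireApi::Chat,
        query_params: Some(HashMap::from([(
            "api-version".to_string(),
            "2025-04-01-preview".to_string(),
        )])),
    };

    // With the Chat wire API, get_full_url() targets /chat/completions and
    // appends the query string built from query_params.
    assert_eq!(
        azure.get_full_url(),
        "https://YOUR_PROJECT_NAME.openai.azure.com/openai/chat/completions?api-version=2025-04-01-preview"
    );
}
```

Note that because `query_params` is a `HashMap`, the ordering of multiple parameters in the generated query string is not guaranteed; the assertion above is only deterministic because there is a single parameter.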