feat: add query_params option to ModelProviderInfo to support Azure (#1435)
As discovered in https://github.com/openai/codex/issues/1365, the Azure provider needs to be able to specify `api-version` as a query param, so this PR introduces a generic `query_params` option to the `model_providers` config so that an Azure provider can be defined as follows: ```toml [model_providers.azure] name = "Azure" base_url = "https://YOUR_PROJECT_NAME.openai.azure.com/openai" env_key = "AZURE_OPENAI_API_KEY" query_params = { api-version = "2025-04-01-preview" } ``` This PR also updates the docs with this example. While here, we also update `wire_api` to default to `"chat"`, as that is likely the common case for someone defining an external provider. Fixes https://github.com/openai/codex/issues/1365.
This commit is contained in:
@@ -41,8 +41,11 @@ base_url = "https://api.openai.com/v1"
|
||||
# using Codex with this provider. The value of the environment variable must be
|
||||
# non-empty and will be used in the `Bearer TOKEN` HTTP header for the POST request.
|
||||
env_key = "OPENAI_API_KEY"
|
||||
# Valid values for wire_api are "chat" and "responses".
|
||||
# Valid values for wire_api are "chat" and "responses". Defaults to "chat" if omitted.
|
||||
wire_api = "chat"
|
||||
# If necessary, extra query params that need to be added to the URL.
|
||||
# See the Azure example below.
|
||||
query_params = {}
|
||||
```
|
||||
|
||||
Note this makes it possible to use Codex CLI with non-OpenAI models, so long as they use a wire API that is compatible with the OpenAI chat completions API. For example, you could define the following provider to use Codex CLI with Ollama running locally:
|
||||
@@ -51,7 +54,6 @@ Note this makes it possible to use Codex CLI with non-OpenAI models, so long as
|
||||
[model_providers.ollama]
|
||||
name = "Ollama"
|
||||
base_url = "http://localhost:11434/v1"
|
||||
wire_api = "chat"
|
||||
```
|
||||
|
||||
Or a third-party provider (using a distinct environment variable for the API key):
|
||||
@@ -61,7 +63,17 @@ Or a third-party provider (using a distinct environment variable for the API key
|
||||
name = "Mistral"
|
||||
base_url = "https://api.mistral.ai/v1"
|
||||
env_key = "MISTRAL_API_KEY"
|
||||
wire_api = "chat"
|
||||
```
|
||||
|
||||
Note that Azure requires `api-version` to be passed as a query parameter, so be sure to specify it as part of `query_params` when defining the Azure provider:
|
||||
|
||||
```toml
|
||||
[model_providers.azure]
|
||||
name = "Azure"
|
||||
# Make sure you set the appropriate subdomain for this URL.
|
||||
base_url = "https://YOUR_PROJECT_NAME.openai.azure.com/openai"
|
||||
env_key = "AZURE_OPENAI_API_KEY" # Or "OPENAI_API_KEY", whichever you use.
|
||||
query_params = { api-version = "2025-04-01-preview" }
|
||||
```
|
||||
|
||||
## model_provider
|
||||
|
||||
@@ -114,8 +114,7 @@ pub(crate) async fn stream_chat_completions(
|
||||
"tools": tools_json,
|
||||
});
|
||||
|
||||
let base_url = provider.base_url.trim_end_matches('/');
|
||||
let url = format!("{}/chat/completions", base_url);
|
||||
let url = provider.get_full_url();
|
||||
|
||||
debug!(
|
||||
"POST to {url}: {}",
|
||||
|
||||
@@ -123,9 +123,7 @@ impl ModelClient {
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let base_url = self.provider.base_url.clone();
|
||||
let base_url = base_url.trim_end_matches('/');
|
||||
let url = format!("{}/responses", base_url);
|
||||
let url = self.provider.get_full_url();
|
||||
trace!("POST to {url}: {}", serde_json::to_string(&payload)?);
|
||||
|
||||
let mut attempt = 0;
|
||||
|
||||
@@ -658,6 +658,7 @@ disable_response_storage = true
|
||||
env_key: Some("OPENAI_API_KEY".to_string()),
|
||||
wire_api: crate::WireApi::Chat,
|
||||
env_key_instructions: None,
|
||||
query_params: None,
|
||||
};
|
||||
let model_provider_map = {
|
||||
let mut model_provider_map = built_in_model_providers();
|
||||
|
||||
@@ -23,9 +23,10 @@ use crate::openai_api_key::get_openai_api_key;
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum WireApi {
|
||||
/// The experimental “Responses” API exposed by OpenAI at `/v1/responses`.
|
||||
#[default]
|
||||
Responses,
|
||||
|
||||
/// Regular Chat Completions compatible with `/v1/chat/completions`.
|
||||
#[default]
|
||||
Chat,
|
||||
}
|
||||
|
||||
@@ -44,7 +45,32 @@ pub struct ModelProviderInfo {
|
||||
pub env_key_instructions: Option<String>,
|
||||
|
||||
/// Which wire protocol this provider expects.
|
||||
#[serde(default)]
|
||||
pub wire_api: WireApi,
|
||||
|
||||
/// Optional query parameters to append to the base URL.
|
||||
pub query_params: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
pub(crate) fn get_full_url(&self) -> String {
|
||||
let query_string = self
|
||||
.query_params
|
||||
.as_ref()
|
||||
.map_or_else(String::new, |params| {
|
||||
let full_params = params
|
||||
.iter()
|
||||
.map(|(k, v)| format!("{k}={v}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
format!("?{full_params}")
|
||||
});
|
||||
let base_url = &self.base_url;
|
||||
match self.wire_api {
|
||||
WireApi::Responses => format!("{base_url}/responses{query_string}"),
|
||||
WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
@@ -96,6 +122,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||
env_key: Some("OPENAI_API_KEY".into()),
|
||||
env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()),
|
||||
wire_api: WireApi::Responses,
|
||||
query_params: None,
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -103,3 +130,51 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||
.map(|(k, v)| (k.to_string(), v))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_ollama_model_provider_toml() {
|
||||
let azure_provider_toml = r#"
|
||||
name = "Ollama"
|
||||
base_url = "http://localhost:11434/v1"
|
||||
"#;
|
||||
let expected_provider = ModelProviderInfo {
|
||||
name: "Ollama".into(),
|
||||
base_url: "http://localhost:11434/v1".into(),
|
||||
env_key: None,
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: None,
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
assert_eq!(expected_provider, provider);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_azure_model_provider_toml() {
|
||||
let azure_provider_toml = r#"
|
||||
name = "Azure"
|
||||
base_url = "https://xxxxx.openai.azure.com/openai"
|
||||
env_key = "AZURE_OPENAI_API_KEY"
|
||||
query_params = { api-version = "2025-04-01-preview" }
|
||||
"#;
|
||||
let expected_provider = ModelProviderInfo {
|
||||
name: "Azure".into(),
|
||||
base_url: "https://xxxxx.openai.azure.com/openai".into(),
|
||||
env_key: Some("AZURE_OPENAI_API_KEY".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: WireApi::Chat,
|
||||
query_params: Some(maplit::hashmap! {
|
||||
"api-version".to_string() => "2025-04-01-preview".to_string(),
|
||||
}),
|
||||
};
|
||||
|
||||
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
|
||||
assert_eq!(expected_provider, provider);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,6 +107,7 @@ async fn keeps_previous_response_id_between_tasks() {
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
};
|
||||
|
||||
// Init session
|
||||
|
||||
@@ -96,6 +96,7 @@ async fn retries_on_early_close() {
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
};
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
|
||||
Reference in New Issue
Block a user