feat: add query_params option to ModelProviderInfo to support Azure (#1435)
As discovered in https://github.com/openai/codex/issues/1365, the Azure provider needs to be able to specify `api-version` as a query parameter, so this PR introduces a generic `query_params` option on the `model_providers` config. An Azure provider can now be defined as follows:

```toml
[model_providers.azure]
name = "Azure"
base_url = "https://YOUR_PROJECT_NAME.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
```

This PR also updates the docs with this example. While here, we also update `wire_api` to default to `"chat"`, as that is likely the common case for someone defining an external provider.

Fixes https://github.com/openai/codex/issues/1365.
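For illustration, with the Azure provider defined above (the subdomain is a placeholder) and the default `wire_api = "chat"`, the `get_full_url()` helper added in this PR composes the request URL as `{base_url}/chat/completions` followed by the query string, i.e.:

```
https://YOUR_PROJECT_NAME.openai.azure.com/openai/chat/completions?api-version=2025-04-01-preview
```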
````diff
@@ -41,8 +41,11 @@ base_url = "https://api.openai.com/v1"
 # using Codex with this provider. The value of the environment variable must be
 # non-empty and will be used in the `Bearer TOKEN` HTTP header for the POST request.
 env_key = "OPENAI_API_KEY"
-# Valid values for wire_api are "chat" and "responses".
+# Valid values for wire_api are "chat" and "responses". Defaults to "chat" if omitted.
 wire_api = "chat"
+# If necessary, extra query params that need to be added to the URL.
+# See the Azure example below.
+query_params = {}
 ```
 
 Note this makes it possible to use Codex CLI with non-OpenAI models, so long as they use a wire API that is compatible with the OpenAI chat completions API. For example, you could define the following provider to use Codex CLI with Ollama running locally:
````
````diff
@@ -51,7 +54,6 @@ Note this makes it possible to use Codex CLI with non-OpenAI models, so long as
 [model_providers.ollama]
 name = "Ollama"
 base_url = "http://localhost:11434/v1"
-wire_api = "chat"
 ```
 
 Or a third-party provider (using a distinct environment variable for the API key):
````
````diff
@@ -61,7 +63,17 @@ Or a third-party provider (using a distinct environment variable for the API key
 name = "Mistral"
 base_url = "https://api.mistral.ai/v1"
 env_key = "MISTRAL_API_KEY"
-wire_api = "chat"
 ```
 
+Note that Azure requires `api-version` to be passed as a query parameter, so be sure to specify it as part of `query_params` when defining the Azure provider:
+
+```toml
+[model_providers.azure]
+name = "Azure"
+# Make sure you set the appropriate subdomain for this URL.
+base_url = "https://YOUR_PROJECT_NAME.openai.azure.com/openai"
+env_key = "AZURE_OPENAI_API_KEY" # Or "OPENAI_API_KEY", whichever you use.
+query_params = { api-version = "2025-04-01-preview" }
+```
+
 ## model_provider
````
```diff
@@ -114,8 +114,7 @@ pub(crate) async fn stream_chat_completions(
         "tools": tools_json,
     });
 
-    let base_url = provider.base_url.trim_end_matches('/');
-    let url = format!("{}/chat/completions", base_url);
+    let url = provider.get_full_url();
 
     debug!(
         "POST to {url}: {}",
```
```diff
@@ -123,9 +123,7 @@ impl ModelClient {
             stream: true,
         };
 
-        let base_url = self.provider.base_url.clone();
-        let base_url = base_url.trim_end_matches('/');
-        let url = format!("{}/responses", base_url);
+        let url = self.provider.get_full_url();
         trace!("POST to {url}: {}", serde_json::to_string(&payload)?);
 
         let mut attempt = 0;
```
```diff
@@ -658,6 +658,7 @@ disable_response_storage = true
             env_key: Some("OPENAI_API_KEY".to_string()),
             wire_api: crate::WireApi::Chat,
             env_key_instructions: None,
+            query_params: None,
         };
         let model_provider_map = {
             let mut model_provider_map = built_in_model_providers();
```
```diff
@@ -23,9 +23,10 @@ use crate::openai_api_key::get_openai_api_key;
 #[serde(rename_all = "lowercase")]
 pub enum WireApi {
     /// The experimental “Responses” API exposed by OpenAI at `/v1/responses`.
-    #[default]
     Responses,
+
     /// Regular Chat Completions compatible with `/v1/chat/completions`.
+    #[default]
     Chat,
 }
 
```
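Since `WireApi` keeps its derived `Default`, moving the `#[default]` marker changes what `WireApi::default()` returns; together with the `#[serde(default)]` attribute added to the `wire_api` field in the next hunk, a provider entry that omits `wire_api` now deserializes to `Chat`. A minimal sketch of the resulting behavior, assuming the derives already on the enum (`Default`, `Debug`, `PartialEq`):

```rust
// Sketch only: after this change, the derived Default implementation
// returns the #[default]-marked variant, i.e. Chat rather than Responses.
assert_eq!(WireApi::default(), WireApi::Chat);
```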
```diff
@@ -44,7 +45,32 @@ pub struct ModelProviderInfo {
     pub env_key_instructions: Option<String>,
 
     /// Which wire protocol this provider expects.
+    #[serde(default)]
     pub wire_api: WireApi,
+
+    /// Optional query parameters to append to the base URL.
+    pub query_params: Option<HashMap<String, String>>,
 }
 
+impl ModelProviderInfo {
+    pub(crate) fn get_full_url(&self) -> String {
+        let query_string = self
+            .query_params
+            .as_ref()
+            .map_or_else(String::new, |params| {
+                let full_params = params
+                    .iter()
+                    .map(|(k, v)| format!("{k}={v}"))
+                    .collect::<Vec<_>>()
+                    .join("&");
+                format!("?{full_params}")
+            });
+        let base_url = &self.base_url;
+        match self.wire_api {
+            WireApi::Responses => format!("{base_url}/responses{query_string}"),
+            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
+        }
+    }
+}
+
 impl ModelProviderInfo {
```
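To make the behavior of `get_full_url()` concrete, here is a sketch of what it produces for an Azure-style provider. It assumes crate-internal context (the method is `pub(crate)`) and reuses the placeholder subdomain and API version from the tests added later in this commit:

```rust
// Sketch: exercising get_full_url() as implemented above.
let azure = ModelProviderInfo {
    name: "Azure".into(),
    base_url: "https://xxxxx.openai.azure.com/openai".into(),
    env_key: Some("AZURE_OPENAI_API_KEY".into()),
    env_key_instructions: None,
    wire_api: WireApi::Chat,
    query_params: Some(maplit::hashmap! {
        "api-version".to_string() => "2025-04-01-preview".to_string(),
    }),
};

// The chat wire API appends /chat/completions, then the query string.
// With multiple query_params entries, the k=v pairs are joined with '&'
// in HashMap iteration order; with a single entry the result is stable.
assert_eq!(
    azure.get_full_url(),
    "https://xxxxx.openai.azure.com/openai/chat/completions?api-version=2025-04-01-preview"
);
```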
```diff
@@ -96,6 +122,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
                 env_key: Some("OPENAI_API_KEY".into()),
                 env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()),
                 wire_api: WireApi::Responses,
+                query_params: None,
             },
         ),
     ]
```
```diff
@@ -103,3 +130,51 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
     .map(|(k, v)| (k.to_string(), v))
     .collect()
 }
+
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::unwrap_used)]
+    use super::*;
+
+    #[test]
+    fn test_deserialize_ollama_model_provider_toml() {
+        let azure_provider_toml = r#"
+name = "Ollama"
+base_url = "http://localhost:11434/v1"
+        "#;
+        let expected_provider = ModelProviderInfo {
+            name: "Ollama".into(),
+            base_url: "http://localhost:11434/v1".into(),
+            env_key: None,
+            env_key_instructions: None,
+            wire_api: WireApi::Chat,
+            query_params: None,
+        };
+
+        let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
+        assert_eq!(expected_provider, provider);
+    }
+
+    #[test]
+    fn test_deserialize_azure_model_provider_toml() {
+        let azure_provider_toml = r#"
+name = "Azure"
+base_url = "https://xxxxx.openai.azure.com/openai"
+env_key = "AZURE_OPENAI_API_KEY"
+query_params = { api-version = "2025-04-01-preview" }
+        "#;
+        let expected_provider = ModelProviderInfo {
+            name: "Azure".into(),
+            base_url: "https://xxxxx.openai.azure.com/openai".into(),
+            env_key: Some("AZURE_OPENAI_API_KEY".into()),
+            env_key_instructions: None,
+            wire_api: WireApi::Chat,
+            query_params: Some(maplit::hashmap! {
+                "api-version".to_string() => "2025-04-01-preview".to_string(),
+            }),
+        };
+
+        let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
+        assert_eq!(expected_provider, provider);
+    }
+}
```
```diff
@@ -107,6 +107,7 @@ async fn keeps_previous_response_id_between_tasks() {
         env_key: Some("PATH".into()),
         env_key_instructions: None,
         wire_api: codex_core::WireApi::Responses,
+        query_params: None,
     };
 
     // Init session
```
```diff
@@ -96,6 +96,7 @@ async fn retries_on_early_close() {
         env_key: Some("PATH".into()),
         env_key_instructions: None,
         wire_api: codex_core::WireApi::Responses,
+        query_params: None,
    };
 
     let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
```