feat: read model_provider and model_providers from config.toml (#853)

This is the first step in supporting other model providers in the Rust
CLI. Specifically, this PR adds support for the new entries in `Config`
and `ConfigOverrides` to specify a `ModelProviderInfo`, which is the
basic config needed for an LLM provider. This PR does not get us all the
way there yet because `client.rs` still categorically appends
`/responses` to the URL and expects the endpoint to support the OpenAI
Responses API. Will fix that next!
This commit is contained in:
Michael Bolin
2025-05-07 17:38:28 -07:00
committed by GitHub
parent cfe50c7107
commit 86022f097e
12 changed files with 208 additions and 30 deletions

View File

@@ -1,5 +1,7 @@
use crate::flags::OPENAI_DEFAULT_MODEL;
use crate::mcp_server_config::McpServerConfig;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPermission;
use crate::protocol::SandboxPolicy;
@@ -19,6 +21,9 @@ pub struct Config {
/// Optional override of model selection.
pub model: String,
/// Info needed to make an API request to the model.
pub model_provider: ModelProviderInfo,
/// Approval policy for executing commands.
pub approval_policy: AskForApproval,
@@ -61,6 +66,9 @@ pub struct Config {
/// Definition for MCP servers that Codex can reach out to for tool calls.
pub mcp_servers: HashMap<String, McpServerConfig>,
/// Combined provider map (defaults merged with user-defined overrides).
pub model_providers: HashMap<String, ModelProviderInfo>,
}
/// Base config deserialized from ~/.codex/config.toml.
@@ -69,6 +77,9 @@ pub struct ConfigToml {
/// Optional override of model selection.
pub model: Option<String>,
/// Provider to use from the model_providers map.
pub model_provider: Option<String>,
/// Default approval policy for executing commands.
pub approval_policy: Option<AskForApproval>,
@@ -93,6 +104,10 @@ pub struct ConfigToml {
/// Definition for MCP servers that Codex can reach out to for tool calls.
#[serde(default)]
pub mcp_servers: HashMap<String, McpServerConfig>,
/// User-defined provider entries that extend/override the built-in list.
#[serde(default)]
pub model_providers: HashMap<String, ModelProviderInfo>,
}
impl ConfigToml {
@@ -152,6 +167,7 @@ pub struct ConfigOverrides {
pub approval_policy: Option<AskForApproval>,
pub sandbox_policy: Option<SandboxPolicy>,
pub disable_response_storage: Option<bool>,
pub provider: Option<String>,
}
impl Config {
@@ -161,10 +177,13 @@ impl Config {
pub fn load_with_overrides(overrides: ConfigOverrides) -> std::io::Result<Self> {
let cfg: ConfigToml = ConfigToml::load_from_toml()?;
tracing::warn!("Config parsed from config.toml: {cfg:?}");
Ok(Self::load_from_base_config_with_overrides(cfg, overrides))
Self::load_from_base_config_with_overrides(cfg, overrides)
}
fn load_from_base_config_with_overrides(cfg: ConfigToml, overrides: ConfigOverrides) -> Self {
fn load_from_base_config_with_overrides(
cfg: ConfigToml,
overrides: ConfigOverrides,
) -> std::io::Result<Self> {
// Instructions: user-provided instructions.md > embedded default.
let instructions =
Self::load_instructions().or_else(|| Some(EMBEDDED_INSTRUCTIONS.to_string()));
@@ -176,6 +195,7 @@ impl Config {
approval_policy,
sandbox_policy,
disable_response_storage,
provider,
} = overrides;
let sandbox_policy = match sandbox_policy {
@@ -193,8 +213,28 @@ impl Config {
}
};
Self {
let mut model_providers = built_in_model_providers();
// Merge user-defined providers into the built-in list.
for (key, provider) in cfg.model_providers.into_iter() {
model_providers.entry(key).or_insert(provider);
}
let model_provider_name = provider
.or(cfg.model_provider)
.unwrap_or_else(|| "openai".to_string());
let model_provider = model_providers
.get(&model_provider_name)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Model provider `{model_provider_name}` not found"),
)
})?
.clone();
let config = Self {
model: model.or(cfg.model).unwrap_or_else(default_model),
model_provider,
cwd: cwd.map_or_else(
|| {
tracing::info!("cwd not set, using current dir");
@@ -222,7 +262,9 @@ impl Config {
notify: cfg.notify,
instructions,
mcp_servers: cfg.mcp_servers,
}
model_providers,
};
Ok(config)
}
fn load_instructions() -> Option<String> {
@@ -238,6 +280,7 @@ impl Config {
ConfigToml::default(),
ConfigOverrides::default(),
)
.expect("defaults for test should always succeed")
}
}