feat: read model_provider and model_providers from config.toml (#853)
This is the first step in supporting other model providers in the Rust CLI. Specifically, this PR adds support for the new entries in `Config` and `ConfigOverrides` to specify a `ModelProviderInfo`, which is the basic config needed for an LLM provider. This PR does not get us all the way there yet because `client.rs` still categorically appends `/responses` to the URL and expects the endpoint to support the OpenAI Responses API. Will fix that next!
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
use crate::flags::OPENAI_DEFAULT_MODEL;
|
||||
use crate::mcp_server_config::McpServerConfig;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::built_in_model_providers;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPermission;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
@@ -19,6 +21,9 @@ pub struct Config {
|
||||
/// Optional override of model selection.
|
||||
pub model: String,
|
||||
|
||||
/// Info needed to make an API request to the model.
|
||||
pub model_provider: ModelProviderInfo,
|
||||
|
||||
/// Approval policy for executing commands.
|
||||
pub approval_policy: AskForApproval,
|
||||
|
||||
@@ -61,6 +66,9 @@ pub struct Config {
|
||||
|
||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// Combined provider map (defaults merged with user-defined overrides).
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
}
|
||||
|
||||
/// Base config deserialized from ~/.codex/config.toml.
|
||||
@@ -69,6 +77,9 @@ pub struct ConfigToml {
|
||||
/// Optional override of model selection.
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Provider to use from the model_providers map.
|
||||
pub model_provider: Option<String>,
|
||||
|
||||
/// Default approval policy for executing commands.
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
|
||||
@@ -93,6 +104,10 @@ pub struct ConfigToml {
|
||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||
#[serde(default)]
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// User-defined provider entries that extend/override the built-in list.
|
||||
#[serde(default)]
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
@@ -152,6 +167,7 @@ pub struct ConfigOverrides {
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
pub sandbox_policy: Option<SandboxPolicy>,
|
||||
pub disable_response_storage: Option<bool>,
|
||||
pub provider: Option<String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -161,10 +177,13 @@ impl Config {
|
||||
pub fn load_with_overrides(overrides: ConfigOverrides) -> std::io::Result<Self> {
|
||||
let cfg: ConfigToml = ConfigToml::load_from_toml()?;
|
||||
tracing::warn!("Config parsed from config.toml: {cfg:?}");
|
||||
Ok(Self::load_from_base_config_with_overrides(cfg, overrides))
|
||||
Self::load_from_base_config_with_overrides(cfg, overrides)
|
||||
}
|
||||
|
||||
fn load_from_base_config_with_overrides(cfg: ConfigToml, overrides: ConfigOverrides) -> Self {
|
||||
fn load_from_base_config_with_overrides(
|
||||
cfg: ConfigToml,
|
||||
overrides: ConfigOverrides,
|
||||
) -> std::io::Result<Self> {
|
||||
// Instructions: user-provided instructions.md > embedded default.
|
||||
let instructions =
|
||||
Self::load_instructions().or_else(|| Some(EMBEDDED_INSTRUCTIONS.to_string()));
|
||||
@@ -176,6 +195,7 @@ impl Config {
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
disable_response_storage,
|
||||
provider,
|
||||
} = overrides;
|
||||
|
||||
let sandbox_policy = match sandbox_policy {
|
||||
@@ -193,8 +213,28 @@ impl Config {
|
||||
}
|
||||
};
|
||||
|
||||
Self {
|
||||
let mut model_providers = built_in_model_providers();
|
||||
// Merge user-defined providers into the built-in list.
|
||||
for (key, provider) in cfg.model_providers.into_iter() {
|
||||
model_providers.entry(key).or_insert(provider);
|
||||
}
|
||||
|
||||
let model_provider_name = provider
|
||||
.or(cfg.model_provider)
|
||||
.unwrap_or_else(|| "openai".to_string());
|
||||
let model_provider = model_providers
|
||||
.get(&model_provider_name)
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
format!("Model provider `{model_provider_name}` not found"),
|
||||
)
|
||||
})?
|
||||
.clone();
|
||||
|
||||
let config = Self {
|
||||
model: model.or(cfg.model).unwrap_or_else(default_model),
|
||||
model_provider,
|
||||
cwd: cwd.map_or_else(
|
||||
|| {
|
||||
tracing::info!("cwd not set, using current dir");
|
||||
@@ -222,7 +262,9 @@ impl Config {
|
||||
notify: cfg.notify,
|
||||
instructions,
|
||||
mcp_servers: cfg.mcp_servers,
|
||||
}
|
||||
model_providers,
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn load_instructions() -> Option<String> {
|
||||
@@ -238,6 +280,7 @@ impl Config {
|
||||
ConfigToml::default(),
|
||||
ConfigOverrides::default(),
|
||||
)
|
||||
.expect("defaults for test should always succeed")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user