feat: read model_provider and model_providers from config.toml (#853)

This is the first step in supporting other model providers in the Rust
CLI. Specifically, this PR adds new entries to `Config` and
`ConfigOverrides` for specifying a `ModelProviderInfo`, the basic
configuration needed to talk to an LLM provider. It does not get us all
the way there yet because `client.rs` still unconditionally appends
`/responses` to the URL and expects the endpoint to support the OpenAI
Responses API. Will fix that next!
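
Roughly, the new `config.toml` entries deserialize into a provider struct with `name`, `base_url`, and `env_key` fields (see `model_provider_info.rs` below). Here is a minimal, self-contained sketch of that shape, assuming the `serde` and `toml` crates; the `example` key and URL are made up for illustration and are not part of this commit:

```rust
// Illustrative only: trimmed-down stand-ins for ConfigToml / ModelProviderInfo.
use std::collections::HashMap;
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ProviderEntry {
    name: String,     // friendly display name
    base_url: String, // OpenAI-compatible API base URL
    env_key: String,  // env var that holds the API key
}

#[derive(Debug, Deserialize)]
struct ConfigSketch {
    // Key into `model_providers`; the real code falls back to "openai".
    model_provider: Option<String>,
    #[serde(default)]
    model_providers: HashMap<String, ProviderEntry>,
}

fn main() {
    // Hypothetical ~/.codex/config.toml contents.
    let raw = r#"
        model_provider = "example"

        [model_providers.example]
        name = "Example"
        base_url = "https://example.invalid/v1"
        env_key = "EXAMPLE_API_KEY"
    "#;
    let cfg: ConfigSketch = toml::from_str(raw).expect("valid TOML");
    println!("{cfg:#?}");
}
```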
Author: Michael Bolin
Date: 2025-05-07 17:38:28 -07:00 (committed by GitHub)
Parent: cfe50c7107
Commit: 86022f097e
12 changed files with 208 additions and 30 deletions


@@ -26,10 +26,9 @@ use tracing::warn;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::flags::OPENAI_API_BASE;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::flags::get_api_key;
use crate::model_provider_info::ModelProviderInfo;
use crate::models::ResponseItem;
use crate::util::backoff;
@@ -141,13 +140,16 @@ static DEFAULT_TOOLS: LazyLock<Vec<ResponsesApiTool>> = LazyLock::new(|| {
pub struct ModelClient {
model: String,
client: reqwest::Client,
provider: ModelProviderInfo,
}
impl ModelClient {
pub fn new(model: impl ToString) -> Self {
let model = model.to_string();
let client = reqwest::Client::new();
Self { model, client }
pub fn new(model: impl ToString, provider: ModelProviderInfo) -> Self {
Self {
model: model.to_string(),
client: reqwest::Client::new(),
provider,
}
}
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
@@ -188,7 +190,9 @@ impl ModelClient {
stream: true,
};
let url = format!("{}/v1/responses", *OPENAI_API_BASE);
let base_url = self.provider.base_url.clone();
let base_url = base_url.trim_end_matches('/');
let url = format!("{}/responses", base_url);
debug!(url, "POST");
trace!("request payload: {}", serde_json::to_string(&payload)?);
@@ -196,10 +200,14 @@ impl ModelClient {
loop {
attempt += 1;
let api_key = self
.provider
.api_key()
.ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?;
let res = self
.client
.post(&url)
.bearer_auth(get_api_key()?)
.bearer_auth(api_key)
.header("OpenAI-Beta", "responses=experimental")
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(&payload)
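
These `client.rs` hunks replace the global `OPENAI_API_BASE` flag and `get_api_key()` with the provider's own `base_url` and per-provider `env_key`. A minimal standalone sketch of that logic (not the actual `ModelClient` code):

```rust
struct ModelProviderInfo {
    name: String,
    base_url: String,
    env_key: String,
}

impl ModelProviderInfo {
    /// API key for this provider, if the configured env var is set.
    fn api_key(&self) -> Option<String> {
        std::env::var(&self.env_key).ok()
    }
}

/// Build the Responses endpoint; trimming the trailing slash avoids "…//responses".
fn responses_url(provider: &ModelProviderInfo) -> String {
    format!("{}/responses", provider.base_url.trim_end_matches('/'))
}

fn main() {
    let provider = ModelProviderInfo {
        name: "OpenAI".into(),
        base_url: "https://api.openai.com/v1/".into(),
        env_key: "OPENAI_API_KEY".into(),
    };
    assert_eq!(responses_url(&provider), "https://api.openai.com/v1/responses");
    // The real client maps a missing key to CodexErr::EnvVar instead of printing.
    println!("{}: api key set = {}", provider.name, provider.api_key().is_some());
}
```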


@@ -80,6 +80,7 @@ impl Codex {
let (tx_sub, rx_sub) = async_channel::bounded(64);
let (tx_event, rx_event) = async_channel::bounded(64);
let configure_session = Op::ConfigureSession {
provider: config.model_provider.clone(),
model: config.model.clone(),
instructions: config.instructions.clone(),
approval_policy: config.approval_policy,
@@ -504,6 +505,7 @@ async fn submission_loop(
sess.abort();
}
Op::ConfigureSession {
provider,
model,
instructions,
approval_policy,
@@ -512,7 +514,7 @@ async fn submission_loop(
notify,
cwd,
} => {
info!(model, "Configuring session");
info!("Configuring session: model={model}; provider={provider:?}");
if !cwd.is_absolute() {
let message = format!("cwd is not absolute: {cwd:?}");
error!(message);
@@ -526,7 +528,7 @@ async fn submission_loop(
return;
}
let client = ModelClient::new(model.clone());
let client = ModelClient::new(model.clone(), provider.clone());
// abort any current running session and clone its state
let state = match sess.take() {


@@ -1,5 +1,7 @@
use crate::flags::OPENAI_DEFAULT_MODEL;
use crate::mcp_server_config::McpServerConfig;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPermission;
use crate::protocol::SandboxPolicy;
@@ -19,6 +21,9 @@ pub struct Config {
/// Optional override of model selection.
pub model: String,
/// Info needed to make an API request to the model.
pub model_provider: ModelProviderInfo,
/// Approval policy for executing commands.
pub approval_policy: AskForApproval,
@@ -61,6 +66,9 @@ pub struct Config {
/// Definition for MCP servers that Codex can reach out to for tool calls.
pub mcp_servers: HashMap<String, McpServerConfig>,
/// Combined provider map (defaults merged with user-defined overrides).
pub model_providers: HashMap<String, ModelProviderInfo>,
}
/// Base config deserialized from ~/.codex/config.toml.
@@ -69,6 +77,9 @@ pub struct ConfigToml {
/// Optional override of model selection.
pub model: Option<String>,
/// Provider to use from the model_providers map.
pub model_provider: Option<String>,
/// Default approval policy for executing commands.
pub approval_policy: Option<AskForApproval>,
@@ -93,6 +104,10 @@ pub struct ConfigToml {
/// Definition for MCP servers that Codex can reach out to for tool calls.
#[serde(default)]
pub mcp_servers: HashMap<String, McpServerConfig>,
/// User-defined provider entries that extend/override the built-in list.
#[serde(default)]
pub model_providers: HashMap<String, ModelProviderInfo>,
}
impl ConfigToml {
@@ -152,6 +167,7 @@ pub struct ConfigOverrides {
pub approval_policy: Option<AskForApproval>,
pub sandbox_policy: Option<SandboxPolicy>,
pub disable_response_storage: Option<bool>,
pub provider: Option<String>,
}
impl Config {
@@ -161,10 +177,13 @@ impl Config {
pub fn load_with_overrides(overrides: ConfigOverrides) -> std::io::Result<Self> {
let cfg: ConfigToml = ConfigToml::load_from_toml()?;
tracing::warn!("Config parsed from config.toml: {cfg:?}");
Ok(Self::load_from_base_config_with_overrides(cfg, overrides))
Self::load_from_base_config_with_overrides(cfg, overrides)
}
fn load_from_base_config_with_overrides(cfg: ConfigToml, overrides: ConfigOverrides) -> Self {
fn load_from_base_config_with_overrides(
cfg: ConfigToml,
overrides: ConfigOverrides,
) -> std::io::Result<Self> {
// Instructions: user-provided instructions.md > embedded default.
let instructions =
Self::load_instructions().or_else(|| Some(EMBEDDED_INSTRUCTIONS.to_string()));
@@ -176,6 +195,7 @@ impl Config {
approval_policy,
sandbox_policy,
disable_response_storage,
provider,
} = overrides;
let sandbox_policy = match sandbox_policy {
@@ -193,8 +213,28 @@ impl Config {
}
};
Self {
let mut model_providers = built_in_model_providers();
// Merge user-defined providers into the built-in list.
for (key, provider) in cfg.model_providers.into_iter() {
model_providers.entry(key).or_insert(provider);
}
let model_provider_name = provider
.or(cfg.model_provider)
.unwrap_or_else(|| "openai".to_string());
let model_provider = model_providers
.get(&model_provider_name)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("Model provider `{model_provider_name}` not found"),
)
})?
.clone();
let config = Self {
model: model.or(cfg.model).unwrap_or_else(default_model),
model_provider,
cwd: cwd.map_or_else(
|| {
tracing::info!("cwd not set, using current dir");
@@ -222,7 +262,9 @@ impl Config {
notify: cfg.notify,
instructions,
mcp_servers: cfg.mcp_servers,
}
model_providers,
};
Ok(config)
}
fn load_instructions() -> Option<String> {
@@ -238,6 +280,7 @@ impl Config {
ConfigToml::default(),
ConfigOverrides::default(),
)
.expect("defaults for test should always succeed")
}
}
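
Pulling the new resolution logic out of the diff above: built-in providers are merged with user-defined ones (`or_insert`, so a built-in entry wins on a key collision), the provider name is taken from the CLI override, then the config file, then the `"openai"` default, and a missing entry becomes an `io::Error`. A self-contained sketch of that precedence, with `String` standing in for `ModelProviderInfo`:

```rust
use std::collections::HashMap;

/// Sketch of the lookup order in `load_from_base_config_with_overrides`.
fn resolve_provider(
    built_in: HashMap<String, String>,     // stand-in for built_in_model_providers()
    user_defined: HashMap<String, String>, // cfg.model_providers from config.toml
    cli_override: Option<String>,          // ConfigOverrides::provider
    config_value: Option<String>,          // ConfigToml::model_provider
) -> std::io::Result<String> {
    let mut providers = built_in;
    // or_insert: a user-defined entry only lands when the key is not already built in.
    for (key, provider) in user_defined {
        providers.entry(key).or_insert(provider);
    }
    let name = cli_override
        .or(config_value)
        .unwrap_or_else(|| "openai".to_string());
    providers.get(&name).cloned().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("Model provider `{name}` not found"),
        )
    })
}

fn main() -> std::io::Result<()> {
    let built_in = HashMap::from([("openai".to_string(), "built-in".to_string())]);
    let user = HashMap::from([("local".to_string(), "user-defined".to_string())]);
    // No CLI or config.toml selection: falls back to the "openai" entry.
    assert_eq!(resolve_provider(built_in, user, None, None)?, "built-in");
    Ok(())
}
```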


@@ -2,12 +2,11 @@ use std::time::Duration;
use env_flags::env_flags;
use crate::error::CodexErr;
use crate::error::Result;
env_flags! {
pub OPENAI_DEFAULT_MODEL: &str = "o3";
pub OPENAI_API_BASE: &str = "https://api.openai.com";
pub OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Fallback when the provider-specific key is not set.
pub OPENAI_API_KEY: Option<&str> = None;
pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
value.parse().map(Duration::from_millis)
@@ -21,9 +20,6 @@ env_flags! {
value.parse().map(Duration::from_millis)
};
/// Fixture path for offline tests (see client.rs).
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
}
pub fn get_api_key() -> Result<&'static str> {
OPENAI_API_KEY.ok_or_else(|| CodexErr::EnvVar("OPENAI_API_KEY"))
}


@@ -7,6 +7,7 @@
mod client;
pub mod codex;
pub use codex::Codex;
pub mod codex_wrapper;
pub mod config;
pub mod error;
@@ -18,6 +19,8 @@ pub mod linux;
mod mcp_connection_manager;
pub mod mcp_server_config;
mod mcp_tool_call;
mod model_provider_info;
pub use model_provider_info::ModelProviderInfo;
mod models;
pub mod protocol;
mod rollout;
@@ -25,5 +28,3 @@ mod safety;
mod user_notification;
pub mod util;
mod zdr_transcript;
pub use codex::Codex;


@@ -0,0 +1,103 @@
//! Registry of model providers supported by Codex.
//!
//! Providers can be defined in two places:
//! 1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//! key. These override or extend the defaults at runtime.
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelProviderInfo {
/// Friendly display name.
pub name: String,
/// Base URL for the provider's OpenAI-compatible API.
pub base_url: String,
/// Environment variable that stores the user's API key for this provider.
pub env_key: String,
}
impl ModelProviderInfo {
/// Returns the API key for this provider if present in the environment.
pub fn api_key(&self) -> Option<String> {
std::env::var(&self.env_key).ok()
}
}
/// Built-in default provider list.
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
use ModelProviderInfo as P;
[
(
"openai",
P {
name: "OpenAI".into(),
base_url: "https://api.openai.com/v1".into(),
env_key: "OPENAI_API_KEY".into(),
},
),
(
"openrouter",
P {
name: "OpenRouter".into(),
base_url: "https://openrouter.ai/api/v1".into(),
env_key: "OPENROUTER_API_KEY".into(),
},
),
(
"gemini",
P {
name: "Gemini".into(),
base_url: "https://generativelanguage.googleapis.com/v1beta/openai".into(),
env_key: "GEMINI_API_KEY".into(),
},
),
(
"ollama",
P {
name: "Ollama".into(),
base_url: "http://localhost:11434/v1".into(),
env_key: "OLLAMA_API_KEY".into(),
},
),
(
"mistral",
P {
name: "Mistral".into(),
base_url: "https://api.mistral.ai/v1".into(),
env_key: "MISTRAL_API_KEY".into(),
},
),
(
"deepseek",
P {
name: "DeepSeek".into(),
base_url: "https://api.deepseek.com".into(),
env_key: "DEEPSEEK_API_KEY".into(),
},
),
(
"xai",
P {
name: "xAI".into(),
base_url: "https://api.x.ai/v1".into(),
env_key: "XAI_API_KEY".into(),
},
),
(
"groq",
P {
name: "Groq".into(),
base_url: "https://api.groq.com/openai/v1".into(),
env_key: "GROQ_API_KEY".into(),
},
),
]
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect()
}


@@ -11,6 +11,8 @@ use mcp_types::CallToolResult;
use serde::Deserialize;
use serde::Serialize;
use crate::model_provider_info::ModelProviderInfo;
/// Submission Queue Entry - requests from user
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Submission {
@@ -27,6 +29,9 @@ pub struct Submission {
pub enum Op {
/// Configure the model session.
ConfigureSession {
/// Provider identifier ("openai", "openrouter", ...).
provider: ModelProviderInfo,
/// If not specified, server will use its default model.
model: String,
/// Model instructions


@@ -1,6 +1,7 @@
use std::time::Duration;
use codex_core::Codex;
use codex_core::ModelProviderInfo;
use codex_core::config::Config;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
@@ -80,14 +81,21 @@ async fn keeps_previous_response_id_between_tasks() {
// Update environment `set_var` is `unsafe` starting with the 2024
// edition so we group the calls into a single `unsafe { … }` block.
unsafe {
std::env::set_var("OPENAI_API_KEY", "test-key");
std::env::set_var("OPENAI_API_BASE", server.uri());
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
}
let model_provider = ModelProviderInfo {
name: "openai".into(),
base_url: format!("{}/v1", server.uri()),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// provider is not set.
env_key: "PATH".into(),
};
// Init session
let config = Config::load_default_config_for_test();
let mut config = Config::load_default_config_for_test();
config.model_provider = model_provider;
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();


@@ -4,6 +4,7 @@
use std::time::Duration;
use codex_core::Codex;
use codex_core::ModelProviderInfo;
use codex_core::config::Config;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
@@ -68,15 +69,23 @@ async fn retries_on_early_close() {
// scope is very small and clearly delineated.
unsafe {
std::env::set_var("OPENAI_API_KEY", "test-key");
std::env::set_var("OPENAI_API_BASE", server.uri());
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1");
std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000");
}
let model_provider = ModelProviderInfo {
name: "openai".into(),
base_url: format!("{}/v1", server.uri()),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// provider is not set.
env_key: "PATH".into(),
};
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
let config = Config::load_default_config_for_test();
let mut config = Config::load_default_config_for_test();
config.model_provider = model_provider;
let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap();
codex


@@ -66,6 +66,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> {
None
},
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
provider: None,
};
let config = Config::load_with_overrides(overrides)?;


@@ -158,6 +158,7 @@ impl CodexToolCallParam {
approval_policy: approval_policy.map(Into::into),
sandbox_policy,
disable_response_storage,
provider: None,
};
let cfg = codex_core::config::Config::load_with_overrides(overrides)?;


@@ -58,6 +58,7 @@ pub fn run_main(cli: Cli) -> std::io::Result<()> {
None
},
cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
provider: None,
};
#[allow(clippy::print_stderr)]
match Config::load_with_overrides(overrides) {