From 86022f097e5ff81e751704e61ddd89e8a2acd642 Mon Sep 17 00:00:00 2001
From: Michael Bolin
Date: Wed, 7 May 2025 17:38:28 -0700
Subject: [PATCH] feat: read `model_provider` and `model_providers` from
 config.toml (#853)

This is the first step in supporting other model providers in the Rust
CLI.

Specifically, this PR adds new entries to `Config` and `ConfigOverrides`
to specify a `ModelProviderInfo`, which is the basic config needed for
an LLM provider. (A sketch of the resulting `config.toml` usage follows
the diff.)

This PR does not get us all the way there yet because `client.rs` still
unconditionally appends `/responses` to the URL and expects the endpoint
to support the OpenAI Responses API. Will fix that next!
---
 codex-rs/core/src/client.rs                  |  24 +++--
 codex-rs/core/src/codex.rs                   |   6 +-
 codex-rs/core/src/config.rs                  |  51 ++++++++-
 codex-rs/core/src/flags.rs                   |  12 +--
 codex-rs/core/src/lib.rs                     |   5 +-
 codex-rs/core/src/model_provider_info.rs     | 103 +++++++++++++++++++
 codex-rs/core/src/protocol.rs                |   5 +
 codex-rs/core/tests/previous_response_id.rs  |  14 ++-
 codex-rs/core/tests/stream_no_completed.rs   |  15 ++-
 codex-rs/exec/src/lib.rs                     |   1 +
 codex-rs/mcp-server/src/codex_tool_config.rs |   1 +
 codex-rs/tui/src/lib.rs                      |   1 +
 12 files changed, 208 insertions(+), 30 deletions(-)
 create mode 100644 codex-rs/core/src/model_provider_info.rs

diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 79f99e8c..9216e68c 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -26,10 +26,9 @@ use tracing::warn;
 use crate::error::CodexErr;
 use crate::error::Result;
 use crate::flags::CODEX_RS_SSE_FIXTURE;
-use crate::flags::OPENAI_API_BASE;
 use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
 use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
-use crate::flags::get_api_key;
+use crate::model_provider_info::ModelProviderInfo;
 use crate::models::ResponseItem;
 use crate::util::backoff;
 
@@ -141,13 +140,16 @@ static DEFAULT_TOOLS: LazyLock<Vec<serde_json::Value>> = LazyLock::new(|| {
 pub struct ModelClient {
     model: String,
     client: reqwest::Client,
+    provider: ModelProviderInfo,
 }
 
 impl ModelClient {
-    pub fn new(model: impl ToString) -> Self {
-        let model = model.to_string();
-        let client = reqwest::Client::new();
-        Self { model, client }
+    pub fn new(model: impl ToString, provider: ModelProviderInfo) -> Self {
+        Self {
+            model: model.to_string(),
+            client: reqwest::Client::new(),
+            provider,
+        }
     }
 
     pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
@@ -188,7 +190,9 @@ impl ModelClient {
             stream: true,
         };
 
-        let url = format!("{}/v1/responses", *OPENAI_API_BASE);
+        let base_url = self.provider.base_url.clone();
+        let base_url = base_url.trim_end_matches('/');
+        let url = format!("{}/responses", base_url);
         debug!(url, "POST");
         trace!("request payload: {}", serde_json::to_string(&payload)?);
 
@@ -196,10 +200,14 @@ impl ModelClient {
         loop {
             attempt += 1;
 
+            let api_key = self
+                .provider
+                .api_key()
+                .ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?;
             let res = self
                 .client
                 .post(&url)
-                .bearer_auth(get_api_key()?)
+                .bearer_auth(api_key)
                 .header("OpenAI-Beta", "responses=experimental")
                 .header(reqwest::header::ACCEPT, "text/event-stream")
                 .json(&payload)
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 7749ee7d..039e11ce 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -80,6 +80,7 @@ impl Codex {
         let (tx_sub, rx_sub) = async_channel::bounded(64);
         let (tx_event, rx_event) = async_channel::bounded(64);
         let configure_session = Op::ConfigureSession {
+            provider: config.model_provider.clone(),
             model: config.model.clone(),
             instructions: config.instructions.clone(),
             approval_policy: config.approval_policy,
@@ -504,6 +505,7 @@ async fn submission_loop(
                 sess.abort();
             }
             Op::ConfigureSession {
+                provider,
                 model,
                 instructions,
                 approval_policy,
@@ -512,7 +514,7 @@ async fn submission_loop(
                 notify,
                 cwd,
             } => {
-                info!(model, "Configuring session");
+                info!("Configuring session: model={model}; provider={provider:?}");
                 if !cwd.is_absolute() {
                     let message = format!("cwd is not absolute: {cwd:?}");
                     error!(message);
@@ -526,7 +528,7 @@ async fn submission_loop(
                     return;
                 }
 
-                let client = ModelClient::new(model.clone());
+                let client = ModelClient::new(model.clone(), provider.clone());
 
                 // abort any current running session and clone its state
                 let state = match sess.take() {
diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs
index 68fec35e..087d6afb 100644
--- a/codex-rs/core/src/config.rs
+++ b/codex-rs/core/src/config.rs
@@ -1,5 +1,7 @@
 use crate::flags::OPENAI_DEFAULT_MODEL;
 use crate::mcp_server_config::McpServerConfig;
+use crate::model_provider_info::ModelProviderInfo;
+use crate::model_provider_info::built_in_model_providers;
 use crate::protocol::AskForApproval;
 use crate::protocol::SandboxPermission;
 use crate::protocol::SandboxPolicy;
@@ -19,6 +21,9 @@ pub struct Config {
     /// Optional override of model selection.
     pub model: String,
 
+    /// Info needed to make an API request to the model.
+    pub model_provider: ModelProviderInfo,
+
     /// Approval policy for executing commands.
     pub approval_policy: AskForApproval,
@@ -61,6 +66,9 @@ pub struct Config {
     /// Definition for MCP servers that Codex can reach out to for tool calls.
     pub mcp_servers: HashMap<String, McpServerConfig>,
+
+    /// Combined provider map (defaults merged with user-defined overrides).
+    pub model_providers: HashMap<String, ModelProviderInfo>,
 }
 
 /// Base config deserialized from ~/.codex/config.toml.
@@ -69,6 +77,9 @@ pub struct ConfigToml {
     /// Optional override of model selection.
     pub model: Option<String>,
 
+    /// Provider to use from the model_providers map.
+    pub model_provider: Option<String>,
+
     /// Default approval policy for executing commands.
     pub approval_policy: Option<AskForApproval>,
@@ -93,6 +104,10 @@ pub struct ConfigToml {
     /// Definition for MCP servers that Codex can reach out to for tool calls.
     #[serde(default)]
     pub mcp_servers: HashMap<String, McpServerConfig>,
+
+    /// User-defined provider entries that extend/override the built-in list.
+    #[serde(default)]
+    pub model_providers: HashMap<String, ModelProviderInfo>,
 }
 
 impl ConfigToml {
@@ -152,6 +167,7 @@ pub struct ConfigOverrides {
     pub approval_policy: Option<AskForApproval>,
     pub sandbox_policy: Option<SandboxPolicy>,
     pub disable_response_storage: Option<bool>,
+    pub provider: Option<String>,
 }
 
 impl Config {
@@ -161,10 +177,13 @@ impl Config {
     pub fn load_with_overrides(overrides: ConfigOverrides) -> std::io::Result<Self> {
         let cfg: ConfigToml = ConfigToml::load_from_toml()?;
         tracing::warn!("Config parsed from config.toml: {cfg:?}");
-        Ok(Self::load_from_base_config_with_overrides(cfg, overrides))
+        Self::load_from_base_config_with_overrides(cfg, overrides)
     }
 
-    fn load_from_base_config_with_overrides(cfg: ConfigToml, overrides: ConfigOverrides) -> Self {
+    fn load_from_base_config_with_overrides(
+        cfg: ConfigToml,
+        overrides: ConfigOverrides,
+    ) -> std::io::Result<Self> {
         // Instructions: user-provided instructions.md > embedded default.
         let instructions =
             Self::load_instructions().or_else(|| Some(EMBEDDED_INSTRUCTIONS.to_string()));
@@ -176,6 +195,7 @@ impl Config {
             approval_policy,
             sandbox_policy,
             disable_response_storage,
+            provider,
         } = overrides;
 
         let sandbox_policy = match sandbox_policy {
@@ -193,8 +213,28 @@ impl Config {
             }
         };
 
-        Self {
+        let mut model_providers = built_in_model_providers();
+        // Merge user-defined providers into the built-in list.
+        for (key, provider) in cfg.model_providers.into_iter() {
+            model_providers.entry(key).or_insert(provider);
+        }
+
+        let model_provider_name = provider
+            .or(cfg.model_provider)
+            .unwrap_or_else(|| "openai".to_string());
+        let model_provider = model_providers
+            .get(&model_provider_name)
+            .ok_or_else(|| {
+                std::io::Error::new(
+                    std::io::ErrorKind::NotFound,
+                    format!("Model provider `{model_provider_name}` not found"),
+                )
+            })?
+            .clone();
+
+        let config = Self {
             model: model.or(cfg.model).unwrap_or_else(default_model),
+            model_provider,
             cwd: cwd.map_or_else(
                 || {
                     tracing::info!("cwd not set, using current dir");
@@ -222,7 +262,9 @@ impl Config {
             notify: cfg.notify,
             instructions,
             mcp_servers: cfg.mcp_servers,
-        }
+            model_providers,
+        };
+        Ok(config)
     }
 
     fn load_instructions() -> Option<String> {
@@ -238,6 +280,7 @@ impl Config {
             ConfigToml::default(),
             ConfigOverrides::default(),
         )
+        .expect("defaults for test should always succeed")
     }
 }
diff --git a/codex-rs/core/src/flags.rs b/codex-rs/core/src/flags.rs
index 4d0d4bbe..44198fde 100644
--- a/codex-rs/core/src/flags.rs
+++ b/codex-rs/core/src/flags.rs
@@ -2,12 +2,11 @@ use std::time::Duration;
 
 use env_flags::env_flags;
 
-use crate::error::CodexErr;
-use crate::error::Result;
-
 env_flags! {
     pub OPENAI_DEFAULT_MODEL: &str = "o3";
-    pub OPENAI_API_BASE: &str = "https://api.openai.com";
+    pub OPENAI_API_BASE: &str = "https://api.openai.com/v1";
+
+    /// Fallback when the provider-specific key is not set.
     pub OPENAI_API_KEY: Option<&str> = None;
 
     pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
         value.parse().map(Duration::from_millis)
@@ -21,9 +20,6 @@ env_flags! {
         value.parse().map(Duration::from_millis)
     };
 
+    /// Fixture path for offline tests (see client.rs).
     pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
 }
-
-pub fn get_api_key() -> Result<&'static str> {
-    OPENAI_API_KEY.ok_or_else(|| CodexErr::EnvVar("OPENAI_API_KEY"))
-}
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index ef671a94..1c3a46df 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -7,6 +7,7 @@ mod client;
 pub mod codex;
+pub use codex::Codex;
 pub mod codex_wrapper;
 pub mod config;
 pub mod error;
@@ -18,6 +19,8 @@ pub mod linux;
 mod mcp_connection_manager;
 pub mod mcp_server_config;
 mod mcp_tool_call;
+mod model_provider_info;
+pub use model_provider_info::ModelProviderInfo;
 mod models;
 pub mod protocol;
 mod rollout;
@@ -25,5 +28,3 @@ mod safety;
 mod user_notification;
 pub mod util;
 mod zdr_transcript;
-
-pub use codex::Codex;
diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs
new file mode 100644
index 00000000..e7069c04
--- /dev/null
+++ b/codex-rs/core/src/model_provider_info.rs
@@ -0,0 +1,103 @@
+//! Registry of model providers supported by Codex.
+//!
+//! Providers can be defined in two places:
+//!   1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
+//!   2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
+//!      key. These override or extend the defaults at runtime.
+
+use serde::Deserialize;
+use serde::Serialize;
+use std::collections::HashMap;
+
+/// Serializable representation of a provider definition.
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ModelProviderInfo {
+    /// Friendly display name.
+    pub name: String,
+    /// Base URL for the provider's OpenAI-compatible API.
+    pub base_url: String,
+    /// Environment variable that stores the user's API key for this provider.
+    pub env_key: String,
+}
+
+impl ModelProviderInfo {
+    /// Returns the API key for this provider if present in the environment.
+    pub fn api_key(&self) -> Option<String> {
+        std::env::var(&self.env_key).ok()
+    }
+}
+
+/// Built-in default provider list.
+pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
+    use ModelProviderInfo as P;
+
+    [
+        (
+            "openai",
+            P {
+                name: "OpenAI".into(),
+                base_url: "https://api.openai.com/v1".into(),
+                env_key: "OPENAI_API_KEY".into(),
+            },
+        ),
+        (
+            "openrouter",
+            P {
+                name: "OpenRouter".into(),
+                base_url: "https://openrouter.ai/api/v1".into(),
+                env_key: "OPENROUTER_API_KEY".into(),
+            },
+        ),
+        (
+            "gemini",
+            P {
+                name: "Gemini".into(),
+                base_url: "https://generativelanguage.googleapis.com/v1beta/openai".into(),
+                env_key: "GEMINI_API_KEY".into(),
+            },
+        ),
+        (
+            "ollama",
+            P {
+                name: "Ollama".into(),
+                base_url: "http://localhost:11434/v1".into(),
+                env_key: "OLLAMA_API_KEY".into(),
+            },
+        ),
+        (
+            "mistral",
+            P {
+                name: "Mistral".into(),
+                base_url: "https://api.mistral.ai/v1".into(),
+                env_key: "MISTRAL_API_KEY".into(),
+            },
+        ),
+        (
+            "deepseek",
+            P {
+                name: "DeepSeek".into(),
+                base_url: "https://api.deepseek.com".into(),
+                env_key: "DEEPSEEK_API_KEY".into(),
+            },
+        ),
+        (
+            "xai",
+            P {
+                name: "xAI".into(),
+                base_url: "https://api.x.ai/v1".into(),
+                env_key: "XAI_API_KEY".into(),
+            },
+        ),
+        (
+            "groq",
+            P {
+                name: "Groq".into(),
+                base_url: "https://api.groq.com/openai/v1".into(),
+                env_key: "GROQ_API_KEY".into(),
+            },
+        ),
+    ]
+    .into_iter()
+    .map(|(k, v)| (k.to_string(), v))
+    .collect()
+}
diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs
index 4796381d..613dfe72 100644
--- a/codex-rs/core/src/protocol.rs
+++ b/codex-rs/core/src/protocol.rs
@@ -11,6 +11,8 @@ use mcp_types::CallToolResult;
 use serde::Deserialize;
 use serde::Serialize;
 
+use crate::model_provider_info::ModelProviderInfo;
+
 /// Submission Queue Entry - requests from user
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct Submission {
@@ -27,6 +29,9 @@ pub struct Submission {
 pub enum Op {
     /// Configure the model session.
     ConfigureSession {
+        /// Provider to use for this session (resolved in config from an id like "openai").
+        provider: ModelProviderInfo,
+
         /// If not specified, server will use its default model.
         model: String,
         /// Model instructions
diff --git a/codex-rs/core/tests/previous_response_id.rs b/codex-rs/core/tests/previous_response_id.rs
index de1b1b2b..50c1ba39 100644
--- a/codex-rs/core/tests/previous_response_id.rs
+++ b/codex-rs/core/tests/previous_response_id.rs
@@ -1,6 +1,7 @@
 use std::time::Duration;
 
 use codex_core::Codex;
+use codex_core::ModelProviderInfo;
 use codex_core::config::Config;
 use codex_core::protocol::InputItem;
 use codex_core::protocol::Op;
@@ -80,14 +81,21 @@ async fn keeps_previous_response_id_between_tasks() {
     // Update environment – `set_var` is `unsafe` starting with the 2024
     // edition so we group the calls into a single `unsafe { … }` block.
     unsafe {
-        std::env::set_var("OPENAI_API_KEY", "test-key");
-        std::env::set_var("OPENAI_API_BASE", server.uri());
         std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
         std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
     }
 
+    let model_provider = ModelProviderInfo {
+        name: "openai".into(),
+        base_url: format!("{}/v1", server.uri()),
+        // Environment variable that should exist in the test environment.
+        // ModelClient will return an error if the environment variable for the
+        // provider is not set.
+ env_key: "PATH".into(), + }; // Init session - let config = Config::load_default_config_for_test(); + let mut config = Config::load_default_config_for_test(); + config.model_provider = model_provider; let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap(); diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs index 061f9b2f..1af5fc4a 100644 --- a/codex-rs/core/tests/stream_no_completed.rs +++ b/codex-rs/core/tests/stream_no_completed.rs @@ -4,6 +4,7 @@ use std::time::Duration; use codex_core::Codex; +use codex_core::ModelProviderInfo; use codex_core::config::Config; use codex_core::protocol::InputItem; use codex_core::protocol::Op; @@ -68,15 +69,23 @@ async fn retries_on_early_close() { // scope is very small and clearly delineated. unsafe { - std::env::set_var("OPENAI_API_KEY", "test-key"); - std::env::set_var("OPENAI_API_BASE", server.uri()); std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0"); std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1"); std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000"); } + let model_provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + // Environment variable that should exist in the test environment. + // ModelClient will return an error if the environment variable for the + // provider is not set. + env_key: "PATH".into(), + }; + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); - let config = Config::load_default_config_for_test(); + let mut config = Config::load_default_config_for_test(); + config.model_provider = model_provider; let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap(); codex diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 1bd5069e..cb11ca62 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -66,6 +66,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> { None }, cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)), + provider: None, }; let config = Config::load_with_overrides(overrides)?; diff --git a/codex-rs/mcp-server/src/codex_tool_config.rs b/codex-rs/mcp-server/src/codex_tool_config.rs index d05ec154..89b19f72 100644 --- a/codex-rs/mcp-server/src/codex_tool_config.rs +++ b/codex-rs/mcp-server/src/codex_tool_config.rs @@ -158,6 +158,7 @@ impl CodexToolCallParam { approval_policy: approval_policy.map(Into::into), sandbox_policy, disable_response_storage, + provider: None, }; let cfg = codex_core::config::Config::load_with_overrides(overrides)?; diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 30169699..a7de9aae 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -58,6 +58,7 @@ pub fn run_main(cli: Cli) -> std::io::Result<()> { None }, cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)), + provider: None, }; #[allow(clippy::print_stderr)] match Config::load_with_overrides(overrides) {