feat: read model_provider and model_providers from config.toml (#853)
This is the first step in supporting other model providers in the Rust CLI. Specifically, this PR adds support for the new entries in `Config` and `ConfigOverrides` to specify a `ModelProviderInfo`, which is the basic config needed for an LLM provider. This PR does not get us all the way there yet because `client.rs` still categorically appends `/responses` to the URL and expects the endpoint to support the OpenAI Responses API. Will fix that next!
This commit is contained in:
@@ -26,10 +26,9 @@ use tracing::warn;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||
use crate::flags::OPENAI_API_BASE;
|
||||
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
|
||||
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
|
||||
use crate::flags::get_api_key;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::util::backoff;
|
||||
|
||||
@@ -141,13 +140,16 @@ static DEFAULT_TOOLS: LazyLock<Vec<ResponsesApiTool>> = LazyLock::new(|| {
|
||||
pub struct ModelClient {
|
||||
model: String,
|
||||
client: reqwest::Client,
|
||||
provider: ModelProviderInfo,
|
||||
}
|
||||
|
||||
impl ModelClient {
|
||||
pub fn new(model: impl ToString) -> Self {
|
||||
let model = model.to_string();
|
||||
let client = reqwest::Client::new();
|
||||
Self { model, client }
|
||||
pub fn new(model: impl ToString, provider: ModelProviderInfo) -> Self {
|
||||
Self {
|
||||
model: model.to_string(),
|
||||
client: reqwest::Client::new(),
|
||||
provider,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
@@ -188,7 +190,9 @@ impl ModelClient {
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let url = format!("{}/v1/responses", *OPENAI_API_BASE);
|
||||
let base_url = self.provider.base_url.clone();
|
||||
let base_url = base_url.trim_end_matches('/');
|
||||
let url = format!("{}/responses", base_url);
|
||||
debug!(url, "POST");
|
||||
trace!("request payload: {}", serde_json::to_string(&payload)?);
|
||||
|
||||
@@ -196,10 +200,14 @@ impl ModelClient {
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
let api_key = self
|
||||
.provider
|
||||
.api_key()
|
||||
.ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?;
|
||||
let res = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(get_api_key()?)
|
||||
.bearer_auth(api_key)
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload)
|
||||
|
||||
@@ -80,6 +80,7 @@ impl Codex {
|
||||
let (tx_sub, rx_sub) = async_channel::bounded(64);
|
||||
let (tx_event, rx_event) = async_channel::bounded(64);
|
||||
let configure_session = Op::ConfigureSession {
|
||||
provider: config.model_provider.clone(),
|
||||
model: config.model.clone(),
|
||||
instructions: config.instructions.clone(),
|
||||
approval_policy: config.approval_policy,
|
||||
@@ -504,6 +505,7 @@ async fn submission_loop(
|
||||
sess.abort();
|
||||
}
|
||||
Op::ConfigureSession {
|
||||
provider,
|
||||
model,
|
||||
instructions,
|
||||
approval_policy,
|
||||
@@ -512,7 +514,7 @@ async fn submission_loop(
|
||||
notify,
|
||||
cwd,
|
||||
} => {
|
||||
info!(model, "Configuring session");
|
||||
info!("Configuring session: model={model}; provider={provider:?}");
|
||||
if !cwd.is_absolute() {
|
||||
let message = format!("cwd is not absolute: {cwd:?}");
|
||||
error!(message);
|
||||
@@ -526,7 +528,7 @@ async fn submission_loop(
|
||||
return;
|
||||
}
|
||||
|
||||
let client = ModelClient::new(model.clone());
|
||||
let client = ModelClient::new(model.clone(), provider.clone());
|
||||
|
||||
// abort any current running session and clone its state
|
||||
let state = match sess.take() {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use crate::flags::OPENAI_DEFAULT_MODEL;
|
||||
use crate::mcp_server_config::McpServerConfig;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::built_in_model_providers;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPermission;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
@@ -19,6 +21,9 @@ pub struct Config {
|
||||
/// Optional override of model selection.
|
||||
pub model: String,
|
||||
|
||||
/// Info needed to make an API request to the model.
|
||||
pub model_provider: ModelProviderInfo,
|
||||
|
||||
/// Approval policy for executing commands.
|
||||
pub approval_policy: AskForApproval,
|
||||
|
||||
@@ -61,6 +66,9 @@ pub struct Config {
|
||||
|
||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// Combined provider map (defaults merged with user-defined overrides).
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
}
|
||||
|
||||
/// Base config deserialized from ~/.codex/config.toml.
|
||||
@@ -69,6 +77,9 @@ pub struct ConfigToml {
|
||||
/// Optional override of model selection.
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Provider to use from the model_providers map.
|
||||
pub model_provider: Option<String>,
|
||||
|
||||
/// Default approval policy for executing commands.
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
|
||||
@@ -93,6 +104,10 @@ pub struct ConfigToml {
|
||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||
#[serde(default)]
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// User-defined provider entries that extend/override the built-in list.
|
||||
#[serde(default)]
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
}
|
||||
|
||||
impl ConfigToml {
|
||||
@@ -152,6 +167,7 @@ pub struct ConfigOverrides {
|
||||
pub approval_policy: Option<AskForApproval>,
|
||||
pub sandbox_policy: Option<SandboxPolicy>,
|
||||
pub disable_response_storage: Option<bool>,
|
||||
pub provider: Option<String>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@@ -161,10 +177,13 @@ impl Config {
|
||||
pub fn load_with_overrides(overrides: ConfigOverrides) -> std::io::Result<Self> {
|
||||
let cfg: ConfigToml = ConfigToml::load_from_toml()?;
|
||||
tracing::warn!("Config parsed from config.toml: {cfg:?}");
|
||||
Ok(Self::load_from_base_config_with_overrides(cfg, overrides))
|
||||
Self::load_from_base_config_with_overrides(cfg, overrides)
|
||||
}
|
||||
|
||||
fn load_from_base_config_with_overrides(cfg: ConfigToml, overrides: ConfigOverrides) -> Self {
|
||||
fn load_from_base_config_with_overrides(
|
||||
cfg: ConfigToml,
|
||||
overrides: ConfigOverrides,
|
||||
) -> std::io::Result<Self> {
|
||||
// Instructions: user-provided instructions.md > embedded default.
|
||||
let instructions =
|
||||
Self::load_instructions().or_else(|| Some(EMBEDDED_INSTRUCTIONS.to_string()));
|
||||
@@ -176,6 +195,7 @@ impl Config {
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
disable_response_storage,
|
||||
provider,
|
||||
} = overrides;
|
||||
|
||||
let sandbox_policy = match sandbox_policy {
|
||||
@@ -193,8 +213,28 @@ impl Config {
|
||||
}
|
||||
};
|
||||
|
||||
Self {
|
||||
let mut model_providers = built_in_model_providers();
|
||||
// Merge user-defined providers into the built-in list.
|
||||
for (key, provider) in cfg.model_providers.into_iter() {
|
||||
model_providers.entry(key).or_insert(provider);
|
||||
}
|
||||
|
||||
let model_provider_name = provider
|
||||
.or(cfg.model_provider)
|
||||
.unwrap_or_else(|| "openai".to_string());
|
||||
let model_provider = model_providers
|
||||
.get(&model_provider_name)
|
||||
.ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
format!("Model provider `{model_provider_name}` not found"),
|
||||
)
|
||||
})?
|
||||
.clone();
|
||||
|
||||
let config = Self {
|
||||
model: model.or(cfg.model).unwrap_or_else(default_model),
|
||||
model_provider,
|
||||
cwd: cwd.map_or_else(
|
||||
|| {
|
||||
tracing::info!("cwd not set, using current dir");
|
||||
@@ -222,7 +262,9 @@ impl Config {
|
||||
notify: cfg.notify,
|
||||
instructions,
|
||||
mcp_servers: cfg.mcp_servers,
|
||||
}
|
||||
model_providers,
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn load_instructions() -> Option<String> {
|
||||
@@ -238,6 +280,7 @@ impl Config {
|
||||
ConfigToml::default(),
|
||||
ConfigOverrides::default(),
|
||||
)
|
||||
.expect("defaults for test should always succeed")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,12 +2,11 @@ use std::time::Duration;
|
||||
|
||||
use env_flags::env_flags;
|
||||
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
|
||||
env_flags! {
|
||||
pub OPENAI_DEFAULT_MODEL: &str = "o3";
|
||||
pub OPENAI_API_BASE: &str = "https://api.openai.com";
|
||||
pub OPENAI_API_BASE: &str = "https://api.openai.com/v1";
|
||||
|
||||
/// Fallback when the provider-specific key is not set.
|
||||
pub OPENAI_API_KEY: Option<&str> = None;
|
||||
pub OPENAI_TIMEOUT_MS: Duration = Duration::from_millis(300_000), |value| {
|
||||
value.parse().map(Duration::from_millis)
|
||||
@@ -21,9 +20,6 @@ env_flags! {
|
||||
value.parse().map(Duration::from_millis)
|
||||
};
|
||||
|
||||
/// Fixture path for offline tests (see client.rs).
|
||||
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
|
||||
}
|
||||
|
||||
pub fn get_api_key() -> Result<&'static str> {
|
||||
OPENAI_API_KEY.ok_or_else(|| CodexErr::EnvVar("OPENAI_API_KEY"))
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
mod client;
|
||||
pub mod codex;
|
||||
pub use codex::Codex;
|
||||
pub mod codex_wrapper;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
@@ -18,6 +19,8 @@ pub mod linux;
|
||||
mod mcp_connection_manager;
|
||||
pub mod mcp_server_config;
|
||||
mod mcp_tool_call;
|
||||
mod model_provider_info;
|
||||
pub use model_provider_info::ModelProviderInfo;
|
||||
mod models;
|
||||
pub mod protocol;
|
||||
mod rollout;
|
||||
@@ -25,5 +28,3 @@ mod safety;
|
||||
mod user_notification;
|
||||
pub mod util;
|
||||
mod zdr_transcript;
|
||||
|
||||
pub use codex::Codex;
|
||||
|
||||
103
codex-rs/core/src/model_provider_info.rs
Normal file
103
codex-rs/core/src/model_provider_info.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
//! Registry of model providers supported by Codex.
|
||||
//!
|
||||
//! Providers can be defined in two places:
|
||||
//! 1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
|
||||
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
|
||||
//! key. These override or extend the defaults at runtime.
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Serializable representation of a provider definition.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ModelProviderInfo {
|
||||
/// Friendly display name.
|
||||
pub name: String,
|
||||
/// Base URL for the provider's OpenAI-compatible API.
|
||||
pub base_url: String,
|
||||
/// Environment variable that stores the user's API key for this provider.
|
||||
pub env_key: String,
|
||||
}
|
||||
|
||||
impl ModelProviderInfo {
|
||||
/// Returns the API key for this provider if present in the environment.
|
||||
pub fn api_key(&self) -> Option<String> {
|
||||
std::env::var(&self.env_key).ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// Built-in default provider list.
|
||||
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
|
||||
use ModelProviderInfo as P;
|
||||
|
||||
[
|
||||
(
|
||||
"openai",
|
||||
P {
|
||||
name: "OpenAI".into(),
|
||||
base_url: "https://api.openai.com/v1".into(),
|
||||
env_key: "OPENAI_API_KEY".into(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"openrouter",
|
||||
P {
|
||||
name: "OpenRouter".into(),
|
||||
base_url: "https://openrouter.ai/api/v1".into(),
|
||||
env_key: "OPENROUTER_API_KEY".into(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"gemini",
|
||||
P {
|
||||
name: "Gemini".into(),
|
||||
base_url: "https://generativelanguage.googleapis.com/v1beta/openai".into(),
|
||||
env_key: "GEMINI_API_KEY".into(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"ollama",
|
||||
P {
|
||||
name: "Ollama".into(),
|
||||
base_url: "http://localhost:11434/v1".into(),
|
||||
env_key: "OLLAMA_API_KEY".into(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"mistral",
|
||||
P {
|
||||
name: "Mistral".into(),
|
||||
base_url: "https://api.mistral.ai/v1".into(),
|
||||
env_key: "MISTRAL_API_KEY".into(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"deepseek",
|
||||
P {
|
||||
name: "DeepSeek".into(),
|
||||
base_url: "https://api.deepseek.com".into(),
|
||||
env_key: "DEEPSEEK_API_KEY".into(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"xai",
|
||||
P {
|
||||
name: "xAI".into(),
|
||||
base_url: "https://api.x.ai/v1".into(),
|
||||
env_key: "XAI_API_KEY".into(),
|
||||
},
|
||||
),
|
||||
(
|
||||
"groq",
|
||||
P {
|
||||
name: "Groq".into(),
|
||||
base_url: "https://api.groq.com/openai/v1".into(),
|
||||
env_key: "GROQ_API_KEY".into(),
|
||||
},
|
||||
),
|
||||
]
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.to_string(), v))
|
||||
.collect()
|
||||
}
|
||||
@@ -11,6 +11,8 @@ use mcp_types::CallToolResult;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
|
||||
/// Submission Queue Entry - requests from user
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Submission {
|
||||
@@ -27,6 +29,9 @@ pub struct Submission {
|
||||
pub enum Op {
|
||||
/// Configure the model session.
|
||||
ConfigureSession {
|
||||
/// Provider identifier ("openai", "openrouter", ...).
|
||||
provider: ModelProviderInfo,
|
||||
|
||||
/// If not specified, server will use its default model.
|
||||
model: String,
|
||||
/// Model instructions
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
@@ -80,14 +81,21 @@ async fn keeps_previous_response_id_between_tasks() {
|
||||
// Update environment – `set_var` is `unsafe` starting with the 2024
|
||||
// edition so we group the calls into a single `unsafe { … }` block.
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_API_KEY", "test-key");
|
||||
std::env::set_var("OPENAI_API_BASE", server.uri());
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "0");
|
||||
}
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: "PATH".into(),
|
||||
};
|
||||
|
||||
// Init session
|
||||
let config = Config::load_default_config_for_test();
|
||||
let mut config = Config::load_default_config_for_test();
|
||||
config.model_provider = model_provider;
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::config::Config;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
@@ -68,15 +69,23 @@ async fn retries_on_early_close() {
|
||||
// scope is very small and clearly delineated.
|
||||
|
||||
unsafe {
|
||||
std::env::set_var("OPENAI_API_KEY", "test-key");
|
||||
std::env::set_var("OPENAI_API_BASE", server.uri());
|
||||
std::env::set_var("OPENAI_REQUEST_MAX_RETRIES", "0");
|
||||
std::env::set_var("OPENAI_STREAM_MAX_RETRIES", "1");
|
||||
std::env::set_var("OPENAI_STREAM_IDLE_TIMEOUT_MS", "2000");
|
||||
}
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: "PATH".into(),
|
||||
};
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let config = Config::load_default_config_for_test();
|
||||
let mut config = Config::load_default_config_for_test();
|
||||
config.model_provider = model_provider;
|
||||
let (codex, _init_id) = Codex::spawn(config, ctrl_c).await.unwrap();
|
||||
|
||||
codex
|
||||
|
||||
@@ -66,6 +66,7 @@ pub async fn run_main(cli: Cli) -> anyhow::Result<()> {
|
||||
None
|
||||
},
|
||||
cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)),
|
||||
provider: None,
|
||||
};
|
||||
let config = Config::load_with_overrides(overrides)?;
|
||||
|
||||
|
||||
@@ -158,6 +158,7 @@ impl CodexToolCallParam {
|
||||
approval_policy: approval_policy.map(Into::into),
|
||||
sandbox_policy,
|
||||
disable_response_storage,
|
||||
provider: None,
|
||||
};
|
||||
|
||||
let cfg = codex_core::config::Config::load_with_overrides(overrides)?;
|
||||
|
||||
@@ -58,6 +58,7 @@ pub fn run_main(cli: Cli) -> std::io::Result<()> {
|
||||
None
|
||||
},
|
||||
cwd: cli.cwd.clone().map(|p| p.canonicalize().unwrap_or(p)),
|
||||
provider: None,
|
||||
};
|
||||
#[allow(clippy::print_stderr)]
|
||||
match Config::load_with_overrides(overrides) {
|
||||
|
||||
Reference in New Issue
Block a user