feat: introduce --profile for Rust CLI (#921)

This introduces a much-needed "profile" concept: a user can define a
named collection of configuration options and then select it by passing
the name via `--profile` to the CLI.
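
For example, a profile is defined under `[profiles.<name>]` in `config.toml`
and selected by name (the values below mirror the new unit test; see
`codex-rs/README.md` for the full details):

```toml
# config.toml (in the Codex config directory, typically ~/.codex)

# Optional: profile to fall back to when `--profile` is not passed.
profile = "o3"

[profiles.o3]
model = "o3"
model_provider = "openai"
approval_policy = "never"
```

Running the CLI with `--profile o3` then applies those values, with any
explicit flags such as `--model` still taking precedence.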

This PR introduces the `ConfigProfile` struct and makes it a field of
`ConfigToml`. It further updates
`Config::load_from_base_config_with_overrides()` to respect
`ConfigProfile`, overriding default values where appropriate. A detailed
unit test is added at the end of `config.rs` to verify this behavior.

Details on how to use this feature have also been added to
`codex-rs/README.md`.
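
When several sources set the same option, resolution happens through
`Option::or` chains. A minimal, self-contained sketch of that precedence
(the `resolve_model` helper is illustrative only; the real logic lives in
`Config::load_from_base_config_with_overrides`):

```rust
/// Precedence: CLI flag, then the selected profile, then config.toml,
/// then the built-in default (`crate::flags::OPENAI_DEFAULT_MODEL`).
fn resolve_model(
    cli: Option<String>,     // e.g. `--model o3`
    profile: Option<String>, // from `[profiles.<name>]`
    toml: Option<String>,    // top-level `model = "..."` in config.toml
) -> String {
    cli.or(profile)
        .or(toml)
        .unwrap_or_else(|| "o4-mini".to_string()) // stand-in for the default flag
}

fn main() {
    // A profile value beats the top-level config.toml value...
    assert_eq!(
        resolve_model(None, Some("gpt-3.5-turbo".into()), Some("o3".into())),
        "gpt-3.5-turbo"
    );
    // ...but an explicit CLI flag beats both.
    assert_eq!(
        resolve_model(Some("o3".into()), Some("gpt-3.5-turbo".into()), None),
        "o3"
    );
}
```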
Author: Michael Bolin
Date: 2025-05-13 16:52:52 -07:00
Committed by: GitHub
Parent: ae809f3721
Commit: 3c03c25e56

14 changed files with 309 additions and 15 deletions

View File

@@ -58,5 +58,6 @@ openssl-sys = { version = "*", features = ["vendored"] }
[dev-dependencies]
assert_cmd = "2"
predicates = "3"
pretty_assertions = "1.4.1"
tempfile = "3"
wiremock = "0.6"

View File

@@ -1,3 +1,4 @@
use crate::config_profile::ConfigProfile;
use crate::flags::OPENAI_DEFAULT_MODEL;
use crate::mcp_server_config::McpServerConfig;
use crate::model_provider_info::ModelProviderInfo;
@@ -8,6 +9,7 @@ use crate::protocol::SandboxPolicy;
use dirs::home_dir;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
/// Maximum number of bytes of the documentation that will be embedded. Larger
@@ -16,7 +18,7 @@ use std::path::PathBuf;
pub(crate) const PROJECT_DOC_MAX_BYTES: usize = 32 * 1024; // 32 KiB
/// Application configuration loaded from disk and merged with overrides.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
/// Optional override of model selection.
pub model: String,
@@ -117,6 +119,13 @@ pub struct ConfigToml {
/// Maximum number of bytes to include from an AGENTS.md project doc file.
pub project_doc_max_bytes: Option<usize>,
/// Profile to use from the `profiles` map.
pub profile: Option<String>,
/// Named profiles to facilitate switching between different configurations.
#[serde(default)]
pub profiles: HashMap<String, ConfigProfile>,
}
impl ConfigToml {
@@ -176,7 +185,8 @@ pub struct ConfigOverrides {
pub approval_policy: Option<AskForApproval>,
pub sandbox_policy: Option<SandboxPolicy>,
pub disable_response_storage: Option<bool>,
pub provider: Option<String>,
pub model_provider: Option<String>,
pub config_profile: Option<String>,
}
impl Config {
@@ -186,14 +196,16 @@ impl Config {
pub fn load_with_overrides(overrides: ConfigOverrides) -> std::io::Result<Self> {
let cfg: ConfigToml = ConfigToml::load_from_toml()?;
tracing::warn!("Config parsed from config.toml: {cfg:?}");
Self::load_from_base_config_with_overrides(cfg, overrides)
let codex_dir = codex_dir().ok();
Self::load_from_base_config_with_overrides(cfg, overrides, codex_dir.as_deref())
}
fn load_from_base_config_with_overrides(
cfg: ConfigToml,
overrides: ConfigOverrides,
codex_dir: Option<&Path>,
) -> std::io::Result<Self> {
let instructions = Self::load_instructions();
let instructions = Self::load_instructions(codex_dir);
// Destructure ConfigOverrides fully to ensure all overrides are applied.
let ConfigOverrides {
@@ -202,9 +214,24 @@ impl Config {
approval_policy,
sandbox_policy,
disable_response_storage,
provider,
model_provider,
config_profile: config_profile_key,
} = overrides;
let config_profile = match config_profile_key.or(cfg.profile) {
Some(key) => cfg
.profiles
.get(&key)
.ok_or_else(|| {
std::io::Error::new(
std::io::ErrorKind::NotFound,
format!("config profile `{key}` not found"),
)
})?
.clone(),
None => ConfigProfile::default(),
};
let sandbox_policy = match sandbox_policy {
Some(sandbox_policy) => sandbox_policy,
None => {
@@ -226,7 +253,8 @@ impl Config {
model_providers.entry(key).or_insert(provider);
}
let model_provider_id = provider
let model_provider_id = model_provider
.or(config_profile.model_provider)
.or(cfg.model_provider)
.unwrap_or_else(|| "openai".to_string());
let model_provider = model_providers
@@ -259,15 +287,20 @@ impl Config {
};
let config = Self {
model: model.or(cfg.model).unwrap_or_else(default_model),
model: model
.or(config_profile.model)
.or(cfg.model)
.unwrap_or_else(default_model),
model_provider_id,
model_provider,
cwd: resolved_cwd,
approval_policy: approval_policy
.or(config_profile.approval_policy)
.or(cfg.approval_policy)
.unwrap_or_else(AskForApproval::default),
sandbox_policy,
disable_response_storage: disable_response_storage
.or(config_profile.disable_response_storage)
.or(cfg.disable_response_storage)
.unwrap_or(false),
notify: cfg.notify,
@@ -279,8 +312,12 @@ impl Config {
Ok(config)
}
fn load_instructions() -> Option<String> {
let mut p = codex_dir().ok()?;
fn load_instructions(codex_dir: Option<&Path>) -> Option<String> {
let mut p = match codex_dir {
Some(p) => p.to_path_buf(),
None => return None,
};
p.push("instructions.md");
std::fs::read_to_string(&p).ok().and_then(|s| {
let s = s.trim();
@@ -299,6 +336,7 @@ impl Config {
Self::load_from_base_config_with_overrides(
ConfigToml::default(),
ConfigOverrides::default(),
None,
)
.expect("defaults for test should always succeed")
}
@@ -377,6 +415,8 @@ pub fn parse_sandbox_permission_with_base_path(
mod tests {
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
use pretty_assertions::assert_eq;
use tempfile::TempDir;
/// Verify that the `sandbox_permissions` field on `ConfigToml` correctly
/// differentiates between a value that is completely absent in the
@@ -429,4 +469,173 @@ mod tests {
let msg = err.to_string();
assert!(msg.contains("not-a-real-permission"));
}
/// Users can specify config values at multiple levels that have the
/// following precedence:
///
/// 1. custom command-line argument, e.g. `--model o3`
/// 2. as part of a profile, where the profile is selected via the `--profile`
/// CLI flag (or in the config file itself)
/// 3. as an entry in `config.toml`, e.g. `model = "o3"`
/// 4. the default value for a required field defined in code, e.g.,
/// `crate::flags::OPENAI_DEFAULT_MODEL`
///
/// Note that profiles are the recommended way to specify a group of
/// configuration options together.
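/// For example, with `model = "o3"` at the top level of `config.toml`,
/// `model = "gpt-3.5-turbo"` in the selected profile, and `--model o4-mini`
/// on the command line, the effective model is `o4-mini`.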
#[test]
fn test_precedence_overrides_then_profile_then_config_toml() -> std::io::Result<()> {
let toml = r#"
model = "o3"
approval_policy = "unless-allow-listed"
sandbox_permissions = ["disk-full-read-access"]
disable_response_storage = false
# Can be used to determine which profile to use if not specified by
# `ConfigOverrides`.
profile = "gpt3"
[model_providers.openai-chat-completions]
name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
[profiles.o3]
model = "o3"
model_provider = "openai"
approval_policy = "never"
[profiles.gpt3]
model = "gpt-3.5-turbo"
model_provider = "openai-chat-completions"
[profiles.zdr]
model = "o3"
model_provider = "openai"
approval_policy = "on-failure"
disable_response_storage = true
"#;
let cfg: ConfigToml = toml::from_str(toml).expect("TOML deserialization should succeed");
// Use a temporary directory for the cwd so it does not contain an
// AGENTS.md file.
let cwd_temp_dir = TempDir::new().unwrap();
let cwd = cwd_temp_dir.path().to_path_buf();
// Make it look like a Git repo so it does not search for AGENTS.md in
// a parent folder, either.
std::fs::write(cwd.join(".git"), "gitdir: nowhere")?;
let openai_chat_completions_provider = ModelProviderInfo {
name: "OpenAI using Chat Completions".to_string(),
base_url: "https://api.openai.com/v1".to_string(),
env_key: Some("OPENAI_API_KEY".to_string()),
wire_api: crate::WireApi::Chat,
env_key_instructions: None,
};
let model_provider_map = {
let mut model_provider_map = built_in_model_providers();
model_provider_map.insert(
"openai-chat-completions".to_string(),
openai_chat_completions_provider.clone(),
);
model_provider_map
};
let openai_provider = model_provider_map
.get("openai")
.expect("openai provider should exist")
.clone();
let o3_profile_overrides = ConfigOverrides {
config_profile: Some("o3".to_string()),
cwd: Some(cwd.clone()),
..Default::default()
};
let o3_profile_config =
Config::load_from_base_config_with_overrides(cfg.clone(), o3_profile_overrides, None)?;
assert_eq!(
Config {
model: "o3".to_string(),
model_provider_id: "openai".to_string(),
model_provider: openai_provider.clone(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::new_read_only_policy(),
disable_response_storage: false,
instructions: None,
notify: None,
cwd: cwd.clone(),
mcp_servers: HashMap::new(),
model_providers: model_provider_map.clone(),
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
},
o3_profile_config
);
let gpt3_profile_overrides = ConfigOverrides {
config_profile: Some("gpt3".to_string()),
cwd: Some(cwd.clone()),
..Default::default()
};
let gpt3_profile_config = Config::load_from_base_config_with_overrides(
cfg.clone(),
gpt3_profile_overrides,
None,
)?;
let expected_gpt3_profile_config = Config {
model: "gpt-3.5-turbo".to_string(),
model_provider_id: "openai-chat-completions".to_string(),
model_provider: openai_chat_completions_provider,
approval_policy: AskForApproval::UnlessAllowListed,
sandbox_policy: SandboxPolicy::new_read_only_policy(),
disable_response_storage: false,
instructions: None,
notify: None,
cwd: cwd.clone(),
mcp_servers: HashMap::new(),
model_providers: model_provider_map.clone(),
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
};
assert_eq!(expected_gpt3_profile_config.clone(), gpt3_profile_config);
// Verify that loading without specifying a profile in ConfigOverrides
// uses the default profile from the config file.
let default_profile_overrides = ConfigOverrides {
cwd: Some(cwd.clone()),
..Default::default()
};
let default_profile_config = Config::load_from_base_config_with_overrides(
cfg.clone(),
default_profile_overrides,
None,
)?;
assert_eq!(expected_gpt3_profile_config, default_profile_config);
let zdr_profile_overrides = ConfigOverrides {
config_profile: Some("zdr".to_string()),
cwd: Some(cwd.clone()),
..Default::default()
};
let zdr_profile_config =
Config::load_from_base_config_with_overrides(cfg.clone(), zdr_profile_overrides, None)?;
assert_eq!(
Config {
model: "o3".to_string(),
model_provider_id: "openai".to_string(),
model_provider: openai_provider.clone(),
approval_policy: AskForApproval::OnFailure,
sandbox_policy: SandboxPolicy::new_read_only_policy(),
disable_response_storage: true,
instructions: None,
notify: None,
cwd: cwd.clone(),
mcp_servers: HashMap::new(),
model_providers: model_provider_map.clone(),
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
},
zdr_profile_config
);
Ok(())
}
}
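
The heart of the diff above is the profile lookup. As a standalone sketch
with the surrounding types trimmed away (only the names that appear in the
diff are real; everything else is scaffolding for the example):

```rust
use std::collections::HashMap;
use std::io;

#[derive(Clone, Debug, Default, PartialEq)]
struct ConfigProfile {
    model: Option<String>,
}

/// Mirrors the lookup added to `Config::load_from_base_config_with_overrides`:
/// a `--profile` override wins over the `profile` key in config.toml, and
/// naming a profile that does not exist is a hard error rather than a
/// silent fallback.
fn resolve_profile(
    cli_key: Option<String>,
    toml_key: Option<String>,
    profiles: &HashMap<String, ConfigProfile>,
) -> io::Result<ConfigProfile> {
    match cli_key.or(toml_key) {
        Some(key) => profiles.get(&key).cloned().ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!("config profile `{key}` not found"),
            )
        }),
        None => Ok(ConfigProfile::default()),
    }
}

fn main() {
    let mut profiles = HashMap::new();
    profiles.insert(
        "o3".to_string(),
        ConfigProfile { model: Some("o3".into()) },
    );
    // Unknown profile names fail loudly.
    assert!(resolve_profile(Some("nope".into()), None, &profiles).is_err());
    // With no CLI override, the `profile` key from config.toml is used.
    let p = resolve_profile(None, Some("o3".into()), &profiles).unwrap();
    assert_eq!(p.model.as_deref(), Some("o3"));
}
```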

View File

@@ -0,0 +1,15 @@
use serde::Deserialize;
use crate::protocol::AskForApproval;
/// Collection of common configuration options that a user can define as a unit
/// in `config.toml`.
#[derive(Debug, Clone, Default, PartialEq, Deserialize)]
pub struct ConfigProfile {
pub model: Option<String>,
/// The key in the `model_providers` map identifying the
/// [`ModelProviderInfo`] to use.
pub model_provider: Option<String>,
pub approval_policy: Option<AskForApproval>,
pub disable_response_storage: Option<bool>,
}
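
Because every field is an `Option` and the struct derives `Default`, a
`[profiles.<name>]` table may specify any subset of the options. A quick
round-trip sketch (assumes the `toml` and `serde` crates, as used in the test
above; the struct here is a trimmed stand-in so the example is self-contained):

```rust
use serde::Deserialize;

// Stand-in for the real ConfigProfile, limited to two fields.
#[derive(Debug, Default, PartialEq, Deserialize)]
struct ConfigProfile {
    model: Option<String>,
    model_provider: Option<String>,
}

fn main() {
    let profile: ConfigProfile =
        toml::from_str(r#"model = "gpt-3.5-turbo""#).expect("valid profile table");
    assert_eq!(profile.model.as_deref(), Some("gpt-3.5-turbo"));
    // Options left out of the profile stay `None` and fall through to
    // config.toml values or built-in defaults during config resolution.
    assert_eq!(profile.model_provider, None);
}
```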

View File

@@ -3,7 +3,7 @@ use std::time::Duration;
use env_flags::env_flags;
env_flags! {
pub OPENAI_DEFAULT_MODEL: &str = "o3";
pub OPENAI_DEFAULT_MODEL: &str = "o4-mini";
pub OPENAI_API_BASE: &str = "https://api.openai.com/v1";
/// Fallback when the provider-specific key is not set.

View File

@@ -13,6 +13,7 @@ pub mod codex;
pub use codex::Codex;
pub mod codex_wrapper;
pub mod config;
pub mod config_profile;
mod conversation_history;
pub mod error;
pub mod exec;

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use serde::Deserialize;
#[derive(Deserialize, Debug, Clone)]
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct McpServerConfig {
pub command: String,

View File

@@ -29,7 +29,7 @@ pub enum WireApi {
}
/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
/// Friendly display name.
pub name: String,