diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index da19466c..6289625a 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1374,6 +1374,7 @@ dependencies = [ "rmcp", "serde", "serde_json", + "serial_test", "sha2", "tempfile", "tiny_http", diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index e7fd7b8d..50274f3e 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -6,6 +6,7 @@ use anyhow::anyhow; use anyhow::bail; use clap::ArgGroup; use codex_common::CliConfigOverrides; +use codex_common::format_env_display::format_env_display; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config::find_codex_home; @@ -227,6 +228,8 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re command: command_bin, args: command_args, env: env_map, + env_vars: Vec::new(), + cwd: None, } } AddMcpTransportArgs { @@ -400,11 +403,19 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> .copied() .unwrap_or(McpAuthStatus::Unsupported); let transport = match &cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => serde_json::json!({ "type": "stdio", "command": command, "args": args, "env": env, + "env_vars": env_vars, + "cwd": cwd, }), McpServerTransportConfig::StreamableHttp { url, @@ -446,30 +457,29 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> return Ok(()); } - let mut stdio_rows: Vec<[String; 6]> = Vec::new(); + let mut stdio_rows: Vec<[String; 7]> = Vec::new(); let mut http_rows: Vec<[String; 5]> = Vec::new(); for (name, cfg) in entries { match &cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { let args_display = if args.is_empty() { "-".to_string() } else { args.join(" ") }; - let env_display = match env.as_ref() { - None => "-".to_string(), - Some(map) if map.is_empty() => "-".to_string(), - Some(map) => { - let mut pairs: Vec<_> = map.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - pairs - .into_iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join(", ") - } - }; + let env_display = format_env_display(env.as_ref(), env_vars); + let cwd_display = cwd + .as_ref() + .map(|path| path.display().to_string()) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "-".to_string()); let status = if cfg.enabled { "enabled".to_string() } else { @@ -485,6 +495,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> command.clone(), args_display, env_display, + cwd_display, status, auth_status, ]); @@ -521,6 +532,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> "Command".len(), "Args".len(), "Env".len(), + "Cwd".len(), "Status".len(), "Auth".len(), ]; @@ -531,36 +543,40 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> } println!( - "{name: Re if get_args.json { let transport = match &server.transport { - McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => serde_json::json!({ "type": "stdio", "command": command, "args": args, "env": env, + "env_vars": env_vars, + "cwd": cwd, }), McpServerTransportConfig::StreamableHttp { url, @@ -666,7 +690,13 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re println!("{}", get_args.name); println!(" enabled: {}", server.enabled); match &server.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { println!(" transport: stdio"); println!(" command: {command}"); let args_display = if args.is_empty() { @@ -675,19 +705,13 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re args.join(" ") }; println!(" args: {args_display}"); - let env_display = match env.as_ref() { - None => "-".to_string(), - Some(map) if map.is_empty() => "-".to_string(), - Some(map) => { - let mut pairs: Vec<_> = map.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - pairs - .into_iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join(", ") - } - }; + let cwd_display = cwd + .as_ref() + .map(|path| path.display().to_string()) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "-".to_string()); + println!(" cwd: {cwd_display}"); + let env_display = format_env_display(env.as_ref(), env_vars); println!(" env: {env_display}"); } McpServerTransportConfig::StreamableHttp { diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs index 83abe72c..7a6c2daa 100644 --- a/codex-rs/cli/tests/mcp_add_remove.rs +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -28,10 +28,18 @@ async fn add_and_remove_server_updates_global_config() -> Result<()> { assert_eq!(servers.len(), 1); let docs = servers.get("docs").expect("server should exist"); match &docs.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { assert_eq!(command, "echo"); assert_eq!(args, &vec!["hello".to_string()]); assert!(env.is_none()); + assert!(env_vars.is_empty()); + assert!(cwd.is_none()); } other => panic!("unexpected transport: {other:?}"), } diff --git a/codex-rs/cli/tests/mcp_list.rs b/codex-rs/cli/tests/mcp_list.rs index 8f33a8e4..ea0d6fc1 100644 --- a/codex-rs/cli/tests/mcp_list.rs +++ b/codex-rs/cli/tests/mcp_list.rs @@ -1,6 +1,9 @@ use std::path::Path; use anyhow::Result; +use codex_core::config::load_global_mcp_servers; +use codex_core::config::write_global_mcp_servers; +use codex_core::config_types::McpServerTransportConfig; use predicates::prelude::PredicateBooleanExt; use predicates::str::contains; use pretty_assertions::assert_eq; @@ -27,8 +30,8 @@ fn list_shows_empty_state() -> Result<()> { Ok(()) } -#[test] -fn list_and_get_render_expected_output() -> Result<()> { +#[tokio::test] +async fn list_and_get_render_expected_output() -> Result<()> { let codex_home = TempDir::new()?; let mut add = codex_command(codex_home.path())?; @@ -46,6 +49,18 @@ fn list_and_get_render_expected_output() -> Result<()> { .assert() .success(); + let mut servers = load_global_mcp_servers(codex_home.path()).await?; + let docs_entry = servers + .get_mut("docs") + .expect("docs server should exist after add"); + match &mut docs_entry.transport { + McpServerTransportConfig::Stdio { env_vars, .. } => { + *env_vars = vec!["APP_TOKEN".to_string(), "WORKSPACE_ID".to_string()]; + } + other => panic!("unexpected transport: {other:?}"), + } + write_global_mcp_servers(codex_home.path(), &servers)?; + let mut list_cmd = codex_command(codex_home.path())?; let list_output = list_cmd.args(["mcp", "list"]).output()?; assert!(list_output.status.success()); @@ -54,6 +69,8 @@ fn list_and_get_render_expected_output() -> Result<()> { assert!(stdout.contains("docs")); assert!(stdout.contains("docs-server")); assert!(stdout.contains("TOKEN=secret")); + assert!(stdout.contains("APP_TOKEN=$APP_TOKEN")); + assert!(stdout.contains("WORKSPACE_ID=$WORKSPACE_ID")); assert!(stdout.contains("Status")); assert!(stdout.contains("Auth")); assert!(stdout.contains("enabled")); @@ -79,7 +96,12 @@ fn list_and_get_render_expected_output() -> Result<()> { ], "env": { "TOKEN": "secret" - } + }, + "env_vars": [ + "APP_TOKEN", + "WORKSPACE_ID" + ], + "cwd": null }, "startup_timeout_sec": null, "tool_timeout_sec": null, @@ -98,6 +120,8 @@ fn list_and_get_render_expected_output() -> Result<()> { assert!(stdout.contains("command: docs-server")); assert!(stdout.contains("args: --port 4000")); assert!(stdout.contains("env: TOKEN=secret")); + assert!(stdout.contains("APP_TOKEN=$APP_TOKEN")); + assert!(stdout.contains("WORKSPACE_ID=$WORKSPACE_ID")); assert!(stdout.contains("enabled: true")); assert!(stdout.contains("remove: codex mcp remove docs")); diff --git a/codex-rs/common/src/format_env_display.rs b/codex-rs/common/src/format_env_display.rs new file mode 100644 index 00000000..640be307 --- /dev/null +++ b/codex-rs/common/src/format_env_display.rs @@ -0,0 +1,66 @@ +use std::collections::HashMap; + +pub fn format_env_display(env: Option<&HashMap>, env_vars: &[String]) -> String { + let mut parts: Vec = Vec::new(); + + if let Some(map) = env { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + parts.extend( + pairs + .into_iter() + .map(|(key, value)| format!("{key}={value}")), + ); + } + + if !env_vars.is_empty() { + parts.extend(env_vars.iter().map(|var| format!("{var}=${var}"))); + } + + if parts.is_empty() { + "-".to_string() + } else { + parts.join(", ") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn returns_dash_when_empty() { + assert_eq!(format_env_display(None, &[]), "-"); + + let empty_map = HashMap::new(); + assert_eq!(format_env_display(Some(&empty_map), &[]), "-"); + } + + #[test] + fn formats_sorted_env_pairs() { + let mut env = HashMap::new(); + env.insert("B".to_string(), "two".to_string()); + env.insert("A".to_string(), "one".to_string()); + + assert_eq!(format_env_display(Some(&env), &[]), "A=one, B=two"); + } + + #[test] + fn formats_env_vars_with_dollar_prefix() { + let vars = vec!["TOKEN".to_string(), "PATH".to_string()]; + + assert_eq!(format_env_display(None, &vars), "TOKEN=$TOKEN, PATH=$PATH"); + } + + #[test] + fn combines_env_pairs_and_vars() { + let mut env = HashMap::new(); + env.insert("HOME".to_string(), "/tmp".to_string()); + let vars = vec!["TOKEN".to_string()]; + + assert_eq!( + format_env_display(Some(&env), &vars), + "HOME=/tmp, TOKEN=$TOKEN" + ); + } +} diff --git a/codex-rs/common/src/lib.rs b/codex-rs/common/src/lib.rs index 292503f7..276bfca0 100644 --- a/codex-rs/common/src/lib.rs +++ b/codex-rs/common/src/lib.rs @@ -13,6 +13,9 @@ mod sandbox_mode_cli_arg; #[cfg(feature = "cli")] pub use sandbox_mode_cli_arg::SandboxModeCliArg; +#[cfg(feature = "cli")] +pub mod format_env_display; + #[cfg(any(feature = "cli", test))] mod config_override; diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 570cd515..6f5defce 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -388,7 +388,13 @@ pub fn write_global_mcp_servers( let mut entry = TomlTable::new(); entry.set_implicit(false); match &config.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { entry["command"] = toml_edit::value(command.clone()); if !args.is_empty() { @@ -411,6 +417,15 @@ pub fn write_global_mcp_servers( } entry["env"] = TomlItem::Table(env_table); } + + if !env_vars.is_empty() { + entry["env_vars"] = + TomlItem::Value(env_vars.iter().collect::().into()); + } + + if let Some(cwd) = cwd { + entry["cwd"] = toml_edit::value(cwd.to_string_lossy().to_string()); + } } McpServerTransportConfig::StreamableHttp { url, @@ -1806,6 +1821,8 @@ approve_all = true command: "echo".to_string(), args: vec!["hello".to_string()], env: None, + env_vars: Vec::new(), + cwd: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(3)), @@ -1819,10 +1836,18 @@ approve_all = true assert_eq!(loaded.len(), 1); let docs = loaded.get("docs").expect("docs entry"); match &docs.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { assert_eq!(command, "echo"); assert_eq!(args, &vec!["hello".to_string()]); assert!(env.is_none()); + assert!(env_vars.is_empty()); + assert!(cwd.is_none()); } other => panic!("unexpected transport {other:?}"), } @@ -1932,6 +1957,8 @@ bearer_token = "secret" ("ZIG_VAR".to_string(), "3".to_string()), ("ALPHA_VAR".to_string(), "1".to_string()), ])), + env_vars: Vec::new(), + cwd: None, }, enabled: true, startup_timeout_sec: None, @@ -1958,7 +1985,13 @@ ZIG_VAR = "3" let loaded = load_global_mcp_servers(codex_home.path()).await?; let docs = loaded.get("docs").expect("docs entry"); match &docs.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { assert_eq!(command, "docs-server"); assert_eq!(args, &vec!["--verbose".to_string()]); let env = env @@ -1966,6 +1999,91 @@ ZIG_VAR = "3" .expect("env should be preserved for stdio transport"); assert_eq!(env.get("ALPHA_VAR"), Some(&"1".to_string())); assert_eq!(env.get("ZIG_VAR"), Some(&"3".to_string())); + assert!(env_vars.is_empty()); + assert!(cwd.is_none()); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_serializes_env_vars() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "docs-server".to_string(), + args: Vec::new(), + env: None, + env_vars: vec!["ALPHA".to_string(), "BETA".to_string()], + cwd: None, + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert!( + serialized.contains(r#"env_vars = ["ALPHA", "BETA"]"#), + "serialized config missing env_vars field:\n{serialized}" + ); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::Stdio { env_vars, .. } => { + assert_eq!(env_vars, &vec!["ALPHA".to_string(), "BETA".to_string()]); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_serializes_cwd() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let cwd_path = PathBuf::from("/tmp/codex-mcp"); + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "docs-server".to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: Some(cwd_path.clone()), + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert!( + serialized.contains(r#"cwd = "/tmp/codex-mcp""#), + "serialized config missing cwd field:\n{serialized}" + ); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::Stdio { cwd, .. } => { + assert_eq!(cwd.as_deref(), Some(Path::new("/tmp/codex-mcp"))); } other => panic!("unexpected transport {other:?}"), } @@ -2205,6 +2323,8 @@ url = "https://example.com/mcp" command: "logs-server".to_string(), args: vec!["--follow".to_string()], env: None, + env_vars: Vec::new(), + cwd: None, }, enabled: true, startup_timeout_sec: None, @@ -2277,6 +2397,8 @@ url = "https://example.com/mcp" command: "docs-server".to_string(), args: Vec::new(), env: None, + env_vars: Vec::new(), + cwd: None, }, enabled: false, startup_timeout_sec: None, diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index 3a14d77a..3da61086 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -44,20 +44,26 @@ impl<'de> Deserialize<'de> for McpServerConfig { { #[derive(Deserialize)] struct RawMcpServerConfig { + // stdio command: Option, #[serde(default)] args: Option>, #[serde(default)] env: Option>, #[serde(default)] + env_vars: Option>, + #[serde(default)] + cwd: Option, http_headers: Option>, #[serde(default)] env_http_headers: Option>, + // streamable_http url: Option, bearer_token: Option, bearer_token_env_var: Option, + // shared #[serde(default)] startup_timeout_sec: Option, #[serde(default)] @@ -96,6 +102,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { command: Some(command), args, env, + env_vars, + cwd, url, bearer_token_env_var, http_headers, @@ -114,6 +122,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { command, args: args.unwrap_or_default(), env, + env_vars: env_vars.unwrap_or_default(), + cwd, } } RawMcpServerConfig { @@ -123,13 +133,20 @@ impl<'de> Deserialize<'de> for McpServerConfig { command, args, env, + env_vars, + cwd, http_headers, env_http_headers, - .. + startup_timeout_sec: _, + tool_timeout_sec: _, + startup_timeout_ms: _, + enabled: _, } => { throw_if_set("streamable_http", "command", command.as_ref())?; throw_if_set("streamable_http", "args", args.as_ref())?; throw_if_set("streamable_http", "env", env.as_ref())?; + throw_if_set("streamable_http", "env_vars", env_vars.as_ref())?; + throw_if_set("streamable_http", "cwd", cwd.as_ref())?; throw_if_set("streamable_http", "bearer_token", bearer_token.as_ref())?; McpServerTransportConfig::StreamableHttp { url, @@ -164,6 +181,10 @@ pub enum McpServerTransportConfig { args: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] env: Option>, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + env_vars: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + cwd: Option, }, /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http StreamableHttp { @@ -500,7 +521,9 @@ mod tests { McpServerTransportConfig::Stdio { command: "echo".to_string(), args: vec![], - env: None + env: None, + env_vars: Vec::new(), + cwd: None, } ); assert!(cfg.enabled); @@ -521,7 +544,9 @@ mod tests { McpServerTransportConfig::Stdio { command: "echo".to_string(), args: vec!["hello".to_string(), "world".to_string()], - env: None + env: None, + env_vars: Vec::new(), + cwd: None, } ); assert!(cfg.enabled); @@ -543,12 +568,58 @@ mod tests { McpServerTransportConfig::Stdio { command: "echo".to_string(), args: vec!["hello".to_string(), "world".to_string()], - env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])) + env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])), + env_vars: Vec::new(), + cwd: None, } ); assert!(cfg.enabled); } + #[test] + fn deserialize_stdio_command_server_config_with_env_vars() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + env_vars = ["FOO", "BAR"] + "#, + ) + .expect("should deserialize command config with env_vars"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None, + env_vars: vec!["FOO".to_string(), "BAR".to_string()], + cwd: None, + } + ); + } + + #[test] + fn deserialize_stdio_command_server_config_with_cwd() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + cwd = "/tmp" + "#, + ) + .expect("should deserialize command config with cwd"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None, + env_vars: Vec::new(), + cwd: Some(PathBuf::from("/tmp")), + } + ); + } + #[test] fn deserialize_disabled_server_config() { let cfg: McpServerConfig = toml::from_str( diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 8ddb0366..e7ff9983 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -10,6 +10,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::env; use std::ffi::OsString; +use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; @@ -102,20 +103,25 @@ enum McpClientAdapter { } impl McpClientAdapter { + #[allow(clippy::too_many_arguments)] async fn new_stdio_client( use_rmcp_client: bool, program: OsString, args: Vec, env: Option>, + env_vars: Vec, + cwd: Option, params: mcp_types::InitializeRequestParams, startup_timeout: Duration, ) -> Result { if use_rmcp_client { - let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?); + let client = + Arc::new(RmcpClient::new_stdio_client(program, args, env, &env_vars, cwd).await?); client.initialize(params, Some(startup_timeout)).await?; Ok(McpClientAdapter::Rmcp(client)) } else { - let client = Arc::new(McpClient::new_stdio_client(program, args, env).await?); + let client = + Arc::new(McpClient::new_stdio_client(program, args, env, &env_vars, cwd).await?); client.initialize(params, Some(startup_timeout)).await?; Ok(McpClientAdapter::Legacy(client)) } @@ -256,7 +262,13 @@ impl McpConnectionManager { }; let client = match transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { let command_os: OsString = command.into(); let args_os: Vec = args.into_iter().map(Into::into).collect(); McpClientAdapter::new_stdio_client( @@ -264,6 +276,8 @@ impl McpConnectionManager { command_os, args_os, env, + env_vars, + cwd, params, startup_timeout, ) diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index b2559cd2..1a2815b0 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::ffi::OsStr; use std::ffi::OsString; use std::fs; use std::net::TcpListener; @@ -35,6 +36,7 @@ use tokio::time::sleep; use wiremock::matchers::any; #[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[serial(mcp_test_value)] async fn stdio_server_round_trip() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -86,6 +88,8 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> { "MCP_TEST_VALUE".to_string(), expected_env_value.to_string(), )])), + env_vars: Vec::new(), + cwd: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(10)), @@ -106,7 +110,143 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> { final_output_json_schema: None, cwd: fixture.cwd.path().to_path_buf(), approval_policy: AskForApproval::Never, - sandbox_policy: SandboxPolicy::DangerFullAccess, + sandbox_policy: SandboxPolicy::ReadOnly, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let begin_event = wait_for_event_with_timeout( + &fixture.codex, + |ev| matches!(ev, EventMsg::McpToolCallBegin(_)), + Duration::from_secs(10), + ) + .await; + + let EventMsg::McpToolCallBegin(begin) = begin_event else { + unreachable!("event guard guarantees McpToolCallBegin"); + }; + assert_eq!(begin.invocation.server, server_name); + assert_eq!(begin.invocation.tool, "echo"); + + let end_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + let EventMsg::McpToolCallEnd(end) = end_event else { + unreachable!("event guard guarantees McpToolCallEnd"); + }; + + let result = end + .result + .as_ref() + .expect("rmcp echo tool should return success"); + assert_eq!(result.is_error, Some(false)); + assert!( + result.content.is_empty(), + "content should default to an empty array" + ); + + let structured = result + .structured_content + .as_ref() + .expect("structured content"); + let Value::Object(map) = structured else { + panic!("structured content should be an object: {structured:?}"); + }; + let echo_value = map + .get("echo") + .and_then(Value::as_str) + .expect("echo payload present"); + assert_eq!(echo_value, "ECHOING: ping"); + let env_value = map + .get("env") + .and_then(Value::as_str) + .expect("env snapshot inserted"); + assert_eq!(env_value, expected_env_value); + + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + server.verify().await; + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[serial(mcp_test_value)] +async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + + let call_id = "call-1234"; + let server_name = "rmcp_whitelist"; + let tool_name = format!("{server_name}__echo"); + + mount_sse_once_match( + &server, + any(), + responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once_match( + &server, + any(), + responses::sse(vec![ + responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + let expected_env_value = "propagated-env-from-whitelist"; + let _guard = EnvVarGuard::set("MCP_TEST_VALUE", OsStr::new(expected_env_value)); + let rmcp_test_server_bin = CargoBuild::new() + .package("codex-rmcp-client") + .bin("test_stdio_server") + .run()? + .path() + .to_string_lossy() + .into_owned(); + + let fixture = test_codex() + .with_config(move |config| { + config.features.enable(Feature::RmcpClient); + config.mcp_servers.insert( + server_name.to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: rmcp_test_server_bin, + args: Vec::new(), + env: None, + env_vars: vec!["MCP_TEST_VALUE".to_string()], + cwd: None, + }, + enabled: true, + startup_timeout_sec: Some(Duration::from_secs(10)), + tool_timeout_sec: None, + }, + ); + }) + .build(&server) + .await?; + let session_model = fixture.session_configured.model.clone(); + + fixture + .codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "call the rmcp echo tool".into(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, model: session_model, effort: None, summary: ReasoningSummary::Auto, @@ -257,7 +397,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { final_output_json_schema: None, cwd: fixture.cwd.path().to_path_buf(), approval_policy: AskForApproval::Never, - sandbox_policy: SandboxPolicy::DangerFullAccess, + sandbox_policy: SandboxPolicy::ReadOnly, model: session_model, effort: None, summary: ReasoningSummary::Auto, @@ -440,7 +580,7 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> { final_output_json_schema: None, cwd: fixture.cwd.path().to_path_buf(), approval_policy: AskForApproval::Never, - sandbox_policy: SandboxPolicy::DangerFullAccess, + sandbox_policy: SandboxPolicy::ReadOnly, model: session_model, effort: None, summary: ReasoningSummary::Auto, diff --git a/codex-rs/mcp-client/src/main.rs b/codex-rs/mcp-client/src/main.rs index f46058b9..8e1f322d 100644 --- a/codex-rs/mcp-client/src/main.rs +++ b/codex-rs/mcp-client/src/main.rs @@ -49,7 +49,7 @@ async fn main() -> Result<()> { // Spawn the subprocess and connect the client. let program = args.remove(0); let env = None; - let client = McpClient::new_stdio_client(program, args, env) + let client = McpClient::new_stdio_client(program, args, env, &[], None) .await .with_context(|| format!("failed to spawn subprocess: {original_args:?}"))?; diff --git a/codex-rs/mcp-client/src/mcp_client.rs b/codex-rs/mcp-client/src/mcp_client.rs index 27f96494..3be93f35 100644 --- a/codex-rs/mcp-client/src/mcp_client.rs +++ b/codex-rs/mcp-client/src/mcp_client.rs @@ -13,6 +13,7 @@ use std::collections::HashMap; use std::ffi::OsString; +use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -86,19 +87,26 @@ impl McpClient { program: OsString, args: Vec, env: Option>, + env_vars: &[String], + cwd: Option, ) -> std::io::Result { - let mut child = Command::new(program) + let mut command = Command::new(program); + command .args(args) .env_clear() - .envs(create_env_for_mcp_server(env)) + .envs(create_env_for_mcp_server(env, env_vars)) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::null()) // As noted in the `kill_on_drop` documentation, the Tokio runtime makes // a "best effort" to reap-after-exit to avoid zombie processes, but it // is not a guarantee. - .kill_on_drop(true) - .spawn()?; + .kill_on_drop(true); + if let Some(cwd) = cwd { + command.current_dir(cwd); + } + + let mut child = command.spawn()?; let stdin = child .stdin @@ -447,12 +455,16 @@ const DEFAULT_ENV_VARS: &[&str] = &[ /// `config.toml`. fn create_env_for_mcp_server( extra_env: Option>, + env_vars: &[String], ) -> HashMap { DEFAULT_ENV_VARS .iter() - .filter_map(|var| match std::env::var(var) { - Ok(value) => Some((var.to_string(), value)), - Err(_) => None, + .copied() + .chain(env_vars.iter().map(String::as_str)) + .filter_map(|var| { + std::env::var(var) + .ok() + .map(|value| (var.to_string(), value)) }) .chain(extra_env.unwrap_or_default()) .collect::>() @@ -462,14 +474,36 @@ fn create_env_for_mcp_server( mod tests { use super::*; + fn set_env_var(key: &str, value: &str) { + unsafe { + std::env::set_var(key, value); + } + } + + fn remove_env_var(key: &str) { + unsafe { + std::env::remove_var(key); + } + } + #[test] fn test_create_env_for_mcp_server() { let env_var = "USER"; let env_var_existing_value = std::env::var(env_var).unwrap_or_default(); let env_var_new_value = format!("{env_var_existing_value}-extra"); let extra_env = HashMap::from([(env_var.to_owned(), env_var_new_value.clone())]); - let mcp_server_env = create_env_for_mcp_server(Some(extra_env)); + let mcp_server_env = create_env_for_mcp_server(Some(extra_env), &[]); assert!(mcp_server_env.contains_key("PATH")); assert_eq!(Some(&env_var_new_value), mcp_server_env.get(env_var)); } + + #[test] + fn test_create_env_for_mcp_server_includes_extra_whitelisted_vars() { + let custom_var = "CUSTOM_TEST_VAR"; + let value = "value".to_string(); + set_env_var(custom_var, &value); + let mcp_server_env = create_env_for_mcp_server(None, &[custom_var.to_string()]); + assert_eq!(Some(&value), mcp_server_env.get(custom_var)); + remove_env_var(custom_var); + } } diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index 99a609b3..0016f114 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -58,4 +58,5 @@ webbrowser = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } +serial_test = { workspace = true } tempfile = { workspace = true } diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 038f32ef..245b4eb3 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::ffi::OsString; use std::io; +use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; @@ -79,6 +80,8 @@ impl RmcpClient { program: OsString, args: Vec, env: Option>, + env_vars: &[String], + cwd: Option, ) -> io::Result { let program_name = program.to_string_lossy().into_owned(); let mut command = Command::new(&program); @@ -87,8 +90,11 @@ impl RmcpClient { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .env_clear() - .envs(create_env_for_mcp_server(env)) + .envs(create_env_for_mcp_server(env, env_vars)) .args(&args); + if let Some(cwd) = cwd { + command.current_dir(cwd); + } let (transport, stderr) = TokioChildProcess::builder(command) .stderr(Stdio::piped()) diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs index cccb12c0..17e050cb 100644 --- a/codex-rs/rmcp-client/src/utils.rs +++ b/codex-rs/rmcp-client/src/utils.rs @@ -74,9 +74,12 @@ where pub(crate) fn create_env_for_mcp_server( extra_env: Option>, + env_vars: &[String], ) -> HashMap { DEFAULT_ENV_VARS .iter() + .copied() + .chain(env_vars.iter().map(String::as_str)) .filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value))) .chain(extra_env.unwrap_or_default()) .collect() @@ -185,13 +188,59 @@ mod tests { use rmcp::model::CallToolResult as RmcpCallToolResult; use serde_json::json; + use serial_test::serial; + use std::ffi::OsString; + + struct EnvVarGuard { + key: String, + original: Option, + } + + impl EnvVarGuard { + fn set(key: &str, value: &str) -> Self { + let original = std::env::var_os(key); + unsafe { + std::env::set_var(key, value); + } + Self { + key: key.to_string(), + original, + } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + if let Some(value) = &self.original { + unsafe { + std::env::set_var(&self.key, value); + } + } else { + unsafe { + std::env::remove_var(&self.key); + } + } + } + } + #[tokio::test] async fn create_env_honors_overrides() { let value = "custom".to_string(); - let env = create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())]))); + let env = + create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])), &[]); assert_eq!(env.get("TZ"), Some(&value)); } + #[test] + #[serial(extra_rmcp_env)] + fn create_env_includes_additional_whitelisted_variables() { + let custom_var = "EXTRA_RMCP_ENV"; + let value = "from-env"; + let _guard = EnvVarGuard::set(custom_var, value); + let env = create_env_for_mcp_server(None, &[custom_var.to_string()]); + assert_eq!(env.get(custom_var), Some(&value.to_string())); + } + #[test] fn convert_call_tool_result_defaults_missing_content() -> Result<()> { let structured_content = json!({ "key": "value" }); diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index 2dd2b578..dae48d48 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -20,6 +20,7 @@ use crate::wrapping::RtOptions; use crate::wrapping::word_wrap_line; use crate::wrapping::word_wrap_lines; use base64::Engine; +use codex_common::format_env_display::format_env_display; use codex_core::config::Config; use codex_core::config_types::McpServerTransportConfig; use codex_core::config_types::ReasoningSummaryFormat; @@ -1013,7 +1014,13 @@ pub(crate) fn new_mcp_tools_output( lines.push(vec![" • Auth: ".into(), status.to_string().into()].into()); match &cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { let args_suffix = if args.is_empty() { String::new() } else { @@ -1022,13 +1029,13 @@ pub(crate) fn new_mcp_tools_output( let cmd_display = format!("{command}{args_suffix}"); lines.push(vec![" • Command: ".into(), cmd_display.into()].into()); - if let Some(env) = env.as_ref() - && !env.is_empty() - { - let mut env_pairs: Vec = - env.iter().map(|(k, v)| format!("{k}={v}")).collect(); - env_pairs.sort(); - lines.push(vec![" • Env: ".into(), env_pairs.join(" ").into()].into()); + if let Some(cwd) = cwd.as_ref() { + lines.push(vec![" • Cwd: ".into(), cwd.display().to_string().into()].into()); + } + + let env_display = format_env_display(env.as_ref(), env_vars); + if env_display != "-" { + lines.push(vec![" • Env: ".into(), env_display.into()].into()); } } McpServerTransportConfig::StreamableHttp { @@ -1077,7 +1084,6 @@ pub(crate) fn new_mcp_tools_output( PlainHistoryCell { lines } } - pub(crate) fn new_info_event(message: String, hint: Option) -> PlainHistoryCell { let mut line = vec!["• ".dim(), message.into()]; if let Some(hint) = hint { diff --git a/docs/config.md b/docs/config.md index 8b5a45ab..dafd9dba 100644 --- a/docs/config.md +++ b/docs/config.md @@ -359,6 +359,11 @@ env = { "API_KEY" = "value" } # or [mcp_servers.server_name.env] API_KEY = "value" +# Optional: Additional list of environment variables that will be whitelisted in the MCP server's environment. +env_vars = ["API_KEY2"] + +# Optional: cwd that the command will be run from +cwd = "/Users//code/my-server" ``` #### Streamable HTTP