From bdda762deb38f6cc8ba18debc74716928d22a947 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 16 Oct 2025 21:24:43 -0700 Subject: [PATCH] [MCP] Allow specifying cwd and additional env vars (#5246) This makes stdio mcp servers more flexible by allowing users to specify the cwd to run the server command from and adding additional environment variables to be passed through to the server. Example config using the test server in this repo: ```toml [mcp_servers.test_stdio] cwd = "/Users//code/codex/codex-rs" command = "cargo" args = ["run", "--bin", "test_stdio_server"] env_vars = ["MCP_TEST_VALUE"] ``` @bolinfest I know you hate these env var tests but let's roll with this for now. I may take a stab at the env guard + serial macro at some point. --- codex-rs/Cargo.lock | 1 + codex-rs/cli/src/mcp_cmd.rs | 102 ++++++++------ codex-rs/cli/tests/mcp_add_remove.rs | 10 +- codex-rs/cli/tests/mcp_list.rs | 30 +++- codex-rs/common/src/format_env_display.rs | 66 +++++++++ codex-rs/common/src/lib.rs | 3 + codex-rs/core/src/config.rs | 128 ++++++++++++++++- codex-rs/core/src/config_types.rs | 79 ++++++++++- codex-rs/core/src/mcp_connection_manager.rs | 20 ++- codex-rs/core/tests/suite/rmcp_client.rs | 146 +++++++++++++++++++- codex-rs/mcp-client/src/main.rs | 2 +- codex-rs/mcp-client/src/mcp_client.rs | 50 +++++-- codex-rs/rmcp-client/Cargo.toml | 1 + codex-rs/rmcp-client/src/rmcp_client.rs | 8 +- codex-rs/rmcp-client/src/utils.rs | 51 ++++++- codex-rs/tui/src/history_cell.rs | 24 ++-- docs/config.md | 5 + 17 files changed, 650 insertions(+), 76 deletions(-) create mode 100644 codex-rs/common/src/format_env_display.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index da19466c..6289625a 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1374,6 +1374,7 @@ dependencies = [ "rmcp", "serde", "serde_json", + "serial_test", "sha2", "tempfile", "tiny_http", diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index e7fd7b8d..50274f3e 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -6,6 +6,7 @@ use anyhow::anyhow; use anyhow::bail; use clap::ArgGroup; use codex_common::CliConfigOverrides; +use codex_common::format_env_display::format_env_display; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config::find_codex_home; @@ -227,6 +228,8 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re command: command_bin, args: command_args, env: env_map, + env_vars: Vec::new(), + cwd: None, } } AddMcpTransportArgs { @@ -400,11 +403,19 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> .copied() .unwrap_or(McpAuthStatus::Unsupported); let transport = match &cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => serde_json::json!({ "type": "stdio", "command": command, "args": args, "env": env, + "env_vars": env_vars, + "cwd": cwd, }), McpServerTransportConfig::StreamableHttp { url, @@ -446,30 +457,29 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> return Ok(()); } - let mut stdio_rows: Vec<[String; 6]> = Vec::new(); + let mut stdio_rows: Vec<[String; 7]> = Vec::new(); let mut http_rows: Vec<[String; 5]> = Vec::new(); for (name, cfg) in entries { match &cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { let args_display = if args.is_empty() { "-".to_string() } else { args.join(" ") }; - let env_display = match env.as_ref() { - None => "-".to_string(), - Some(map) if map.is_empty() => "-".to_string(), - Some(map) => { - let mut pairs: Vec<_> = map.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - pairs - .into_iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join(", ") - } - }; + let env_display = format_env_display(env.as_ref(), env_vars); + let cwd_display = cwd + .as_ref() + .map(|path| path.display().to_string()) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "-".to_string()); let status = if cfg.enabled { "enabled".to_string() } else { @@ -485,6 +495,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> command.clone(), args_display, env_display, + cwd_display, status, auth_status, ]); @@ -521,6 +532,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> "Command".len(), "Args".len(), "Env".len(), + "Cwd".len(), "Status".len(), "Auth".len(), ]; @@ -531,36 +543,40 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> } println!( - "{name: Re if get_args.json { let transport = match &server.transport { - McpServerTransportConfig::Stdio { command, args, env } => serde_json::json!({ + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => serde_json::json!({ "type": "stdio", "command": command, "args": args, "env": env, + "env_vars": env_vars, + "cwd": cwd, }), McpServerTransportConfig::StreamableHttp { url, @@ -666,7 +690,13 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re println!("{}", get_args.name); println!(" enabled: {}", server.enabled); match &server.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { println!(" transport: stdio"); println!(" command: {command}"); let args_display = if args.is_empty() { @@ -675,19 +705,13 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re args.join(" ") }; println!(" args: {args_display}"); - let env_display = match env.as_ref() { - None => "-".to_string(), - Some(map) if map.is_empty() => "-".to_string(), - Some(map) => { - let mut pairs: Vec<_> = map.iter().collect(); - pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); - pairs - .into_iter() - .map(|(k, v)| format!("{k}={v}")) - .collect::>() - .join(", ") - } - }; + let cwd_display = cwd + .as_ref() + .map(|path| path.display().to_string()) + .filter(|value| !value.is_empty()) + .unwrap_or_else(|| "-".to_string()); + println!(" cwd: {cwd_display}"); + let env_display = format_env_display(env.as_ref(), env_vars); println!(" env: {env_display}"); } McpServerTransportConfig::StreamableHttp { diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs index 83abe72c..7a6c2daa 100644 --- a/codex-rs/cli/tests/mcp_add_remove.rs +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -28,10 +28,18 @@ async fn add_and_remove_server_updates_global_config() -> Result<()> { assert_eq!(servers.len(), 1); let docs = servers.get("docs").expect("server should exist"); match &docs.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { assert_eq!(command, "echo"); assert_eq!(args, &vec!["hello".to_string()]); assert!(env.is_none()); + assert!(env_vars.is_empty()); + assert!(cwd.is_none()); } other => panic!("unexpected transport: {other:?}"), } diff --git a/codex-rs/cli/tests/mcp_list.rs b/codex-rs/cli/tests/mcp_list.rs index 8f33a8e4..ea0d6fc1 100644 --- a/codex-rs/cli/tests/mcp_list.rs +++ b/codex-rs/cli/tests/mcp_list.rs @@ -1,6 +1,9 @@ use std::path::Path; use anyhow::Result; +use codex_core::config::load_global_mcp_servers; +use codex_core::config::write_global_mcp_servers; +use codex_core::config_types::McpServerTransportConfig; use predicates::prelude::PredicateBooleanExt; use predicates::str::contains; use pretty_assertions::assert_eq; @@ -27,8 +30,8 @@ fn list_shows_empty_state() -> Result<()> { Ok(()) } -#[test] -fn list_and_get_render_expected_output() -> Result<()> { +#[tokio::test] +async fn list_and_get_render_expected_output() -> Result<()> { let codex_home = TempDir::new()?; let mut add = codex_command(codex_home.path())?; @@ -46,6 +49,18 @@ fn list_and_get_render_expected_output() -> Result<()> { .assert() .success(); + let mut servers = load_global_mcp_servers(codex_home.path()).await?; + let docs_entry = servers + .get_mut("docs") + .expect("docs server should exist after add"); + match &mut docs_entry.transport { + McpServerTransportConfig::Stdio { env_vars, .. } => { + *env_vars = vec!["APP_TOKEN".to_string(), "WORKSPACE_ID".to_string()]; + } + other => panic!("unexpected transport: {other:?}"), + } + write_global_mcp_servers(codex_home.path(), &servers)?; + let mut list_cmd = codex_command(codex_home.path())?; let list_output = list_cmd.args(["mcp", "list"]).output()?; assert!(list_output.status.success()); @@ -54,6 +69,8 @@ fn list_and_get_render_expected_output() -> Result<()> { assert!(stdout.contains("docs")); assert!(stdout.contains("docs-server")); assert!(stdout.contains("TOKEN=secret")); + assert!(stdout.contains("APP_TOKEN=$APP_TOKEN")); + assert!(stdout.contains("WORKSPACE_ID=$WORKSPACE_ID")); assert!(stdout.contains("Status")); assert!(stdout.contains("Auth")); assert!(stdout.contains("enabled")); @@ -79,7 +96,12 @@ fn list_and_get_render_expected_output() -> Result<()> { ], "env": { "TOKEN": "secret" - } + }, + "env_vars": [ + "APP_TOKEN", + "WORKSPACE_ID" + ], + "cwd": null }, "startup_timeout_sec": null, "tool_timeout_sec": null, @@ -98,6 +120,8 @@ fn list_and_get_render_expected_output() -> Result<()> { assert!(stdout.contains("command: docs-server")); assert!(stdout.contains("args: --port 4000")); assert!(stdout.contains("env: TOKEN=secret")); + assert!(stdout.contains("APP_TOKEN=$APP_TOKEN")); + assert!(stdout.contains("WORKSPACE_ID=$WORKSPACE_ID")); assert!(stdout.contains("enabled: true")); assert!(stdout.contains("remove: codex mcp remove docs")); diff --git a/codex-rs/common/src/format_env_display.rs b/codex-rs/common/src/format_env_display.rs new file mode 100644 index 00000000..640be307 --- /dev/null +++ b/codex-rs/common/src/format_env_display.rs @@ -0,0 +1,66 @@ +use std::collections::HashMap; + +pub fn format_env_display(env: Option<&HashMap>, env_vars: &[String]) -> String { + let mut parts: Vec = Vec::new(); + + if let Some(map) = env { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + parts.extend( + pairs + .into_iter() + .map(|(key, value)| format!("{key}={value}")), + ); + } + + if !env_vars.is_empty() { + parts.extend(env_vars.iter().map(|var| format!("{var}=${var}"))); + } + + if parts.is_empty() { + "-".to_string() + } else { + parts.join(", ") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn returns_dash_when_empty() { + assert_eq!(format_env_display(None, &[]), "-"); + + let empty_map = HashMap::new(); + assert_eq!(format_env_display(Some(&empty_map), &[]), "-"); + } + + #[test] + fn formats_sorted_env_pairs() { + let mut env = HashMap::new(); + env.insert("B".to_string(), "two".to_string()); + env.insert("A".to_string(), "one".to_string()); + + assert_eq!(format_env_display(Some(&env), &[]), "A=one, B=two"); + } + + #[test] + fn formats_env_vars_with_dollar_prefix() { + let vars = vec!["TOKEN".to_string(), "PATH".to_string()]; + + assert_eq!(format_env_display(None, &vars), "TOKEN=$TOKEN, PATH=$PATH"); + } + + #[test] + fn combines_env_pairs_and_vars() { + let mut env = HashMap::new(); + env.insert("HOME".to_string(), "/tmp".to_string()); + let vars = vec!["TOKEN".to_string()]; + + assert_eq!( + format_env_display(Some(&env), &vars), + "HOME=/tmp, TOKEN=$TOKEN" + ); + } +} diff --git a/codex-rs/common/src/lib.rs b/codex-rs/common/src/lib.rs index 292503f7..276bfca0 100644 --- a/codex-rs/common/src/lib.rs +++ b/codex-rs/common/src/lib.rs @@ -13,6 +13,9 @@ mod sandbox_mode_cli_arg; #[cfg(feature = "cli")] pub use sandbox_mode_cli_arg::SandboxModeCliArg; +#[cfg(feature = "cli")] +pub mod format_env_display; + #[cfg(any(feature = "cli", test))] mod config_override; diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 570cd515..6f5defce 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -388,7 +388,13 @@ pub fn write_global_mcp_servers( let mut entry = TomlTable::new(); entry.set_implicit(false); match &config.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { entry["command"] = toml_edit::value(command.clone()); if !args.is_empty() { @@ -411,6 +417,15 @@ pub fn write_global_mcp_servers( } entry["env"] = TomlItem::Table(env_table); } + + if !env_vars.is_empty() { + entry["env_vars"] = + TomlItem::Value(env_vars.iter().collect::().into()); + } + + if let Some(cwd) = cwd { + entry["cwd"] = toml_edit::value(cwd.to_string_lossy().to_string()); + } } McpServerTransportConfig::StreamableHttp { url, @@ -1806,6 +1821,8 @@ approve_all = true command: "echo".to_string(), args: vec!["hello".to_string()], env: None, + env_vars: Vec::new(), + cwd: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(3)), @@ -1819,10 +1836,18 @@ approve_all = true assert_eq!(loaded.len(), 1); let docs = loaded.get("docs").expect("docs entry"); match &docs.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { assert_eq!(command, "echo"); assert_eq!(args, &vec!["hello".to_string()]); assert!(env.is_none()); + assert!(env_vars.is_empty()); + assert!(cwd.is_none()); } other => panic!("unexpected transport {other:?}"), } @@ -1932,6 +1957,8 @@ bearer_token = "secret" ("ZIG_VAR".to_string(), "3".to_string()), ("ALPHA_VAR".to_string(), "1".to_string()), ])), + env_vars: Vec::new(), + cwd: None, }, enabled: true, startup_timeout_sec: None, @@ -1958,7 +1985,13 @@ ZIG_VAR = "3" let loaded = load_global_mcp_servers(codex_home.path()).await?; let docs = loaded.get("docs").expect("docs entry"); match &docs.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { assert_eq!(command, "docs-server"); assert_eq!(args, &vec!["--verbose".to_string()]); let env = env @@ -1966,6 +1999,91 @@ ZIG_VAR = "3" .expect("env should be preserved for stdio transport"); assert_eq!(env.get("ALPHA_VAR"), Some(&"1".to_string())); assert_eq!(env.get("ZIG_VAR"), Some(&"3".to_string())); + assert!(env_vars.is_empty()); + assert!(cwd.is_none()); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_serializes_env_vars() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "docs-server".to_string(), + args: Vec::new(), + env: None, + env_vars: vec!["ALPHA".to_string(), "BETA".to_string()], + cwd: None, + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert!( + serialized.contains(r#"env_vars = ["ALPHA", "BETA"]"#), + "serialized config missing env_vars field:\n{serialized}" + ); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::Stdio { env_vars, .. } => { + assert_eq!(env_vars, &vec!["ALPHA".to_string(), "BETA".to_string()]); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_serializes_cwd() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let cwd_path = PathBuf::from("/tmp/codex-mcp"); + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "docs-server".to_string(), + args: Vec::new(), + env: None, + env_vars: Vec::new(), + cwd: Some(cwd_path.clone()), + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert!( + serialized.contains(r#"cwd = "/tmp/codex-mcp""#), + "serialized config missing cwd field:\n{serialized}" + ); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::Stdio { cwd, .. } => { + assert_eq!(cwd.as_deref(), Some(Path::new("/tmp/codex-mcp"))); } other => panic!("unexpected transport {other:?}"), } @@ -2205,6 +2323,8 @@ url = "https://example.com/mcp" command: "logs-server".to_string(), args: vec!["--follow".to_string()], env: None, + env_vars: Vec::new(), + cwd: None, }, enabled: true, startup_timeout_sec: None, @@ -2277,6 +2397,8 @@ url = "https://example.com/mcp" command: "docs-server".to_string(), args: Vec::new(), env: None, + env_vars: Vec::new(), + cwd: None, }, enabled: false, startup_timeout_sec: None, diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index 3a14d77a..3da61086 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -44,20 +44,26 @@ impl<'de> Deserialize<'de> for McpServerConfig { { #[derive(Deserialize)] struct RawMcpServerConfig { + // stdio command: Option, #[serde(default)] args: Option>, #[serde(default)] env: Option>, #[serde(default)] + env_vars: Option>, + #[serde(default)] + cwd: Option, http_headers: Option>, #[serde(default)] env_http_headers: Option>, + // streamable_http url: Option, bearer_token: Option, bearer_token_env_var: Option, + // shared #[serde(default)] startup_timeout_sec: Option, #[serde(default)] @@ -96,6 +102,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { command: Some(command), args, env, + env_vars, + cwd, url, bearer_token_env_var, http_headers, @@ -114,6 +122,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { command, args: args.unwrap_or_default(), env, + env_vars: env_vars.unwrap_or_default(), + cwd, } } RawMcpServerConfig { @@ -123,13 +133,20 @@ impl<'de> Deserialize<'de> for McpServerConfig { command, args, env, + env_vars, + cwd, http_headers, env_http_headers, - .. + startup_timeout_sec: _, + tool_timeout_sec: _, + startup_timeout_ms: _, + enabled: _, } => { throw_if_set("streamable_http", "command", command.as_ref())?; throw_if_set("streamable_http", "args", args.as_ref())?; throw_if_set("streamable_http", "env", env.as_ref())?; + throw_if_set("streamable_http", "env_vars", env_vars.as_ref())?; + throw_if_set("streamable_http", "cwd", cwd.as_ref())?; throw_if_set("streamable_http", "bearer_token", bearer_token.as_ref())?; McpServerTransportConfig::StreamableHttp { url, @@ -164,6 +181,10 @@ pub enum McpServerTransportConfig { args: Vec, #[serde(default, skip_serializing_if = "Option::is_none")] env: Option>, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + env_vars: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + cwd: Option, }, /// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http StreamableHttp { @@ -500,7 +521,9 @@ mod tests { McpServerTransportConfig::Stdio { command: "echo".to_string(), args: vec![], - env: None + env: None, + env_vars: Vec::new(), + cwd: None, } ); assert!(cfg.enabled); @@ -521,7 +544,9 @@ mod tests { McpServerTransportConfig::Stdio { command: "echo".to_string(), args: vec!["hello".to_string(), "world".to_string()], - env: None + env: None, + env_vars: Vec::new(), + cwd: None, } ); assert!(cfg.enabled); @@ -543,12 +568,58 @@ mod tests { McpServerTransportConfig::Stdio { command: "echo".to_string(), args: vec!["hello".to_string(), "world".to_string()], - env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])) + env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])), + env_vars: Vec::new(), + cwd: None, } ); assert!(cfg.enabled); } + #[test] + fn deserialize_stdio_command_server_config_with_env_vars() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + env_vars = ["FOO", "BAR"] + "#, + ) + .expect("should deserialize command config with env_vars"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None, + env_vars: vec!["FOO".to_string(), "BAR".to_string()], + cwd: None, + } + ); + } + + #[test] + fn deserialize_stdio_command_server_config_with_cwd() { + let cfg: McpServerConfig = toml::from_str( + r#" + command = "echo" + cwd = "/tmp" + "#, + ) + .expect("should deserialize command config with cwd"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::Stdio { + command: "echo".to_string(), + args: vec![], + env: None, + env_vars: Vec::new(), + cwd: Some(PathBuf::from("/tmp")), + } + ); + } + #[test] fn deserialize_disabled_server_config() { let cfg: McpServerConfig = toml::from_str( diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 8ddb0366..e7ff9983 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -10,6 +10,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::env; use std::ffi::OsString; +use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; @@ -102,20 +103,25 @@ enum McpClientAdapter { } impl McpClientAdapter { + #[allow(clippy::too_many_arguments)] async fn new_stdio_client( use_rmcp_client: bool, program: OsString, args: Vec, env: Option>, + env_vars: Vec, + cwd: Option, params: mcp_types::InitializeRequestParams, startup_timeout: Duration, ) -> Result { if use_rmcp_client { - let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?); + let client = + Arc::new(RmcpClient::new_stdio_client(program, args, env, &env_vars, cwd).await?); client.initialize(params, Some(startup_timeout)).await?; Ok(McpClientAdapter::Rmcp(client)) } else { - let client = Arc::new(McpClient::new_stdio_client(program, args, env).await?); + let client = + Arc::new(McpClient::new_stdio_client(program, args, env, &env_vars, cwd).await?); client.initialize(params, Some(startup_timeout)).await?; Ok(McpClientAdapter::Legacy(client)) } @@ -256,7 +262,13 @@ impl McpConnectionManager { }; let client = match transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { let command_os: OsString = command.into(); let args_os: Vec = args.into_iter().map(Into::into).collect(); McpClientAdapter::new_stdio_client( @@ -264,6 +276,8 @@ impl McpConnectionManager { command_os, args_os, env, + env_vars, + cwd, params, startup_timeout, ) diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index b2559cd2..1a2815b0 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::ffi::OsStr; use std::ffi::OsString; use std::fs; use std::net::TcpListener; @@ -35,6 +36,7 @@ use tokio::time::sleep; use wiremock::matchers::any; #[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[serial(mcp_test_value)] async fn stdio_server_round_trip() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -86,6 +88,8 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> { "MCP_TEST_VALUE".to_string(), expected_env_value.to_string(), )])), + env_vars: Vec::new(), + cwd: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(10)), @@ -106,7 +110,143 @@ async fn stdio_server_round_trip() -> anyhow::Result<()> { final_output_json_schema: None, cwd: fixture.cwd.path().to_path_buf(), approval_policy: AskForApproval::Never, - sandbox_policy: SandboxPolicy::DangerFullAccess, + sandbox_policy: SandboxPolicy::ReadOnly, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let begin_event = wait_for_event_with_timeout( + &fixture.codex, + |ev| matches!(ev, EventMsg::McpToolCallBegin(_)), + Duration::from_secs(10), + ) + .await; + + let EventMsg::McpToolCallBegin(begin) = begin_event else { + unreachable!("event guard guarantees McpToolCallBegin"); + }; + assert_eq!(begin.invocation.server, server_name); + assert_eq!(begin.invocation.tool, "echo"); + + let end_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + let EventMsg::McpToolCallEnd(end) = end_event else { + unreachable!("event guard guarantees McpToolCallEnd"); + }; + + let result = end + .result + .as_ref() + .expect("rmcp echo tool should return success"); + assert_eq!(result.is_error, Some(false)); + assert!( + result.content.is_empty(), + "content should default to an empty array" + ); + + let structured = result + .structured_content + .as_ref() + .expect("structured content"); + let Value::Object(map) = structured else { + panic!("structured content should be an object: {structured:?}"); + }; + let echo_value = map + .get("echo") + .and_then(Value::as_str) + .expect("echo payload present"); + assert_eq!(echo_value, "ECHOING: ping"); + let env_value = map + .get("env") + .and_then(Value::as_str) + .expect("env snapshot inserted"); + assert_eq!(env_value, expected_env_value); + + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + server.verify().await; + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +#[serial(mcp_test_value)] +async fn stdio_server_propagates_whitelisted_env_vars() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + + let call_id = "call-1234"; + let server_name = "rmcp_whitelist"; + let tool_name = format!("{server_name}__echo"); + + mount_sse_once_match( + &server, + any(), + responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once_match( + &server, + any(), + responses::sse(vec![ + responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + let expected_env_value = "propagated-env-from-whitelist"; + let _guard = EnvVarGuard::set("MCP_TEST_VALUE", OsStr::new(expected_env_value)); + let rmcp_test_server_bin = CargoBuild::new() + .package("codex-rmcp-client") + .bin("test_stdio_server") + .run()? + .path() + .to_string_lossy() + .into_owned(); + + let fixture = test_codex() + .with_config(move |config| { + config.features.enable(Feature::RmcpClient); + config.mcp_servers.insert( + server_name.to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: rmcp_test_server_bin, + args: Vec::new(), + env: None, + env_vars: vec!["MCP_TEST_VALUE".to_string()], + cwd: None, + }, + enabled: true, + startup_timeout_sec: Some(Duration::from_secs(10)), + tool_timeout_sec: None, + }, + ); + }) + .build(&server) + .await?; + let session_model = fixture.session_configured.model.clone(); + + fixture + .codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "call the rmcp echo tool".into(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::ReadOnly, model: session_model, effort: None, summary: ReasoningSummary::Auto, @@ -257,7 +397,7 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { final_output_json_schema: None, cwd: fixture.cwd.path().to_path_buf(), approval_policy: AskForApproval::Never, - sandbox_policy: SandboxPolicy::DangerFullAccess, + sandbox_policy: SandboxPolicy::ReadOnly, model: session_model, effort: None, summary: ReasoningSummary::Auto, @@ -440,7 +580,7 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> { final_output_json_schema: None, cwd: fixture.cwd.path().to_path_buf(), approval_policy: AskForApproval::Never, - sandbox_policy: SandboxPolicy::DangerFullAccess, + sandbox_policy: SandboxPolicy::ReadOnly, model: session_model, effort: None, summary: ReasoningSummary::Auto, diff --git a/codex-rs/mcp-client/src/main.rs b/codex-rs/mcp-client/src/main.rs index f46058b9..8e1f322d 100644 --- a/codex-rs/mcp-client/src/main.rs +++ b/codex-rs/mcp-client/src/main.rs @@ -49,7 +49,7 @@ async fn main() -> Result<()> { // Spawn the subprocess and connect the client. let program = args.remove(0); let env = None; - let client = McpClient::new_stdio_client(program, args, env) + let client = McpClient::new_stdio_client(program, args, env, &[], None) .await .with_context(|| format!("failed to spawn subprocess: {original_args:?}"))?; diff --git a/codex-rs/mcp-client/src/mcp_client.rs b/codex-rs/mcp-client/src/mcp_client.rs index 27f96494..3be93f35 100644 --- a/codex-rs/mcp-client/src/mcp_client.rs +++ b/codex-rs/mcp-client/src/mcp_client.rs @@ -13,6 +13,7 @@ use std::collections::HashMap; use std::ffi::OsString; +use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -86,19 +87,26 @@ impl McpClient { program: OsString, args: Vec, env: Option>, + env_vars: &[String], + cwd: Option, ) -> std::io::Result { - let mut child = Command::new(program) + let mut command = Command::new(program); + command .args(args) .env_clear() - .envs(create_env_for_mcp_server(env)) + .envs(create_env_for_mcp_server(env, env_vars)) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::null()) // As noted in the `kill_on_drop` documentation, the Tokio runtime makes // a "best effort" to reap-after-exit to avoid zombie processes, but it // is not a guarantee. - .kill_on_drop(true) - .spawn()?; + .kill_on_drop(true); + if let Some(cwd) = cwd { + command.current_dir(cwd); + } + + let mut child = command.spawn()?; let stdin = child .stdin @@ -447,12 +455,16 @@ const DEFAULT_ENV_VARS: &[&str] = &[ /// `config.toml`. fn create_env_for_mcp_server( extra_env: Option>, + env_vars: &[String], ) -> HashMap { DEFAULT_ENV_VARS .iter() - .filter_map(|var| match std::env::var(var) { - Ok(value) => Some((var.to_string(), value)), - Err(_) => None, + .copied() + .chain(env_vars.iter().map(String::as_str)) + .filter_map(|var| { + std::env::var(var) + .ok() + .map(|value| (var.to_string(), value)) }) .chain(extra_env.unwrap_or_default()) .collect::>() @@ -462,14 +474,36 @@ fn create_env_for_mcp_server( mod tests { use super::*; + fn set_env_var(key: &str, value: &str) { + unsafe { + std::env::set_var(key, value); + } + } + + fn remove_env_var(key: &str) { + unsafe { + std::env::remove_var(key); + } + } + #[test] fn test_create_env_for_mcp_server() { let env_var = "USER"; let env_var_existing_value = std::env::var(env_var).unwrap_or_default(); let env_var_new_value = format!("{env_var_existing_value}-extra"); let extra_env = HashMap::from([(env_var.to_owned(), env_var_new_value.clone())]); - let mcp_server_env = create_env_for_mcp_server(Some(extra_env)); + let mcp_server_env = create_env_for_mcp_server(Some(extra_env), &[]); assert!(mcp_server_env.contains_key("PATH")); assert_eq!(Some(&env_var_new_value), mcp_server_env.get(env_var)); } + + #[test] + fn test_create_env_for_mcp_server_includes_extra_whitelisted_vars() { + let custom_var = "CUSTOM_TEST_VAR"; + let value = "value".to_string(); + set_env_var(custom_var, &value); + let mcp_server_env = create_env_for_mcp_server(None, &[custom_var.to_string()]); + assert_eq!(Some(&value), mcp_server_env.get(custom_var)); + remove_env_var(custom_var); + } } diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index 99a609b3..0016f114 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -58,4 +58,5 @@ webbrowser = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } +serial_test = { workspace = true } tempfile = { workspace = true } diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 038f32ef..245b4eb3 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::ffi::OsString; use std::io; +use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; @@ -79,6 +80,8 @@ impl RmcpClient { program: OsString, args: Vec, env: Option>, + env_vars: &[String], + cwd: Option, ) -> io::Result { let program_name = program.to_string_lossy().into_owned(); let mut command = Command::new(&program); @@ -87,8 +90,11 @@ impl RmcpClient { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .env_clear() - .envs(create_env_for_mcp_server(env)) + .envs(create_env_for_mcp_server(env, env_vars)) .args(&args); + if let Some(cwd) = cwd { + command.current_dir(cwd); + } let (transport, stderr) = TokioChildProcess::builder(command) .stderr(Stdio::piped()) diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs index cccb12c0..17e050cb 100644 --- a/codex-rs/rmcp-client/src/utils.rs +++ b/codex-rs/rmcp-client/src/utils.rs @@ -74,9 +74,12 @@ where pub(crate) fn create_env_for_mcp_server( extra_env: Option>, + env_vars: &[String], ) -> HashMap { DEFAULT_ENV_VARS .iter() + .copied() + .chain(env_vars.iter().map(String::as_str)) .filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value))) .chain(extra_env.unwrap_or_default()) .collect() @@ -185,13 +188,59 @@ mod tests { use rmcp::model::CallToolResult as RmcpCallToolResult; use serde_json::json; + use serial_test::serial; + use std::ffi::OsString; + + struct EnvVarGuard { + key: String, + original: Option, + } + + impl EnvVarGuard { + fn set(key: &str, value: &str) -> Self { + let original = std::env::var_os(key); + unsafe { + std::env::set_var(key, value); + } + Self { + key: key.to_string(), + original, + } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + if let Some(value) = &self.original { + unsafe { + std::env::set_var(&self.key, value); + } + } else { + unsafe { + std::env::remove_var(&self.key); + } + } + } + } + #[tokio::test] async fn create_env_honors_overrides() { let value = "custom".to_string(); - let env = create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())]))); + let env = + create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])), &[]); assert_eq!(env.get("TZ"), Some(&value)); } + #[test] + #[serial(extra_rmcp_env)] + fn create_env_includes_additional_whitelisted_variables() { + let custom_var = "EXTRA_RMCP_ENV"; + let value = "from-env"; + let _guard = EnvVarGuard::set(custom_var, value); + let env = create_env_for_mcp_server(None, &[custom_var.to_string()]); + assert_eq!(env.get(custom_var), Some(&value.to_string())); + } + #[test] fn convert_call_tool_result_defaults_missing_content() -> Result<()> { let structured_content = json!({ "key": "value" }); diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index 2dd2b578..dae48d48 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -20,6 +20,7 @@ use crate::wrapping::RtOptions; use crate::wrapping::word_wrap_line; use crate::wrapping::word_wrap_lines; use base64::Engine; +use codex_common::format_env_display::format_env_display; use codex_core::config::Config; use codex_core::config_types::McpServerTransportConfig; use codex_core::config_types::ReasoningSummaryFormat; @@ -1013,7 +1014,13 @@ pub(crate) fn new_mcp_tools_output( lines.push(vec![" • Auth: ".into(), status.to_string().into()].into()); match &cfg.transport { - McpServerTransportConfig::Stdio { command, args, env } => { + McpServerTransportConfig::Stdio { + command, + args, + env, + env_vars, + cwd, + } => { let args_suffix = if args.is_empty() { String::new() } else { @@ -1022,13 +1029,13 @@ pub(crate) fn new_mcp_tools_output( let cmd_display = format!("{command}{args_suffix}"); lines.push(vec![" • Command: ".into(), cmd_display.into()].into()); - if let Some(env) = env.as_ref() - && !env.is_empty() - { - let mut env_pairs: Vec = - env.iter().map(|(k, v)| format!("{k}={v}")).collect(); - env_pairs.sort(); - lines.push(vec![" • Env: ".into(), env_pairs.join(" ").into()].into()); + if let Some(cwd) = cwd.as_ref() { + lines.push(vec![" • Cwd: ".into(), cwd.display().to_string().into()].into()); + } + + let env_display = format_env_display(env.as_ref(), env_vars); + if env_display != "-" { + lines.push(vec![" • Env: ".into(), env_display.into()].into()); } } McpServerTransportConfig::StreamableHttp { @@ -1077,7 +1084,6 @@ pub(crate) fn new_mcp_tools_output( PlainHistoryCell { lines } } - pub(crate) fn new_info_event(message: String, hint: Option) -> PlainHistoryCell { let mut line = vec!["• ".dim(), message.into()]; if let Some(hint) = hint { diff --git a/docs/config.md b/docs/config.md index 8b5a45ab..dafd9dba 100644 --- a/docs/config.md +++ b/docs/config.md @@ -359,6 +359,11 @@ env = { "API_KEY" = "value" } # or [mcp_servers.server_name.env] API_KEY = "value" +# Optional: Additional list of environment variables that will be whitelisted in the MCP server's environment. +env_vars = ["API_KEY2"] + +# Optional: cwd that the command will be run from +cwd = "/Users//code/my-server" ``` #### Streamable HTTP