diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 888a3092..e7fd7b8d 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -239,6 +239,8 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re } => McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers: None, + env_http_headers: None, }, AddMcpTransportArgs { .. } => bail!("exactly one of --command or --url must be provided"), }; @@ -260,11 +262,20 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re if let McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var: None, + http_headers, + env_http_headers, } = transport && matches!(supports_oauth_login(&url).await, Ok(true)) { println!("Detected OAuth support. Starting OAuth flow…"); - perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?; + perform_oauth_login( + &name, + &url, + config.mcp_oauth_credentials_store_mode, + http_headers.clone(), + env_http_headers.clone(), + ) + .await?; println!("Successfully logged in."); } @@ -317,12 +328,24 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) bail!("No MCP server named '{name}' found."); }; - let url = match &server.transport { - McpServerTransportConfig::StreamableHttp { url, .. } => url.clone(), + let (url, http_headers, env_http_headers) = match &server.transport { + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + .. + } => (url.clone(), http_headers.clone(), env_http_headers.clone()), _ => bail!("OAuth login is only supported for streamable HTTP servers."), }; - perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?; + perform_oauth_login( + &name, + &url, + config.mcp_oauth_credentials_store_mode, + http_headers, + env_http_headers, + ) + .await?; println!("Successfully logged in to MCP server '{name}'."); Ok(()) } @@ -386,11 +409,15 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { serde_json::json!({ "type": "streamable_http", "url": url, "bearer_token_env_var": bearer_token_env_var, + "http_headers": http_headers, + "env_http_headers": env_http_headers, }) } }; @@ -465,6 +492,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + .. } => { let status = if cfg.enabled { "enabled".to_string() @@ -610,10 +638,14 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => serde_json::json!({ "type": "streamable_http", "url": url, "bearer_token_env_var": bearer_token_env_var, + "http_headers": http_headers, + "env_http_headers": env_http_headers, }), }; let output = serde_json::to_string_pretty(&serde_json::json!({ @@ -661,11 +693,39 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { println!(" transport: streamable_http"); println!(" url: {url}"); let env_var = bearer_token_env_var.as_deref().unwrap_or("-"); println!(" bearer_token_env_var: {env_var}"); + let headers_display = match http_headers { + Some(map) if !map.is_empty() => { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + pairs + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join(", ") + } + _ => "-".to_string(), + }; + println!(" http_headers: {headers_display}"); + let env_headers_display = match env_http_headers { + Some(map) if !map.is_empty() => { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + pairs + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join(", ") + } + _ => "-".to_string(), + }; + println!(" env_http_headers: {env_headers_display}"); } } if let Some(timeout) = server.startup_timeout_sec { diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs index 705509ab..83abe72c 100644 --- a/codex-rs/cli/tests/mcp_add_remove.rs +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -112,9 +112,13 @@ async fn add_streamable_http_without_manual_token() -> Result<()> { McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { assert_eq!(url, "https://example.com/mcp"); assert!(bearer_token_env_var.is_none()); + assert!(http_headers.is_none()); + assert!(env_http_headers.is_none()); } other => panic!("unexpected transport: {other:?}"), } @@ -150,9 +154,13 @@ async fn add_streamable_http_with_custom_env_var() -> Result<()> { McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { assert_eq!(url, "https://example.com/issues"); assert_eq!(bearer_token_env_var.as_deref(), Some("GITHUB_TOKEN")); + assert!(http_headers.is_none()); + assert!(env_http_headers.is_none()); } other => panic!("unexpected transport: {other:?}"), } diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index ae414079..570cd515 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -415,11 +415,37 @@ pub fn write_global_mcp_servers( McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { entry["url"] = toml_edit::value(url.clone()); if let Some(env_var) = bearer_token_env_var { entry["bearer_token_env_var"] = toml_edit::value(env_var.clone()); } + if let Some(headers) = http_headers + && !headers.is_empty() + { + let mut table = TomlTable::new(); + table.set_implicit(false); + let mut pairs: Vec<_> = headers.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + for (key, value) in pairs { + table.insert(key, toml_edit::value(value.clone())); + } + entry["http_headers"] = TomlItem::Table(table); + } + if let Some(headers) = env_http_headers + && !headers.is_empty() + { + let mut table = TomlTable::new(); + table.set_implicit(false); + let mut pairs: Vec<_> = headers.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + for (key, value) in pairs { + table.insert(key, toml_edit::value(value.clone())); + } + entry["env_http_headers"] = TomlItem::Table(table); + } } } @@ -1948,15 +1974,18 @@ ZIG_VAR = "3" } #[tokio::test] - async fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> { + async fn write_global_mcp_servers_streamable_http_serializes_bearer_token() -> anyhow::Result<()> + { let codex_home = TempDir::new()?; - let mut servers = BTreeMap::from([( + let servers = BTreeMap::from([( "docs".to_string(), McpServerConfig { transport: McpServerTransportConfig::StreamableHttp { url: "https://example.com/mcp".to_string(), bearer_token_env_var: Some("MCP_TOKEN".to_string()), + http_headers: None, + env_http_headers: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(2)), @@ -1983,20 +2012,127 @@ startup_timeout_sec = 2.0 McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { assert_eq!(url, "https://example.com/mcp"); assert_eq!(bearer_token_env_var.as_deref(), Some("MCP_TOKEN")); + assert!(http_headers.is_none()); + assert!(env_http_headers.is_none()); } other => panic!("unexpected transport {other:?}"), } assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(2))); + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_streamable_http_serializes_custom_headers() + -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: Some("MCP_TOKEN".to_string()), + http_headers: Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])), + env_http_headers: Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string(), + )])), + }, + enabled: true, + startup_timeout_sec: Some(Duration::from_secs(2)), + tool_timeout_sec: None, + }, + )]); + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert_eq!( + serialized, + r#"[mcp_servers.docs] +url = "https://example.com/mcp" +bearer_token_env_var = "MCP_TOKEN" +startup_timeout_sec = 2.0 + +[mcp_servers.docs.http_headers] +X-Doc = "42" + +[mcp_servers.docs.env_http_headers] +X-Auth = "DOCS_AUTH" +"# + ); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::StreamableHttp { + http_headers, + env_http_headers, + .. + } => { + assert_eq!( + http_headers, + &Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])) + ); + assert_eq!( + env_http_headers, + &Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string() + )])) + ); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_streamable_http_removes_optional_sections() + -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + + let mut servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: Some("MCP_TOKEN".to_string()), + http_headers: Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])), + env_http_headers: Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string(), + )])), + }, + enabled: true, + startup_timeout_sec: Some(Duration::from_secs(2)), + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + let serialized_with_optional = std::fs::read_to_string(&config_path)?; + assert!(serialized_with_optional.contains("bearer_token_env_var = \"MCP_TOKEN\"")); + assert!(serialized_with_optional.contains("[mcp_servers.docs.http_headers]")); + assert!(serialized_with_optional.contains("[mcp_servers.docs.env_http_headers]")); + servers.insert( "docs".to_string(), McpServerConfig { transport: McpServerTransportConfig::StreamableHttp { url: "https://example.com/mcp".to_string(), bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, }, enabled: true, startup_timeout_sec: None, @@ -2019,9 +2155,110 @@ url = "https://example.com/mcp" McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { assert_eq!(url, "https://example.com/mcp"); assert!(bearer_token_env_var.is_none()); + assert!(http_headers.is_none()); + assert!(env_http_headers.is_none()); + } + other => panic!("unexpected transport {other:?}"), + } + + assert!(docs.startup_timeout_sec.is_none()); + + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_streamable_http_isolates_headers_between_servers() + -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + + let servers = BTreeMap::from([ + ( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: Some("MCP_TOKEN".to_string()), + http_headers: Some(HashMap::from([( + "X-Doc".to_string(), + "42".to_string(), + )])), + env_http_headers: Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string(), + )])), + }, + enabled: true, + startup_timeout_sec: Some(Duration::from_secs(2)), + tool_timeout_sec: None, + }, + ), + ( + "logs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "logs-server".to_string(), + args: vec!["--follow".to_string()], + env: None, + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + ), + ]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let serialized = std::fs::read_to_string(&config_path)?; + assert!( + serialized.contains("[mcp_servers.docs.http_headers]"), + "serialized config missing docs headers section:\n{serialized}" + ); + assert!( + !serialized.contains("[mcp_servers.logs.http_headers]"), + "serialized config should not add logs headers section:\n{serialized}" + ); + assert!( + !serialized.contains("[mcp_servers.logs.env_http_headers]"), + "serialized config should not add logs env headers section:\n{serialized}" + ); + assert!( + !serialized.contains("mcp_servers.logs.bearer_token_env_var"), + "serialized config should not add bearer token to logs:\n{serialized}" + ); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::StreamableHttp { + http_headers, + env_http_headers, + .. + } => { + assert_eq!( + http_headers, + &Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])) + ); + assert_eq!( + env_http_headers, + &Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string() + )])) + ); + } + other => panic!("unexpected transport {other:?}"), + } + let logs = loaded.get("logs").expect("logs entry"); + match &logs.transport { + McpServerTransportConfig::Stdio { env, .. } => { + assert!(env.is_none()); } other => panic!("unexpected transport {other:?}"), } diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index b724086a..3a14d77a 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -49,6 +49,10 @@ impl<'de> Deserialize<'de> for McpServerConfig { args: Option>, #[serde(default)] env: Option>, + #[serde(default)] + http_headers: Option>, + #[serde(default)] + env_http_headers: Option>, url: Option, bearer_token: Option, @@ -94,6 +98,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { env, url, bearer_token_env_var, + http_headers, + env_http_headers, .. } => { throw_if_set("stdio", "url", url.as_ref())?; @@ -102,6 +108,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { "bearer_token_env_var", bearer_token_env_var.as_ref(), )?; + throw_if_set("stdio", "http_headers", http_headers.as_ref())?; + throw_if_set("stdio", "env_http_headers", env_http_headers.as_ref())?; McpServerTransportConfig::Stdio { command, args: args.unwrap_or_default(), @@ -115,6 +123,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { command, args, env, + http_headers, + env_http_headers, .. } => { throw_if_set("streamable_http", "command", command.as_ref())?; @@ -124,6 +134,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } } _ => return Err(SerdeError::custom("invalid transport")), @@ -161,6 +173,12 @@ pub enum McpServerTransportConfig { /// The actual secret value must be provided via the environment. #[serde(default, skip_serializing_if = "Option::is_none")] bearer_token_env_var: Option, + /// Additional HTTP headers to include in requests to this server. + #[serde(default, skip_serializing_if = "Option::is_none")] + http_headers: Option>, + /// HTTP headers where the value is sourced from an environment variable. + #[serde(default, skip_serializing_if = "Option::is_none")] + env_http_headers: Option>, }, } @@ -557,7 +575,9 @@ mod tests { cfg.transport, McpServerTransportConfig::StreamableHttp { url: "https://example.com/mcp".to_string(), - bearer_token_env_var: None + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, } ); assert!(cfg.enabled); @@ -577,12 +597,39 @@ mod tests { cfg.transport, McpServerTransportConfig::StreamableHttp { url: "https://example.com/mcp".to_string(), - bearer_token_env_var: Some("GITHUB_TOKEN".to_string()) + bearer_token_env_var: Some("GITHUB_TOKEN".to_string()), + http_headers: None, + env_http_headers: None, } ); assert!(cfg.enabled); } + #[test] + fn deserialize_streamable_http_server_config_with_headers() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + http_headers = { "X-Foo" = "bar" } + env_http_headers = { "X-Token" = "TOKEN_ENV" } + "#, + ) + .expect("should deserialize http config with headers"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: None, + http_headers: Some(HashMap::from([("X-Foo".to_string(), "bar".to_string())])), + env_http_headers: Some(HashMap::from([( + "X-Token".to_string(), + "TOKEN_ENV".to_string() + )])), + } + ); + } + #[test] fn deserialize_rejects_command_and_url() { toml::from_str::( @@ -605,6 +652,25 @@ mod tests { .expect_err("should reject env for http transport"); } + #[test] + fn deserialize_rejects_headers_for_stdio() { + toml::from_str::( + r#" + command = "echo" + http_headers = { "X-Foo" = "bar" } + "#, + ) + .expect_err("should reject http_headers for stdio transport"); + + toml::from_str::( + r#" + command = "echo" + env_http_headers = { "X-Foo" = "BAR_ENV" } + "#, + ) + .expect_err("should reject env_http_headers for stdio transport"); + } + #[test] fn deserialize_rejects_inline_bearer_token_field() { let err = toml::from_str::( diff --git a/codex-rs/core/src/mcp/auth.rs b/codex-rs/core/src/mcp/auth.rs index dbb9db80..22d1f5f5 100644 --- a/codex-rs/core/src/mcp/auth.rs +++ b/codex-rs/core/src/mcp/auth.rs @@ -45,11 +45,15 @@ async fn compute_auth_status( McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { determine_streamable_http_auth_status( server_name, url, bearer_token_env_var.as_deref(), + http_headers.clone(), + env_http_headers.clone(), store_mode, ) .await diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 768c6b01..8ddb0366 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -121,17 +121,27 @@ impl McpClientAdapter { } } + #[allow(clippy::too_many_arguments)] async fn new_streamable_http_client( server_name: String, url: String, bearer_token: Option, + http_headers: Option>, + env_http_headers: Option>, params: mcp_types::InitializeRequestParams, startup_timeout: Duration, store_mode: OAuthCredentialsStoreMode, ) -> Result { let client = Arc::new( - RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode) - .await?, + RmcpClient::new_streamable_http_client( + &server_name, + &url, + bearer_token, + http_headers, + env_http_headers, + store_mode, + ) + .await?, ); client.initialize(params, Some(startup_timeout)).await?; Ok(McpClientAdapter::Rmcp(client)) @@ -259,11 +269,18 @@ impl McpConnectionManager { ) .await } - McpServerTransportConfig::StreamableHttp { url, .. } => { + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + .. + } => { McpClientAdapter::new_streamable_http_client( server_name.clone(), url, resolved_bearer_token.unwrap_or_default(), + http_headers, + env_http_headers, params, startup_timeout, store_mode, diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 9dd921b9..b2559cd2 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -235,6 +235,8 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { transport: McpServerTransportConfig::StreamableHttp { url: server_url, bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(10)), @@ -416,6 +418,8 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> { transport: McpServerTransportConfig::StreamableHttp { url: server_url, bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(10)), diff --git a/codex-rs/rmcp-client/src/auth_status.rs b/codex-rs/rmcp-client/src/auth_status.rs index 5e32eed4..77c33f69 100644 --- a/codex-rs/rmcp-client/src/auth_status.rs +++ b/codex-rs/rmcp-client/src/auth_status.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::time::Duration; use anyhow::Error; @@ -6,11 +7,14 @@ use codex_protocol::protocol::McpAuthStatus; use reqwest::Client; use reqwest::StatusCode; use reqwest::Url; +use reqwest::header::HeaderMap; use serde::Deserialize; use tracing::debug; use crate::OAuthCredentialsStoreMode; use crate::oauth::has_oauth_tokens; +use crate::utils::apply_default_headers; +use crate::utils::build_default_headers; const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5); const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version"; @@ -21,6 +25,8 @@ pub async fn determine_streamable_http_auth_status( server_name: &str, url: &str, bearer_token_env_var: Option<&str>, + http_headers: Option>, + env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, ) -> Result { if bearer_token_env_var.is_some() { @@ -31,7 +37,9 @@ pub async fn determine_streamable_http_auth_status( return Ok(McpAuthStatus::OAuth); } - match supports_oauth_login(url).await { + let default_headers = build_default_headers(http_headers, env_http_headers)?; + + match supports_oauth_login_with_headers(url, &default_headers).await { Ok(true) => Ok(McpAuthStatus::NotLoggedIn), Ok(false) => Ok(McpAuthStatus::Unsupported), Err(error) => { @@ -45,8 +53,13 @@ pub async fn determine_streamable_http_auth_status( /// Attempt to determine whether a streamable HTTP MCP server advertises OAuth login. pub async fn supports_oauth_login(url: &str) -> Result { + supports_oauth_login_with_headers(url, &HeaderMap::new()).await +} + +async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result { let base_url = Url::parse(url)?; - let client = Client::builder().timeout(DISCOVERY_TIMEOUT).build()?; + let builder = Client::builder().timeout(DISCOVERY_TIMEOUT); + let client = apply_default_headers(builder, default_headers).build()?; let mut last_error: Option = None; for candidate_path in discovery_paths(base_url.path()) { diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index c2d39a21..c5276227 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::string::String; use std::sync::Arc; use std::time::Duration; @@ -5,6 +6,7 @@ use std::time::Duration; use anyhow::Context; use anyhow::Result; use anyhow::anyhow; +use reqwest::ClientBuilder; use rmcp::transport::auth::OAuthState; use tiny_http::Response; use tiny_http::Server; @@ -16,6 +18,8 @@ use crate::OAuthCredentialsStoreMode; use crate::StoredOAuthTokens; use crate::WrappedOAuthTokenResponse; use crate::save_oauth_tokens; +use crate::utils::apply_default_headers; +use crate::utils::build_default_headers; struct CallbackServerGuard { server: Arc, @@ -31,6 +35,8 @@ pub async fn perform_oauth_login( server_name: &str, server_url: &str, store_mode: OAuthCredentialsStoreMode, + http_headers: Option>, + env_http_headers: Option>, ) -> Result<()> { let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?); let guard = CallbackServerGuard { @@ -51,7 +57,10 @@ pub async fn perform_oauth_login( let (tx, rx) = oneshot::channel(); spawn_callback_server(server, tx); - let mut oauth_state = OAuthState::new(server_url, None).await?; + let default_headers = build_default_headers(http_headers, env_http_headers)?; + let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?; + + let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?; oauth_state .start_authorization(&[], &redirect_uri, Some("Codex")) .await?; diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 3d12e508..038f32ef 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -14,6 +14,7 @@ use mcp_types::InitializeRequestParams; use mcp_types::InitializeResult; use mcp_types::ListToolsRequestParams; use mcp_types::ListToolsResult; +use reqwest::header::HeaderMap; use rmcp::model::CallToolRequestParam; use rmcp::model::InitializeRequestParam; use rmcp::model::PaginatedRequestParam; @@ -38,6 +39,8 @@ use crate::logging_client_handler::LoggingClientHandler; use crate::oauth::OAuthCredentialsStoreMode; use crate::oauth::OAuthPersistor; use crate::oauth::StoredOAuthTokens; +use crate::utils::apply_default_headers; +use crate::utils::build_default_headers; use crate::utils::convert_call_tool_result; use crate::utils::convert_to_mcp; use crate::utils::convert_to_rmcp; @@ -116,12 +119,17 @@ impl RmcpClient { }) } + #[allow(clippy::too_many_arguments)] pub async fn new_streamable_http_client( server_name: &str, url: &str, bearer_token: Option, + http_headers: Option>, + env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, ) -> Result { + let default_headers = build_default_headers(http_headers, env_http_headers)?; + let initial_oauth_tokens = match bearer_token { Some(_) => None, None => match load_oauth_tokens(server_name, url, store_mode) { @@ -132,21 +140,30 @@ impl RmcpClient { } }, }; + let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() { - let (transport, oauth_persistor) = - create_oauth_transport_and_runtime(server_name, url, initial_tokens, store_mode) - .await?; + let (transport, oauth_persistor) = create_oauth_transport_and_runtime( + server_name, + url, + initial_tokens, + store_mode, + default_headers.clone(), + ) + .await?; PendingTransport::StreamableHttpWithOAuth { transport, oauth_persistor, } } else { let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string()); - if let Some(bearer_token) = bearer_token { + if let Some(bearer_token) = bearer_token.clone() { http_config = http_config.auth_header(bearer_token); } - let transport = StreamableHttpClientTransport::from_config(http_config); + let http_client = + apply_default_headers(reqwest::Client::builder(), &default_headers).build()?; + + let transport = StreamableHttpClientTransport::with_client(http_client, http_config); PendingTransport::StreamableHttp { transport } }; Ok(Self { @@ -290,11 +307,13 @@ async fn create_oauth_transport_and_runtime( url: &str, initial_tokens: StoredOAuthTokens, credentials_store: OAuthCredentialsStoreMode, + default_headers: HeaderMap, ) -> Result<( StreamableHttpClientTransport>, OAuthPersistor, )> { - let http_client = reqwest::Client::builder().build()?; + let http_client = + apply_default_headers(reqwest::Client::builder(), &default_headers).build()?; let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?; oauth_state diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs index 6b7bd894..cccb12c0 100644 --- a/codex-rs/rmcp-client/src/utils.rs +++ b/codex-rs/rmcp-client/src/utils.rs @@ -6,6 +6,10 @@ use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use mcp_types::CallToolResult; +use reqwest::ClientBuilder; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderName; +use reqwest::header::HeaderValue; use rmcp::model::CallToolResult as RmcpCallToolResult; use rmcp::service::ServiceError; use serde_json::Value; @@ -78,6 +82,75 @@ pub(crate) fn create_env_for_mcp_server( .collect() } +pub(crate) fn build_default_headers( + http_headers: Option>, + env_http_headers: Option>, +) -> Result { + let mut headers = HeaderMap::new(); + + if let Some(static_headers) = http_headers { + for (name, value) in static_headers { + let header_name = match HeaderName::from_bytes(name.as_bytes()) { + Ok(name) => name, + Err(err) => { + tracing::warn!("invalid HTTP header name `{name}`: {err}"); + continue; + } + }; + let header_value = match HeaderValue::from_str(value.as_str()) { + Ok(value) => value, + Err(err) => { + tracing::warn!("invalid HTTP header value for `{name}`: {err}"); + continue; + } + }; + headers.insert(header_name, header_value); + } + } + + if let Some(env_headers) = env_http_headers { + for (name, env_var) in env_headers { + if let Ok(value) = env::var(&env_var) { + if value.trim().is_empty() { + continue; + } + + let header_name = match HeaderName::from_bytes(name.as_bytes()) { + Ok(name) => name, + Err(err) => { + tracing::warn!("invalid HTTP header name `{name}`: {err}"); + continue; + } + }; + + let header_value = match HeaderValue::from_str(value.as_str()) { + Ok(value) => value, + Err(err) => { + tracing::warn!( + "invalid HTTP header value read from {env_var} for `{name}`: {err}" + ); + continue; + } + }; + headers.insert(header_name, header_value); + } + } + } + + Ok(headers) +} + +pub(crate) fn apply_default_headers( + builder: ClientBuilder, + default_headers: &HeaderMap, +) -> ClientBuilder { + if default_headers.is_empty() { + builder + } else { + builder.default_headers(default_headers.clone()) + } +} + #[cfg(unix)] pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[ "HOME", diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index fe37d5fa..2dd2b578 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -1031,8 +1031,37 @@ pub(crate) fn new_mcp_tools_output( lines.push(vec![" • Env: ".into(), env_pairs.join(" ").into()].into()); } } - McpServerTransportConfig::StreamableHttp { url, .. } => { + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + .. + } => { lines.push(vec![" • URL: ".into(), url.clone().into()].into()); + if let Some(headers) = http_headers.as_ref() + && !headers.is_empty() + { + let mut pairs: Vec<_> = headers.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + let display = pairs + .into_iter() + .map(|(name, value)| format!("{name}={value}")) + .collect::>() + .join(", "); + lines.push(vec![" • HTTP headers: ".into(), display.into()].into()); + } + if let Some(headers) = env_http_headers.as_ref() + && !headers.is_empty() + { + let mut pairs: Vec<_> = headers.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + let display = pairs + .into_iter() + .map(|(name, env_var)| format!("{name}={env_var}")) + .collect::>() + .join(", "); + lines.push(vec![" • Env HTTP headers: ".into(), display.into()].into()); + } } }