From a5d48a775bb934f57cf38b99d48e12538999cfdf Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Thu, 16 Oct 2025 20:15:47 -0700 Subject: [PATCH] [MCP] Allow specifying custom headers with streamable http servers (#5241) This adds two new config fields to streamable http mcp servers: `http_headers`: a map of key to value `env_http_headers` a map of key to env var which will be resolved at request time All headers will be passed to all MCP requests to that server just like authorization headers. There is a test ensuring that headers are not passed to other servers. Fixes #5180 --- codex-rs/cli/src/mcp_cmd.rs | 68 ++++- codex-rs/cli/tests/mcp_add_remove.rs | 8 + codex-rs/core/src/config.rs | 241 +++++++++++++++++- codex-rs/core/src/config_types.rs | 70 ++++- codex-rs/core/src/mcp/auth.rs | 4 + codex-rs/core/src/mcp_connection_manager.rs | 23 +- codex-rs/core/tests/suite/rmcp_client.rs | 4 + codex-rs/rmcp-client/src/auth_status.rs | 17 +- .../rmcp-client/src/perform_oauth_login.rs | 11 +- codex-rs/rmcp-client/src/rmcp_client.rs | 31 ++- codex-rs/rmcp-client/src/utils.rs | 73 ++++++ codex-rs/tui/src/history_cell.rs | 31 ++- 12 files changed, 560 insertions(+), 21 deletions(-) diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index 888a3092..e7fd7b8d 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -239,6 +239,8 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re } => McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers: None, + env_http_headers: None, }, AddMcpTransportArgs { .. } => bail!("exactly one of --command or --url must be provided"), }; @@ -260,11 +262,20 @@ async fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Re if let McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var: None, + http_headers, + env_http_headers, } = transport && matches!(supports_oauth_login(&url).await, Ok(true)) { println!("Detected OAuth support. Starting OAuth flow…"); - perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?; + perform_oauth_login( + &name, + &url, + config.mcp_oauth_credentials_store_mode, + http_headers.clone(), + env_http_headers.clone(), + ) + .await?; println!("Successfully logged in."); } @@ -317,12 +328,24 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) bail!("No MCP server named '{name}' found."); }; - let url = match &server.transport { - McpServerTransportConfig::StreamableHttp { url, .. } => url.clone(), + let (url, http_headers, env_http_headers) = match &server.transport { + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + .. + } => (url.clone(), http_headers.clone(), env_http_headers.clone()), _ => bail!("OAuth login is only supported for streamable HTTP servers."), }; - perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?; + perform_oauth_login( + &name, + &url, + config.mcp_oauth_credentials_store_mode, + http_headers, + env_http_headers, + ) + .await?; println!("Successfully logged in to MCP server '{name}'."); Ok(()) } @@ -386,11 +409,15 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { serde_json::json!({ "type": "streamable_http", "url": url, "bearer_token_env_var": bearer_token_env_var, + "http_headers": http_headers, + "env_http_headers": env_http_headers, }) } }; @@ -465,6 +492,7 @@ async fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + .. } => { let status = if cfg.enabled { "enabled".to_string() @@ -610,10 +638,14 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => serde_json::json!({ "type": "streamable_http", "url": url, "bearer_token_env_var": bearer_token_env_var, + "http_headers": http_headers, + "env_http_headers": env_http_headers, }), }; let output = serde_json::to_string_pretty(&serde_json::json!({ @@ -661,11 +693,39 @@ async fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Re McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { println!(" transport: streamable_http"); println!(" url: {url}"); let env_var = bearer_token_env_var.as_deref().unwrap_or("-"); println!(" bearer_token_env_var: {env_var}"); + let headers_display = match http_headers { + Some(map) if !map.is_empty() => { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + pairs + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join(", ") + } + _ => "-".to_string(), + }; + println!(" http_headers: {headers_display}"); + let env_headers_display = match env_http_headers { + Some(map) if !map.is_empty() => { + let mut pairs: Vec<_> = map.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + pairs + .into_iter() + .map(|(k, v)| format!("{k}={v}")) + .collect::>() + .join(", ") + } + _ => "-".to_string(), + }; + println!(" env_http_headers: {env_headers_display}"); } } if let Some(timeout) = server.startup_timeout_sec { diff --git a/codex-rs/cli/tests/mcp_add_remove.rs b/codex-rs/cli/tests/mcp_add_remove.rs index 705509ab..83abe72c 100644 --- a/codex-rs/cli/tests/mcp_add_remove.rs +++ b/codex-rs/cli/tests/mcp_add_remove.rs @@ -112,9 +112,13 @@ async fn add_streamable_http_without_manual_token() -> Result<()> { McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { assert_eq!(url, "https://example.com/mcp"); assert!(bearer_token_env_var.is_none()); + assert!(http_headers.is_none()); + assert!(env_http_headers.is_none()); } other => panic!("unexpected transport: {other:?}"), } @@ -150,9 +154,13 @@ async fn add_streamable_http_with_custom_env_var() -> Result<()> { McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { assert_eq!(url, "https://example.com/issues"); assert_eq!(bearer_token_env_var.as_deref(), Some("GITHUB_TOKEN")); + assert!(http_headers.is_none()); + assert!(env_http_headers.is_none()); } other => panic!("unexpected transport: {other:?}"), } diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index ae414079..570cd515 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -415,11 +415,37 @@ pub fn write_global_mcp_servers( McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { entry["url"] = toml_edit::value(url.clone()); if let Some(env_var) = bearer_token_env_var { entry["bearer_token_env_var"] = toml_edit::value(env_var.clone()); } + if let Some(headers) = http_headers + && !headers.is_empty() + { + let mut table = TomlTable::new(); + table.set_implicit(false); + let mut pairs: Vec<_> = headers.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + for (key, value) in pairs { + table.insert(key, toml_edit::value(value.clone())); + } + entry["http_headers"] = TomlItem::Table(table); + } + if let Some(headers) = env_http_headers + && !headers.is_empty() + { + let mut table = TomlTable::new(); + table.set_implicit(false); + let mut pairs: Vec<_> = headers.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + for (key, value) in pairs { + table.insert(key, toml_edit::value(value.clone())); + } + entry["env_http_headers"] = TomlItem::Table(table); + } } } @@ -1948,15 +1974,18 @@ ZIG_VAR = "3" } #[tokio::test] - async fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> { + async fn write_global_mcp_servers_streamable_http_serializes_bearer_token() -> anyhow::Result<()> + { let codex_home = TempDir::new()?; - let mut servers = BTreeMap::from([( + let servers = BTreeMap::from([( "docs".to_string(), McpServerConfig { transport: McpServerTransportConfig::StreamableHttp { url: "https://example.com/mcp".to_string(), bearer_token_env_var: Some("MCP_TOKEN".to_string()), + http_headers: None, + env_http_headers: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(2)), @@ -1983,20 +2012,127 @@ startup_timeout_sec = 2.0 McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { assert_eq!(url, "https://example.com/mcp"); assert_eq!(bearer_token_env_var.as_deref(), Some("MCP_TOKEN")); + assert!(http_headers.is_none()); + assert!(env_http_headers.is_none()); } other => panic!("unexpected transport {other:?}"), } assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(2))); + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_streamable_http_serializes_custom_headers() + -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: Some("MCP_TOKEN".to_string()), + http_headers: Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])), + env_http_headers: Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string(), + )])), + }, + enabled: true, + startup_timeout_sec: Some(Duration::from_secs(2)), + tool_timeout_sec: None, + }, + )]); + write_global_mcp_servers(codex_home.path(), &servers)?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + let serialized = std::fs::read_to_string(&config_path)?; + assert_eq!( + serialized, + r#"[mcp_servers.docs] +url = "https://example.com/mcp" +bearer_token_env_var = "MCP_TOKEN" +startup_timeout_sec = 2.0 + +[mcp_servers.docs.http_headers] +X-Doc = "42" + +[mcp_servers.docs.env_http_headers] +X-Auth = "DOCS_AUTH" +"# + ); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::StreamableHttp { + http_headers, + env_http_headers, + .. + } => { + assert_eq!( + http_headers, + &Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])) + ); + assert_eq!( + env_http_headers, + &Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string() + )])) + ); + } + other => panic!("unexpected transport {other:?}"), + } + + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_streamable_http_removes_optional_sections() + -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + + let mut servers = BTreeMap::from([( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: Some("MCP_TOKEN".to_string()), + http_headers: Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])), + env_http_headers: Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string(), + )])), + }, + enabled: true, + startup_timeout_sec: Some(Duration::from_secs(2)), + tool_timeout_sec: None, + }, + )]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + let serialized_with_optional = std::fs::read_to_string(&config_path)?; + assert!(serialized_with_optional.contains("bearer_token_env_var = \"MCP_TOKEN\"")); + assert!(serialized_with_optional.contains("[mcp_servers.docs.http_headers]")); + assert!(serialized_with_optional.contains("[mcp_servers.docs.env_http_headers]")); + servers.insert( "docs".to_string(), McpServerConfig { transport: McpServerTransportConfig::StreamableHttp { url: "https://example.com/mcp".to_string(), bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, }, enabled: true, startup_timeout_sec: None, @@ -2019,9 +2155,110 @@ url = "https://example.com/mcp" McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { assert_eq!(url, "https://example.com/mcp"); assert!(bearer_token_env_var.is_none()); + assert!(http_headers.is_none()); + assert!(env_http_headers.is_none()); + } + other => panic!("unexpected transport {other:?}"), + } + + assert!(docs.startup_timeout_sec.is_none()); + + Ok(()) + } + + #[tokio::test] + async fn write_global_mcp_servers_streamable_http_isolates_headers_between_servers() + -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + + let servers = BTreeMap::from([ + ( + "docs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: Some("MCP_TOKEN".to_string()), + http_headers: Some(HashMap::from([( + "X-Doc".to_string(), + "42".to_string(), + )])), + env_http_headers: Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string(), + )])), + }, + enabled: true, + startup_timeout_sec: Some(Duration::from_secs(2)), + tool_timeout_sec: None, + }, + ), + ( + "logs".to_string(), + McpServerConfig { + transport: McpServerTransportConfig::Stdio { + command: "logs-server".to_string(), + args: vec!["--follow".to_string()], + env: None, + }, + enabled: true, + startup_timeout_sec: None, + tool_timeout_sec: None, + }, + ), + ]); + + write_global_mcp_servers(codex_home.path(), &servers)?; + + let serialized = std::fs::read_to_string(&config_path)?; + assert!( + serialized.contains("[mcp_servers.docs.http_headers]"), + "serialized config missing docs headers section:\n{serialized}" + ); + assert!( + !serialized.contains("[mcp_servers.logs.http_headers]"), + "serialized config should not add logs headers section:\n{serialized}" + ); + assert!( + !serialized.contains("[mcp_servers.logs.env_http_headers]"), + "serialized config should not add logs env headers section:\n{serialized}" + ); + assert!( + !serialized.contains("mcp_servers.logs.bearer_token_env_var"), + "serialized config should not add bearer token to logs:\n{serialized}" + ); + + let loaded = load_global_mcp_servers(codex_home.path()).await?; + let docs = loaded.get("docs").expect("docs entry"); + match &docs.transport { + McpServerTransportConfig::StreamableHttp { + http_headers, + env_http_headers, + .. + } => { + assert_eq!( + http_headers, + &Some(HashMap::from([("X-Doc".to_string(), "42".to_string())])) + ); + assert_eq!( + env_http_headers, + &Some(HashMap::from([( + "X-Auth".to_string(), + "DOCS_AUTH".to_string() + )])) + ); + } + other => panic!("unexpected transport {other:?}"), + } + let logs = loaded.get("logs").expect("logs entry"); + match &logs.transport { + McpServerTransportConfig::Stdio { env, .. } => { + assert!(env.is_none()); } other => panic!("unexpected transport {other:?}"), } diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index b724086a..3a14d77a 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -49,6 +49,10 @@ impl<'de> Deserialize<'de> for McpServerConfig { args: Option>, #[serde(default)] env: Option>, + #[serde(default)] + http_headers: Option>, + #[serde(default)] + env_http_headers: Option>, url: Option, bearer_token: Option, @@ -94,6 +98,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { env, url, bearer_token_env_var, + http_headers, + env_http_headers, .. } => { throw_if_set("stdio", "url", url.as_ref())?; @@ -102,6 +108,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { "bearer_token_env_var", bearer_token_env_var.as_ref(), )?; + throw_if_set("stdio", "http_headers", http_headers.as_ref())?; + throw_if_set("stdio", "env_http_headers", env_http_headers.as_ref())?; McpServerTransportConfig::Stdio { command, args: args.unwrap_or_default(), @@ -115,6 +123,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { command, args, env, + http_headers, + env_http_headers, .. } => { throw_if_set("streamable_http", "command", command.as_ref())?; @@ -124,6 +134,8 @@ impl<'de> Deserialize<'de> for McpServerConfig { McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } } _ => return Err(SerdeError::custom("invalid transport")), @@ -161,6 +173,12 @@ pub enum McpServerTransportConfig { /// The actual secret value must be provided via the environment. #[serde(default, skip_serializing_if = "Option::is_none")] bearer_token_env_var: Option, + /// Additional HTTP headers to include in requests to this server. + #[serde(default, skip_serializing_if = "Option::is_none")] + http_headers: Option>, + /// HTTP headers where the value is sourced from an environment variable. + #[serde(default, skip_serializing_if = "Option::is_none")] + env_http_headers: Option>, }, } @@ -557,7 +575,9 @@ mod tests { cfg.transport, McpServerTransportConfig::StreamableHttp { url: "https://example.com/mcp".to_string(), - bearer_token_env_var: None + bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, } ); assert!(cfg.enabled); @@ -577,12 +597,39 @@ mod tests { cfg.transport, McpServerTransportConfig::StreamableHttp { url: "https://example.com/mcp".to_string(), - bearer_token_env_var: Some("GITHUB_TOKEN".to_string()) + bearer_token_env_var: Some("GITHUB_TOKEN".to_string()), + http_headers: None, + env_http_headers: None, } ); assert!(cfg.enabled); } + #[test] + fn deserialize_streamable_http_server_config_with_headers() { + let cfg: McpServerConfig = toml::from_str( + r#" + url = "https://example.com/mcp" + http_headers = { "X-Foo" = "bar" } + env_http_headers = { "X-Token" = "TOKEN_ENV" } + "#, + ) + .expect("should deserialize http config with headers"); + + assert_eq!( + cfg.transport, + McpServerTransportConfig::StreamableHttp { + url: "https://example.com/mcp".to_string(), + bearer_token_env_var: None, + http_headers: Some(HashMap::from([("X-Foo".to_string(), "bar".to_string())])), + env_http_headers: Some(HashMap::from([( + "X-Token".to_string(), + "TOKEN_ENV".to_string() + )])), + } + ); + } + #[test] fn deserialize_rejects_command_and_url() { toml::from_str::( @@ -605,6 +652,25 @@ mod tests { .expect_err("should reject env for http transport"); } + #[test] + fn deserialize_rejects_headers_for_stdio() { + toml::from_str::( + r#" + command = "echo" + http_headers = { "X-Foo" = "bar" } + "#, + ) + .expect_err("should reject http_headers for stdio transport"); + + toml::from_str::( + r#" + command = "echo" + env_http_headers = { "X-Foo" = "BAR_ENV" } + "#, + ) + .expect_err("should reject env_http_headers for stdio transport"); + } + #[test] fn deserialize_rejects_inline_bearer_token_field() { let err = toml::from_str::( diff --git a/codex-rs/core/src/mcp/auth.rs b/codex-rs/core/src/mcp/auth.rs index dbb9db80..22d1f5f5 100644 --- a/codex-rs/core/src/mcp/auth.rs +++ b/codex-rs/core/src/mcp/auth.rs @@ -45,11 +45,15 @@ async fn compute_auth_status( McpServerTransportConfig::StreamableHttp { url, bearer_token_env_var, + http_headers, + env_http_headers, } => { determine_streamable_http_auth_status( server_name, url, bearer_token_env_var.as_deref(), + http_headers.clone(), + env_http_headers.clone(), store_mode, ) .await diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 768c6b01..8ddb0366 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -121,17 +121,27 @@ impl McpClientAdapter { } } + #[allow(clippy::too_many_arguments)] async fn new_streamable_http_client( server_name: String, url: String, bearer_token: Option, + http_headers: Option>, + env_http_headers: Option>, params: mcp_types::InitializeRequestParams, startup_timeout: Duration, store_mode: OAuthCredentialsStoreMode, ) -> Result { let client = Arc::new( - RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode) - .await?, + RmcpClient::new_streamable_http_client( + &server_name, + &url, + bearer_token, + http_headers, + env_http_headers, + store_mode, + ) + .await?, ); client.initialize(params, Some(startup_timeout)).await?; Ok(McpClientAdapter::Rmcp(client)) @@ -259,11 +269,18 @@ impl McpConnectionManager { ) .await } - McpServerTransportConfig::StreamableHttp { url, .. } => { + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + .. + } => { McpClientAdapter::new_streamable_http_client( server_name.clone(), url, resolved_bearer_token.unwrap_or_default(), + http_headers, + env_http_headers, params, startup_timeout, store_mode, diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs index 9dd921b9..b2559cd2 100644 --- a/codex-rs/core/tests/suite/rmcp_client.rs +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -235,6 +235,8 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> { transport: McpServerTransportConfig::StreamableHttp { url: server_url, bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(10)), @@ -416,6 +418,8 @@ async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> { transport: McpServerTransportConfig::StreamableHttp { url: server_url, bearer_token_env_var: None, + http_headers: None, + env_http_headers: None, }, enabled: true, startup_timeout_sec: Some(Duration::from_secs(10)), diff --git a/codex-rs/rmcp-client/src/auth_status.rs b/codex-rs/rmcp-client/src/auth_status.rs index 5e32eed4..77c33f69 100644 --- a/codex-rs/rmcp-client/src/auth_status.rs +++ b/codex-rs/rmcp-client/src/auth_status.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::time::Duration; use anyhow::Error; @@ -6,11 +7,14 @@ use codex_protocol::protocol::McpAuthStatus; use reqwest::Client; use reqwest::StatusCode; use reqwest::Url; +use reqwest::header::HeaderMap; use serde::Deserialize; use tracing::debug; use crate::OAuthCredentialsStoreMode; use crate::oauth::has_oauth_tokens; +use crate::utils::apply_default_headers; +use crate::utils::build_default_headers; const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5); const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version"; @@ -21,6 +25,8 @@ pub async fn determine_streamable_http_auth_status( server_name: &str, url: &str, bearer_token_env_var: Option<&str>, + http_headers: Option>, + env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, ) -> Result { if bearer_token_env_var.is_some() { @@ -31,7 +37,9 @@ pub async fn determine_streamable_http_auth_status( return Ok(McpAuthStatus::OAuth); } - match supports_oauth_login(url).await { + let default_headers = build_default_headers(http_headers, env_http_headers)?; + + match supports_oauth_login_with_headers(url, &default_headers).await { Ok(true) => Ok(McpAuthStatus::NotLoggedIn), Ok(false) => Ok(McpAuthStatus::Unsupported), Err(error) => { @@ -45,8 +53,13 @@ pub async fn determine_streamable_http_auth_status( /// Attempt to determine whether a streamable HTTP MCP server advertises OAuth login. pub async fn supports_oauth_login(url: &str) -> Result { + supports_oauth_login_with_headers(url, &HeaderMap::new()).await +} + +async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result { let base_url = Url::parse(url)?; - let client = Client::builder().timeout(DISCOVERY_TIMEOUT).build()?; + let builder = Client::builder().timeout(DISCOVERY_TIMEOUT); + let client = apply_default_headers(builder, default_headers).build()?; let mut last_error: Option = None; for candidate_path in discovery_paths(base_url.path()) { diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index c2d39a21..c5276227 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::string::String; use std::sync::Arc; use std::time::Duration; @@ -5,6 +6,7 @@ use std::time::Duration; use anyhow::Context; use anyhow::Result; use anyhow::anyhow; +use reqwest::ClientBuilder; use rmcp::transport::auth::OAuthState; use tiny_http::Response; use tiny_http::Server; @@ -16,6 +18,8 @@ use crate::OAuthCredentialsStoreMode; use crate::StoredOAuthTokens; use crate::WrappedOAuthTokenResponse; use crate::save_oauth_tokens; +use crate::utils::apply_default_headers; +use crate::utils::build_default_headers; struct CallbackServerGuard { server: Arc, @@ -31,6 +35,8 @@ pub async fn perform_oauth_login( server_name: &str, server_url: &str, store_mode: OAuthCredentialsStoreMode, + http_headers: Option>, + env_http_headers: Option>, ) -> Result<()> { let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?); let guard = CallbackServerGuard { @@ -51,7 +57,10 @@ pub async fn perform_oauth_login( let (tx, rx) = oneshot::channel(); spawn_callback_server(server, tx); - let mut oauth_state = OAuthState::new(server_url, None).await?; + let default_headers = build_default_headers(http_headers, env_http_headers)?; + let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?; + + let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?; oauth_state .start_authorization(&[], &redirect_uri, Some("Codex")) .await?; diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 3d12e508..038f32ef 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -14,6 +14,7 @@ use mcp_types::InitializeRequestParams; use mcp_types::InitializeResult; use mcp_types::ListToolsRequestParams; use mcp_types::ListToolsResult; +use reqwest::header::HeaderMap; use rmcp::model::CallToolRequestParam; use rmcp::model::InitializeRequestParam; use rmcp::model::PaginatedRequestParam; @@ -38,6 +39,8 @@ use crate::logging_client_handler::LoggingClientHandler; use crate::oauth::OAuthCredentialsStoreMode; use crate::oauth::OAuthPersistor; use crate::oauth::StoredOAuthTokens; +use crate::utils::apply_default_headers; +use crate::utils::build_default_headers; use crate::utils::convert_call_tool_result; use crate::utils::convert_to_mcp; use crate::utils::convert_to_rmcp; @@ -116,12 +119,17 @@ impl RmcpClient { }) } + #[allow(clippy::too_many_arguments)] pub async fn new_streamable_http_client( server_name: &str, url: &str, bearer_token: Option, + http_headers: Option>, + env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, ) -> Result { + let default_headers = build_default_headers(http_headers, env_http_headers)?; + let initial_oauth_tokens = match bearer_token { Some(_) => None, None => match load_oauth_tokens(server_name, url, store_mode) { @@ -132,21 +140,30 @@ impl RmcpClient { } }, }; + let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() { - let (transport, oauth_persistor) = - create_oauth_transport_and_runtime(server_name, url, initial_tokens, store_mode) - .await?; + let (transport, oauth_persistor) = create_oauth_transport_and_runtime( + server_name, + url, + initial_tokens, + store_mode, + default_headers.clone(), + ) + .await?; PendingTransport::StreamableHttpWithOAuth { transport, oauth_persistor, } } else { let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string()); - if let Some(bearer_token) = bearer_token { + if let Some(bearer_token) = bearer_token.clone() { http_config = http_config.auth_header(bearer_token); } - let transport = StreamableHttpClientTransport::from_config(http_config); + let http_client = + apply_default_headers(reqwest::Client::builder(), &default_headers).build()?; + + let transport = StreamableHttpClientTransport::with_client(http_client, http_config); PendingTransport::StreamableHttp { transport } }; Ok(Self { @@ -290,11 +307,13 @@ async fn create_oauth_transport_and_runtime( url: &str, initial_tokens: StoredOAuthTokens, credentials_store: OAuthCredentialsStoreMode, + default_headers: HeaderMap, ) -> Result<( StreamableHttpClientTransport>, OAuthPersistor, )> { - let http_client = reqwest::Client::builder().build()?; + let http_client = + apply_default_headers(reqwest::Client::builder(), &default_headers).build()?; let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?; oauth_state diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs index 6b7bd894..cccb12c0 100644 --- a/codex-rs/rmcp-client/src/utils.rs +++ b/codex-rs/rmcp-client/src/utils.rs @@ -6,6 +6,10 @@ use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use mcp_types::CallToolResult; +use reqwest::ClientBuilder; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderName; +use reqwest::header::HeaderValue; use rmcp::model::CallToolResult as RmcpCallToolResult; use rmcp::service::ServiceError; use serde_json::Value; @@ -78,6 +82,75 @@ pub(crate) fn create_env_for_mcp_server( .collect() } +pub(crate) fn build_default_headers( + http_headers: Option>, + env_http_headers: Option>, +) -> Result { + let mut headers = HeaderMap::new(); + + if let Some(static_headers) = http_headers { + for (name, value) in static_headers { + let header_name = match HeaderName::from_bytes(name.as_bytes()) { + Ok(name) => name, + Err(err) => { + tracing::warn!("invalid HTTP header name `{name}`: {err}"); + continue; + } + }; + let header_value = match HeaderValue::from_str(value.as_str()) { + Ok(value) => value, + Err(err) => { + tracing::warn!("invalid HTTP header value for `{name}`: {err}"); + continue; + } + }; + headers.insert(header_name, header_value); + } + } + + if let Some(env_headers) = env_http_headers { + for (name, env_var) in env_headers { + if let Ok(value) = env::var(&env_var) { + if value.trim().is_empty() { + continue; + } + + let header_name = match HeaderName::from_bytes(name.as_bytes()) { + Ok(name) => name, + Err(err) => { + tracing::warn!("invalid HTTP header name `{name}`: {err}"); + continue; + } + }; + + let header_value = match HeaderValue::from_str(value.as_str()) { + Ok(value) => value, + Err(err) => { + tracing::warn!( + "invalid HTTP header value read from {env_var} for `{name}`: {err}" + ); + continue; + } + }; + headers.insert(header_name, header_value); + } + } + } + + Ok(headers) +} + +pub(crate) fn apply_default_headers( + builder: ClientBuilder, + default_headers: &HeaderMap, +) -> ClientBuilder { + if default_headers.is_empty() { + builder + } else { + builder.default_headers(default_headers.clone()) + } +} + #[cfg(unix)] pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[ "HOME", diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index fe37d5fa..2dd2b578 100644 --- a/codex-rs/tui/src/history_cell.rs +++ b/codex-rs/tui/src/history_cell.rs @@ -1031,8 +1031,37 @@ pub(crate) fn new_mcp_tools_output( lines.push(vec![" • Env: ".into(), env_pairs.join(" ").into()].into()); } } - McpServerTransportConfig::StreamableHttp { url, .. } => { + McpServerTransportConfig::StreamableHttp { + url, + http_headers, + env_http_headers, + .. + } => { lines.push(vec![" • URL: ".into(), url.clone().into()].into()); + if let Some(headers) = http_headers.as_ref() + && !headers.is_empty() + { + let mut pairs: Vec<_> = headers.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + let display = pairs + .into_iter() + .map(|(name, value)| format!("{name}={value}")) + .collect::>() + .join(", "); + lines.push(vec![" • HTTP headers: ".into(), display.into()].into()); + } + if let Some(headers) = env_http_headers.as_ref() + && !headers.is_empty() + { + let mut pairs: Vec<_> = headers.iter().collect(); + pairs.sort_by(|(a, _), (b, _)| a.cmp(b)); + let display = pairs + .into_iter() + .map(|(name, env_var)| format!("{name}={env_var}")) + .collect::>() + .join(", "); + lines.push(vec![" • Env HTTP headers: ".into(), display.into()].into()); + } } }