[MCP] Allow specifying custom headers with streamable http servers (#5241)
This adds two new config fields to streamable http mcp servers: `http_headers`: a map of key to value `env_http_headers` a map of key to env var which will be resolved at request time All headers will be passed to all MCP requests to that server just like authorization headers. There is a test ensuring that headers are not passed to other servers. Fixes #5180
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Error;
|
||||
@@ -6,11 +7,14 @@ use codex_protocol::protocol::McpAuthStatus;
|
||||
use reqwest::Client;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::Url;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::OAuthCredentialsStoreMode;
|
||||
use crate::oauth::has_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
|
||||
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
|
||||
@@ -21,6 +25,8 @@ pub async fn determine_streamable_http_auth_status(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token_env_var: Option<&str>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<McpAuthStatus> {
|
||||
if bearer_token_env_var.is_some() {
|
||||
@@ -31,7 +37,9 @@ pub async fn determine_streamable_http_auth_status(
|
||||
return Ok(McpAuthStatus::OAuth);
|
||||
}
|
||||
|
||||
match supports_oauth_login(url).await {
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
|
||||
match supports_oauth_login_with_headers(url, &default_headers).await {
|
||||
Ok(true) => Ok(McpAuthStatus::NotLoggedIn),
|
||||
Ok(false) => Ok(McpAuthStatus::Unsupported),
|
||||
Err(error) => {
|
||||
@@ -45,8 +53,13 @@ pub async fn determine_streamable_http_auth_status(
|
||||
|
||||
/// Attempt to determine whether a streamable HTTP MCP server advertises OAuth login.
|
||||
pub async fn supports_oauth_login(url: &str) -> Result<bool> {
|
||||
supports_oauth_login_with_headers(url, &HeaderMap::new()).await
|
||||
}
|
||||
|
||||
async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result<bool> {
|
||||
let base_url = Url::parse(url)?;
|
||||
let client = Client::builder().timeout(DISCOVERY_TIMEOUT).build()?;
|
||||
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT);
|
||||
let client = apply_default_headers(builder, default_headers).build()?;
|
||||
|
||||
let mut last_error: Option<Error> = None;
|
||||
for candidate_path in discovery_paths(base_url.path()) {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::string::String;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
@@ -5,6 +6,7 @@ use std::time::Duration;
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use reqwest::ClientBuilder;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
@@ -16,6 +18,8 @@ use crate::OAuthCredentialsStoreMode;
|
||||
use crate::StoredOAuthTokens;
|
||||
use crate::WrappedOAuthTokenResponse;
|
||||
use crate::save_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
|
||||
struct CallbackServerGuard {
|
||||
server: Arc<Server>,
|
||||
@@ -31,6 +35,8 @@ pub async fn perform_oauth_login(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
) -> Result<()> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
@@ -51,7 +57,10 @@ pub async fn perform_oauth_login(
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, None).await?;
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
oauth_state
|
||||
.start_authorization(&[], &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
|
||||
@@ -14,6 +14,7 @@ use mcp_types::InitializeRequestParams;
|
||||
use mcp_types::InitializeResult;
|
||||
use mcp_types::ListToolsRequestParams;
|
||||
use mcp_types::ListToolsResult;
|
||||
use reqwest::header::HeaderMap;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
use rmcp::model::InitializeRequestParam;
|
||||
use rmcp::model::PaginatedRequestParam;
|
||||
@@ -38,6 +39,8 @@ use crate::logging_client_handler::LoggingClientHandler;
|
||||
use crate::oauth::OAuthCredentialsStoreMode;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use crate::utils::convert_call_tool_result;
|
||||
use crate::utils::convert_to_mcp;
|
||||
use crate::utils::convert_to_rmcp;
|
||||
@@ -116,12 +119,17 @@ impl RmcpClient {
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new_streamable_http_client(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token: Option<String>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
|
||||
let initial_oauth_tokens = match bearer_token {
|
||||
Some(_) => None,
|
||||
None => match load_oauth_tokens(server_name, url, store_mode) {
|
||||
@@ -132,21 +140,30 @@ impl RmcpClient {
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
|
||||
let (transport, oauth_persistor) =
|
||||
create_oauth_transport_and_runtime(server_name, url, initial_tokens, store_mode)
|
||||
.await?;
|
||||
let (transport, oauth_persistor) = create_oauth_transport_and_runtime(
|
||||
server_name,
|
||||
url,
|
||||
initial_tokens,
|
||||
store_mode,
|
||||
default_headers.clone(),
|
||||
)
|
||||
.await?;
|
||||
PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
}
|
||||
} else {
|
||||
let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
|
||||
if let Some(bearer_token) = bearer_token {
|
||||
if let Some(bearer_token) = bearer_token.clone() {
|
||||
http_config = http_config.auth_header(bearer_token);
|
||||
}
|
||||
|
||||
let transport = StreamableHttpClientTransport::from_config(http_config);
|
||||
let http_client =
|
||||
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(http_client, http_config);
|
||||
PendingTransport::StreamableHttp { transport }
|
||||
};
|
||||
Ok(Self {
|
||||
@@ -290,11 +307,13 @@ async fn create_oauth_transport_and_runtime(
|
||||
url: &str,
|
||||
initial_tokens: StoredOAuthTokens,
|
||||
credentials_store: OAuthCredentialsStoreMode,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<(
|
||||
StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
OAuthPersistor,
|
||||
)> {
|
||||
let http_client = reqwest::Client::builder().build()?;
|
||||
let http_client =
|
||||
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
|
||||
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
|
||||
|
||||
oauth_state
|
||||
|
||||
@@ -6,6 +6,10 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use mcp_types::CallToolResult;
|
||||
use reqwest::ClientBuilder;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
use rmcp::model::CallToolResult as RmcpCallToolResult;
|
||||
use rmcp::service::ServiceError;
|
||||
use serde_json::Value;
|
||||
@@ -78,6 +82,75 @@ pub(crate) fn create_env_for_mcp_server(
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn build_default_headers(
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
if let Some(static_headers) = http_headers {
|
||||
for (name, value) in static_headers {
|
||||
let header_name = match HeaderName::from_bytes(name.as_bytes()) {
|
||||
Ok(name) => name,
|
||||
Err(err) => {
|
||||
tracing::warn!("invalid HTTP header name `{name}`: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let header_value = match HeaderValue::from_str(value.as_str()) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
tracing::warn!("invalid HTTP header value for `{name}`: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(env_headers) = env_http_headers {
|
||||
for (name, env_var) in env_headers {
|
||||
if let Ok(value) = env::var(&env_var) {
|
||||
if value.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let header_name = match HeaderName::from_bytes(name.as_bytes()) {
|
||||
Ok(name) => name,
|
||||
Err(err) => {
|
||||
tracing::warn!("invalid HTTP header name `{name}`: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let header_value = match HeaderValue::from_str(value.as_str()) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"invalid HTTP header value read from {env_var} for `{name}`: {err}"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
pub(crate) fn apply_default_headers(
|
||||
builder: ClientBuilder,
|
||||
default_headers: &HeaderMap,
|
||||
) -> ClientBuilder {
|
||||
if default_headers.is_empty() {
|
||||
builder
|
||||
} else {
|
||||
builder.default_headers(default_headers.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
||||
"HOME",
|
||||
|
||||
Reference in New Issue
Block a user