use std::collections::HashMap; use std::env; use std::time::Duration; use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use mcp_types::CallToolResult; use reqwest::ClientBuilder; use reqwest::header::HeaderMap; use reqwest::header::HeaderName; use reqwest::header::HeaderValue; use rmcp::model::CallToolResult as RmcpCallToolResult; use rmcp::service::ServiceError; use serde_json::Value; use tokio::time; pub(crate) async fn run_with_timeout( fut: F, timeout: Option, label: &str, ) -> Result where F: std::future::Future>, { if let Some(duration) = timeout { let result = time::timeout(duration, fut) .await .with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?; result.map_err(|err| anyhow!("{label} failed: {err}")) } else { fut.await.map_err(|err| anyhow!("{label} failed: {err}")) } } pub(crate) fn convert_call_tool_result(result: RmcpCallToolResult) -> Result { let mut value = serde_json::to_value(result)?; if let Some(obj) = value.as_object_mut() && (obj.get("content").is_none() || obj.get("content").is_some_and(serde_json::Value::is_null)) { obj.insert("content".to_string(), Value::Array(Vec::new())); } serde_json::from_value(value).context("failed to convert call tool result") } /// Convert from mcp-types to Rust SDK types. /// /// The Rust SDK types are the same as our mcp-types crate because they are both /// derived from the same MCP specification. /// As a result, it should be safe to convert directly from one to the other. pub(crate) fn convert_to_rmcp(value: T) -> Result where T: serde::Serialize, U: serde::de::DeserializeOwned, { let json = serde_json::to_value(value)?; serde_json::from_value(json).map_err(|err| anyhow!(err)) } /// Convert from Rust SDK types to mcp-types. /// /// The Rust SDK types are the same as our mcp-types crate because they are both /// derived from the same MCP specification. /// As a result, it should be safe to convert directly from one to the other. pub(crate) fn convert_to_mcp(value: T) -> Result where T: serde::Serialize, U: serde::de::DeserializeOwned, { let json = serde_json::to_value(value)?; serde_json::from_value(json).map_err(|err| anyhow!(err)) } pub(crate) fn create_env_for_mcp_server( extra_env: Option>, env_vars: &[String], ) -> HashMap { DEFAULT_ENV_VARS .iter() .copied() .chain(env_vars.iter().map(String::as_str)) .filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value))) .chain(extra_env.unwrap_or_default()) .collect() } pub(crate) fn build_default_headers( http_headers: Option>, env_http_headers: Option>, ) -> Result { let mut headers = HeaderMap::new(); if let Some(static_headers) = http_headers { for (name, value) in static_headers { let header_name = match HeaderName::from_bytes(name.as_bytes()) { Ok(name) => name, Err(err) => { tracing::warn!("invalid HTTP header name `{name}`: {err}"); continue; } }; let header_value = match HeaderValue::from_str(value.as_str()) { Ok(value) => value, Err(err) => { tracing::warn!("invalid HTTP header value for `{name}`: {err}"); continue; } }; headers.insert(header_name, header_value); } } if let Some(env_headers) = env_http_headers { for (name, env_var) in env_headers { if let Ok(value) = env::var(&env_var) { if value.trim().is_empty() { continue; } let header_name = match HeaderName::from_bytes(name.as_bytes()) { Ok(name) => name, Err(err) => { tracing::warn!("invalid HTTP header name `{name}`: {err}"); continue; } }; let header_value = match HeaderValue::from_str(value.as_str()) { Ok(value) => value, Err(err) => { tracing::warn!( "invalid HTTP header value read from {env_var} for `{name}`: {err}" ); continue; } }; headers.insert(header_name, header_value); } } } Ok(headers) } pub(crate) fn apply_default_headers( builder: ClientBuilder, default_headers: &HeaderMap, ) -> ClientBuilder { if default_headers.is_empty() { builder } else { builder.default_headers(default_headers.clone()) } } #[cfg(unix)] pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[ "HOME", "LOGNAME", "PATH", "SHELL", "USER", "__CF_USER_TEXT_ENCODING", "LANG", "LC_ALL", "TERM", "TMPDIR", "TZ", ]; #[cfg(windows)] pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[ // Core path resolution "PATH", "PATHEXT", // Shell and system roots "COMSPEC", "SYSTEMROOT", "SYSTEMDRIVE", // User context and profiles "USERNAME", "USERDOMAIN", "USERPROFILE", "HOMEDRIVE", "HOMEPATH", // Program locations "PROGRAMFILES", "PROGRAMFILES(X86)", "PROGRAMW6432", "PROGRAMDATA", // App data and caches "LOCALAPPDATA", "APPDATA", // Temp locations "TEMP", "TMP", // Common shells/pwsh hints "POWERSHELL", "PWSH", ]; #[cfg(test)] mod tests { use super::*; use mcp_types::ContentBlock; use pretty_assertions::assert_eq; use rmcp::model::CallToolResult as RmcpCallToolResult; use serde_json::json; use serial_test::serial; use std::ffi::OsString; struct EnvVarGuard { key: String, original: Option, } impl EnvVarGuard { fn set(key: &str, value: &str) -> Self { let original = std::env::var_os(key); unsafe { std::env::set_var(key, value); } Self { key: key.to_string(), original, } } } impl Drop for EnvVarGuard { fn drop(&mut self) { if let Some(value) = &self.original { unsafe { std::env::set_var(&self.key, value); } } else { unsafe { std::env::remove_var(&self.key); } } } } #[tokio::test] async fn create_env_honors_overrides() { let value = "custom".to_string(); let env = create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])), &[]); assert_eq!(env.get("TZ"), Some(&value)); } #[test] #[serial(extra_rmcp_env)] fn create_env_includes_additional_whitelisted_variables() { let custom_var = "EXTRA_RMCP_ENV"; let value = "from-env"; let _guard = EnvVarGuard::set(custom_var, value); let env = create_env_for_mcp_server(None, &[custom_var.to_string()]); assert_eq!(env.get(custom_var), Some(&value.to_string())); } #[test] fn convert_call_tool_result_defaults_missing_content() -> Result<()> { let structured_content = json!({ "key": "value" }); let rmcp_result = RmcpCallToolResult { content: vec![], structured_content: Some(structured_content.clone()), is_error: Some(true), meta: None, }; let result = convert_call_tool_result(rmcp_result)?; assert!(result.content.is_empty()); assert_eq!(result.structured_content, Some(structured_content)); assert_eq!(result.is_error, Some(true)); Ok(()) } #[test] fn convert_call_tool_result_preserves_existing_content() -> Result<()> { let rmcp_result = RmcpCallToolResult::success(vec![rmcp::model::Content::text("hello")]); let result = convert_call_tool_result(rmcp_result)?; assert_eq!(result.content.len(), 1); match &result.content[0] { ContentBlock::TextContent(text_content) => { assert_eq!(text_content.text, "hello"); assert_eq!(text_content.r#type, "text"); } other => panic!("expected text content got {other:?}"), } assert_eq!(result.structured_content, None); assert_eq!(result.is_error, Some(false)); Ok(()) } }