use std::collections::HashMap; use std::ffi::OsString; use std::io; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; use anyhow::Result; use anyhow::anyhow; use futures::FutureExt; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::InitializeRequestParams; use mcp_types::InitializeResult; use mcp_types::ListToolsRequestParams; use mcp_types::ListToolsResult; use rmcp::model::CallToolRequestParam; use rmcp::model::InitializeRequestParam; use rmcp::model::PaginatedRequestParam; use rmcp::service::RoleClient; use rmcp::service::RunningService; use rmcp::service::{self}; use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::auth::AuthClient; use rmcp::transport::auth::OAuthState; use rmcp::transport::child_process::TokioChildProcess; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Command; use tokio::sync::Mutex; use tokio::time; use tracing::info; use tracing::warn; use crate::load_oauth_tokens; use crate::logging_client_handler::LoggingClientHandler; use crate::oauth::OAuthPersistor; use crate::oauth::StoredOAuthTokens; use crate::utils::convert_call_tool_result; use crate::utils::convert_to_mcp; use crate::utils::convert_to_rmcp; use crate::utils::create_env_for_mcp_server; use crate::utils::run_with_timeout; enum PendingTransport { ChildProcess(TokioChildProcess), StreamableHttp { transport: StreamableHttpClientTransport, }, StreamableHttpWithOAuth { transport: StreamableHttpClientTransport>, oauth_persistor: OAuthPersistor, }, } enum ClientState { Connecting { transport: Option, }, Ready { service: Arc>, oauth: Option, }, } /// MCP client implemented on top of the official `rmcp` SDK. /// https://github.com/modelcontextprotocol/rust-sdk pub struct RmcpClient { state: Mutex, } impl RmcpClient { pub async fn new_stdio_client( program: OsString, args: Vec, env: Option>, ) -> io::Result { let program_name = program.to_string_lossy().into_owned(); let mut command = Command::new(&program); command .kill_on_drop(true) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .env_clear() .envs(create_env_for_mcp_server(env)) .args(&args); let (transport, stderr) = TokioChildProcess::builder(command) .stderr(Stdio::piped()) .spawn()?; if let Some(stderr) = stderr { tokio::spawn(async move { let mut reader = BufReader::new(stderr).lines(); loop { match reader.next_line().await { Ok(Some(line)) => { info!("MCP server stderr ({program_name}): {line}"); } Ok(None) => break, Err(error) => { warn!("Failed to read MCP server stderr ({program_name}): {error}"); break; } } } }); } Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(PendingTransport::ChildProcess(transport)), }), }) } pub async fn new_streamable_http_client( server_name: &str, url: &str, bearer_token: Option, ) -> Result { let initial_tokens = match load_oauth_tokens(server_name, url) { Ok(tokens) => tokens, Err(err) => { warn!("failed to read tokens for server `{server_name}`: {err}"); None } }; let transport = if let Some(initial_tokens) = initial_tokens.clone() { let (transport, oauth_persistor) = create_oauth_transport_and_runtime(server_name, url, initial_tokens).await?; PendingTransport::StreamableHttpWithOAuth { transport, oauth_persistor, } } else { let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string()); if let Some(bearer_token) = bearer_token { http_config = http_config.auth_header(format!("Bearer {bearer_token}")); } let transport = StreamableHttpClientTransport::from_config(http_config); PendingTransport::StreamableHttp { transport } }; Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), }), }) } /// Perform the initialization handshake with the MCP server. /// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization pub async fn initialize( &self, params: InitializeRequestParams, timeout: Option, ) -> Result { let rmcp_params: InitializeRequestParam = convert_to_rmcp(params.clone())?; let client_handler = LoggingClientHandler::new(rmcp_params); let (transport, oauth_persistor) = { let mut guard = self.state.lock().await; match &mut *guard { ClientState::Connecting { transport } => match transport.take() { Some(PendingTransport::ChildProcess(transport)) => ( service::serve_client(client_handler.clone(), transport).boxed(), None, ), Some(PendingTransport::StreamableHttp { transport }) => ( service::serve_client(client_handler.clone(), transport).boxed(), None, ), Some(PendingTransport::StreamableHttpWithOAuth { transport, oauth_persistor, }) => ( service::serve_client(client_handler.clone(), transport).boxed(), Some(oauth_persistor), ), None => return Err(anyhow!("client already initializing")), }, ClientState::Ready { .. } => return Err(anyhow!("client already initialized")), } }; let service = match timeout { Some(duration) => time::timeout(duration, transport) .await .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, None => transport .await .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, }; let initialize_result_rmcp = service .peer() .peer_info() .ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?; let initialize_result = convert_to_mcp(initialize_result_rmcp)?; { let mut guard = self.state.lock().await; *guard = ClientState::Ready { service: Arc::new(service), oauth: oauth_persistor.clone(), }; } if let Some(runtime) = oauth_persistor && let Err(error) = runtime.persist_if_needed().await { warn!("failed to persist OAuth tokens after initialize: {error}"); } Ok(initialize_result) } pub async fn list_tools( &self, params: Option, timeout: Option, ) -> Result { let service = self.service().await?; let rmcp_params = params .map(convert_to_rmcp::<_, PaginatedRequestParam>) .transpose()?; let fut = service.list_tools(rmcp_params); let result = run_with_timeout(fut, timeout, "tools/list").await?; let converted = convert_to_mcp(result)?; self.persist_oauth_tokens().await; Ok(converted) } pub async fn call_tool( &self, name: String, arguments: Option, timeout: Option, ) -> Result { let service = self.service().await?; let params = CallToolRequestParams { arguments, name }; let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?; let fut = service.call_tool(rmcp_params); let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?; let converted = convert_call_tool_result(rmcp_result)?; self.persist_oauth_tokens().await; Ok(converted) } async fn service(&self) -> Result>> { let guard = self.state.lock().await; match &*guard { ClientState::Ready { service, .. } => Ok(Arc::clone(service)), ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")), } } async fn oauth_persistor(&self) -> Option { let guard = self.state.lock().await; match &*guard { ClientState::Ready { oauth: Some(runtime), service: _, } => Some(runtime.clone()), _ => None, } } async fn persist_oauth_tokens(&self) { if let Some(runtime) = self.oauth_persistor().await && let Err(error) = runtime.persist_if_needed().await { warn!("failed to persist OAuth tokens: {error}"); } } } async fn create_oauth_transport_and_runtime( server_name: &str, url: &str, initial_tokens: StoredOAuthTokens, ) -> Result<( StreamableHttpClientTransport>, OAuthPersistor, )> { let http_client = reqwest::Client::builder().build()?; let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?; oauth_state .set_credentials( &initial_tokens.client_id, initial_tokens.token_response.0.clone(), ) .await?; let manager = match oauth_state { OAuthState::Authorized(manager) => manager, OAuthState::Unauthorized(manager) => manager, OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => { return Err(anyhow!("unexpected OAuth state during client setup")); } }; let auth_client = AuthClient::new(http_client, manager); let auth_manager = auth_client.auth_manager.clone(); let transport = StreamableHttpClientTransport::with_client( auth_client, StreamableHttpClientTransportConfig::with_uri(url.to_string()), ); let runtime = OAuthPersistor::new( server_name.to_string(), url.to_string(), auth_manager, Some(initial_tokens), ); Ok((transport, runtime)) }