use std::collections::HashMap; use std::ffi::OsString; use std::io; use std::process::Stdio; use std::sync::Arc; use std::time::Duration; use anyhow::Result; use anyhow::anyhow; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::InitializeRequestParams; use mcp_types::InitializeResult; use mcp_types::ListToolsRequestParams; use mcp_types::ListToolsResult; use rmcp::model::CallToolRequestParam; use rmcp::model::InitializeRequestParam; use rmcp::model::PaginatedRequestParam; use rmcp::service::RoleClient; use rmcp::service::RunningService; use rmcp::service::{self}; use rmcp::transport::child_process::TokioChildProcess; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Command; use tokio::sync::Mutex; use tokio::time; use tracing::info; use tracing::warn; use crate::logging_client_handler::LoggingClientHandler; use crate::utils::convert_call_tool_result; use crate::utils::convert_to_mcp; use crate::utils::convert_to_rmcp; use crate::utils::create_env_for_mcp_server; use crate::utils::run_with_timeout; enum ClientState { Connecting { transport: Option, }, Ready { service: Arc>, }, } /// MCP client implemented on top of the official `rmcp` SDK. /// https://github.com/modelcontextprotocol/rust-sdk pub struct RmcpClient { state: Mutex, } impl RmcpClient { pub async fn new_stdio_client( program: OsString, args: Vec, env: Option>, ) -> io::Result { let program_name = program.to_string_lossy().into_owned(); let mut command = Command::new(&program); command .kill_on_drop(true) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .env_clear() .envs(create_env_for_mcp_server(env)) .args(&args); let (transport, stderr) = TokioChildProcess::builder(command) .stderr(Stdio::piped()) .spawn()?; if let Some(stderr) = stderr { tokio::spawn(async move { let mut reader = BufReader::new(stderr).lines(); loop { match reader.next_line().await { Ok(Some(line)) => { info!("MCP server stderr ({program_name}): {line}"); } Ok(None) => break, Err(error) => { warn!("Failed to read MCP server stderr ({program_name}): {error}"); break; } } } }); } Ok(Self { state: Mutex::new(ClientState::Connecting { transport: Some(transport), }), }) } /// Perform the initialization handshake with the MCP server. /// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization pub async fn initialize( &self, params: InitializeRequestParams, timeout: Option, ) -> Result { let transport = { let mut guard = self.state.lock().await; match &mut *guard { ClientState::Connecting { transport } => transport .take() .ok_or_else(|| anyhow!("client already initializing"))?, ClientState::Ready { .. } => { return Err(anyhow!("client already initialized")); } } }; let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?; let client_handler = LoggingClientHandler::new(client_info); let service_future = service::serve_client(client_handler, transport); let service = match timeout { Some(duration) => time::timeout(duration, service_future) .await .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, None => service_future .await .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, }; let initialize_result_rmcp = service .peer() .peer_info() .ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?; let initialize_result = convert_to_mcp(initialize_result_rmcp)?; { let mut guard = self.state.lock().await; *guard = ClientState::Ready { service: Arc::new(service), }; } Ok(initialize_result) } pub async fn list_tools( &self, params: Option, timeout: Option, ) -> Result { let service = self.service().await?; let rmcp_params = params .map(convert_to_rmcp::<_, PaginatedRequestParam>) .transpose()?; let fut = service.list_tools(rmcp_params); let result = run_with_timeout(fut, timeout, "tools/list").await?; convert_to_mcp(result) } pub async fn call_tool( &self, name: String, arguments: Option, timeout: Option, ) -> Result { let service = self.service().await?; let params = CallToolRequestParams { arguments, name }; let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?; let fut = service.call_tool(rmcp_params); let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?; convert_call_tool_result(rmcp_result) } async fn service(&self) -> Result>> { let guard = self.state.lock().await; match &*guard { ClientState::Ready { service } => Ok(Arc::clone(service)), ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")), } } }