//! Connection manager for Model Context Protocol (MCP) servers. //! //! The [`McpConnectionManager`] owns one [`codex_mcp_client::McpClient`] per //! configured server (keyed by the *server name*). It offers convenience //! helpers to query the available tools across *all* servers and returns them //! in a single aggregated map using the fully-qualified tool name //! `""` as the key. use std::collections::HashMap; use anyhow::anyhow; use anyhow::Context; use anyhow::Result; use codex_mcp_client::McpClient; use mcp_types::Tool; use tokio::task::JoinSet; use tracing::info; use crate::mcp_server_config::McpServerConfig; /// Delimiter used to separate the server name from the tool name in a fully /// qualified tool name. /// /// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must /// choose a delimiter from this character set. const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__"; fn fully_qualified_tool_name(server: &str, tool: &str) -> String { format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}") } pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> { let (server, tool) = fq_name.split_once(MCP_TOOL_NAME_DELIMITER)?; if server.is_empty() || tool.is_empty() { return None; } Some((server.to_string(), tool.to_string())) } /// A thin wrapper around a set of running [`McpClient`] instances. #[derive(Default)] pub(crate) struct McpConnectionManager { /// Server-name -> client instance. /// /// The server name originates from the keys of the `mcp_servers` map in /// the user configuration. clients: HashMap>, /// Fully qualified tool name -> tool instance. tools: HashMap, } impl McpConnectionManager { /// Spawn a [`McpClient`] for each configured server. /// /// * `mcp_servers` – Map loaded from the user configuration where *keys* /// are human-readable server identifiers and *values* are the spawn /// instructions. pub async fn new(mcp_servers: HashMap) -> Result { // Early exit if no servers are configured. if mcp_servers.is_empty() { return Ok(Self::default()); } // Spin up all servers concurrently. let mut join_set = JoinSet::new(); // Spawn tasks to launch each server. for (server_name, cfg) in mcp_servers { // TODO: Verify server name: require `^[a-zA-Z0-9_-]+$`? join_set.spawn(async move { let McpServerConfig { command, args, env } = cfg; let client_res = McpClient::new_stdio_client(command, args, env).await; (server_name, client_res) }); } let mut clients: HashMap> = HashMap::with_capacity(join_set.len()); while let Some(res) = join_set.join_next().await { let (server_name, client_res) = res?; let client = client_res .with_context(|| format!("failed to spawn MCP server `{server_name}`"))?; clients.insert(server_name, std::sync::Arc::new(client)); } let tools = list_all_tools(&clients).await?; Ok(Self { clients, tools }) } /// Returns a single map that contains **all** tools. Each key is the /// fully-qualified name for the tool. pub fn list_all_tools(&self) -> HashMap { self.tools.clone() } /// Invoke the tool indicated by the (server, tool) pair. pub async fn call_tool( &self, server: &str, tool: &str, arguments: Option, ) -> Result { let client = self .clients .get(server) .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))? .clone(); client .call_tool(tool.to_string(), arguments) .await .with_context(|| format!("tool call failed for `{server}/{tool}`")) } } /// Query every server for its available tools and return a single map that /// contains **all** tools. Each key is the fully-qualified name for the tool. pub async fn list_all_tools( clients: &HashMap>, ) -> Result> { let mut join_set = JoinSet::new(); // Spawn one task per server so we can query them concurrently. This // keeps the overall latency roughly at the slowest server instead of // the cumulative latency. for (server_name, client) in clients { let server_name_cloned = server_name.clone(); let client_clone = client.clone(); join_set.spawn(async move { let res = client_clone.list_tools(None).await; (server_name_cloned, res) }); } let mut aggregated: HashMap = HashMap::with_capacity(join_set.len()); while let Some(join_res) = join_set.join_next().await { let (server_name, list_result) = join_res?; let list_result = list_result?; for tool in list_result.tools { // TODO(mbolin): escape tool names that contain invalid characters. let fq_name = fully_qualified_tool_name(&server_name, &tool.name); if aggregated.insert(fq_name.clone(), tool).is_some() { panic!("tool name collision for '{fq_name}': suspicious"); } } } info!( "aggregated {} tools from {} servers", aggregated.len(), clients.len() ); Ok(aggregated) }