Files
llmx/codex-rs/core/src/mcp_connection_manager.rs
Michael Bolin 115fb0b95d fix: navigate initialization phase before tools/list request in MCP client (#904)
Apparently the MCP server implemented in JavaScript did not require the
`initialize` handshake before responding to tool list/call, so I missed
this.
2025-05-12 15:15:26 -07:00

211 lines
7.6 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Connection manager for Model Context Protocol (MCP) servers.
//!
//! The [`McpConnectionManager`] owns one [`codex_mcp_client::McpClient`] per
//! configured server (keyed by the *server name*). It offers convenience
//! helpers to query the available tools across *all* servers and returns them
//! in a single aggregated map using the fully-qualified tool name
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.
use std::collections::HashMap;
use std::time::Duration;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use codex_mcp_client::McpClient;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
use mcp_types::Tool;
use tokio::task::JoinSet;
use tracing::info;
use crate::mcp_server_config::McpServerConfig;
/// Delimiter used to separate the server name from the tool name in a fully
/// qualified tool name.
///
/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must
/// choose a delimiter from this character set.
///
/// Note that `try_parse_fully_qualified_tool_name` splits on the *first*
/// occurrence of this delimiter, so a server name that itself contains the
/// delimiter would not round-trip.
const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__";
/// Timeout for the `tools/list` request.
///
/// The same limit is passed to every server's `list_tools` call; the requests
/// are issued from separate tasks, one per server.
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
/// Map that holds a startup error for every MCP server that could **not** be
/// started successfully, keyed by server name.
///
/// An entry is recorded both when the client for a server cannot be created
/// and when the client is created but its `initialize` request fails.
pub type ClientStartErrors = HashMap<String, anyhow::Error>;
/// Build the fully qualified name of `tool` as exposed by `server`.
///
/// The result is the server name, followed by [`MCP_TOOL_NAME_DELIMITER`],
/// followed by the tool name. Neither input is validated or escaped.
fn fully_qualified_tool_name(server: &str, tool: &str) -> String {
    [server, MCP_TOOL_NAME_DELIMITER, tool].concat()
}
/// Split a fully qualified tool name back into its `(server, tool)` parts.
///
/// The split happens at the first occurrence of [`MCP_TOOL_NAME_DELIMITER`];
/// anything after it (including further delimiters) belongs to the tool name.
///
/// Returns `None` when the delimiter is absent or when either side of it is
/// empty.
pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> {
    match fq_name.split_once(MCP_TOOL_NAME_DELIMITER) {
        Some((server_part, tool_part)) if !server_part.is_empty() && !tool_part.is_empty() => {
            Some((server_part.to_owned(), tool_part.to_owned()))
        }
        _ => None,
    }
}
/// A thin wrapper around a set of running [`McpClient`] instances.
///
/// The derived `Default` yields a manager with no clients and no tools, which
/// is what `new` returns when no servers are configured.
#[derive(Default)]
pub(crate) struct McpConnectionManager {
/// Server-name -> client instance.
///
/// The server name originates from the keys of the `mcp_servers` map in
/// the user configuration. Only servers whose client was created and
/// initialized successfully are stored here.
clients: HashMap<String, std::sync::Arc<McpClient>>,
/// Fully qualified tool name -> tool instance.
///
/// Filled once by `new` from the aggregated `tools/list` responses of all
/// stored clients; `list_all_tools` hands out clones of this map.
tools: HashMap<String, Tool>,
}
impl McpConnectionManager {
/// Spawn a [`McpClient`] for each configured server.
///
/// * `mcp_servers` Map loaded from the user configuration where *keys*
/// are human-readable server identifiers and *values* are the spawn
/// instructions.
///
/// Servers that fail to start are reported in `ClientStartErrors`: the
/// user should be informed about these errors.
pub async fn new(
mcp_servers: HashMap<String, McpServerConfig>,
) -> Result<(Self, ClientStartErrors)> {
// Early exit if no servers are configured.
if mcp_servers.is_empty() {
return Ok((Self::default(), ClientStartErrors::default()));
}
// Launch all configured servers concurrently.
let mut join_set = JoinSet::new();
for (server_name, cfg) in mcp_servers {
// TODO: Verify server name: require `^[a-zA-Z0-9_-]+$`?
join_set.spawn(async move {
let McpServerConfig { command, args, env } = cfg;
let client_res = McpClient::new_stdio_client(command, args, env).await;
match client_res {
Ok(client) => {
// Initialize the client.
let params = mcp_types::InitializeRequestParams {
capabilities: ClientCapabilities {
experimental: None,
roots: None,
sampling: None,
},
client_info: Implementation {
name: "codex-mcp-client".to_owned(),
version: env!("CARGO_PKG_VERSION").to_owned(),
},
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
};
let initialize_notification_params = None;
let timeout = Some(Duration::from_secs(10));
match client
.initialize(params, initialize_notification_params, timeout)
.await
{
Ok(_response) => (server_name, Ok(client)),
Err(e) => (server_name, Err(e)),
}
}
Err(e) => (server_name, Err(e.into())),
}
});
}
let mut clients: HashMap<String, std::sync::Arc<McpClient>> =
HashMap::with_capacity(join_set.len());
let mut errors = ClientStartErrors::new();
while let Some(res) = join_set.join_next().await {
let (server_name, client_res) = res?; // JoinError propagation
match client_res {
Ok(client) => {
clients.insert(server_name, std::sync::Arc::new(client));
}
Err(e) => {
errors.insert(server_name, e);
}
}
}
let tools = list_all_tools(&clients).await?;
Ok((Self { clients, tools }, errors))
}
/// Returns a single map that contains **all** tools. Each key is the
/// fully-qualified name for the tool.
pub fn list_all_tools(&self) -> HashMap<String, Tool> {
self.tools.clone()
}
/// Invoke the tool indicated by the (server, tool) pair.
pub async fn call_tool(
&self,
server: &str,
tool: &str,
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> Result<mcp_types::CallToolResult> {
let client = self
.clients
.get(server)
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
.clone();
client
.call_tool(tool.to_string(), arguments, timeout)
.await
.with_context(|| format!("tool call failed for `{server}/{tool}`"))
}
}
/// Query every server for its available tools and return a single map that
/// contains **all** tools. Each key is the fully-qualified name for the tool.
pub async fn list_all_tools(
clients: &HashMap<String, std::sync::Arc<McpClient>>,
) -> Result<HashMap<String, Tool>> {
let mut join_set = JoinSet::new();
// Spawn one task per server so we can query them concurrently. This
// keeps the overall latency roughly at the slowest server instead of
// the cumulative latency.
for (server_name, client) in clients {
let server_name_cloned = server_name.clone();
let client_clone = client.clone();
join_set.spawn(async move {
let res = client_clone
.list_tools(None, Some(LIST_TOOLS_TIMEOUT))
.await;
(server_name_cloned, res)
});
}
let mut aggregated: HashMap<String, Tool> = HashMap::with_capacity(join_set.len());
while let Some(join_res) = join_set.join_next().await {
let (server_name, list_result) = join_res?;
let list_result = list_result?;
for tool in list_result.tools {
// TODO(mbolin): escape tool names that contain invalid characters.
let fq_name = fully_qualified_tool_name(&server_name, &tool.name);
if aggregated.insert(fq_name.clone(), tool).is_some() {
panic!("tool name collision for '{fq_name}': suspicious");
}
}
}
info!(
"aggregated {} tools from {} servers",
aggregated.len(),
clients.len()
);
Ok(aggregated)
}