From 115fb0b95d3da9bcf9145d2220f1a7f65aa2a01f Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Mon, 12 May 2025 15:15:26 -0700 Subject: [PATCH] fix: navigate initialization phase before tools/list request in MCP client (#904) Apparently the MCP server implemented in JavaScript did not require the `initialize` handshake before responding to tool list/call, so I missed this. --- codex-rs/core/src/mcp_connection_manager.rs | 32 ++++++++++++- codex-rs/mcp-client/src/main.rs | 25 ++++++++++ codex-rs/mcp-client/src/mcp_client.rs | 53 +++++++++++++++++++++ 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index e4124b90..714c9452 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -13,6 +13,8 @@ use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use codex_mcp_client::McpClient; +use mcp_types::ClientCapabilities; +use mcp_types::Implementation; use mcp_types::Tool; use tokio::task::JoinSet; use tracing::info; @@ -83,7 +85,33 @@ impl McpConnectionManager { join_set.spawn(async move { let McpServerConfig { command, args, env } = cfg; let client_res = McpClient::new_stdio_client(command, args, env).await; - (server_name, client_res) + match client_res { + Ok(client) => { + // Initialize the client. + let params = mcp_types::InitializeRequestParams { + capabilities: ClientCapabilities { + experimental: None, + roots: None, + sampling: None, + }, + client_info: Implementation { + name: "codex-mcp-client".to_owned(), + version: env!("CARGO_PKG_VERSION").to_owned(), + }, + protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), + }; + let initialize_notification_params = None; + let timeout = Some(Duration::from_secs(10)); + match client + .initialize(params, initialize_notification_params, timeout) + .await + { + Ok(_response) => (server_name, Ok(client)), + Err(e) => (server_name, Err(e)), + } + } + Err(e) => (server_name, Err(e.into())), + } }); } @@ -99,7 +127,7 @@ impl McpConnectionManager { clients.insert(server_name, std::sync::Arc::new(client)); } Err(e) => { - errors.insert(server_name, e.into()); + errors.insert(server_name, e); } } } diff --git a/codex-rs/mcp-client/src/main.rs b/codex-rs/mcp-client/src/main.rs index eb784252..af4b0509 100644 --- a/codex-rs/mcp-client/src/main.rs +++ b/codex-rs/mcp-client/src/main.rs @@ -10,10 +10,16 @@ //! program. The utility connects, issues a `tools/list` request and prints the //! server's response as pretty JSON. +use std::time::Duration; + use anyhow::Context; use anyhow::Result; use codex_mcp_client::McpClient; +use mcp_types::ClientCapabilities; +use mcp_types::Implementation; +use mcp_types::InitializeRequestParams; use mcp_types::ListToolsRequestParams; +use mcp_types::MCP_SCHEMA_VERSION; #[tokio::main] async fn main() -> Result<()> { @@ -33,6 +39,25 @@ async fn main() -> Result<()> { .await .with_context(|| format!("failed to spawn subprocess: {original_args:?}"))?; + let params = InitializeRequestParams { + capabilities: ClientCapabilities { + experimental: None, + roots: None, + sampling: None, + }, + client_info: Implementation { + name: "codex-mcp-client".to_owned(), + version: env!("CARGO_PKG_VERSION").to_owned(), + }, + protocol_version: MCP_SCHEMA_VERSION.to_owned(), + }; + let initialize_notification_params = None; + let timeout = Some(Duration::from_secs(10)); + let response = client + .initialize(params, initialize_notification_params, timeout) + .await?; + eprintln!("initialize response: {response:?}"); + // Issue `tools/list` request (no params). let timeout = None; let tools = client diff --git a/codex-rs/mcp-client/src/mcp_client.rs b/codex-rs/mcp-client/src/mcp_client.rs index 641de0e8..3c6a5218 100644 --- a/codex-rs/mcp-client/src/mcp_client.rs +++ b/codex-rs/mcp-client/src/mcp_client.rs @@ -17,10 +17,14 @@ use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; use std::time::Duration; +use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use mcp_types::CallToolRequest; use mcp_types::CallToolRequestParams; +use mcp_types::InitializeRequest; +use mcp_types::InitializeRequestParams; +use mcp_types::InitializedNotification; use mcp_types::JSONRPC_VERSION; use mcp_types::JSONRPCMessage; use mcp_types::JSONRPCNotification; @@ -29,6 +33,7 @@ use mcp_types::JSONRPCResponse; use mcp_types::ListToolsRequest; use mcp_types::ListToolsRequestParams; use mcp_types::ListToolsResult; +use mcp_types::ModelContextProtocolNotification; use mcp_types::ModelContextProtocolRequest; use mcp_types::RequestId; use serde::Serialize; @@ -74,6 +79,8 @@ pub struct McpClient { impl McpClient { /// Spawn the given command and establish an MCP session over its STDIO. + /// Caller is responsible for sending the `initialize` request. See + /// [`initialize`](Self::initialize) for details. pub async fn new_stdio_client( program: String, args: Vec, @@ -273,6 +280,52 @@ impl McpClient { } } + pub async fn send_notification(&self, params: N::Params) -> Result<()> + where + N: ModelContextProtocolNotification, + N::Params: Serialize, + { + // Serialize params -> JSON. For many request types `Params` is + // `Option` and `None` should be encoded as *absence* of the field. + let params_json = serde_json::to_value(¶ms)?; + let params_field = if params_json.is_null() { + None + } else { + Some(params_json) + }; + + let method = N::METHOD.to_string(); + let jsonrpc_notification = JSONRPCNotification { + jsonrpc: JSONRPC_VERSION.to_string(), + method: method.clone(), + params: params_field, + }; + + let notification = JSONRPCMessage::Notification(jsonrpc_notification); + self.outgoing_tx + .send(notification) + .await + .with_context(|| format!("failed to send notification `{method}` to writer task")) + } + + /// Negotiates the initialization with the MCP server. Sends an `initialize` + /// request with the specified `initialize_params` and then the + /// `notifications/initialized` notification once the response has been + /// received. Returns the response to the `initialize` request. + pub async fn initialize( + &self, + initialize_params: InitializeRequestParams, + initialize_notification_params: Option, + timeout: Option, + ) -> Result { + let response = self + .send_request::(initialize_params, timeout) + .await?; + self.send_notification::(initialize_notification_params) + .await?; + Ok(response) + } + /// Convenience wrapper around `tools/list`. pub async fn list_tools( &self,