[MCP] Introduce an experimental official rust sdk based mcp client (#4252)
The [official Rust
SDK](57fc428c57)
has come a long way since we first started our mcp client implementation
5 months ago and, today, it is much more complete than our own
stdio-only implementation.
This PR introduces a new config flag `experimental_use_rmcp_client`
which will use a new mcp client powered by the sdk instead of our own.
To keep this PR simple, I've only implemented the same stdio MCP
functionality that we had but will expand on it with future PRs.
---------
Co-authored-by: pakrym-oai <pakrym@openai.com>
This commit is contained in:
183
codex-rs/rmcp-client/src/rmcp_client.rs
Normal file
183
codex-rs/rmcp-client/src/rmcp_client.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::io;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::InitializeRequestParams;
|
||||
use mcp_types::InitializeResult;
|
||||
use mcp_types::ListToolsRequestParams;
|
||||
use mcp_types::ListToolsResult;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
use rmcp::model::InitializeRequestParam;
|
||||
use rmcp::model::PaginatedRequestParam;
|
||||
use rmcp::service::RoleClient;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::service::{self};
|
||||
use rmcp::transport::child_process::TokioChildProcess;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::logging_client_handler::LoggingClientHandler;
|
||||
use crate::utils::convert_call_tool_result;
|
||||
use crate::utils::convert_to_mcp;
|
||||
use crate::utils::convert_to_rmcp;
|
||||
use crate::utils::create_env_for_mcp_server;
|
||||
use crate::utils::run_with_timeout;
|
||||
|
||||
enum ClientState {
|
||||
Connecting {
|
||||
transport: Option<TokioChildProcess>,
|
||||
},
|
||||
Ready {
|
||||
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
|
||||
},
|
||||
}
|
||||
|
||||
/// MCP client implemented on top of the official `rmcp` SDK.
|
||||
/// https://github.com/modelcontextprotocol/rust-sdk
|
||||
pub struct RmcpClient {
|
||||
state: Mutex<ClientState>,
|
||||
}
|
||||
|
||||
impl RmcpClient {
|
||||
pub async fn new_stdio_client(
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
) -> io::Result<Self> {
|
||||
let program_name = program.to_string_lossy().into_owned();
|
||||
let mut command = Command::new(&program);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.env_clear()
|
||||
.envs(create_env_for_mcp_server(env))
|
||||
.args(&args);
|
||||
|
||||
let (transport, stderr) = TokioChildProcess::builder(command)
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
if let Some(stderr) = stderr {
|
||||
tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stderr).lines();
|
||||
loop {
|
||||
match reader.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
info!("MCP server stderr ({program_name}): {line}");
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(error) => {
|
||||
warn!("Failed to read MCP server stderr ({program_name}): {error}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(transport),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform the initialization handshake with the MCP server.
|
||||
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
params: InitializeRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<InitializeResult> {
|
||||
let transport = {
|
||||
let mut guard = self.state.lock().await;
|
||||
match &mut *guard {
|
||||
ClientState::Connecting { transport } => transport
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("client already initializing"))?,
|
||||
ClientState::Ready { .. } => {
|
||||
return Err(anyhow!("client already initialized"));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?;
|
||||
let client_handler = LoggingClientHandler::new(client_info);
|
||||
let service_future = service::serve_client(client_handler, transport);
|
||||
|
||||
let service = match timeout {
|
||||
Some(duration) => time::timeout(duration, service_future)
|
||||
.await
|
||||
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
None => service_future
|
||||
.await
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
};
|
||||
|
||||
let initialize_result_rmcp = service
|
||||
.peer()
|
||||
.peer_info()
|
||||
.ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?;
|
||||
let initialize_result = convert_to_mcp(initialize_result_rmcp)?;
|
||||
|
||||
{
|
||||
let mut guard = self.state.lock().await;
|
||||
*guard = ClientState::Ready {
|
||||
service: Arc::new(service),
|
||||
};
|
||||
}
|
||||
|
||||
Ok(initialize_result)
|
||||
}
|
||||
|
||||
pub async fn list_tools(
|
||||
&self,
|
||||
params: Option<ListToolsRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ListToolsResult> {
|
||||
let service = self.service().await?;
|
||||
let rmcp_params = params
|
||||
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
||||
.transpose()?;
|
||||
|
||||
let fut = service.list_tools(rmcp_params);
|
||||
let result = run_with_timeout(fut, timeout, "tools/list").await?;
|
||||
convert_to_mcp(result)
|
||||
}
|
||||
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
name: String,
|
||||
arguments: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<CallToolResult> {
|
||||
let service = self.service().await?;
|
||||
let params = CallToolRequestParams { arguments, name };
|
||||
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
|
||||
let fut = service.call_tool(rmcp_params);
|
||||
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
|
||||
convert_call_tool_result(rmcp_result)
|
||||
}
|
||||
|
||||
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
|
||||
let guard = self.state.lock().await;
|
||||
match &*guard {
|
||||
ClientState::Ready { service } => Ok(Arc::clone(service)),
|
||||
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user