[MCP] Introduce an experimental official rust sdk based mcp client (#4252)

The [official Rust
SDK](57fc428c57)
has come a long way since we first started our mcp client implementation
5 months ago and, today, it is much more complete than our own
stdio-only implementation.

This PR introduces a new config flag `experimental_use_rmcp_client`
which will use a new mcp client powered by the sdk instead of our own.

To keep this PR simple, I've only implemented the same stdio MCP
functionality that we had but will expand on it with future PRs.

---------

Co-authored-by: pakrym-oai <pakrym@openai.com>
This commit is contained in:
Gabriel Peal
2025-09-26 10:13:37 -07:00
committed by GitHub
parent ea095e30c1
commit e555a36c6a
17 changed files with 1136 additions and 62 deletions

View File

@@ -0,0 +1,142 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;
use rmcp::ErrorData as McpError;
use rmcp::ServiceExt;
use rmcp::handler::server::ServerHandler;
use rmcp::model::CallToolRequestParam;
use rmcp::model::CallToolResult;
use rmcp::model::JsonObject;
use rmcp::model::ListToolsResult;
use rmcp::model::PaginatedRequestParam;
use rmcp::model::ServerCapabilities;
use rmcp::model::ServerInfo;
use rmcp::model::Tool;
use serde::Deserialize;
use serde_json::json;
use tokio::task;
#[derive(Clone)]
struct TestToolServer {
tools: Arc<Vec<Tool>>,
}
pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) {
(tokio::io::stdin(), tokio::io::stdout())
}
impl TestToolServer {
fn new() -> Self {
let tools = vec![Self::echo_tool()];
Self {
tools: Arc::new(tools),
}
}
fn echo_tool() -> Tool {
#[expect(clippy::expect_used)]
let schema: JsonObject = serde_json::from_value(json!({
"type": "object",
"properties": {
"message": { "type": "string" },
"env_var": { "type": "string" }
},
"required": ["message"],
"additionalProperties": false
}))
.expect("echo tool schema should deserialize");
Tool::new(
Cow::Borrowed("echo"),
Cow::Borrowed("Echo back the provided message and include environment data."),
Arc::new(schema),
)
}
}
#[derive(Deserialize)]
struct EchoArgs {
message: String,
#[allow(dead_code)]
env_var: Option<String>,
}
impl ServerHandler for TestToolServer {
fn get_info(&self) -> ServerInfo {
ServerInfo {
capabilities: ServerCapabilities::builder()
.enable_tools()
.enable_tool_list_changed()
.build(),
..ServerInfo::default()
}
}
fn list_tools(
&self,
_request: Option<PaginatedRequestParam>,
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
) -> impl std::future::Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
let tools = self.tools.clone();
async move {
Ok(ListToolsResult {
tools: (*tools).clone(),
next_cursor: None,
})
}
}
async fn call_tool(
&self,
request: CallToolRequestParam,
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
) -> Result<CallToolResult, McpError> {
match request.name.as_ref() {
"echo" => {
let args: EchoArgs = match request.arguments {
Some(arguments) => serde_json::from_value(serde_json::Value::Object(
arguments.into_iter().collect(),
))
.map_err(|err| McpError::invalid_params(err.to_string(), None))?,
None => {
return Err(McpError::invalid_params(
"missing arguments for echo tool",
None,
));
}
};
let env_snapshot: HashMap<String, String> = std::env::vars().collect();
let structured_content = json!({
"echo": args.message,
"env": env_snapshot.get("MCP_TEST_VALUE"),
});
Ok(CallToolResult {
content: Vec::new(),
structured_content: Some(structured_content),
is_error: Some(false),
meta: None,
})
}
other => Err(McpError::invalid_params(
format!("unknown tool: {other}"),
None,
)),
}
}
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("starting rmcp test server");
// Run the server with STDIO transport. If the client disconnects we simply
// bubble up the error so the process exits.
let service = TestToolServer::new();
let running = service.serve(stdio()).await?;
// Wait for the client to finish interacting with the server.
running.waiting().await?;
// Drain background tasks to ensure clean shutdown.
task::yield_now().await;
Ok(())
}

View File

@@ -0,0 +1,5 @@
mod logging_client_handler;
mod rmcp_client;
mod utils;
pub use rmcp_client::RmcpClient;

View File

@@ -0,0 +1,134 @@
use rmcp::ClientHandler;
use rmcp::RoleClient;
use rmcp::model::CancelledNotificationParam;
use rmcp::model::ClientInfo;
use rmcp::model::CreateElicitationRequestParam;
use rmcp::model::CreateElicitationResult;
use rmcp::model::ElicitationAction;
use rmcp::model::LoggingLevel;
use rmcp::model::LoggingMessageNotificationParam;
use rmcp::model::ProgressNotificationParam;
use rmcp::model::ResourceUpdatedNotificationParam;
use rmcp::service::NotificationContext;
use rmcp::service::RequestContext;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::warn;
#[derive(Debug, Clone)]
pub(crate) struct LoggingClientHandler {
client_info: ClientInfo,
}
impl LoggingClientHandler {
pub(crate) fn new(client_info: ClientInfo) -> Self {
Self { client_info }
}
}
impl ClientHandler for LoggingClientHandler {
// TODO (CODEX-3571): support elicitations.
async fn create_elicitation(
&self,
request: CreateElicitationRequestParam,
_context: RequestContext<RoleClient>,
) -> Result<CreateElicitationResult, rmcp::ErrorData> {
info!(
"MCP server requested elicitation ({}). Elicitations are not supported yet. Declining.",
request.message
);
Ok(CreateElicitationResult {
action: ElicitationAction::Decline,
content: None,
})
}
async fn on_cancelled(
&self,
params: CancelledNotificationParam,
_context: NotificationContext<RoleClient>,
) {
info!(
"MCP server cancelled request (request_id: {}, reason: {:?})",
params.request_id, params.reason
);
}
async fn on_progress(
&self,
params: ProgressNotificationParam,
_context: NotificationContext<RoleClient>,
) {
info!(
"MCP server progress notification (token: {:?}, progress: {}, total: {:?}, message: {:?})",
params.progress_token, params.progress, params.total, params.message
);
}
async fn on_resource_updated(
&self,
params: ResourceUpdatedNotificationParam,
_context: NotificationContext<RoleClient>,
) {
info!("MCP server resource updated (uri: {})", params.uri);
}
async fn on_resource_list_changed(&self, _context: NotificationContext<RoleClient>) {
info!("MCP server resource list changed");
}
async fn on_tool_list_changed(&self, _context: NotificationContext<RoleClient>) {
info!("MCP server tool list changed");
}
async fn on_prompt_list_changed(&self, _context: NotificationContext<RoleClient>) {
info!("MCP server prompt list changed");
}
fn get_info(&self) -> ClientInfo {
self.client_info.clone()
}
async fn on_logging_message(
&self,
params: LoggingMessageNotificationParam,
_context: NotificationContext<RoleClient>,
) {
let LoggingMessageNotificationParam {
level,
logger,
data,
} = params;
let logger = logger.as_deref();
match level {
LoggingLevel::Emergency
| LoggingLevel::Alert
| LoggingLevel::Critical
| LoggingLevel::Error => {
error!(
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
level, logger, data
);
}
LoggingLevel::Warning => {
warn!(
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
level, logger, data
);
}
LoggingLevel::Notice | LoggingLevel::Info => {
info!(
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
level, logger, data
);
}
LoggingLevel::Debug => {
debug!(
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
level, logger, data
);
}
}
}
}

View File

@@ -0,0 +1,183 @@
use std::collections::HashMap;
use std::ffi::OsString;
use std::io;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use anyhow::anyhow;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::InitializeRequestParams;
use mcp_types::InitializeResult;
use mcp_types::ListToolsRequestParams;
use mcp_types::ListToolsResult;
use rmcp::model::CallToolRequestParam;
use rmcp::model::InitializeRequestParam;
use rmcp::model::PaginatedRequestParam;
use rmcp::service::RoleClient;
use rmcp::service::RunningService;
use rmcp::service::{self};
use rmcp::transport::child_process::TokioChildProcess;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::time;
use tracing::info;
use tracing::warn;
use crate::logging_client_handler::LoggingClientHandler;
use crate::utils::convert_call_tool_result;
use crate::utils::convert_to_mcp;
use crate::utils::convert_to_rmcp;
use crate::utils::create_env_for_mcp_server;
use crate::utils::run_with_timeout;
enum ClientState {
Connecting {
transport: Option<TokioChildProcess>,
},
Ready {
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
},
}
/// MCP client implemented on top of the official `rmcp` SDK.
/// https://github.com/modelcontextprotocol/rust-sdk
pub struct RmcpClient {
state: Mutex<ClientState>,
}
impl RmcpClient {
pub async fn new_stdio_client(
program: OsString,
args: Vec<OsString>,
env: Option<HashMap<String, String>>,
) -> io::Result<Self> {
let program_name = program.to_string_lossy().into_owned();
let mut command = Command::new(&program);
command
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.env_clear()
.envs(create_env_for_mcp_server(env))
.args(&args);
let (transport, stderr) = TokioChildProcess::builder(command)
.stderr(Stdio::piped())
.spawn()?;
if let Some(stderr) = stderr {
tokio::spawn(async move {
let mut reader = BufReader::new(stderr).lines();
loop {
match reader.next_line().await {
Ok(Some(line)) => {
info!("MCP server stderr ({program_name}): {line}");
}
Ok(None) => break,
Err(error) => {
warn!("Failed to read MCP server stderr ({program_name}): {error}");
break;
}
}
}
});
}
Ok(Self {
state: Mutex::new(ClientState::Connecting {
transport: Some(transport),
}),
})
}
/// Perform the initialization handshake with the MCP server.
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization
pub async fn initialize(
&self,
params: InitializeRequestParams,
timeout: Option<Duration>,
) -> Result<InitializeResult> {
let transport = {
let mut guard = self.state.lock().await;
match &mut *guard {
ClientState::Connecting { transport } => transport
.take()
.ok_or_else(|| anyhow!("client already initializing"))?,
ClientState::Ready { .. } => {
return Err(anyhow!("client already initialized"));
}
}
};
let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?;
let client_handler = LoggingClientHandler::new(client_info);
let service_future = service::serve_client(client_handler, transport);
let service = match timeout {
Some(duration) => time::timeout(duration, service_future)
.await
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
None => service_future
.await
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
};
let initialize_result_rmcp = service
.peer()
.peer_info()
.ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?;
let initialize_result = convert_to_mcp(initialize_result_rmcp)?;
{
let mut guard = self.state.lock().await;
*guard = ClientState::Ready {
service: Arc::new(service),
};
}
Ok(initialize_result)
}
pub async fn list_tools(
&self,
params: Option<ListToolsRequestParams>,
timeout: Option<Duration>,
) -> Result<ListToolsResult> {
let service = self.service().await?;
let rmcp_params = params
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
.transpose()?;
let fut = service.list_tools(rmcp_params);
let result = run_with_timeout(fut, timeout, "tools/list").await?;
convert_to_mcp(result)
}
pub async fn call_tool(
&self,
name: String,
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> Result<CallToolResult> {
let service = self.service().await?;
let params = CallToolRequestParams { arguments, name };
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
let fut = service.call_tool(rmcp_params);
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
convert_call_tool_result(rmcp_result)
}
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
let guard = self.state.lock().await;
match &*guard {
ClientState::Ready { service } => Ok(Arc::clone(service)),
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
}
}
}

View File

@@ -0,0 +1,160 @@
use std::collections::HashMap;
use std::env;
use std::time::Duration;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use mcp_types::CallToolResult;
use rmcp::model::CallToolResult as RmcpCallToolResult;
use rmcp::service::ServiceError;
use serde_json::Value;
use tokio::time;
pub(crate) async fn run_with_timeout<F, T>(
fut: F,
timeout: Option<Duration>,
label: &str,
) -> Result<T>
where
F: std::future::Future<Output = Result<T, ServiceError>>,
{
if let Some(duration) = timeout {
let result = time::timeout(duration, fut)
.await
.with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?;
result.map_err(|err| anyhow!("{label} failed: {err}"))
} else {
fut.await.map_err(|err| anyhow!("{label} failed: {err}"))
}
}
pub(crate) fn convert_call_tool_result(result: RmcpCallToolResult) -> Result<CallToolResult> {
let mut value = serde_json::to_value(result)?;
if let Some(obj) = value.as_object_mut()
&& (obj.get("content").is_none()
|| obj.get("content").is_some_and(serde_json::Value::is_null))
{
obj.insert("content".to_string(), Value::Array(Vec::new()));
}
serde_json::from_value(value).context("failed to convert call tool result")
}
/// Convert from mcp-types to Rust SDK types.
///
/// The Rust SDK types are the same as our mcp-types crate because they are both
/// derived from the same MCP specification.
/// As a result, it should be safe to convert directly from one to the other.
pub(crate) fn convert_to_rmcp<T, U>(value: T) -> Result<U>
where
T: serde::Serialize,
U: serde::de::DeserializeOwned,
{
let json = serde_json::to_value(value)?;
serde_json::from_value(json).map_err(|err| anyhow!(err))
}
/// Convert from Rust SDK types to mcp-types.
///
/// The Rust SDK types are the same as our mcp-types crate because they are both
/// derived from the same MCP specification.
/// As a result, it should be safe to convert directly from one to the other.
pub(crate) fn convert_to_mcp<T, U>(value: T) -> Result<U>
where
T: serde::Serialize,
U: serde::de::DeserializeOwned,
{
let json = serde_json::to_value(value)?;
serde_json::from_value(json).map_err(|err| anyhow!(err))
}
pub(crate) fn create_env_for_mcp_server(
extra_env: Option<HashMap<String, String>>,
) -> HashMap<String, String> {
DEFAULT_ENV_VARS
.iter()
.filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value)))
.chain(extra_env.unwrap_or_default())
.collect()
}
#[cfg(unix)]
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
"HOME",
"LOGNAME",
"PATH",
"SHELL",
"USER",
"__CF_USER_TEXT_ENCODING",
"LANG",
"LC_ALL",
"TERM",
"TMPDIR",
"TZ",
];
#[cfg(windows)]
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
"PATH",
"PATHEXT",
"USERNAME",
"USERDOMAIN",
"USERPROFILE",
"TEMP",
"TMP",
];
#[cfg(test)]
mod tests {
use super::*;
use mcp_types::ContentBlock;
use pretty_assertions::assert_eq;
use rmcp::model::CallToolResult as RmcpCallToolResult;
use serde_json::json;
#[tokio::test]
async fn create_env_honors_overrides() {
let value = "custom".to_string();
let env = create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])));
assert_eq!(env.get("TZ"), Some(&value));
}
#[test]
fn convert_call_tool_result_defaults_missing_content() -> Result<()> {
let structured_content = json!({ "key": "value" });
let rmcp_result = RmcpCallToolResult {
content: vec![],
structured_content: Some(structured_content.clone()),
is_error: Some(true),
meta: None,
};
let result = convert_call_tool_result(rmcp_result)?;
assert!(result.content.is_empty());
assert_eq!(result.structured_content, Some(structured_content));
assert_eq!(result.is_error, Some(true));
Ok(())
}
#[test]
fn convert_call_tool_result_preserves_existing_content() -> Result<()> {
let rmcp_result = RmcpCallToolResult::success(vec![rmcp::model::Content::text("hello")]);
let result = convert_call_tool_result(rmcp_result)?;
assert_eq!(result.content.len(), 1);
match &result.content[0] {
ContentBlock::TextContent(text_content) => {
assert_eq!(text_content.text, "hello");
assert_eq!(text_content.r#type, "text");
}
other => panic!("expected text content got {other:?}"),
}
assert_eq!(result.structured_content, None);
assert_eq!(result.is_error, Some(false));
Ok(())
}
}