Files
llmx/codex-rs/rmcp-client/src/rmcp_client.rs
Gabriel Peal 40fba1bb4c [MCP] Add support for resources (#5239)
This PR adds support for [MCP
resources](https://modelcontextprotocol.io/specification/2025-06-18/server/resources)
by adding three new tools for the model:
1. `list_resources`
2. `list_resource_templates`
3. `read_resource`

These 3 tools correspond to the [three primary MCP resource protocol
messages](https://modelcontextprotocol.io/specification/2025-06-18/server/resources#protocol-messages).

Example of listing and reading a GitHub resource template
<img width="2984" height="804" alt="CleanShot 2025-10-15 at 17 31 10"
src="https://github.com/user-attachments/assets/89b7f215-2e2a-41c5-90dd-b932ac84a585"
/>

`/mcp` with Figma configured
<img width="2984" height="442" alt="CleanShot 2025-10-15 at 18 29 35"
src="https://github.com/user-attachments/assets/a7578080-2ed2-4c59-b9b4-d8461f90d8ee"
/>

Fixes #4956
2025-10-17 01:05:15 -04:00

415 lines
14 KiB
Rust

use std::collections::HashMap;
use std::ffi::OsString;
use std::io;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use anyhow::anyhow;
use futures::FutureExt;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::InitializeRequestParams;
use mcp_types::InitializeResult;
use mcp_types::ListResourceTemplatesRequestParams;
use mcp_types::ListResourceTemplatesResult;
use mcp_types::ListResourcesRequestParams;
use mcp_types::ListResourcesResult;
use mcp_types::ListToolsRequestParams;
use mcp_types::ListToolsResult;
use mcp_types::ReadResourceRequestParams;
use mcp_types::ReadResourceResult;
use reqwest::header::HeaderMap;
use rmcp::model::CallToolRequestParam;
use rmcp::model::InitializeRequestParam;
use rmcp::model::PaginatedRequestParam;
use rmcp::model::ReadResourceRequestParam;
use rmcp::service::RoleClient;
use rmcp::service::RunningService;
use rmcp::service::{self};
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::auth::AuthClient;
use rmcp::transport::auth::OAuthState;
use rmcp::transport::child_process::TokioChildProcess;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::time;
use tracing::info;
use tracing::warn;
use crate::load_oauth_tokens;
use crate::logging_client_handler::LoggingClientHandler;
use crate::oauth::OAuthCredentialsStoreMode;
use crate::oauth::OAuthPersistor;
use crate::oauth::StoredOAuthTokens;
use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use crate::utils::convert_call_tool_result;
use crate::utils::convert_to_mcp;
use crate::utils::convert_to_rmcp;
use crate::utils::create_env_for_mcp_server;
use crate::utils::run_with_timeout;
/// A transport that has been constructed but not yet used for the MCP handshake.
///
/// It is held in `ClientState::Connecting` and consumed exactly once by
/// `RmcpClient::initialize`, which passes it to `service::serve_client`.
enum PendingTransport {
/// Stdio transport to a spawned server process (built by `new_stdio_client`).
ChildProcess(TokioChildProcess),
/// Streamable HTTP transport without OAuth; an explicit bearer token, if any, is
/// configured as the auth header on the transport config.
StreamableHttp {
transport: StreamableHttpClientTransport<reqwest::Client>,
},
/// Streamable HTTP transport whose client is wrapped in an `AuthClient`, plus the
/// persistor used afterwards to save tokens via `persist_if_needed`.
StreamableHttpWithOAuth {
transport: StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
oauth_persistor: OAuthPersistor,
},
}
/// Lifecycle state of an `RmcpClient`.
enum ClientState {
/// Constructed but not yet initialized. `transport` becomes `None` as soon as
/// `initialize` takes it, i.e. while a handshake is running or after one failed.
Connecting {
transport: Option<PendingTransport>,
},
/// Handshake completed; requests are issued through `service`. `oauth` is `Some`
/// only when the client was built with stored OAuth tokens.
Ready {
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
oauth: Option<OAuthPersistor>,
},
}
/// MCP client implemented on top of the official `rmcp` SDK.
/// https://github.com/modelcontextprotocol/rust-sdk
///
/// Construct with `new_stdio_client` or `new_streamable_http_client`, then call
/// `initialize` before issuing any requests; until then request methods return
/// "MCP client not initialized".
pub struct RmcpClient {
// Guarded by a `tokio::sync::Mutex`; moves from `Connecting` to `Ready` in `initialize`.
state: Mutex<ClientState>,
}
impl RmcpClient {
pub async fn new_stdio_client(
program: OsString,
args: Vec<OsString>,
env: Option<HashMap<String, String>>,
env_vars: &[String],
cwd: Option<PathBuf>,
) -> io::Result<Self> {
let program_name = program.to_string_lossy().into_owned();
let mut command = Command::new(&program);
command
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.env_clear()
.envs(create_env_for_mcp_server(env, env_vars))
.args(&args);
if let Some(cwd) = cwd {
command.current_dir(cwd);
}
let (transport, stderr) = TokioChildProcess::builder(command)
.stderr(Stdio::piped())
.spawn()?;
if let Some(stderr) = stderr {
tokio::spawn(async move {
let mut reader = BufReader::new(stderr).lines();
loop {
match reader.next_line().await {
Ok(Some(line)) => {
info!("MCP server stderr ({program_name}): {line}");
}
Ok(None) => break,
Err(error) => {
warn!("Failed to read MCP server stderr ({program_name}): {error}");
break;
}
}
}
});
}
Ok(Self {
state: Mutex::new(ClientState::Connecting {
transport: Some(PendingTransport::ChildProcess(transport)),
}),
})
}
#[allow(clippy::too_many_arguments)]
pub async fn new_streamable_http_client(
server_name: &str,
url: &str,
bearer_token: Option<String>,
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
store_mode: OAuthCredentialsStoreMode,
) -> Result<Self> {
let default_headers = build_default_headers(http_headers, env_http_headers)?;
let initial_oauth_tokens = match bearer_token {
Some(_) => None,
None => match load_oauth_tokens(server_name, url, store_mode) {
Ok(tokens) => tokens,
Err(err) => {
warn!("failed to read tokens for server `{server_name}`: {err}");
None
}
},
};
let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
let (transport, oauth_persistor) = create_oauth_transport_and_runtime(
server_name,
url,
initial_tokens,
store_mode,
default_headers.clone(),
)
.await?;
PendingTransport::StreamableHttpWithOAuth {
transport,
oauth_persistor,
}
} else {
let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
if let Some(bearer_token) = bearer_token.clone() {
http_config = http_config.auth_header(bearer_token);
}
let http_client =
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
let transport = StreamableHttpClientTransport::with_client(http_client, http_config);
PendingTransport::StreamableHttp { transport }
};
Ok(Self {
state: Mutex::new(ClientState::Connecting {
transport: Some(transport),
}),
})
}
/// Perform the initialization handshake with the MCP server.
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization
pub async fn initialize(
&self,
params: InitializeRequestParams,
timeout: Option<Duration>,
) -> Result<InitializeResult> {
let rmcp_params: InitializeRequestParam = convert_to_rmcp(params.clone())?;
let client_handler = LoggingClientHandler::new(rmcp_params);
let (transport, oauth_persistor) = {
let mut guard = self.state.lock().await;
match &mut *guard {
ClientState::Connecting { transport } => match transport.take() {
Some(PendingTransport::ChildProcess(transport)) => (
service::serve_client(client_handler.clone(), transport).boxed(),
None,
),
Some(PendingTransport::StreamableHttp { transport }) => (
service::serve_client(client_handler.clone(), transport).boxed(),
None,
),
Some(PendingTransport::StreamableHttpWithOAuth {
transport,
oauth_persistor,
}) => (
service::serve_client(client_handler.clone(), transport).boxed(),
Some(oauth_persistor),
),
None => return Err(anyhow!("client already initializing")),
},
ClientState::Ready { .. } => return Err(anyhow!("client already initialized")),
}
};
let service = match timeout {
Some(duration) => time::timeout(duration, transport)
.await
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
None => transport
.await
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
};
let initialize_result_rmcp = service
.peer()
.peer_info()
.ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?;
let initialize_result = convert_to_mcp(initialize_result_rmcp)?;
{
let mut guard = self.state.lock().await;
*guard = ClientState::Ready {
service: Arc::new(service),
oauth: oauth_persistor.clone(),
};
}
if let Some(runtime) = oauth_persistor
&& let Err(error) = runtime.persist_if_needed().await
{
warn!("failed to persist OAuth tokens after initialize: {error}");
}
Ok(initialize_result)
}
pub async fn list_tools(
&self,
params: Option<ListToolsRequestParams>,
timeout: Option<Duration>,
) -> Result<ListToolsResult> {
let service = self.service().await?;
let rmcp_params = params
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
.transpose()?;
let fut = service.list_tools(rmcp_params);
let result = run_with_timeout(fut, timeout, "tools/list").await?;
let converted = convert_to_mcp(result)?;
self.persist_oauth_tokens().await;
Ok(converted)
}
pub async fn list_resources(
&self,
params: Option<ListResourcesRequestParams>,
timeout: Option<Duration>,
) -> Result<ListResourcesResult> {
let service = self.service().await?;
let rmcp_params = params
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
.transpose()?;
let fut = service.list_resources(rmcp_params);
let result = run_with_timeout(fut, timeout, "resources/list").await?;
let converted = convert_to_mcp(result)?;
self.persist_oauth_tokens().await;
Ok(converted)
}
pub async fn list_resource_templates(
&self,
params: Option<ListResourceTemplatesRequestParams>,
timeout: Option<Duration>,
) -> Result<ListResourceTemplatesResult> {
let service = self.service().await?;
let rmcp_params = params
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
.transpose()?;
let fut = service.list_resource_templates(rmcp_params);
let result = run_with_timeout(fut, timeout, "resources/templates/list").await?;
let converted = convert_to_mcp(result)?;
self.persist_oauth_tokens().await;
Ok(converted)
}
pub async fn read_resource(
&self,
params: ReadResourceRequestParams,
timeout: Option<Duration>,
) -> Result<ReadResourceResult> {
let service = self.service().await?;
let rmcp_params: ReadResourceRequestParam = convert_to_rmcp(params)?;
let fut = service.read_resource(rmcp_params);
let result = run_with_timeout(fut, timeout, "resources/read").await?;
let converted = convert_to_mcp(result)?;
self.persist_oauth_tokens().await;
Ok(converted)
}
pub async fn call_tool(
&self,
name: String,
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> Result<CallToolResult> {
let service = self.service().await?;
let params = CallToolRequestParams { arguments, name };
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
let fut = service.call_tool(rmcp_params);
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
let converted = convert_call_tool_result(rmcp_result)?;
self.persist_oauth_tokens().await;
Ok(converted)
}
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
let guard = self.state.lock().await;
match &*guard {
ClientState::Ready { service, .. } => Ok(Arc::clone(service)),
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
}
}
async fn oauth_persistor(&self) -> Option<OAuthPersistor> {
let guard = self.state.lock().await;
match &*guard {
ClientState::Ready {
oauth: Some(runtime),
service: _,
} => Some(runtime.clone()),
_ => None,
}
}
/// This should be called after every tool call so that if a given tool call triggered
/// a refresh of the OAuth tokens, they are persisted.
async fn persist_oauth_tokens(&self) {
if let Some(runtime) = self.oauth_persistor().await
&& let Err(error) = runtime.persist_if_needed().await
{
warn!("failed to persist OAuth tokens: {error}");
}
}
}
/// Builds the OAuth-aware streamable HTTP transport together with the persistor that
/// later saves tokens.
///
/// Steps:
/// 1. Build a `reqwest` client carrying `default_headers`.
/// 2. Create an `OAuthState` for `url` with that client and seed it with the stored
///    `initial_tokens`.
/// 3. Require the state to be `Authorized` or `Unauthorized`; any other state is
///    rejected as an error.
///
/// The resulting authorization manager backs the transport's `AuthClient`. A clone of
/// that client's `auth_manager` is handed to the returned `OAuthPersistor`. Presumably
/// this is a shared handle, so the persistor can observe refreshed tokens — confirm
/// against `OAuthPersistor`.
///
/// # Errors
/// Fails if:
/// - the HTTP client cannot be built;
/// - the OAuth state cannot be created or seeded;
/// - the state ends up neither authorized nor unauthorized.
async fn create_oauth_transport_and_runtime(
    server_name: &str,
    url: &str,
    initial_tokens: StoredOAuthTokens,
    credentials_store: OAuthCredentialsStoreMode,
    default_headers: HeaderMap,
) -> Result<(
    StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
    OAuthPersistor,
)> {
    let owned_url = url.to_string();
    let base_client =
        apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;

    let mut state = OAuthState::new(owned_url.clone(), Some(base_client.clone())).await?;
    let token_response = initial_tokens.token_response.0.clone();
    state
        .set_credentials(&initial_tokens.client_id, token_response)
        .await?;

    // Both acceptable states carry the same manager type; anything else is a setup error.
    let (OAuthState::Authorized(manager) | OAuthState::Unauthorized(manager)) = state else {
        return Err(anyhow!("unexpected OAuth state during client setup"));
    };

    let auth_client = AuthClient::new(base_client, manager);
    let persisted_manager = auth_client.auth_manager.clone();
    let transport_config = StreamableHttpClientTransportConfig::with_uri(owned_url.clone());
    let transport = StreamableHttpClientTransport::with_client(auth_client, transport_config);

    let persistor = OAuthPersistor::new(
        server_name.to_string(),
        owned_url,
        persisted_manager,
        credentials_store,
        Some(initial_tokens),
    );
    Ok((transport, persistor))
}