This PR adds support for [MCP resources](https://modelcontextprotocol.io/specification/2025-06-18/server/resources) by adding three new tools for the model: 1. `list_resources` 2. `list_resource_templates` 3. `read_resource` These 3 tools correspond to the [three primary MCP resource protocol messages](https://modelcontextprotocol.io/specification/2025-06-18/server/resources#protocol-messages). Example of listing and reading a GitHub resource template: <img width="2984" height="804" alt="CleanShot 2025-10-15 at 17 31 10" src="https://github.com/user-attachments/assets/89b7f215-2e2a-41c5-90dd-b932ac84a585" /> `/mcp` with Figma configured: <img width="2984" height="442" alt="CleanShot 2025-10-15 at 18 29 35" src="https://github.com/user-attachments/assets/a7578080-2ed2-4c59-b9b4-d8461f90d8ee" /> Fixes #4956
415 lines
14 KiB
Rust
415 lines
14 KiB
Rust
use std::collections::HashMap;
|
|
use std::ffi::OsString;
|
|
use std::io;
|
|
use std::path::PathBuf;
|
|
use std::process::Stdio;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::Result;
|
|
use anyhow::anyhow;
|
|
use futures::FutureExt;
|
|
use mcp_types::CallToolRequestParams;
|
|
use mcp_types::CallToolResult;
|
|
use mcp_types::InitializeRequestParams;
|
|
use mcp_types::InitializeResult;
|
|
use mcp_types::ListResourceTemplatesRequestParams;
|
|
use mcp_types::ListResourceTemplatesResult;
|
|
use mcp_types::ListResourcesRequestParams;
|
|
use mcp_types::ListResourcesResult;
|
|
use mcp_types::ListToolsRequestParams;
|
|
use mcp_types::ListToolsResult;
|
|
use mcp_types::ReadResourceRequestParams;
|
|
use mcp_types::ReadResourceResult;
|
|
use reqwest::header::HeaderMap;
|
|
use rmcp::model::CallToolRequestParam;
|
|
use rmcp::model::InitializeRequestParam;
|
|
use rmcp::model::PaginatedRequestParam;
|
|
use rmcp::model::ReadResourceRequestParam;
|
|
use rmcp::service::RoleClient;
|
|
use rmcp::service::RunningService;
|
|
use rmcp::service::{self};
|
|
use rmcp::transport::StreamableHttpClientTransport;
|
|
use rmcp::transport::auth::AuthClient;
|
|
use rmcp::transport::auth::OAuthState;
|
|
use rmcp::transport::child_process::TokioChildProcess;
|
|
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
|
use tokio::io::AsyncBufReadExt;
|
|
use tokio::io::BufReader;
|
|
use tokio::process::Command;
|
|
use tokio::sync::Mutex;
|
|
use tokio::time;
|
|
use tracing::info;
|
|
use tracing::warn;
|
|
|
|
use crate::load_oauth_tokens;
|
|
use crate::logging_client_handler::LoggingClientHandler;
|
|
use crate::oauth::OAuthCredentialsStoreMode;
|
|
use crate::oauth::OAuthPersistor;
|
|
use crate::oauth::StoredOAuthTokens;
|
|
use crate::utils::apply_default_headers;
|
|
use crate::utils::build_default_headers;
|
|
use crate::utils::convert_call_tool_result;
|
|
use crate::utils::convert_to_mcp;
|
|
use crate::utils::convert_to_rmcp;
|
|
use crate::utils::create_env_for_mcp_server;
|
|
use crate::utils::run_with_timeout;
|
|
|
|
enum PendingTransport {
|
|
ChildProcess(TokioChildProcess),
|
|
StreamableHttp {
|
|
transport: StreamableHttpClientTransport<reqwest::Client>,
|
|
},
|
|
StreamableHttpWithOAuth {
|
|
transport: StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
|
oauth_persistor: OAuthPersistor,
|
|
},
|
|
}
|
|
|
|
enum ClientState {
|
|
Connecting {
|
|
transport: Option<PendingTransport>,
|
|
},
|
|
Ready {
|
|
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
|
|
oauth: Option<OAuthPersistor>,
|
|
},
|
|
}
|
|
|
|
/// MCP client implemented on top of the official `rmcp` SDK.
|
|
/// https://github.com/modelcontextprotocol/rust-sdk
|
|
pub struct RmcpClient {
|
|
state: Mutex<ClientState>,
|
|
}
|
|
|
|
impl RmcpClient {
|
|
pub async fn new_stdio_client(
|
|
program: OsString,
|
|
args: Vec<OsString>,
|
|
env: Option<HashMap<String, String>>,
|
|
env_vars: &[String],
|
|
cwd: Option<PathBuf>,
|
|
) -> io::Result<Self> {
|
|
let program_name = program.to_string_lossy().into_owned();
|
|
let mut command = Command::new(&program);
|
|
command
|
|
.kill_on_drop(true)
|
|
.stdin(Stdio::piped())
|
|
.stdout(Stdio::piped())
|
|
.env_clear()
|
|
.envs(create_env_for_mcp_server(env, env_vars))
|
|
.args(&args);
|
|
if let Some(cwd) = cwd {
|
|
command.current_dir(cwd);
|
|
}
|
|
|
|
let (transport, stderr) = TokioChildProcess::builder(command)
|
|
.stderr(Stdio::piped())
|
|
.spawn()?;
|
|
|
|
if let Some(stderr) = stderr {
|
|
tokio::spawn(async move {
|
|
let mut reader = BufReader::new(stderr).lines();
|
|
loop {
|
|
match reader.next_line().await {
|
|
Ok(Some(line)) => {
|
|
info!("MCP server stderr ({program_name}): {line}");
|
|
}
|
|
Ok(None) => break,
|
|
Err(error) => {
|
|
warn!("Failed to read MCP server stderr ({program_name}): {error}");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
Ok(Self {
|
|
state: Mutex::new(ClientState::Connecting {
|
|
transport: Some(PendingTransport::ChildProcess(transport)),
|
|
}),
|
|
})
|
|
}
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
pub async fn new_streamable_http_client(
|
|
server_name: &str,
|
|
url: &str,
|
|
bearer_token: Option<String>,
|
|
http_headers: Option<HashMap<String, String>>,
|
|
env_http_headers: Option<HashMap<String, String>>,
|
|
store_mode: OAuthCredentialsStoreMode,
|
|
) -> Result<Self> {
|
|
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
|
|
|
let initial_oauth_tokens = match bearer_token {
|
|
Some(_) => None,
|
|
None => match load_oauth_tokens(server_name, url, store_mode) {
|
|
Ok(tokens) => tokens,
|
|
Err(err) => {
|
|
warn!("failed to read tokens for server `{server_name}`: {err}");
|
|
None
|
|
}
|
|
},
|
|
};
|
|
|
|
let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
|
|
let (transport, oauth_persistor) = create_oauth_transport_and_runtime(
|
|
server_name,
|
|
url,
|
|
initial_tokens,
|
|
store_mode,
|
|
default_headers.clone(),
|
|
)
|
|
.await?;
|
|
PendingTransport::StreamableHttpWithOAuth {
|
|
transport,
|
|
oauth_persistor,
|
|
}
|
|
} else {
|
|
let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
|
|
if let Some(bearer_token) = bearer_token.clone() {
|
|
http_config = http_config.auth_header(bearer_token);
|
|
}
|
|
|
|
let http_client =
|
|
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
|
|
|
|
let transport = StreamableHttpClientTransport::with_client(http_client, http_config);
|
|
PendingTransport::StreamableHttp { transport }
|
|
};
|
|
Ok(Self {
|
|
state: Mutex::new(ClientState::Connecting {
|
|
transport: Some(transport),
|
|
}),
|
|
})
|
|
}
|
|
|
|
/// Perform the initialization handshake with the MCP server.
|
|
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization
|
|
pub async fn initialize(
|
|
&self,
|
|
params: InitializeRequestParams,
|
|
timeout: Option<Duration>,
|
|
) -> Result<InitializeResult> {
|
|
let rmcp_params: InitializeRequestParam = convert_to_rmcp(params.clone())?;
|
|
let client_handler = LoggingClientHandler::new(rmcp_params);
|
|
|
|
let (transport, oauth_persistor) = {
|
|
let mut guard = self.state.lock().await;
|
|
match &mut *guard {
|
|
ClientState::Connecting { transport } => match transport.take() {
|
|
Some(PendingTransport::ChildProcess(transport)) => (
|
|
service::serve_client(client_handler.clone(), transport).boxed(),
|
|
None,
|
|
),
|
|
Some(PendingTransport::StreamableHttp { transport }) => (
|
|
service::serve_client(client_handler.clone(), transport).boxed(),
|
|
None,
|
|
),
|
|
Some(PendingTransport::StreamableHttpWithOAuth {
|
|
transport,
|
|
oauth_persistor,
|
|
}) => (
|
|
service::serve_client(client_handler.clone(), transport).boxed(),
|
|
Some(oauth_persistor),
|
|
),
|
|
None => return Err(anyhow!("client already initializing")),
|
|
},
|
|
ClientState::Ready { .. } => return Err(anyhow!("client already initialized")),
|
|
}
|
|
};
|
|
|
|
let service = match timeout {
|
|
Some(duration) => time::timeout(duration, transport)
|
|
.await
|
|
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
|
|
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
|
None => transport
|
|
.await
|
|
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
|
};
|
|
|
|
let initialize_result_rmcp = service
|
|
.peer()
|
|
.peer_info()
|
|
.ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?;
|
|
let initialize_result = convert_to_mcp(initialize_result_rmcp)?;
|
|
|
|
{
|
|
let mut guard = self.state.lock().await;
|
|
*guard = ClientState::Ready {
|
|
service: Arc::new(service),
|
|
oauth: oauth_persistor.clone(),
|
|
};
|
|
}
|
|
|
|
if let Some(runtime) = oauth_persistor
|
|
&& let Err(error) = runtime.persist_if_needed().await
|
|
{
|
|
warn!("failed to persist OAuth tokens after initialize: {error}");
|
|
}
|
|
|
|
Ok(initialize_result)
|
|
}
|
|
|
|
pub async fn list_tools(
|
|
&self,
|
|
params: Option<ListToolsRequestParams>,
|
|
timeout: Option<Duration>,
|
|
) -> Result<ListToolsResult> {
|
|
let service = self.service().await?;
|
|
let rmcp_params = params
|
|
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
|
.transpose()?;
|
|
|
|
let fut = service.list_tools(rmcp_params);
|
|
let result = run_with_timeout(fut, timeout, "tools/list").await?;
|
|
let converted = convert_to_mcp(result)?;
|
|
self.persist_oauth_tokens().await;
|
|
Ok(converted)
|
|
}
|
|
|
|
pub async fn list_resources(
|
|
&self,
|
|
params: Option<ListResourcesRequestParams>,
|
|
timeout: Option<Duration>,
|
|
) -> Result<ListResourcesResult> {
|
|
let service = self.service().await?;
|
|
let rmcp_params = params
|
|
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
|
.transpose()?;
|
|
|
|
let fut = service.list_resources(rmcp_params);
|
|
let result = run_with_timeout(fut, timeout, "resources/list").await?;
|
|
let converted = convert_to_mcp(result)?;
|
|
self.persist_oauth_tokens().await;
|
|
Ok(converted)
|
|
}
|
|
|
|
pub async fn list_resource_templates(
|
|
&self,
|
|
params: Option<ListResourceTemplatesRequestParams>,
|
|
timeout: Option<Duration>,
|
|
) -> Result<ListResourceTemplatesResult> {
|
|
let service = self.service().await?;
|
|
let rmcp_params = params
|
|
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
|
.transpose()?;
|
|
|
|
let fut = service.list_resource_templates(rmcp_params);
|
|
let result = run_with_timeout(fut, timeout, "resources/templates/list").await?;
|
|
let converted = convert_to_mcp(result)?;
|
|
self.persist_oauth_tokens().await;
|
|
Ok(converted)
|
|
}
|
|
|
|
pub async fn read_resource(
|
|
&self,
|
|
params: ReadResourceRequestParams,
|
|
timeout: Option<Duration>,
|
|
) -> Result<ReadResourceResult> {
|
|
let service = self.service().await?;
|
|
let rmcp_params: ReadResourceRequestParam = convert_to_rmcp(params)?;
|
|
let fut = service.read_resource(rmcp_params);
|
|
let result = run_with_timeout(fut, timeout, "resources/read").await?;
|
|
let converted = convert_to_mcp(result)?;
|
|
self.persist_oauth_tokens().await;
|
|
Ok(converted)
|
|
}
|
|
|
|
pub async fn call_tool(
|
|
&self,
|
|
name: String,
|
|
arguments: Option<serde_json::Value>,
|
|
timeout: Option<Duration>,
|
|
) -> Result<CallToolResult> {
|
|
let service = self.service().await?;
|
|
let params = CallToolRequestParams { arguments, name };
|
|
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
|
|
let fut = service.call_tool(rmcp_params);
|
|
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
|
|
let converted = convert_call_tool_result(rmcp_result)?;
|
|
self.persist_oauth_tokens().await;
|
|
Ok(converted)
|
|
}
|
|
|
|
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
|
|
let guard = self.state.lock().await;
|
|
match &*guard {
|
|
ClientState::Ready { service, .. } => Ok(Arc::clone(service)),
|
|
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
|
|
}
|
|
}
|
|
|
|
async fn oauth_persistor(&self) -> Option<OAuthPersistor> {
|
|
let guard = self.state.lock().await;
|
|
match &*guard {
|
|
ClientState::Ready {
|
|
oauth: Some(runtime),
|
|
service: _,
|
|
} => Some(runtime.clone()),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
/// This should be called after every tool call so that if a given tool call triggered
|
|
/// a refresh of the OAuth tokens, they are persisted.
|
|
async fn persist_oauth_tokens(&self) {
|
|
if let Some(runtime) = self.oauth_persistor().await
|
|
&& let Err(error) = runtime.persist_if_needed().await
|
|
{
|
|
warn!("failed to persist OAuth tokens: {error}");
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn create_oauth_transport_and_runtime(
|
|
server_name: &str,
|
|
url: &str,
|
|
initial_tokens: StoredOAuthTokens,
|
|
credentials_store: OAuthCredentialsStoreMode,
|
|
default_headers: HeaderMap,
|
|
) -> Result<(
|
|
StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
|
OAuthPersistor,
|
|
)> {
|
|
let http_client =
|
|
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
|
|
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
|
|
|
|
oauth_state
|
|
.set_credentials(
|
|
&initial_tokens.client_id,
|
|
initial_tokens.token_response.0.clone(),
|
|
)
|
|
.await?;
|
|
|
|
let manager = match oauth_state {
|
|
OAuthState::Authorized(manager) => manager,
|
|
OAuthState::Unauthorized(manager) => manager,
|
|
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
|
|
return Err(anyhow!("unexpected OAuth state during client setup"));
|
|
}
|
|
};
|
|
|
|
let auth_client = AuthClient::new(http_client, manager);
|
|
let auth_manager = auth_client.auth_manager.clone();
|
|
|
|
let transport = StreamableHttpClientTransport::with_client(
|
|
auth_client,
|
|
StreamableHttpClientTransportConfig::with_uri(url.to_string()),
|
|
);
|
|
|
|
let runtime = OAuthPersistor::new(
|
|
server_name.to_string(),
|
|
url.to_string(),
|
|
auth_manager,
|
|
credentials_store,
|
|
Some(initial_tokens),
|
|
);
|
|
|
|
Ok((transport, runtime))
|
|
}
|