Phase 1: Repository & Infrastructure Setup
- Renamed directories: codex-rs -> llmx-rs, codex-cli -> llmx-cli
- Updated package.json files:
- Root: llmx-monorepo
- CLI: @llmx/llmx
- SDK: @llmx/llmx-sdk
- Updated pnpm workspace configuration
- Renamed binary: codex.js -> llmx.js
- Updated environment variables: CODEX_* -> LLMX_*
- Changed repository URLs to valknar/llmx
🤖 Generated with Claude Code
This commit is contained in:
138
llmx-rs/rmcp-client/src/auth_status.rs
Normal file
138
llmx-rs/rmcp-client/src/auth_status.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Error;
|
||||
use anyhow::Result;
|
||||
use codex_protocol::protocol::McpAuthStatus;
|
||||
use reqwest::Client;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::Url;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::OAuthCredentialsStoreMode;
|
||||
use crate::oauth::has_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
|
||||
const DISCOVERY_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const OAUTH_DISCOVERY_HEADER: &str = "MCP-Protocol-Version";
|
||||
const OAUTH_DISCOVERY_VERSION: &str = "2024-11-05";
|
||||
|
||||
/// Determine the authentication status for a streamable HTTP MCP server.
|
||||
pub async fn determine_streamable_http_auth_status(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token_env_var: Option<&str>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<McpAuthStatus> {
|
||||
if bearer_token_env_var.is_some() {
|
||||
return Ok(McpAuthStatus::BearerToken);
|
||||
}
|
||||
|
||||
if has_oauth_tokens(server_name, url, store_mode)? {
|
||||
return Ok(McpAuthStatus::OAuth);
|
||||
}
|
||||
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
|
||||
match supports_oauth_login_with_headers(url, &default_headers).await {
|
||||
Ok(true) => Ok(McpAuthStatus::NotLoggedIn),
|
||||
Ok(false) => Ok(McpAuthStatus::Unsupported),
|
||||
Err(error) => {
|
||||
debug!(
|
||||
"failed to detect OAuth support for MCP server `{server_name}` at {url}: {error:?}"
|
||||
);
|
||||
Ok(McpAuthStatus::Unsupported)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to determine whether a streamable HTTP MCP server advertises OAuth login.
|
||||
pub async fn supports_oauth_login(url: &str) -> Result<bool> {
|
||||
supports_oauth_login_with_headers(url, &HeaderMap::new()).await
|
||||
}
|
||||
|
||||
async fn supports_oauth_login_with_headers(url: &str, default_headers: &HeaderMap) -> Result<bool> {
|
||||
let base_url = Url::parse(url)?;
|
||||
let builder = Client::builder().timeout(DISCOVERY_TIMEOUT);
|
||||
let client = apply_default_headers(builder, default_headers).build()?;
|
||||
|
||||
let mut last_error: Option<Error> = None;
|
||||
for candidate_path in discovery_paths(base_url.path()) {
|
||||
let mut discovery_url = base_url.clone();
|
||||
discovery_url.set_path(&candidate_path);
|
||||
|
||||
let response = match client
|
||||
.get(discovery_url.clone())
|
||||
.header(OAUTH_DISCOVERY_HEADER, OAUTH_DISCOVERY_VERSION)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if response.status() != StatusCode::OK {
|
||||
continue;
|
||||
}
|
||||
|
||||
let metadata = match response.json::<OAuthDiscoveryMetadata>().await {
|
||||
Ok(metadata) => metadata,
|
||||
Err(err) => {
|
||||
last_error = Some(err.into());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if metadata.authorization_endpoint.is_some() && metadata.token_endpoint.is_some() {
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(err) = last_error {
|
||||
debug!("OAuth discovery requests failed for {url}: {err:?}");
|
||||
}
|
||||
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuthDiscoveryMetadata {
|
||||
#[serde(default)]
|
||||
authorization_endpoint: Option<String>,
|
||||
#[serde(default)]
|
||||
token_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
/// Implements RFC 8414 section 3.1 for discovering well-known oauth endpoints.
|
||||
/// This is a requirement for MCP servers to support OAuth.
|
||||
/// https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
|
||||
/// https://github.com/modelcontextprotocol/rust-sdk/blob/main/crates/rmcp/src/transport/auth.rs#L182
|
||||
fn discovery_paths(base_path: &str) -> Vec<String> {
|
||||
let trimmed = base_path.trim_start_matches('/').trim_end_matches('/');
|
||||
let canonical = "/.well-known/oauth-authorization-server".to_string();
|
||||
|
||||
if trimmed.is_empty() {
|
||||
return vec![canonical];
|
||||
}
|
||||
|
||||
let mut candidates = Vec::new();
|
||||
let mut push_unique = |candidate: String| {
|
||||
if !candidates.contains(&candidate) {
|
||||
candidates.push(candidate);
|
||||
}
|
||||
};
|
||||
|
||||
push_unique(format!("{canonical}/{trimmed}"));
|
||||
push_unique(format!("/{trimmed}/.well-known/oauth-authorization-server"));
|
||||
push_unique(canonical);
|
||||
|
||||
candidates
|
||||
}
|
||||
142
llmx-rs/rmcp-client/src/bin/rmcp_test_server.rs
Normal file
142
llmx-rs/rmcp-client/src/bin/rmcp_test_server.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::ServiceExt;
|
||||
use rmcp::handler::server::ServerHandler;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::JsonObject;
|
||||
use rmcp::model::ListToolsResult;
|
||||
use rmcp::model::PaginatedRequestParam;
|
||||
use rmcp::model::ServerCapabilities;
|
||||
use rmcp::model::ServerInfo;
|
||||
use rmcp::model::Tool;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::task;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestToolServer {
|
||||
tools: Arc<Vec<Tool>>,
|
||||
}
|
||||
pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) {
|
||||
(tokio::io::stdin(), tokio::io::stdout())
|
||||
}
|
||||
impl TestToolServer {
|
||||
fn new() -> Self {
|
||||
let tools = vec![Self::echo_tool()];
|
||||
Self {
|
||||
tools: Arc::new(tools),
|
||||
}
|
||||
}
|
||||
|
||||
fn echo_tool() -> Tool {
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema: JsonObject = serde_json::from_value(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" },
|
||||
"env_var": { "type": "string" }
|
||||
},
|
||||
"required": ["message"],
|
||||
"additionalProperties": false
|
||||
}))
|
||||
.expect("echo tool schema should deserialize");
|
||||
|
||||
Tool::new(
|
||||
Cow::Borrowed("echo"),
|
||||
Cow::Borrowed("Echo back the provided message and include environment data."),
|
||||
Arc::new(schema),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EchoArgs {
|
||||
message: String,
|
||||
#[allow(dead_code)]
|
||||
env_var: Option<String>,
|
||||
}
|
||||
|
||||
impl ServerHandler for TestToolServer {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_tools()
|
||||
.enable_tool_list_changed()
|
||||
.build(),
|
||||
..ServerInfo::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn list_tools(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> impl std::future::Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
|
||||
let tools = self.tools.clone();
|
||||
async move {
|
||||
Ok(ListToolsResult {
|
||||
tools: (*tools).clone(),
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_tool(
|
||||
&self,
|
||||
request: CallToolRequestParam,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match request.name.as_ref() {
|
||||
"echo" => {
|
||||
let args: EchoArgs = match request.arguments {
|
||||
Some(arguments) => serde_json::from_value(serde_json::Value::Object(
|
||||
arguments.into_iter().collect(),
|
||||
))
|
||||
.map_err(|err| McpError::invalid_params(err.to_string(), None))?,
|
||||
None => {
|
||||
return Err(McpError::invalid_params(
|
||||
"missing arguments for echo tool",
|
||||
None,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let env_snapshot: HashMap<String, String> = std::env::vars().collect();
|
||||
let structured_content = json!({
|
||||
"echo": args.message,
|
||||
"env": env_snapshot.get("MCP_TEST_VALUE"),
|
||||
});
|
||||
|
||||
Ok(CallToolResult {
|
||||
content: Vec::new(),
|
||||
structured_content: Some(structured_content),
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
other => Err(McpError::invalid_params(
|
||||
format!("unknown tool: {other}"),
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
eprintln!("starting rmcp test server");
|
||||
// Run the server with STDIO transport. If the client disconnects we simply
|
||||
// bubble up the error so the process exits.
|
||||
let service = TestToolServer::new();
|
||||
let running = service.serve(stdio()).await?;
|
||||
|
||||
// Wait for the client to finish interacting with the server.
|
||||
running.waiting().await?;
|
||||
// Drain background tasks to ensure clean shutdown.
|
||||
task::yield_now().await;
|
||||
Ok(())
|
||||
}
|
||||
283
llmx-rs/rmcp-client/src/bin/test_stdio_server.rs
Normal file
283
llmx-rs/rmcp-client/src/bin/test_stdio_server.rs
Normal file
@@ -0,0 +1,283 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::ServiceExt;
|
||||
use rmcp::handler::server::ServerHandler;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::JsonObject;
|
||||
use rmcp::model::ListResourceTemplatesResult;
|
||||
use rmcp::model::ListResourcesResult;
|
||||
use rmcp::model::ListToolsResult;
|
||||
use rmcp::model::PaginatedRequestParam;
|
||||
use rmcp::model::RawResource;
|
||||
use rmcp::model::RawResourceTemplate;
|
||||
use rmcp::model::ReadResourceRequestParam;
|
||||
use rmcp::model::ReadResourceResult;
|
||||
use rmcp::model::Resource;
|
||||
use rmcp::model::ResourceContents;
|
||||
use rmcp::model::ResourceTemplate;
|
||||
use rmcp::model::ServerCapabilities;
|
||||
use rmcp::model::ServerInfo;
|
||||
use rmcp::model::Tool;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::task;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestToolServer {
|
||||
tools: Arc<Vec<Tool>>,
|
||||
resources: Arc<Vec<Resource>>,
|
||||
resource_templates: Arc<Vec<ResourceTemplate>>,
|
||||
}
|
||||
|
||||
const MEMO_URI: &str = "memo://codex/example-note";
|
||||
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
|
||||
pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) {
|
||||
(tokio::io::stdin(), tokio::io::stdout())
|
||||
}
|
||||
impl TestToolServer {
|
||||
fn new() -> Self {
|
||||
let tools = vec![Self::echo_tool(), Self::image_tool()];
|
||||
let resources = vec![Self::memo_resource()];
|
||||
let resource_templates = vec![Self::memo_template()];
|
||||
Self {
|
||||
tools: Arc::new(tools),
|
||||
resources: Arc::new(resources),
|
||||
resource_templates: Arc::new(resource_templates),
|
||||
}
|
||||
}
|
||||
|
||||
fn echo_tool() -> Tool {
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema: JsonObject = serde_json::from_value(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" },
|
||||
"env_var": { "type": "string" }
|
||||
},
|
||||
"required": ["message"],
|
||||
"additionalProperties": false
|
||||
}))
|
||||
.expect("echo tool schema should deserialize");
|
||||
|
||||
Tool::new(
|
||||
Cow::Borrowed("echo"),
|
||||
Cow::Borrowed("Echo back the provided message and include environment data."),
|
||||
Arc::new(schema),
|
||||
)
|
||||
}
|
||||
|
||||
fn image_tool() -> Tool {
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema: JsonObject = serde_json::from_value(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": false
|
||||
}))
|
||||
.expect("image tool schema should deserialize");
|
||||
|
||||
Tool::new(
|
||||
Cow::Borrowed("image"),
|
||||
Cow::Borrowed("Return a single image content block."),
|
||||
Arc::new(schema),
|
||||
)
|
||||
}
|
||||
|
||||
fn memo_resource() -> Resource {
|
||||
let raw = RawResource {
|
||||
uri: MEMO_URI.to_string(),
|
||||
name: "example-note".to_string(),
|
||||
title: Some("Example Note".to_string()),
|
||||
description: Some("A sample MCP resource exposed for integration tests.".to_string()),
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
size: None,
|
||||
icons: None,
|
||||
};
|
||||
Resource::new(raw, None)
|
||||
}
|
||||
|
||||
fn memo_template() -> ResourceTemplate {
|
||||
let raw = RawResourceTemplate {
|
||||
uri_template: "memo://codex/{slug}".to_string(),
|
||||
name: "codex-memo".to_string(),
|
||||
title: Some("Codex Memo".to_string()),
|
||||
description: Some(
|
||||
"Template for memo://codex/{slug} resources used in tests.".to_string(),
|
||||
),
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
};
|
||||
ResourceTemplate::new(raw, None)
|
||||
}
|
||||
|
||||
fn memo_text() -> &'static str {
|
||||
MEMO_CONTENT
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EchoArgs {
|
||||
message: String,
|
||||
#[allow(dead_code)]
|
||||
env_var: Option<String>,
|
||||
}
|
||||
|
||||
impl ServerHandler for TestToolServer {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_tools()
|
||||
.enable_tool_list_changed()
|
||||
.enable_resources()
|
||||
.build(),
|
||||
..ServerInfo::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn list_tools(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> impl std::future::Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
|
||||
let tools = self.tools.clone();
|
||||
async move {
|
||||
Ok(ListToolsResult {
|
||||
tools: (*tools).clone(),
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn list_resources(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> impl std::future::Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
|
||||
let resources = self.resources.clone();
|
||||
async move {
|
||||
Ok(ListResourcesResult {
|
||||
resources: (*resources).clone(),
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resource_templates(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<ListResourceTemplatesResult, McpError> {
|
||||
Ok(ListResourceTemplatesResult {
|
||||
resource_templates: (*self.resource_templates).clone(),
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_resource(
|
||||
&self,
|
||||
ReadResourceRequestParam { uri }: ReadResourceRequestParam,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<ReadResourceResult, McpError> {
|
||||
if uri == MEMO_URI {
|
||||
Ok(ReadResourceResult {
|
||||
contents: vec![ResourceContents::TextResourceContents {
|
||||
uri,
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
text: Self::memo_text().to_string(),
|
||||
meta: None,
|
||||
}],
|
||||
})
|
||||
} else {
|
||||
Err(McpError::resource_not_found(
|
||||
"resource_not_found",
|
||||
Some(json!({ "uri": uri })),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_tool(
|
||||
&self,
|
||||
request: CallToolRequestParam,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match request.name.as_ref() {
|
||||
"echo" => {
|
||||
let args: EchoArgs = match request.arguments {
|
||||
Some(arguments) => serde_json::from_value(serde_json::Value::Object(
|
||||
arguments.into_iter().collect(),
|
||||
))
|
||||
.map_err(|err| McpError::invalid_params(err.to_string(), None))?,
|
||||
None => {
|
||||
return Err(McpError::invalid_params(
|
||||
"missing arguments for echo tool",
|
||||
None,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let env_snapshot: HashMap<String, String> = std::env::vars().collect();
|
||||
let structured_content = json!({
|
||||
"echo": format!("ECHOING: {}", args.message),
|
||||
"env": env_snapshot.get("MCP_TEST_VALUE"),
|
||||
});
|
||||
|
||||
Ok(CallToolResult {
|
||||
content: Vec::new(),
|
||||
structured_content: Some(structured_content),
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
"image" => {
|
||||
// Read a data URL (e.g. data:image/png;base64,AAA...) from env and convert to
|
||||
// an MCP image content block. Tests set MCP_TEST_IMAGE_DATA_URL.
|
||||
let data_url = std::env::var("MCP_TEST_IMAGE_DATA_URL").map_err(|_| {
|
||||
McpError::invalid_params(
|
||||
"missing MCP_TEST_IMAGE_DATA_URL env var for image tool",
|
||||
None,
|
||||
)
|
||||
})?;
|
||||
|
||||
fn parse_data_url(url: &str) -> Option<(String, String)> {
|
||||
let rest = url.strip_prefix("data:")?;
|
||||
let (mime_and_opts, data) = rest.split_once(',')?;
|
||||
let (mime, _opts) =
|
||||
mime_and_opts.split_once(';').unwrap_or((mime_and_opts, ""));
|
||||
Some((mime.to_string(), data.to_string()))
|
||||
}
|
||||
|
||||
let (mime_type, data_b64) = parse_data_url(&data_url).ok_or_else(|| {
|
||||
McpError::invalid_params(
|
||||
format!("invalid data URL for image tool: {data_url}"),
|
||||
None,
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(CallToolResult::success(vec![rmcp::model::Content::image(
|
||||
data_b64, mime_type,
|
||||
)]))
|
||||
}
|
||||
other => Err(McpError::invalid_params(
|
||||
format!("unknown tool: {other}"),
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
eprintln!("starting rmcp test server");
|
||||
// Run the server with STDIO transport. If the client disconnects we simply
|
||||
// bubble up the error so the process exits.
|
||||
let service = TestToolServer::new();
|
||||
let running = service.serve(stdio()).await?;
|
||||
|
||||
// Wait for the client to finish interacting with the server.
|
||||
running.waiting().await?;
|
||||
// Drain background tasks to ensure clean shutdown.
|
||||
task::yield_now().await;
|
||||
Ok(())
|
||||
}
|
||||
320
llmx-rs/rmcp-client/src/bin/test_streamable_http_server.rs
Normal file
320
llmx-rs/rmcp-client/src/bin/test_streamable_http_server.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::io::ErrorKind;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::State;
|
||||
use axum::http::Request;
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::header::AUTHORIZATION;
|
||||
use axum::http::header::CONTENT_TYPE;
|
||||
use axum::middleware;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
use axum::routing::get;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::handler::server::ServerHandler;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::JsonObject;
|
||||
use rmcp::model::ListResourceTemplatesResult;
|
||||
use rmcp::model::ListResourcesResult;
|
||||
use rmcp::model::ListToolsResult;
|
||||
use rmcp::model::PaginatedRequestParam;
|
||||
use rmcp::model::RawResource;
|
||||
use rmcp::model::RawResourceTemplate;
|
||||
use rmcp::model::ReadResourceRequestParam;
|
||||
use rmcp::model::ReadResourceResult;
|
||||
use rmcp::model::Resource;
|
||||
use rmcp::model::ResourceContents;
|
||||
use rmcp::model::ResourceTemplate;
|
||||
use rmcp::model::ServerCapabilities;
|
||||
use rmcp::model::ServerInfo;
|
||||
use rmcp::model::Tool;
|
||||
use rmcp::transport::StreamableHttpServerConfig;
|
||||
use rmcp::transport::StreamableHttpService;
|
||||
use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::task;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestToolServer {
|
||||
tools: Arc<Vec<Tool>>,
|
||||
resources: Arc<Vec<Resource>>,
|
||||
resource_templates: Arc<Vec<ResourceTemplate>>,
|
||||
}
|
||||
|
||||
const MEMO_URI: &str = "memo://codex/example-note";
|
||||
const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server.";
|
||||
|
||||
impl TestToolServer {
|
||||
fn new() -> Self {
|
||||
let tools = vec![Self::echo_tool()];
|
||||
let resources = vec![Self::memo_resource()];
|
||||
let resource_templates = vec![Self::memo_template()];
|
||||
Self {
|
||||
tools: Arc::new(tools),
|
||||
resources: Arc::new(resources),
|
||||
resource_templates: Arc::new(resource_templates),
|
||||
}
|
||||
}
|
||||
|
||||
fn echo_tool() -> Tool {
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema: JsonObject = serde_json::from_value(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": { "type": "string" },
|
||||
"env_var": { "type": "string" }
|
||||
},
|
||||
"required": ["message"],
|
||||
"additionalProperties": false
|
||||
}))
|
||||
.expect("echo tool schema should deserialize");
|
||||
|
||||
Tool::new(
|
||||
Cow::Borrowed("echo"),
|
||||
Cow::Borrowed("Echo back the provided message and include environment data."),
|
||||
Arc::new(schema),
|
||||
)
|
||||
}
|
||||
|
||||
fn memo_resource() -> Resource {
|
||||
let raw = RawResource {
|
||||
uri: MEMO_URI.to_string(),
|
||||
name: "example-note".to_string(),
|
||||
title: Some("Example Note".to_string()),
|
||||
description: Some("A sample MCP resource exposed for integration tests.".to_string()),
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
size: None,
|
||||
icons: None,
|
||||
};
|
||||
Resource::new(raw, None)
|
||||
}
|
||||
|
||||
fn memo_template() -> ResourceTemplate {
|
||||
let raw = RawResourceTemplate {
|
||||
uri_template: "memo://codex/{slug}".to_string(),
|
||||
name: "codex-memo".to_string(),
|
||||
title: Some("Codex Memo".to_string()),
|
||||
description: Some(
|
||||
"Template for memo://codex/{slug} resources used in tests.".to_string(),
|
||||
),
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
};
|
||||
ResourceTemplate::new(raw, None)
|
||||
}
|
||||
|
||||
fn memo_text() -> &'static str {
|
||||
MEMO_CONTENT
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EchoArgs {
|
||||
message: String,
|
||||
#[allow(dead_code)]
|
||||
env_var: Option<String>,
|
||||
}
|
||||
|
||||
impl ServerHandler for TestToolServer {
|
||||
fn get_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
capabilities: ServerCapabilities::builder()
|
||||
.enable_tools()
|
||||
.enable_tool_list_changed()
|
||||
.enable_resources()
|
||||
.build(),
|
||||
..ServerInfo::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn list_tools(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> impl std::future::Future<Output = Result<ListToolsResult, McpError>> + Send + '_ {
|
||||
let tools = self.tools.clone();
|
||||
async move {
|
||||
Ok(ListToolsResult {
|
||||
tools: (*tools).clone(),
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn list_resources(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> impl std::future::Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
|
||||
let resources = self.resources.clone();
|
||||
async move {
|
||||
Ok(ListResourcesResult {
|
||||
resources: (*resources).clone(),
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_resource_templates(
|
||||
&self,
|
||||
_request: Option<PaginatedRequestParam>,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<ListResourceTemplatesResult, McpError> {
|
||||
Ok(ListResourceTemplatesResult {
|
||||
resource_templates: (*self.resource_templates).clone(),
|
||||
next_cursor: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn read_resource(
|
||||
&self,
|
||||
ReadResourceRequestParam { uri }: ReadResourceRequestParam,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<ReadResourceResult, McpError> {
|
||||
if uri == MEMO_URI {
|
||||
Ok(ReadResourceResult {
|
||||
contents: vec![ResourceContents::TextResourceContents {
|
||||
uri,
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
text: Self::memo_text().to_string(),
|
||||
meta: None,
|
||||
}],
|
||||
})
|
||||
} else {
|
||||
Err(McpError::resource_not_found(
|
||||
"resource_not_found",
|
||||
Some(json!({ "uri": uri })),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn call_tool(
|
||||
&self,
|
||||
request: CallToolRequestParam,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match request.name.as_ref() {
|
||||
"echo" => {
|
||||
let args: EchoArgs = match request.arguments {
|
||||
Some(arguments) => serde_json::from_value(serde_json::Value::Object(
|
||||
arguments.into_iter().collect(),
|
||||
))
|
||||
.map_err(|err| McpError::invalid_params(err.to_string(), None))?,
|
||||
None => {
|
||||
return Err(McpError::invalid_params(
|
||||
"missing arguments for echo tool",
|
||||
None,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let env_snapshot: HashMap<String, String> = std::env::vars().collect();
|
||||
let structured_content = json!({
|
||||
"echo": format!("ECHOING: {}", args.message),
|
||||
"env": env_snapshot.get("MCP_TEST_VALUE"),
|
||||
});
|
||||
|
||||
Ok(CallToolResult {
|
||||
content: Vec::new(),
|
||||
structured_content: Some(structured_content),
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
other => Err(McpError::invalid_params(
|
||||
format!("unknown tool: {other}"),
|
||||
None,
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_bind_addr() -> Result<SocketAddr, Box<dyn std::error::Error>> {
|
||||
let default_addr = "127.0.0.1:3920";
|
||||
let bind_addr = std::env::var("MCP_STREAMABLE_HTTP_BIND_ADDR")
|
||||
.or_else(|_| std::env::var("BIND_ADDR"))
|
||||
.unwrap_or_else(|_| default_addr.to_string());
|
||||
Ok(bind_addr.parse()?)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let bind_addr = parse_bind_addr()?;
|
||||
let listener = match tokio::net::TcpListener::bind(&bind_addr).await {
|
||||
Ok(listener) => listener,
|
||||
Err(err) if err.kind() == ErrorKind::PermissionDenied => {
|
||||
eprintln!(
|
||||
"failed to bind to {bind_addr}: {err}. make sure the process has network access"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp");
|
||||
|
||||
let router = Router::new()
|
||||
.route(
|
||||
"/.well-known/oauth-authorization-server/mcp",
|
||||
get({
|
||||
move || async move {
|
||||
let metadata_base = format!("http://{bind_addr}");
|
||||
#[expect(clippy::expect_used)]
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_vec(&json!({
|
||||
"authorization_endpoint": format!("{metadata_base}/oauth/authorize"),
|
||||
"token_endpoint": format!("{metadata_base}/oauth/token"),
|
||||
"scopes_supported": [""],
|
||||
})).expect("failed to serialize metadata"),
|
||||
))
|
||||
.expect("valid metadata response")
|
||||
}
|
||||
}),
|
||||
)
|
||||
.nest_service(
|
||||
"/mcp",
|
||||
StreamableHttpService::new(
|
||||
|| Ok(TestToolServer::new()),
|
||||
Arc::new(LocalSessionManager::default()),
|
||||
StreamableHttpServerConfig::default(),
|
||||
),
|
||||
);
|
||||
|
||||
let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") {
|
||||
let expected = Arc::new(format!("Bearer {token}"));
|
||||
router.layer(middleware::from_fn_with_state(expected, require_bearer))
|
||||
} else {
|
||||
router
|
||||
};
|
||||
|
||||
axum::serve(listener, router).await?;
|
||||
task::yield_now().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn require_bearer(
|
||||
State(expected): State<Arc<String>>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if request.uri().path().contains("/.well-known/") {
|
||||
return Ok(next.run(request).await);
|
||||
}
|
||||
if request
|
||||
.headers()
|
||||
.get(AUTHORIZATION)
|
||||
.is_some_and(|value| value.as_bytes() == expected.as_bytes())
|
||||
{
|
||||
Ok(next.run(request).await)
|
||||
} else {
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
33
llmx-rs/rmcp-client/src/find_codex_home.rs
Normal file
33
llmx-rs/rmcp-client/src/find_codex_home.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use dirs::home_dir;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// This was copied from codex-core but codex-core depends on this crate.
|
||||
/// TODO: move this to a shared crate lower in the dependency tree.
|
||||
///
|
||||
///
|
||||
/// Returns the path to the Codex configuration directory, which can be
|
||||
/// specified by the `CODEX_HOME` environment variable. If not set, defaults to
|
||||
/// `~/.codex`.
|
||||
///
|
||||
/// - If `CODEX_HOME` is set, the value will be canonicalized and this
|
||||
/// function will Err if the path does not exist.
|
||||
/// - If `CODEX_HOME` is not set, this function does not verify that the
|
||||
/// directory exists.
|
||||
pub(crate) fn find_codex_home() -> std::io::Result<PathBuf> {
|
||||
// Honor the `CODEX_HOME` environment variable when it is set to allow users
|
||||
// (and tests) to override the default location.
|
||||
if let Ok(val) = std::env::var("CODEX_HOME")
|
||||
&& !val.is_empty()
|
||||
{
|
||||
return PathBuf::from(val).canonicalize();
|
||||
}
|
||||
|
||||
let mut p = home_dir().ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"Could not find home directory",
|
||||
)
|
||||
})?;
|
||||
p.push(".codex");
|
||||
Ok(p)
|
||||
}
|
||||
19
llmx-rs/rmcp-client/src/lib.rs
Normal file
19
llmx-rs/rmcp-client/src/lib.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
mod auth_status;
|
||||
mod find_codex_home;
|
||||
mod logging_client_handler;
|
||||
mod oauth;
|
||||
mod perform_oauth_login;
|
||||
mod rmcp_client;
|
||||
mod utils;
|
||||
|
||||
pub use auth_status::determine_streamable_http_auth_status;
|
||||
pub use auth_status::supports_oauth_login;
|
||||
pub use codex_protocol::protocol::McpAuthStatus;
|
||||
pub use oauth::OAuthCredentialsStoreMode;
|
||||
pub use oauth::StoredOAuthTokens;
|
||||
pub use oauth::WrappedOAuthTokenResponse;
|
||||
pub use oauth::delete_oauth_tokens;
|
||||
pub(crate) use oauth::load_oauth_tokens;
|
||||
pub use oauth::save_oauth_tokens;
|
||||
pub use perform_oauth_login::perform_oauth_login;
|
||||
pub use rmcp_client::RmcpClient;
|
||||
134
llmx-rs/rmcp-client/src/logging_client_handler.rs
Normal file
134
llmx-rs/rmcp-client/src/logging_client_handler.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use rmcp::ClientHandler;
|
||||
use rmcp::RoleClient;
|
||||
use rmcp::model::CancelledNotificationParam;
|
||||
use rmcp::model::ClientInfo;
|
||||
use rmcp::model::CreateElicitationRequestParam;
|
||||
use rmcp::model::CreateElicitationResult;
|
||||
use rmcp::model::ElicitationAction;
|
||||
use rmcp::model::LoggingLevel;
|
||||
use rmcp::model::LoggingMessageNotificationParam;
|
||||
use rmcp::model::ProgressNotificationParam;
|
||||
use rmcp::model::ResourceUpdatedNotificationParam;
|
||||
use rmcp::service::NotificationContext;
|
||||
use rmcp::service::RequestContext;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct LoggingClientHandler {
|
||||
client_info: ClientInfo,
|
||||
}
|
||||
|
||||
impl LoggingClientHandler {
|
||||
pub(crate) fn new(client_info: ClientInfo) -> Self {
|
||||
Self { client_info }
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientHandler for LoggingClientHandler {
|
||||
// TODO (CODEX-3571): support elicitations.
|
||||
async fn create_elicitation(
|
||||
&self,
|
||||
request: CreateElicitationRequestParam,
|
||||
_context: RequestContext<RoleClient>,
|
||||
) -> Result<CreateElicitationResult, rmcp::ErrorData> {
|
||||
info!(
|
||||
"MCP server requested elicitation ({}). Elicitations are not supported yet. Declining.",
|
||||
request.message
|
||||
);
|
||||
Ok(CreateElicitationResult {
|
||||
action: ElicitationAction::Decline,
|
||||
content: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn on_cancelled(
|
||||
&self,
|
||||
params: CancelledNotificationParam,
|
||||
_context: NotificationContext<RoleClient>,
|
||||
) {
|
||||
info!(
|
||||
"MCP server cancelled request (request_id: {}, reason: {:?})",
|
||||
params.request_id, params.reason
|
||||
);
|
||||
}
|
||||
|
||||
async fn on_progress(
|
||||
&self,
|
||||
params: ProgressNotificationParam,
|
||||
_context: NotificationContext<RoleClient>,
|
||||
) {
|
||||
info!(
|
||||
"MCP server progress notification (token: {:?}, progress: {}, total: {:?}, message: {:?})",
|
||||
params.progress_token, params.progress, params.total, params.message
|
||||
);
|
||||
}
|
||||
|
||||
async fn on_resource_updated(
|
||||
&self,
|
||||
params: ResourceUpdatedNotificationParam,
|
||||
_context: NotificationContext<RoleClient>,
|
||||
) {
|
||||
info!("MCP server resource updated (uri: {})", params.uri);
|
||||
}
|
||||
|
||||
async fn on_resource_list_changed(&self, _context: NotificationContext<RoleClient>) {
|
||||
info!("MCP server resource list changed");
|
||||
}
|
||||
|
||||
async fn on_tool_list_changed(&self, _context: NotificationContext<RoleClient>) {
|
||||
info!("MCP server tool list changed");
|
||||
}
|
||||
|
||||
async fn on_prompt_list_changed(&self, _context: NotificationContext<RoleClient>) {
|
||||
info!("MCP server prompt list changed");
|
||||
}
|
||||
|
||||
fn get_info(&self) -> ClientInfo {
|
||||
self.client_info.clone()
|
||||
}
|
||||
|
||||
async fn on_logging_message(
|
||||
&self,
|
||||
params: LoggingMessageNotificationParam,
|
||||
_context: NotificationContext<RoleClient>,
|
||||
) {
|
||||
let LoggingMessageNotificationParam {
|
||||
level,
|
||||
logger,
|
||||
data,
|
||||
} = params;
|
||||
let logger = logger.as_deref();
|
||||
match level {
|
||||
LoggingLevel::Emergency
|
||||
| LoggingLevel::Alert
|
||||
| LoggingLevel::Critical
|
||||
| LoggingLevel::Error => {
|
||||
error!(
|
||||
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
|
||||
level, logger, data
|
||||
);
|
||||
}
|
||||
LoggingLevel::Warning => {
|
||||
warn!(
|
||||
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
|
||||
level, logger, data
|
||||
);
|
||||
}
|
||||
LoggingLevel::Notice | LoggingLevel::Info => {
|
||||
info!(
|
||||
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
|
||||
level, logger, data
|
||||
);
|
||||
}
|
||||
LoggingLevel::Debug => {
|
||||
debug!(
|
||||
"MCP server log message (level: {:?}, logger: {:?}, data: {})",
|
||||
level, logger, data
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
814
llmx-rs/rmcp-client/src/oauth.rs
Normal file
814
llmx-rs/rmcp-client/src/oauth.rs
Normal file
@@ -0,0 +1,814 @@
|
||||
//! This file handles all logic related to managing MCP OAuth credentials.
|
||||
//! All credentials are stored using the keyring crate which uses os-specific keyring services.
|
||||
//! https://crates.io/crates/keyring
|
||||
//! macOS: macOS keychain.
|
||||
//! Windows: Windows Credential Manager
|
||||
//! Linux: DBus-based Secret Service, the kernel keyutils, and a combo of the two
|
||||
//! FreeBSD, OpenBSD: DBus-based Secret Service
|
||||
//!
|
||||
//! For Linux, we use linux-native-async-persistent which uses both keyutils and async-secret-service (see below) for storage.
|
||||
//! See the docs for the keyutils_persistent module for a full explanation of why both are used. Because this store uses the
|
||||
//! async-secret-service, you must specify the additional features required by that store
|
||||
//!
|
||||
//! async-secret-service provides access to the DBus-based Secret Service storage on Linux, FreeBSD, and OpenBSD. This is an asynchronous
|
||||
//! keystore that always encrypts secrets when they are transferred across the bus. If DBus isn't installed the keystore will fall back to the json
|
||||
//! file because we don't use the "vendored" feature.
|
||||
//!
|
||||
//! If the keyring is not available or fails, we fall back to CODEX_HOME/.credentials.json which is consistent with other coding CLI agents.
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Error;
|
||||
use anyhow::Result;
|
||||
use oauth2::AccessToken;
|
||||
use oauth2::EmptyExtraTokenFields;
|
||||
use oauth2::RefreshToken;
|
||||
use oauth2::Scope;
|
||||
use oauth2::TokenResponse;
|
||||
use oauth2::basic::BasicTokenType;
|
||||
use rmcp::transport::auth::OAuthTokenResponse;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::map::Map as JsonMap;
|
||||
use sha2::Digest;
|
||||
use sha2::Sha256;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fs;
|
||||
use std::io::ErrorKind;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
use tracing::warn;
|
||||
|
||||
use codex_keyring_store::DefaultKeyringStore;
|
||||
use codex_keyring_store::KeyringStore;
|
||||
use rmcp::transport::auth::AuthorizationManager;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::find_codex_home::find_codex_home;
|
||||
|
||||
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct StoredOAuthTokens {
|
||||
pub server_name: String,
|
||||
pub url: String,
|
||||
pub client_id: String,
|
||||
pub token_response: WrappedOAuthTokenResponse,
|
||||
}
|
||||
|
||||
/// Determine where Codex should store and read MCP credentials.
|
||||
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OAuthCredentialsStoreMode {
|
||||
/// `Keyring` when available; otherwise, `File`.
|
||||
/// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
|
||||
#[default]
|
||||
Auto,
|
||||
/// CODEX_HOME/.credentials.json
|
||||
/// This file will be readable to Codex and other applications running as the same user.
|
||||
File,
|
||||
/// Keyring when available, otherwise fail.
|
||||
Keyring,
|
||||
}
|
||||
|
||||
/// Wrap OAuthTokenResponse to allow for partial equality comparison.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WrappedOAuthTokenResponse(pub OAuthTokenResponse);
|
||||
|
||||
impl PartialEq for WrappedOAuthTokenResponse {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (serde_json::to_string(self), serde_json::to_string(other)) {
|
||||
(Ok(s1), Ok(s2)) => s1 == s2,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn load_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto => {
|
||||
load_oauth_tokens_from_keyring_with_fallback_to_file(&keyring_store, server_name, url)
|
||||
}
|
||||
OAuthCredentialsStoreMode::File => load_oauth_tokens_from_file(server_name, url),
|
||||
OAuthCredentialsStoreMode::Keyring => {
|
||||
load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
|
||||
.with_context(|| "failed to read OAuth tokens from keyring".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn has_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<bool> {
|
||||
Ok(load_oauth_tokens(server_name, url, store_mode)?.is_some())
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
|
||||
Ok(Some(tokens)) => Ok(Some(tokens)),
|
||||
Ok(None) => load_oauth_tokens_from_file(server_name, url),
|
||||
Err(error) => {
|
||||
warn!("failed to read OAuth tokens from keyring: {error}");
|
||||
load_oauth_tokens_from_file(server_name, url)
|
||||
.with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_keyring<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
match keyring_store.load(KEYRING_SERVICE, &key) {
|
||||
Ok(Some(serialized)) => {
|
||||
let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
|
||||
.context("failed to deserialize OAuth tokens from keyring")?;
|
||||
Ok(Some(tokens))
|
||||
}
|
||||
Ok(None) => Ok(None),
|
||||
Err(error) => Err(Error::new(error.into_error())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_oauth_tokens(
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<()> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto => save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&keyring_store,
|
||||
server_name,
|
||||
tokens,
|
||||
),
|
||||
OAuthCredentialsStoreMode::File => save_oauth_tokens_to_file(tokens),
|
||||
OAuthCredentialsStoreMode::Keyring => {
|
||||
save_oauth_tokens_with_keyring(&keyring_store, server_name, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_keyring<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?;
|
||||
|
||||
let key = compute_store_key(server_name, &tokens.url)?;
|
||||
match keyring_store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
Ok(()) => {
|
||||
if let Err(error) = delete_oauth_tokens_from_file(&key) {
|
||||
warn!("failed to remove OAuth tokens from fallback storage: {error:?}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
let message = format!(
|
||||
"failed to write OAuth tokens to keyring: {}",
|
||||
error.message()
|
||||
);
|
||||
warn!("{message}");
|
||||
Err(Error::new(error.into_error()).context(message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
match save_oauth_tokens_with_keyring(keyring_store, server_name, tokens) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(error) => {
|
||||
let message = error.to_string();
|
||||
warn!("falling back to file storage for OAuth tokens: {message}");
|
||||
save_oauth_tokens_to_file(tokens)
|
||||
.with_context(|| format!("failed to write OAuth tokens to keyring: {message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn delete_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<bool> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
delete_oauth_tokens_from_keyring_and_file(&keyring_store, store_mode, server_name, url)
|
||||
}
|
||||
|
||||
fn delete_oauth_tokens_from_keyring_and_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<bool> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
let keyring_result = keyring_store.delete(KEYRING_SERVICE, &key);
|
||||
let keyring_removed = match keyring_result {
|
||||
Ok(removed) => removed,
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to delete OAuth tokens from keyring: {message}");
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto | OAuthCredentialsStoreMode::Keyring => {
|
||||
return Err(error.into_error())
|
||||
.context("failed to delete OAuth tokens from keyring");
|
||||
}
|
||||
OAuthCredentialsStoreMode::File => false,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let file_removed = delete_oauth_tokens_from_file(&key)?;
|
||||
Ok(keyring_removed || file_removed)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct OAuthPersistor {
|
||||
inner: Arc<OAuthPersistorInner>,
|
||||
}
|
||||
|
||||
struct OAuthPersistorInner {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
||||
}
|
||||
|
||||
impl OAuthPersistor {
|
||||
pub(crate) fn new(
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
initial_credentials: Option<StoredOAuthTokens>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(OAuthPersistorInner {
|
||||
server_name,
|
||||
url,
|
||||
authorization_manager,
|
||||
store_mode,
|
||||
last_credentials: Mutex::new(initial_credentials),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Persists the latest stored credentials if they have changed.
|
||||
/// Deletes the credentials if they are no longer present.
|
||||
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
|
||||
let (client_id, maybe_credentials) = {
|
||||
let manager = self.inner.authorization_manager.clone();
|
||||
let guard = manager.lock().await;
|
||||
guard.get_credentials().await
|
||||
}?;
|
||||
|
||||
match maybe_credentials {
|
||||
Some(credentials) => {
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.inner.server_name.clone(),
|
||||
url: self.inner.url.clone(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials.clone()),
|
||||
};
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
if last_credentials.as_ref() != Some(&stored) {
|
||||
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
|
||||
*last_credentials = Some(stored);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut last_serialized = self.inner.last_credentials.lock().await;
|
||||
if last_serialized.take().is_some()
|
||||
&& let Err(error) = delete_oauth_tokens(
|
||||
&self.inner.server_name,
|
||||
&self.inner.url,
|
||||
self.inner.store_mode,
|
||||
)
|
||||
{
|
||||
warn!(
|
||||
"failed to remove OAuth tokens for server {}: {error}",
|
||||
self.inner.server_name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
const FALLBACK_FILENAME: &str = ".credentials.json";
|
||||
const MCP_SERVER_TYPE: &str = "http";
|
||||
|
||||
type FallbackFile = BTreeMap<String, FallbackTokenEntry>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FallbackTokenEntry {
|
||||
server_name: String,
|
||||
server_url: String,
|
||||
client_id: String,
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
expires_at: Option<u64>,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default)]
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_file(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
|
||||
let Some(store) = read_fallback_file()? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
|
||||
for entry in store.values() {
|
||||
let entry_key = compute_store_key(&entry.server_name, &entry.server_url)?;
|
||||
if entry_key != key {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut token_response = OAuthTokenResponse::new(
|
||||
AccessToken::new(entry.access_token.clone()),
|
||||
BasicTokenType::Bearer,
|
||||
EmptyExtraTokenFields {},
|
||||
);
|
||||
|
||||
if let Some(refresh) = entry.refresh_token.clone() {
|
||||
token_response.set_refresh_token(Some(RefreshToken::new(refresh)));
|
||||
}
|
||||
|
||||
let scopes = entry.scopes.clone();
|
||||
if !scopes.is_empty() {
|
||||
token_response.set_scopes(Some(scopes.into_iter().map(Scope::new).collect()));
|
||||
}
|
||||
|
||||
if let Some(expires_at) = entry.expires_at
|
||||
&& let Some(seconds) = expires_in_from_timestamp(expires_at)
|
||||
{
|
||||
let duration = Duration::from_secs(seconds);
|
||||
token_response.set_expires_in(Some(&duration));
|
||||
}
|
||||
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: entry.server_name.clone(),
|
||||
url: entry.server_url.clone(),
|
||||
client_id: entry.client_id.clone(),
|
||||
token_response: WrappedOAuthTokenResponse(token_response),
|
||||
};
|
||||
|
||||
return Ok(Some(stored));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> {
|
||||
let key = compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
let mut store = read_fallback_file()?.unwrap_or_default();
|
||||
|
||||
let token_response = &tokens.token_response.0;
|
||||
let refresh_token = token_response
|
||||
.refresh_token()
|
||||
.map(|token| token.secret().to_string());
|
||||
let scopes = token_response
|
||||
.scopes()
|
||||
.map(|s| s.iter().map(|s| s.to_string()).collect())
|
||||
.unwrap_or_default();
|
||||
let entry = FallbackTokenEntry {
|
||||
server_name: tokens.server_name.clone(),
|
||||
server_url: tokens.url.clone(),
|
||||
client_id: tokens.client_id.clone(),
|
||||
access_token: token_response.access_token().secret().to_string(),
|
||||
expires_at: compute_expires_at_millis(token_response),
|
||||
refresh_token,
|
||||
scopes,
|
||||
};
|
||||
|
||||
store.insert(key, entry);
|
||||
write_fallback_file(&store)
|
||||
}
|
||||
|
||||
fn delete_oauth_tokens_from_file(key: &str) -> Result<bool> {
|
||||
let mut store = match read_fallback_file()? {
|
||||
Some(store) => store,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
let removed = store.remove(key).is_some();
|
||||
|
||||
if removed {
|
||||
write_fallback_file(&store)?;
|
||||
}
|
||||
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option<u64> {
|
||||
let expires_in = response.expires_in()?;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
let expiry = now.checked_add(expires_in)?;
|
||||
let millis = expiry.as_millis();
|
||||
if millis > u128::from(u64::MAX) {
|
||||
Some(u64::MAX)
|
||||
} else {
|
||||
Some(millis as u64)
|
||||
}
|
||||
}
|
||||
|
||||
fn expires_in_from_timestamp(expires_at: u64) -> Option<u64> {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
let now_ms = now.as_millis() as u64;
|
||||
|
||||
if expires_at <= now_ms {
|
||||
None
|
||||
} else {
|
||||
Some((expires_at - now_ms) / 1000)
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_store_key(server_name: &str, server_url: &str) -> Result<String> {
|
||||
let mut payload = JsonMap::new();
|
||||
payload.insert(
|
||||
"type".to_string(),
|
||||
Value::String(MCP_SERVER_TYPE.to_string()),
|
||||
);
|
||||
payload.insert("url".to_string(), Value::String(server_url.to_string()));
|
||||
payload.insert("headers".to_string(), Value::Object(JsonMap::new()));
|
||||
|
||||
let truncated = sha_256_prefix(&Value::Object(payload))?;
|
||||
Ok(format!("{server_name}|{truncated}"))
|
||||
}
|
||||
|
||||
fn fallback_file_path() -> Result<PathBuf> {
|
||||
let mut path = find_codex_home()?;
|
||||
path.push(FALLBACK_FILENAME);
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
fn read_fallback_file() -> Result<Option<FallbackFile>> {
|
||||
let path = fallback_file_path()?;
|
||||
let contents = match fs::read_to_string(&path) {
|
||||
Ok(contents) => contents,
|
||||
Err(err) if err.kind() == ErrorKind::NotFound => return Ok(None),
|
||||
Err(err) => {
|
||||
return Err(err).context(format!(
|
||||
"failed to read credentials file at {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
match serde_json::from_str::<FallbackFile>(&contents) {
|
||||
Ok(store) => Ok(Some(store)),
|
||||
Err(e) => Err(e).context(format!(
|
||||
"failed to parse credentials file at {}",
|
||||
path.display()
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_fallback_file(store: &FallbackFile) -> Result<()> {
|
||||
let path = fallback_file_path()?;
|
||||
|
||||
if store.is_empty() {
|
||||
if path.exists() {
|
||||
fs::remove_file(path)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let serialized = serde_json::to_string(store)?;
|
||||
fs::write(&path, serialized)?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let perms = fs::Permissions::from_mode(0o600);
|
||||
fs::set_permissions(&path, perms)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sha_256_prefix(value: &Value) -> Result<String> {
|
||||
let serialized =
|
||||
serde_json::to_string(&value).context("failed to serialize MCP OAuth key payload")?;
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(serialized.as_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let hex = format!("{digest:x}");
|
||||
let truncated = &hex[..16];
|
||||
Ok(truncated.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::Result;
|
||||
use keyring::Error as KeyringError;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::MutexGuard;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::PoisonError;
|
||||
use tempfile::tempdir;
|
||||
|
||||
use codex_keyring_store::tests::MockKeyringStore;
|
||||
|
||||
struct TempCodexHome {
|
||||
_guard: MutexGuard<'static, ()>,
|
||||
_dir: tempfile::TempDir,
|
||||
}
|
||||
|
||||
impl TempCodexHome {
|
||||
fn new() -> Self {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
let guard = LOCK
|
||||
.get_or_init(Mutex::default)
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner);
|
||||
let dir = tempdir().expect("create CODEX_HOME temp dir");
|
||||
unsafe {
|
||||
std::env::set_var("CODEX_HOME", dir.path());
|
||||
}
|
||||
Self {
|
||||
_guard: guard,
|
||||
_dir: dir,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TempCodexHome {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
std::env::remove_var("CODEX_HOME");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_reads_from_keyring_when_available() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
|
||||
let loaded =
|
||||
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?;
|
||||
assert_eq!(loaded, Some(expected));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_falls_back_when_missing_in_keyring() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_falls_back_when_keyring_errors() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
super::save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens,
|
||||
)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(!fallback_path.exists(), "fallback file should be removed");
|
||||
let stored = store.saved_value(&key).expect("value saved to keyring");
|
||||
assert_eq!(serde_json::from_str::<StoredOAuthTokens>(&stored)?, tokens);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_oauth_tokens_writes_fallback_when_keyring_fails() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));
|
||||
|
||||
super::save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens,
|
||||
)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(fallback_path.exists(), "fallback file should be created");
|
||||
let saved = super::read_fallback_file()?.expect("fallback file should load");
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
let entry = saved.get(&key).expect("entry for key");
|
||||
assert_eq!(entry.server_name, tokens.server_name);
|
||||
assert_eq!(entry.server_url, tokens.url);
|
||||
assert_eq!(entry.client_id, tokens.client_id);
|
||||
assert_eq!(
|
||||
entry.access_token,
|
||||
tokens.token_response.0.access_token().secret().as_str()
|
||||
);
|
||||
assert!(store.saved_value(&key).is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_removes_all_storage() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let removed = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?;
|
||||
assert!(removed);
|
||||
assert!(!store.contains(&key));
|
||||
assert!(!super::fallback_file_path()?.exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_file_mode_removes_keyring_only_entry() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
assert!(store.contains(&key));
|
||||
|
||||
let removed = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?;
|
||||
assert!(removed);
|
||||
assert!(!store.contains(&key));
|
||||
assert!(!super::fallback_file_path()?.exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_propagates_keyring_errors() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockKeyringStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
|
||||
super::save_oauth_tokens_to_file(&tokens).unwrap();
|
||||
|
||||
let result = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
);
|
||||
assert!(result.is_err());
|
||||
assert!(super::fallback_file_path().unwrap().exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assert_tokens_match_without_expiry(
|
||||
actual: &StoredOAuthTokens,
|
||||
expected: &StoredOAuthTokens,
|
||||
) {
|
||||
assert_eq!(actual.server_name, expected.server_name);
|
||||
assert_eq!(actual.url, expected.url);
|
||||
assert_eq!(actual.client_id, expected.client_id);
|
||||
assert_token_response_match_without_expiry(
|
||||
&actual.token_response,
|
||||
&expected.token_response,
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_token_response_match_without_expiry(
|
||||
actual: &WrappedOAuthTokenResponse,
|
||||
expected: &WrappedOAuthTokenResponse,
|
||||
) {
|
||||
let actual_response = &actual.0;
|
||||
let expected_response = &expected.0;
|
||||
|
||||
assert_eq!(
|
||||
actual_response.access_token().secret(),
|
||||
expected_response.access_token().secret()
|
||||
);
|
||||
assert_eq!(actual_response.token_type(), expected_response.token_type());
|
||||
assert_eq!(
|
||||
actual_response.refresh_token().map(RefreshToken::secret),
|
||||
expected_response.refresh_token().map(RefreshToken::secret),
|
||||
);
|
||||
assert_eq!(actual_response.scopes(), expected_response.scopes());
|
||||
assert_eq!(
|
||||
actual_response.extra_fields(),
|
||||
expected_response.extra_fields()
|
||||
);
|
||||
assert_eq!(
|
||||
actual_response.expires_in().is_some(),
|
||||
expected_response.expires_in().is_some()
|
||||
);
|
||||
}
|
||||
|
||||
fn sample_tokens() -> StoredOAuthTokens {
|
||||
let mut response = OAuthTokenResponse::new(
|
||||
AccessToken::new("access-token".to_string()),
|
||||
BasicTokenType::Bearer,
|
||||
EmptyExtraTokenFields {},
|
||||
);
|
||||
response.set_refresh_token(Some(RefreshToken::new("refresh-token".to_string())));
|
||||
response.set_scopes(Some(vec![
|
||||
Scope::new("scope-a".to_string()),
|
||||
Scope::new("scope-b".to_string()),
|
||||
]));
|
||||
let expires_in = Duration::from_secs(3600);
|
||||
response.set_expires_in(Some(&expires_in));
|
||||
|
||||
StoredOAuthTokens {
|
||||
server_name: "test-server".to_string(),
|
||||
url: "https://example.test".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
token_response: WrappedOAuthTokenResponse(response),
|
||||
}
|
||||
}
|
||||
}
|
||||
159
llmx-rs/rmcp-client/src/perform_oauth_login.rs
Normal file
159
llmx-rs/rmcp-client/src/perform_oauth_login.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
use std::collections::HashMap;
|
||||
use std::string::String;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use reqwest::ClientBuilder;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use urlencoding::decode;
|
||||
|
||||
use crate::OAuthCredentialsStoreMode;
|
||||
use crate::StoredOAuthTokens;
|
||||
use crate::WrappedOAuthTokenResponse;
|
||||
use crate::save_oauth_tokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
|
||||
struct CallbackServerGuard {
|
||||
server: Arc<Server>,
|
||||
}
|
||||
|
||||
impl Drop for CallbackServerGuard {
|
||||
fn drop(&mut self) {
|
||||
self.server.unblock();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn perform_oauth_login(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
scopes: &[String],
|
||||
) -> Result<()> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
};
|
||||
|
||||
let redirect_uri = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
|
||||
format!("http://{}:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
|
||||
format!("http://[{}]:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine callback address")),
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
let http_client = apply_default_headers(ClientBuilder::new(), &default_headers).build()?;
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, Some(http_client)).await?;
|
||||
let scope_refs: Vec<&str> = scopes.iter().map(String::as_str).collect();
|
||||
oauth_state
|
||||
.start_authorization(&scope_refs, &redirect_uri, Some("Codex"))
|
||||
.await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
|
||||
println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");
|
||||
|
||||
if webbrowser::open(&auth_url).is_err() {
|
||||
println!("(Browser launch failed; please copy the URL above manually.)");
|
||||
}
|
||||
|
||||
let (code, csrf_state) = timeout(Duration::from_secs(300), rx)
|
||||
.await
|
||||
.context("timed out waiting for OAuth callback")?
|
||||
.context("OAuth callback was cancelled")?;
|
||||
|
||||
oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials =
|
||||
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: server_name.to_string(),
|
||||
url: server_url.to_string(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
};
|
||||
save_oauth_tokens(server_name, &stored, store_mode)?;
|
||||
|
||||
drop(guard);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
while let Ok(request) = server.recv() {
|
||||
let path = request.url().to_string();
|
||||
if let Some(OauthCallbackResult { code, state }) = parse_oauth_callback(&path) {
|
||||
let response =
|
||||
Response::from_string("Authentication complete. You may close this window.");
|
||||
if let Err(err) = request.respond(response) {
|
||||
eprintln!("Failed to respond to OAuth callback: {err}");
|
||||
}
|
||||
if let Err(err) = tx.send((code, state)) {
|
||||
eprintln!("Failed to send OAuth callback: {err:?}");
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
let response =
|
||||
Response::from_string("Invalid OAuth callback").with_status_code(400);
|
||||
if let Err(err) = request.respond(response) {
|
||||
eprintln!("Failed to respond to OAuth callback: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
struct OauthCallbackResult {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
|
||||
let (route, query) = path.split_once('?')?;
|
||||
if route != "/callback" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut code = None;
|
||||
let mut state = None;
|
||||
|
||||
for pair in query.split('&') {
|
||||
let (key, value) = pair.split_once('=')?;
|
||||
let decoded = decode(value).ok()?.into_owned();
|
||||
match key {
|
||||
"code" => code = Some(decoded),
|
||||
"state" => state = Some(decoded),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Some(OauthCallbackResult {
|
||||
code: code?,
|
||||
state: state?,
|
||||
})
|
||||
}
|
||||
414
llmx-rs/rmcp-client/src/rmcp_client.rs
Normal file
414
llmx-rs/rmcp-client/src/rmcp_client.rs
Normal file
@@ -0,0 +1,414 @@
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use futures::FutureExt;
|
||||
use mcp_types::CallToolRequestParams;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::InitializeRequestParams;
|
||||
use mcp_types::InitializeResult;
|
||||
use mcp_types::ListResourceTemplatesRequestParams;
|
||||
use mcp_types::ListResourceTemplatesResult;
|
||||
use mcp_types::ListResourcesRequestParams;
|
||||
use mcp_types::ListResourcesResult;
|
||||
use mcp_types::ListToolsRequestParams;
|
||||
use mcp_types::ListToolsResult;
|
||||
use mcp_types::ReadResourceRequestParams;
|
||||
use mcp_types::ReadResourceResult;
|
||||
use reqwest::header::HeaderMap;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
use rmcp::model::InitializeRequestParam;
|
||||
use rmcp::model::PaginatedRequestParam;
|
||||
use rmcp::model::ReadResourceRequestParam;
|
||||
use rmcp::service::RoleClient;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::service::{self};
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::child_process::TokioChildProcess;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::logging_client_handler::LoggingClientHandler;
|
||||
use crate::oauth::OAuthCredentialsStoreMode;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use crate::utils::convert_call_tool_result;
|
||||
use crate::utils::convert_to_mcp;
|
||||
use crate::utils::convert_to_rmcp;
|
||||
use crate::utils::create_env_for_mcp_server;
|
||||
use crate::utils::run_with_timeout;
|
||||
|
||||
enum PendingTransport {
|
||||
ChildProcess(TokioChildProcess),
|
||||
StreamableHttp {
|
||||
transport: StreamableHttpClientTransport<reqwest::Client>,
|
||||
},
|
||||
StreamableHttpWithOAuth {
|
||||
transport: StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
oauth_persistor: OAuthPersistor,
|
||||
},
|
||||
}
|
||||
|
||||
enum ClientState {
|
||||
Connecting {
|
||||
transport: Option<PendingTransport>,
|
||||
},
|
||||
Ready {
|
||||
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
|
||||
oauth: Option<OAuthPersistor>,
|
||||
},
|
||||
}
|
||||
|
||||
/// MCP client implemented on top of the official `rmcp` SDK.
|
||||
/// https://github.com/modelcontextprotocol/rust-sdk
|
||||
pub struct RmcpClient {
|
||||
state: Mutex<ClientState>,
|
||||
}
|
||||
|
||||
impl RmcpClient {
|
||||
pub async fn new_stdio_client(
|
||||
program: OsString,
|
||||
args: Vec<OsString>,
|
||||
env: Option<HashMap<String, String>>,
|
||||
env_vars: &[String],
|
||||
cwd: Option<PathBuf>,
|
||||
) -> io::Result<Self> {
|
||||
let program_name = program.to_string_lossy().into_owned();
|
||||
let mut command = Command::new(&program);
|
||||
command
|
||||
.kill_on_drop(true)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.env_clear()
|
||||
.envs(create_env_for_mcp_server(env, env_vars))
|
||||
.args(&args);
|
||||
if let Some(cwd) = cwd {
|
||||
command.current_dir(cwd);
|
||||
}
|
||||
|
||||
let (transport, stderr) = TokioChildProcess::builder(command)
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
if let Some(stderr) = stderr {
|
||||
tokio::spawn(async move {
|
||||
let mut reader = BufReader::new(stderr).lines();
|
||||
loop {
|
||||
match reader.next_line().await {
|
||||
Ok(Some(line)) => {
|
||||
info!("MCP server stderr ({program_name}): {line}");
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(error) => {
|
||||
warn!("Failed to read MCP server stderr ({program_name}): {error}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(PendingTransport::ChildProcess(transport)),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new_streamable_http_client(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token: Option<String>,
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let default_headers = build_default_headers(http_headers, env_http_headers)?;
|
||||
|
||||
let initial_oauth_tokens = match bearer_token {
|
||||
Some(_) => None,
|
||||
None => match load_oauth_tokens(server_name, url, store_mode) {
|
||||
Ok(tokens) => tokens,
|
||||
Err(err) => {
|
||||
warn!("failed to read tokens for server `{server_name}`: {err}");
|
||||
None
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
|
||||
let (transport, oauth_persistor) = create_oauth_transport_and_runtime(
|
||||
server_name,
|
||||
url,
|
||||
initial_tokens,
|
||||
store_mode,
|
||||
default_headers.clone(),
|
||||
)
|
||||
.await?;
|
||||
PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
}
|
||||
} else {
|
||||
let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
|
||||
if let Some(bearer_token) = bearer_token.clone() {
|
||||
http_config = http_config.auth_header(bearer_token);
|
||||
}
|
||||
|
||||
let http_client =
|
||||
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(http_client, http_config);
|
||||
PendingTransport::StreamableHttp { transport }
|
||||
};
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(transport),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform the initialization handshake with the MCP server.
|
||||
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization
|
||||
pub async fn initialize(
|
||||
&self,
|
||||
params: InitializeRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<InitializeResult> {
|
||||
let rmcp_params: InitializeRequestParam = convert_to_rmcp(params.clone())?;
|
||||
let client_handler = LoggingClientHandler::new(rmcp_params);
|
||||
|
||||
let (transport, oauth_persistor) = {
|
||||
let mut guard = self.state.lock().await;
|
||||
match &mut *guard {
|
||||
ClientState::Connecting { transport } => match transport.take() {
|
||||
Some(PendingTransport::ChildProcess(transport)) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
None,
|
||||
),
|
||||
Some(PendingTransport::StreamableHttp { transport }) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
None,
|
||||
),
|
||||
Some(PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
}) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
Some(oauth_persistor),
|
||||
),
|
||||
None => return Err(anyhow!("client already initializing")),
|
||||
},
|
||||
ClientState::Ready { .. } => return Err(anyhow!("client already initialized")),
|
||||
}
|
||||
};
|
||||
|
||||
let service = match timeout {
|
||||
Some(duration) => time::timeout(duration, transport)
|
||||
.await
|
||||
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
None => transport
|
||||
.await
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
};
|
||||
|
||||
let initialize_result_rmcp = service
|
||||
.peer()
|
||||
.peer_info()
|
||||
.ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?;
|
||||
let initialize_result = convert_to_mcp(initialize_result_rmcp)?;
|
||||
|
||||
{
|
||||
let mut guard = self.state.lock().await;
|
||||
*guard = ClientState::Ready {
|
||||
service: Arc::new(service),
|
||||
oauth: oauth_persistor.clone(),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(runtime) = oauth_persistor
|
||||
&& let Err(error) = runtime.persist_if_needed().await
|
||||
{
|
||||
warn!("failed to persist OAuth tokens after initialize: {error}");
|
||||
}
|
||||
|
||||
Ok(initialize_result)
|
||||
}
|
||||
|
||||
pub async fn list_tools(
|
||||
&self,
|
||||
params: Option<ListToolsRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ListToolsResult> {
|
||||
let service = self.service().await?;
|
||||
let rmcp_params = params
|
||||
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
||||
.transpose()?;
|
||||
|
||||
let fut = service.list_tools(rmcp_params);
|
||||
let result = run_with_timeout(fut, timeout, "tools/list").await?;
|
||||
let converted = convert_to_mcp(result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
pub async fn list_resources(
|
||||
&self,
|
||||
params: Option<ListResourcesRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ListResourcesResult> {
|
||||
let service = self.service().await?;
|
||||
let rmcp_params = params
|
||||
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
||||
.transpose()?;
|
||||
|
||||
let fut = service.list_resources(rmcp_params);
|
||||
let result = run_with_timeout(fut, timeout, "resources/list").await?;
|
||||
let converted = convert_to_mcp(result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
pub async fn list_resource_templates(
|
||||
&self,
|
||||
params: Option<ListResourceTemplatesRequestParams>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ListResourceTemplatesResult> {
|
||||
let service = self.service().await?;
|
||||
let rmcp_params = params
|
||||
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
|
||||
.transpose()?;
|
||||
|
||||
let fut = service.list_resource_templates(rmcp_params);
|
||||
let result = run_with_timeout(fut, timeout, "resources/templates/list").await?;
|
||||
let converted = convert_to_mcp(result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
pub async fn read_resource(
|
||||
&self,
|
||||
params: ReadResourceRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<ReadResourceResult> {
|
||||
let service = self.service().await?;
|
||||
let rmcp_params: ReadResourceRequestParam = convert_to_rmcp(params)?;
|
||||
let fut = service.read_resource(rmcp_params);
|
||||
let result = run_with_timeout(fut, timeout, "resources/read").await?;
|
||||
let converted = convert_to_mcp(result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
pub async fn call_tool(
|
||||
&self,
|
||||
name: String,
|
||||
arguments: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<CallToolResult> {
|
||||
let service = self.service().await?;
|
||||
let params = CallToolRequestParams { arguments, name };
|
||||
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
|
||||
let fut = service.call_tool(rmcp_params);
|
||||
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
|
||||
let converted = convert_call_tool_result(rmcp_result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
|
||||
let guard = self.state.lock().await;
|
||||
match &*guard {
|
||||
ClientState::Ready { service, .. } => Ok(Arc::clone(service)),
|
||||
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth_persistor(&self) -> Option<OAuthPersistor> {
|
||||
let guard = self.state.lock().await;
|
||||
match &*guard {
|
||||
ClientState::Ready {
|
||||
oauth: Some(runtime),
|
||||
service: _,
|
||||
} => Some(runtime.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// This should be called after every tool call so that if a given tool call triggered
|
||||
/// a refresh of the OAuth tokens, they are persisted.
|
||||
async fn persist_oauth_tokens(&self) {
|
||||
if let Some(runtime) = self.oauth_persistor().await
|
||||
&& let Err(error) = runtime.persist_if_needed().await
|
||||
{
|
||||
warn!("failed to persist OAuth tokens: {error}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_oauth_transport_and_runtime(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
initial_tokens: StoredOAuthTokens,
|
||||
credentials_store: OAuthCredentialsStoreMode,
|
||||
default_headers: HeaderMap,
|
||||
) -> Result<(
|
||||
StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
OAuthPersistor,
|
||||
)> {
|
||||
let http_client =
|
||||
apply_default_headers(reqwest::Client::builder(), &default_headers).build()?;
|
||||
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
|
||||
|
||||
oauth_state
|
||||
.set_credentials(
|
||||
&initial_tokens.client_id,
|
||||
initial_tokens.token_response.0.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let manager = match oauth_state {
|
||||
OAuthState::Authorized(manager) => manager,
|
||||
OAuthState::Unauthorized(manager) => manager,
|
||||
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
|
||||
return Err(anyhow!("unexpected OAuth state during client setup"));
|
||||
}
|
||||
};
|
||||
|
||||
let auth_client = AuthClient::new(http_client, manager);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
auth_client,
|
||||
StreamableHttpClientTransportConfig::with_uri(url.to_string()),
|
||||
);
|
||||
|
||||
let runtime = OAuthPersistor::new(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
auth_manager,
|
||||
credentials_store,
|
||||
Some(initial_tokens),
|
||||
);
|
||||
|
||||
Ok((transport, runtime))
|
||||
}
|
||||
302
llmx-rs/rmcp-client/src/utils.rs
Normal file
302
llmx-rs/rmcp-client/src/utils.rs
Normal file
@@ -0,0 +1,302 @@
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use mcp_types::CallToolResult;
|
||||
use reqwest::ClientBuilder;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
use rmcp::model::CallToolResult as RmcpCallToolResult;
|
||||
use rmcp::service::ServiceError;
|
||||
use serde_json::Value;
|
||||
use tokio::time;
|
||||
|
||||
pub(crate) async fn run_with_timeout<F, T>(
|
||||
fut: F,
|
||||
timeout: Option<Duration>,
|
||||
label: &str,
|
||||
) -> Result<T>
|
||||
where
|
||||
F: std::future::Future<Output = Result<T, ServiceError>>,
|
||||
{
|
||||
if let Some(duration) = timeout {
|
||||
let result = time::timeout(duration, fut)
|
||||
.await
|
||||
.with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?;
|
||||
result.map_err(|err| anyhow!("{label} failed: {err}"))
|
||||
} else {
|
||||
fut.await.map_err(|err| anyhow!("{label} failed: {err}"))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn convert_call_tool_result(result: RmcpCallToolResult) -> Result<CallToolResult> {
|
||||
let mut value = serde_json::to_value(result)?;
|
||||
if let Some(obj) = value.as_object_mut()
|
||||
&& (obj.get("content").is_none()
|
||||
|| obj.get("content").is_some_and(serde_json::Value::is_null))
|
||||
{
|
||||
obj.insert("content".to_string(), Value::Array(Vec::new()));
|
||||
}
|
||||
serde_json::from_value(value).context("failed to convert call tool result")
|
||||
}
|
||||
|
||||
/// Convert from mcp-types to Rust SDK types.
|
||||
///
|
||||
/// The Rust SDK types are the same as our mcp-types crate because they are both
|
||||
/// derived from the same MCP specification.
|
||||
/// As a result, it should be safe to convert directly from one to the other.
|
||||
pub(crate) fn convert_to_rmcp<T, U>(value: T) -> Result<U>
|
||||
where
|
||||
T: serde::Serialize,
|
||||
U: serde::de::DeserializeOwned,
|
||||
{
|
||||
let json = serde_json::to_value(value)?;
|
||||
serde_json::from_value(json).map_err(|err| anyhow!(err))
|
||||
}
|
||||
|
||||
/// Convert from Rust SDK types to mcp-types.
|
||||
///
|
||||
/// The Rust SDK types are the same as our mcp-types crate because they are both
|
||||
/// derived from the same MCP specification.
|
||||
/// As a result, it should be safe to convert directly from one to the other.
|
||||
pub(crate) fn convert_to_mcp<T, U>(value: T) -> Result<U>
|
||||
where
|
||||
T: serde::Serialize,
|
||||
U: serde::de::DeserializeOwned,
|
||||
{
|
||||
let json = serde_json::to_value(value)?;
|
||||
serde_json::from_value(json).map_err(|err| anyhow!(err))
|
||||
}
|
||||
|
||||
pub(crate) fn create_env_for_mcp_server(
|
||||
extra_env: Option<HashMap<String, String>>,
|
||||
env_vars: &[String],
|
||||
) -> HashMap<String, String> {
|
||||
DEFAULT_ENV_VARS
|
||||
.iter()
|
||||
.copied()
|
||||
.chain(env_vars.iter().map(String::as_str))
|
||||
.filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value)))
|
||||
.chain(extra_env.unwrap_or_default())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn build_default_headers(
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
) -> Result<HeaderMap> {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
if let Some(static_headers) = http_headers {
|
||||
for (name, value) in static_headers {
|
||||
let header_name = match HeaderName::from_bytes(name.as_bytes()) {
|
||||
Ok(name) => name,
|
||||
Err(err) => {
|
||||
tracing::warn!("invalid HTTP header name `{name}`: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let header_value = match HeaderValue::from_str(value.as_str()) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
tracing::warn!("invalid HTTP header value for `{name}`: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(env_headers) = env_http_headers {
|
||||
for (name, env_var) in env_headers {
|
||||
if let Ok(value) = env::var(&env_var) {
|
||||
if value.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let header_name = match HeaderName::from_bytes(name.as_bytes()) {
|
||||
Ok(name) => name,
|
||||
Err(err) => {
|
||||
tracing::warn!("invalid HTTP header name `{name}`: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let header_value = match HeaderValue::from_str(value.as_str()) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
tracing::warn!(
|
||||
"invalid HTTP header value read from {env_var} for `{name}`: {err}"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
pub(crate) fn apply_default_headers(
|
||||
builder: ClientBuilder,
|
||||
default_headers: &HeaderMap,
|
||||
) -> ClientBuilder {
|
||||
if default_headers.is_empty() {
|
||||
builder
|
||||
} else {
|
||||
builder.default_headers(default_headers.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
||||
"HOME",
|
||||
"LOGNAME",
|
||||
"PATH",
|
||||
"SHELL",
|
||||
"USER",
|
||||
"__CF_USER_TEXT_ENCODING",
|
||||
"LANG",
|
||||
"LC_ALL",
|
||||
"TERM",
|
||||
"TMPDIR",
|
||||
"TZ",
|
||||
];
|
||||
|
||||
#[cfg(windows)]
|
||||
pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[
|
||||
// Core path resolution
|
||||
"PATH",
|
||||
"PATHEXT",
|
||||
// Shell and system roots
|
||||
"COMSPEC",
|
||||
"SYSTEMROOT",
|
||||
"SYSTEMDRIVE",
|
||||
// User context and profiles
|
||||
"USERNAME",
|
||||
"USERDOMAIN",
|
||||
"USERPROFILE",
|
||||
"HOMEDRIVE",
|
||||
"HOMEPATH",
|
||||
// Program locations
|
||||
"PROGRAMFILES",
|
||||
"PROGRAMFILES(X86)",
|
||||
"PROGRAMW6432",
|
||||
"PROGRAMDATA",
|
||||
// App data and caches
|
||||
"LOCALAPPDATA",
|
||||
"APPDATA",
|
||||
// Temp locations
|
||||
"TEMP",
|
||||
"TMP",
|
||||
// Common shells/pwsh hints
|
||||
"POWERSHELL",
|
||||
"PWSH",
|
||||
];
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mcp_types::ContentBlock;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::CallToolResult as RmcpCallToolResult;
|
||||
use serde_json::json;
|
||||
|
||||
use serial_test::serial;
|
||||
use std::ffi::OsString;
|
||||
|
||||
struct EnvVarGuard {
|
||||
key: String,
|
||||
original: Option<OsString>,
|
||||
}
|
||||
|
||||
impl EnvVarGuard {
|
||||
fn set(key: &str, value: &str) -> Self {
|
||||
let original = std::env::var_os(key);
|
||||
unsafe {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
Self {
|
||||
key: key.to_string(),
|
||||
original,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EnvVarGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Some(value) = &self.original {
|
||||
unsafe {
|
||||
std::env::set_var(&self.key, value);
|
||||
}
|
||||
} else {
|
||||
unsafe {
|
||||
std::env::remove_var(&self.key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn create_env_honors_overrides() {
|
||||
let value = "custom".to_string();
|
||||
let env =
|
||||
create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())])), &[]);
|
||||
assert_eq!(env.get("TZ"), Some(&value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(extra_rmcp_env)]
|
||||
fn create_env_includes_additional_whitelisted_variables() {
|
||||
let custom_var = "EXTRA_RMCP_ENV";
|
||||
let value = "from-env";
|
||||
let _guard = EnvVarGuard::set(custom_var, value);
|
||||
let env = create_env_for_mcp_server(None, &[custom_var.to_string()]);
|
||||
assert_eq!(env.get(custom_var), Some(&value.to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_call_tool_result_defaults_missing_content() -> Result<()> {
|
||||
let structured_content = json!({ "key": "value" });
|
||||
let rmcp_result = RmcpCallToolResult {
|
||||
content: vec![],
|
||||
structured_content: Some(structured_content.clone()),
|
||||
is_error: Some(true),
|
||||
meta: None,
|
||||
};
|
||||
|
||||
let result = convert_call_tool_result(rmcp_result)?;
|
||||
|
||||
assert!(result.content.is_empty());
|
||||
assert_eq!(result.structured_content, Some(structured_content));
|
||||
assert_eq!(result.is_error, Some(true));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_call_tool_result_preserves_existing_content() -> Result<()> {
|
||||
let rmcp_result = RmcpCallToolResult::success(vec![rmcp::model::Content::text("hello")]);
|
||||
|
||||
let result = convert_call_tool_result(rmcp_result)?;
|
||||
|
||||
assert_eq!(result.content.len(), 1);
|
||||
match &result.content[0] {
|
||||
ContentBlock::TextContent(text_content) => {
|
||||
assert_eq!(text_content.text, "hello");
|
||||
assert_eq!(text_content.r#type, "text");
|
||||
}
|
||||
other => panic!("expected text content got {other:?}"),
|
||||
}
|
||||
assert_eq!(result.structured_content, None);
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user