use std::borrow::Cow; use std::collections::HashMap; use std::io::ErrorKind; use std::net::SocketAddr; use std::sync::Arc; use axum::Router; use axum::body::Body; use axum::extract::State; use axum::http::Request; use axum::http::StatusCode; use axum::http::header::AUTHORIZATION; use axum::middleware; use axum::middleware::Next; use axum::response::Response; use rmcp::ErrorData as McpError; use rmcp::handler::server::ServerHandler; use rmcp::model::CallToolRequestParam; use rmcp::model::CallToolResult; use rmcp::model::JsonObject; use rmcp::model::ListResourceTemplatesResult; use rmcp::model::ListResourcesResult; use rmcp::model::ListToolsResult; use rmcp::model::PaginatedRequestParam; use rmcp::model::RawResource; use rmcp::model::RawResourceTemplate; use rmcp::model::ReadResourceRequestParam; use rmcp::model::ReadResourceResult; use rmcp::model::Resource; use rmcp::model::ResourceContents; use rmcp::model::ResourceTemplate; use rmcp::model::ServerCapabilities; use rmcp::model::ServerInfo; use rmcp::model::Tool; use rmcp::transport::StreamableHttpServerConfig; use rmcp::transport::StreamableHttpService; use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; use serde::Deserialize; use serde_json::json; use tokio::task; #[derive(Clone)] struct TestToolServer { tools: Arc>, resources: Arc>, resource_templates: Arc>, } const MEMO_URI: &str = "memo://codex/example-note"; const MEMO_CONTENT: &str = "This is a sample MCP resource served by the rmcp test server."; impl TestToolServer { fn new() -> Self { let tools = vec![Self::echo_tool()]; let resources = vec![Self::memo_resource()]; let resource_templates = vec![Self::memo_template()]; Self { tools: Arc::new(tools), resources: Arc::new(resources), resource_templates: Arc::new(resource_templates), } } fn echo_tool() -> Tool { #[expect(clippy::expect_used)] let schema: JsonObject = serde_json::from_value(json!({ "type": "object", "properties": { "message": { "type": "string" }, "env_var": { "type": "string" } }, "required": ["message"], "additionalProperties": false })) .expect("echo tool schema should deserialize"); Tool::new( Cow::Borrowed("echo"), Cow::Borrowed("Echo back the provided message and include environment data."), Arc::new(schema), ) } fn memo_resource() -> Resource { let raw = RawResource { uri: MEMO_URI.to_string(), name: "example-note".to_string(), title: Some("Example Note".to_string()), description: Some("A sample MCP resource exposed for integration tests.".to_string()), mime_type: Some("text/plain".to_string()), size: None, icons: None, }; Resource::new(raw, None) } fn memo_template() -> ResourceTemplate { let raw = RawResourceTemplate { uri_template: "memo://codex/{slug}".to_string(), name: "codex-memo".to_string(), title: Some("Codex Memo".to_string()), description: Some( "Template for memo://codex/{slug} resources used in tests.".to_string(), ), mime_type: Some("text/plain".to_string()), }; ResourceTemplate::new(raw, None) } fn memo_text() -> &'static str { MEMO_CONTENT } } #[derive(Deserialize)] struct EchoArgs { message: String, #[allow(dead_code)] env_var: Option, } impl ServerHandler for TestToolServer { fn get_info(&self) -> ServerInfo { ServerInfo { capabilities: ServerCapabilities::builder() .enable_tools() .enable_tool_list_changed() .enable_resources() .build(), ..ServerInfo::default() } } fn list_tools( &self, _request: Option, _context: rmcp::service::RequestContext, ) -> impl std::future::Future> + Send + '_ { let tools = self.tools.clone(); async move { Ok(ListToolsResult { tools: (*tools).clone(), next_cursor: None, }) } } fn list_resources( &self, _request: Option, _context: rmcp::service::RequestContext, ) -> impl std::future::Future> + Send + '_ { let resources = self.resources.clone(); async move { Ok(ListResourcesResult { resources: (*resources).clone(), next_cursor: None, }) } } async fn list_resource_templates( &self, _request: Option, _context: rmcp::service::RequestContext, ) -> Result { Ok(ListResourceTemplatesResult { resource_templates: (*self.resource_templates).clone(), next_cursor: None, }) } async fn read_resource( &self, ReadResourceRequestParam { uri }: ReadResourceRequestParam, _context: rmcp::service::RequestContext, ) -> Result { if uri == MEMO_URI { Ok(ReadResourceResult { contents: vec![ResourceContents::TextResourceContents { uri, mime_type: Some("text/plain".to_string()), text: Self::memo_text().to_string(), meta: None, }], }) } else { Err(McpError::resource_not_found( "resource_not_found", Some(json!({ "uri": uri })), )) } } async fn call_tool( &self, request: CallToolRequestParam, _context: rmcp::service::RequestContext, ) -> Result { match request.name.as_ref() { "echo" => { let args: EchoArgs = match request.arguments { Some(arguments) => serde_json::from_value(serde_json::Value::Object( arguments.into_iter().collect(), )) .map_err(|err| McpError::invalid_params(err.to_string(), None))?, None => { return Err(McpError::invalid_params( "missing arguments for echo tool", None, )); } }; let env_snapshot: HashMap = std::env::vars().collect(); let structured_content = json!({ "echo": format!("ECHOING: {}", args.message), "env": env_snapshot.get("MCP_TEST_VALUE"), }); Ok(CallToolResult { content: Vec::new(), structured_content: Some(structured_content), is_error: Some(false), meta: None, }) } other => Err(McpError::invalid_params( format!("unknown tool: {other}"), None, )), } } } fn parse_bind_addr() -> Result> { let default_addr = "127.0.0.1:3920"; let bind_addr = std::env::var("MCP_STREAMABLE_HTTP_BIND_ADDR") .or_else(|_| std::env::var("BIND_ADDR")) .unwrap_or_else(|_| default_addr.to_string()); Ok(bind_addr.parse()?) } #[tokio::main] async fn main() -> Result<(), Box> { let bind_addr = parse_bind_addr()?; let listener = match tokio::net::TcpListener::bind(&bind_addr).await { Ok(listener) => listener, Err(err) if err.kind() == ErrorKind::PermissionDenied => { eprintln!( "failed to bind to {bind_addr}: {err}. make sure the process has network access" ); return Ok(()); } Err(err) => return Err(err.into()), }; eprintln!("starting rmcp streamable http test server on http://{bind_addr}/mcp"); let router = Router::new().nest_service( "/mcp", StreamableHttpService::new( || Ok(TestToolServer::new()), Arc::new(LocalSessionManager::default()), StreamableHttpServerConfig::default(), ), ); let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") { let expected = Arc::new(format!("Bearer {token}")); router.layer(middleware::from_fn_with_state(expected, require_bearer)) } else { router }; axum::serve(listener, router).await?; task::yield_now().await; Ok(()) } async fn require_bearer( State(expected): State>, request: Request, next: Next, ) -> Result { if request .headers() .get(AUTHORIZATION) .is_some_and(|value| value.as_bytes() == expected.as_bytes()) { Ok(next.run(request).await) } else { Err(StatusCode::UNAUTHORIZED) } }