From 21cd953dbda85061f4605df0035b79723d4da7bb Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Fri, 2 May 2025 17:25:58 -0700 Subject: [PATCH] feat: introduce mcp-server crate (#792) This introduces the `mcp-server` crate, which contains a barebones MCP server that provides an `echo` tool that echoes the user's request back to them. To test it out, I launched [modelcontextprotocol/inspector](https://github.com/modelcontextprotocol/inspector) like so: ``` mcp-server$ npx @modelcontextprotocol/inspector cargo run -- ``` and opened up `http://127.0.0.1:6274` in my browser: ![image](https://github.com/user-attachments/assets/83fc55d4-25c2-4497-80cd-e9702283ff93) I also had to make a small fix to `mcp-types`, adding `#[serde(untagged)]` to a number of `enum`s. --- codex-rs/Cargo.lock | 13 + codex-rs/Cargo.toml | 1 + codex-rs/mcp-server/Cargo.toml | 30 ++ codex-rs/mcp-server/src/main.rs | 110 +++++ codex-rs/mcp-server/src/message_processor.rs | 425 +++++++++++++++++++ codex-rs/mcp-types/generate_mcp_types.py | 5 +- codex-rs/mcp-types/src/lib.rs | 13 + 7 files changed, 593 insertions(+), 4 deletions(-) create mode 100644 codex-rs/mcp-server/Cargo.toml create mode 100644 codex-rs/mcp-server/src/main.rs create mode 100644 codex-rs/mcp-server/src/message_processor.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index ed0b562b..f2f865b0 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -556,6 +556,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "codex-mcp-server" +version = "0.1.0" +dependencies = [ + "codex-core", + "mcp-types", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "codex-tui" version = "0.1.0" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index ded97915..55aab210 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -7,6 +7,7 @@ members = [ "core", "exec", "execpolicy", + "mcp-server", "mcp-types", "tui", ] diff --git a/codex-rs/mcp-server/Cargo.toml b/codex-rs/mcp-server/Cargo.toml new file mode 100644 index 00000000..258a37aa --- /dev/null +++ b/codex-rs/mcp-server/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "codex-mcp-server" +version = "0.1.0" +edition = "2021" + +[dependencies] +# +# codex-core contains optional functionality that is gated behind the "cli" +# feature. Unfortunately there is an unconditional reference to a module that +# is only compiled when the feature is enabled, which breaks the build when +# the default (no-feature) variant is used. +# +# We therefore explicitly enable the "cli" feature when codex-mcp-server pulls +# in codex-core so that the required symbols are present. This does _not_ +# change the public API of codex-core – it merely opts into compiling the +# extra, feature-gated source files so the build succeeds. +# +codex-core = { path = "../core", features = ["cli"] } +mcp-types = { path = "../mcp-types" } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tracing = { version = "0.1.41", features = ["log"] } +tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } +tokio = { version = "1", features = [ + "io-std", + "macros", + "process", + "rt-multi-thread", + "signal", +] } diff --git a/codex-rs/mcp-server/src/main.rs b/codex-rs/mcp-server/src/main.rs new file mode 100644 index 00000000..b0fb7fec --- /dev/null +++ b/codex-rs/mcp-server/src/main.rs @@ -0,0 +1,110 @@ +//! Prototype MCP server. + +use std::io::Result as IoResult; + +use mcp_types::JSONRPCMessage; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::io::{self}; +use tokio::sync::mpsc; +use tracing::debug; +use tracing::error; +use tracing::info; + +mod message_processor; +use crate::message_processor::MessageProcessor; + +/// Size of the bounded channels used to communicate between tasks. The value +/// is a balance between throughput and memory usage – 128 messages should be +/// plenty for an interactive CLI. +const CHANNEL_CAPACITY: usize = 128; + +#[tokio::main] +async fn main() -> IoResult<()> { + // Install a simple subscriber so `tracing` output is visible. Users can + // control the log level with `RUST_LOG`. + tracing_subscriber::fmt() + .with_writer(std::io::stderr) + .init(); + + // Set up channels. + let (incoming_tx, mut incoming_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); + + // Task: read from stdin, push to `incoming_tx`. + let stdin_reader_handle = tokio::spawn({ + let incoming_tx = incoming_tx.clone(); + async move { + let stdin = io::stdin(); + let reader = BufReader::new(stdin); + let mut lines = reader.lines(); + + while let Some(line) = lines.next_line().await.unwrap_or_default() { + match serde_json::from_str::(&line) { + Ok(msg) => { + if incoming_tx.send(msg).await.is_err() { + // Receiver gone – nothing left to do. + break; + } + } + Err(e) => error!("Failed to deserialize JSONRPCMessage: {e}"), + } + } + + debug!("stdin reader finished (EOF)"); + } + }); + + // Task: process incoming messages. + let processor_handle = tokio::spawn({ + let mut processor = MessageProcessor::new(outgoing_tx.clone()); + async move { + while let Some(msg) = incoming_rx.recv().await { + match msg { + JSONRPCMessage::Request(r) => processor.process_request(r), + JSONRPCMessage::Response(r) => processor.process_response(r), + JSONRPCMessage::Notification(n) => processor.process_notification(n), + JSONRPCMessage::BatchRequest(b) => processor.process_batch_request(b), + JSONRPCMessage::Error(e) => processor.process_error(e), + JSONRPCMessage::BatchResponse(b) => processor.process_batch_response(b), + } + } + + info!("processor task exited (channel closed)"); + } + }); + + // Task: write outgoing messages to stdout. + let stdout_writer_handle = tokio::spawn(async move { + let mut stdout = io::stdout(); + while let Some(msg) = outgoing_rx.recv().await { + match serde_json::to_string(&msg) { + Ok(json) => { + if let Err(e) = stdout.write_all(json.as_bytes()).await { + error!("Failed to write to stdout: {e}"); + break; + } + if let Err(e) = stdout.write_all(b"\n").await { + error!("Failed to write newline to stdout: {e}"); + break; + } + if let Err(e) = stdout.flush().await { + error!("Failed to flush stdout: {e}"); + break; + } + } + Err(e) => error!("Failed to serialize JSONRPCMessage: {e}"), + } + } + + info!("stdout writer exited (channel closed)"); + }); + + // Wait for all tasks to finish. The typical exit path is the stdin reader + // hitting EOF which, once it drops `incoming_tx`, propagates shutdown to + // the processor and then to the stdout task. + let _ = tokio::join!(stdin_reader_handle, processor_handle, stdout_writer_handle); + + Ok(()) +} diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs new file mode 100644 index 00000000..6fcdc75d --- /dev/null +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -0,0 +1,425 @@ +//! Very small proof-of-concept request router for the MCP prototype server. + +use mcp_types::CallToolRequestParams; +use mcp_types::CallToolResultContent; +use mcp_types::ClientRequest; +use mcp_types::JSONRPCBatchRequest; +use mcp_types::JSONRPCBatchResponse; +use mcp_types::JSONRPCError; +use mcp_types::JSONRPCErrorError; +use mcp_types::JSONRPCMessage; +use mcp_types::JSONRPCNotification; +use mcp_types::JSONRPCRequest; +use mcp_types::JSONRPCResponse; +use mcp_types::ListToolsResult; +use mcp_types::ModelContextProtocolRequest; +use mcp_types::RequestId; +use mcp_types::ServerCapabilitiesTools; +use mcp_types::ServerNotification; +use mcp_types::TextContent; +use mcp_types::Tool; +use mcp_types::ToolInputSchema; +use mcp_types::JSONRPC_VERSION; +use serde_json::json; +use tokio::sync::mpsc; + +pub(crate) struct MessageProcessor { + outgoing: mpsc::Sender, + initialized: bool, +} + +impl MessageProcessor { + /// Create a new `MessageProcessor`, retaining a handle to the outgoing + /// `Sender` so handlers can enqueue messages to be written to stdout. + pub(crate) fn new(outgoing: mpsc::Sender) -> Self { + Self { + outgoing, + initialized: false, + } + } + + pub(crate) fn process_request(&mut self, request: JSONRPCRequest) { + // Hold on to the ID so we can respond. + let request_id = request.id.clone(); + + let client_request = match ClientRequest::try_from(request) { + Ok(client_request) => client_request, + Err(e) => { + tracing::warn!("Failed to convert request: {e}"); + return; + } + }; + + // Dispatch to a dedicated handler for each request type. + match client_request { + ClientRequest::InitializeRequest(params) => { + self.handle_initialize(request_id, params); + } + ClientRequest::PingRequest(params) => { + self.handle_ping(request_id, params); + } + ClientRequest::ListResourcesRequest(params) => { + self.handle_list_resources(params); + } + ClientRequest::ListResourceTemplatesRequest(params) => { + self.handle_list_resource_templates(params); + } + ClientRequest::ReadResourceRequest(params) => { + self.handle_read_resource(params); + } + ClientRequest::SubscribeRequest(params) => { + self.handle_subscribe(params); + } + ClientRequest::UnsubscribeRequest(params) => { + self.handle_unsubscribe(params); + } + ClientRequest::ListPromptsRequest(params) => { + self.handle_list_prompts(params); + } + ClientRequest::GetPromptRequest(params) => { + self.handle_get_prompt(params); + } + ClientRequest::ListToolsRequest(params) => { + self.handle_list_tools(request_id, params); + } + ClientRequest::CallToolRequest(params) => { + self.handle_call_tool(request_id, params); + } + ClientRequest::SetLevelRequest(params) => { + self.handle_set_level(params); + } + ClientRequest::CompleteRequest(params) => { + self.handle_complete(params); + } + } + } + + /// Handle a standalone JSON-RPC response originating from the peer. + pub(crate) fn process_response(&mut self, response: JSONRPCResponse) { + tracing::info!("<- response: {:?}", response); + } + + /// Handle a fire-and-forget JSON-RPC notification. + pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) { + let server_notification = match ServerNotification::try_from(notification) { + Ok(n) => n, + Err(e) => { + tracing::warn!("Failed to convert notification: {e}"); + return; + } + }; + + // Similar to requests, route each notification type to its own stub + // handler so additional logic can be implemented incrementally. + match server_notification { + ServerNotification::CancelledNotification(params) => { + self.handle_cancelled_notification(params); + } + ServerNotification::ProgressNotification(params) => { + self.handle_progress_notification(params); + } + ServerNotification::ResourceListChangedNotification(params) => { + self.handle_resource_list_changed(params); + } + ServerNotification::ResourceUpdatedNotification(params) => { + self.handle_resource_updated(params); + } + ServerNotification::PromptListChangedNotification(params) => { + self.handle_prompt_list_changed(params); + } + ServerNotification::ToolListChangedNotification(params) => { + self.handle_tool_list_changed(params); + } + ServerNotification::LoggingMessageNotification(params) => { + self.handle_logging_message(params); + } + } + } + + /// Handle a batch of requests and/or notifications. + pub(crate) fn process_batch_request(&mut self, batch: JSONRPCBatchRequest) { + tracing::info!("<- batch request containing {} item(s)", batch.len()); + for item in batch { + match item { + mcp_types::JSONRPCBatchRequestItem::JSONRPCRequest(req) => { + self.process_request(req); + } + mcp_types::JSONRPCBatchRequestItem::JSONRPCNotification(note) => { + self.process_notification(note); + } + } + } + } + + /// Handle an error object received from the peer. + pub(crate) fn process_error(&mut self, err: JSONRPCError) { + tracing::error!("<- error: {:?}", err); + } + + /// Handle a batch of responses/errors. + pub(crate) fn process_batch_response(&mut self, batch: JSONRPCBatchResponse) { + tracing::info!("<- batch response containing {} item(s)", batch.len()); + for item in batch { + match item { + mcp_types::JSONRPCBatchResponseItem::JSONRPCResponse(resp) => { + self.process_response(resp); + } + mcp_types::JSONRPCBatchResponseItem::JSONRPCError(err) => { + self.process_error(err); + } + } + } + } + + fn handle_initialize( + &mut self, + id: RequestId, + params: ::Params, + ) { + tracing::info!("initialize -> params: {:?}", params); + + if self.initialized { + // Already initialised: send JSON-RPC error response. + let error_msg = JSONRPCMessage::Error(JSONRPCError { + jsonrpc: JSONRPC_VERSION.into(), + id, + error: JSONRPCErrorError { + code: -32600, // Invalid Request + message: "initialize called more than once".to_string(), + data: None, + }, + }); + + if let Err(e) = self.outgoing.try_send(error_msg) { + tracing::error!("Failed to send initialization error: {e}"); + } + return; + } + + self.initialized = true; + + // Build a minimal InitializeResult. Fill with placeholders. + let result = mcp_types::InitializeResult { + capabilities: mcp_types::ServerCapabilities { + completions: None, + experimental: None, + logging: None, + prompts: None, + resources: None, + tools: Some(ServerCapabilitiesTools { + list_changed: Some(true), + }), + }, + instructions: None, + protocol_version: params.protocol_version.clone(), + server_info: mcp_types::Implementation { + name: "codex-mcp-server".to_string(), + version: mcp_types::MCP_SCHEMA_VERSION.to_string(), + }, + }; + + self.send_response::(id, result); + } + + fn send_response(&self, id: RequestId, result: T::Result) + where + T: ModelContextProtocolRequest, + { + let response = JSONRPCMessage::Response(JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id, + result: serde_json::to_value(result).unwrap(), + }); + + if let Err(e) = self.outgoing.try_send(response) { + tracing::error!("Failed to send response: {e}"); + } + } + + fn handle_ping( + &self, + id: RequestId, + params: ::Params, + ) { + tracing::info!("ping -> params: {:?}", params); + let result = json!({}); + self.send_response::(id, result); + } + + fn handle_list_resources( + &self, + params: ::Params, + ) { + tracing::info!("resources/list -> params: {:?}", params); + } + + fn handle_list_resource_templates( + &self, + params: + ::Params, + ) { + tracing::info!("resources/templates/list -> params: {:?}", params); + } + + fn handle_read_resource( + &self, + params: ::Params, + ) { + tracing::info!("resources/read -> params: {:?}", params); + } + + fn handle_subscribe( + &self, + params: ::Params, + ) { + tracing::info!("resources/subscribe -> params: {:?}", params); + } + + fn handle_unsubscribe( + &self, + params: ::Params, + ) { + tracing::info!("resources/unsubscribe -> params: {:?}", params); + } + + fn handle_list_prompts( + &self, + params: ::Params, + ) { + tracing::info!("prompts/list -> params: {:?}", params); + } + + fn handle_get_prompt( + &self, + params: ::Params, + ) { + tracing::info!("prompts/get -> params: {:?}", params); + } + + fn handle_list_tools( + &self, + id: RequestId, + params: ::Params, + ) { + tracing::trace!("tools/list -> {params:?}"); + let result = ListToolsResult { + tools: vec![Tool { + name: "echo".to_string(), + input_schema: ToolInputSchema { + r#type: "object".to_string(), + properties: Some(json!({ + "input": { + "type": "string", + "description": "The input to echo back" + } + })), + required: Some(vec!["input".to_string()]), + }, + description: Some("Echoes the request back".to_string()), + annotations: None, + }], + next_cursor: None, + }; + + self.send_response::(id, result); + } + + fn handle_call_tool( + &self, + id: RequestId, + params: ::Params, + ) { + tracing::info!("tools/call -> params: {:?}", params); + let CallToolRequestParams { name, arguments } = params; + match name.as_str() { + "echo" => { + let result = mcp_types::CallToolResult { + content: vec![CallToolResultContent::TextContent(TextContent { + r#type: "text".to_string(), + text: format!("Echo: {arguments:?}"), + annotations: None, + })], + is_error: None, + }; + self.send_response::(id, result); + } + _ => { + let result = mcp_types::CallToolResult { + content: vec![], + is_error: Some(true), + }; + self.send_response::(id, result); + } + } + } + + fn handle_set_level( + &self, + params: ::Params, + ) { + tracing::info!("logging/setLevel -> params: {:?}", params); + } + + fn handle_complete( + &self, + params: ::Params, + ) { + tracing::info!("completion/complete -> params: {:?}", params); + } + + // --------------------------------------------------------------------- + // Notification handlers + // --------------------------------------------------------------------- + + fn handle_cancelled_notification( + &self, + params: ::Params, + ) { + tracing::info!("notifications/cancelled -> params: {:?}", params); + } + + fn handle_progress_notification( + &self, + params: ::Params, + ) { + tracing::info!("notifications/progress -> params: {:?}", params); + } + + fn handle_resource_list_changed( + &self, + params: ::Params, + ) { + tracing::info!( + "notifications/resources/list_changed -> params: {:?}", + params + ); + } + + fn handle_resource_updated( + &self, + params: ::Params, + ) { + tracing::info!("notifications/resources/updated -> params: {:?}", params); + } + + fn handle_prompt_list_changed( + &self, + params: ::Params, + ) { + tracing::info!("notifications/prompts/list_changed -> params: {:?}", params); + } + + fn handle_tool_list_changed( + &self, + params: ::Params, + ) { + tracing::info!("notifications/tools/list_changed -> params: {:?}", params); + } + + fn handle_logging_message( + &self, + params: ::Params, + ) { + tracing::info!("notifications/message -> params: {:?}", params); + } +} diff --git a/codex-rs/mcp-types/generate_mcp_types.py b/codex-rs/mcp-types/generate_mcp_types.py index 92ac9812..ff11dbf0 100755 --- a/codex-rs/mcp-types/generate_mcp_types.py +++ b/codex-rs/mcp-types/generate_mcp_types.py @@ -359,7 +359,6 @@ def implements_notification_trait(name: str) -> bool: def add_trait_impl( type_name: str, trait_name: str, fields: list[StructField], out: list[str] ) -> None: - # out.append("#[derive(Debug)]\n") out.append(STANDARD_DERIVE) out.append(f"pub enum {type_name} {{}}\n\n") @@ -507,10 +506,8 @@ def get_serde_annotation_for_anyof_type(type_name: str) -> str | None: return '#[serde(tag = "method", content = "params")]' case "ServerNotification": return '#[serde(tag = "method", content = "params")]' - case "JSONRPCMessage": - return "#[serde(untagged)]" case _: - return None + return "#[serde(untagged)]" def map_type( diff --git a/codex-rs/mcp-types/src/lib.rs b/codex-rs/mcp-types/src/lib.rs index c8925cfe..a1880ccd 100644 --- a/codex-rs/mcp-types/src/lib.rs +++ b/codex-rs/mcp-types/src/lib.rs @@ -92,6 +92,7 @@ pub struct CallToolResult { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum CallToolResultContent { TextContent(TextContent), ImageContent(ImageContent), @@ -144,6 +145,7 @@ pub struct ClientCapabilitiesRoots { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum ClientNotification { CancelledNotification(CancelledNotification), InitializedNotification(InitializedNotification), @@ -185,6 +187,7 @@ pub enum ClientRequest { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum ClientResult { Result(Result), CreateMessageResult(CreateMessageResult), @@ -214,6 +217,7 @@ pub struct CompleteRequestParamsArgument { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum CompleteRequestParamsRef { PromptReference(PromptReference), ResourceReference(ResourceReference), @@ -299,6 +303,7 @@ pub struct CreateMessageResult { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum CreateMessageResultContent { TextContent(TextContent), ImageContent(ImageContent), @@ -327,6 +332,7 @@ pub struct EmbeddedResource { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum EmbeddedResourceResource { TextResourceContents(TextResourceContents), BlobResourceContents(BlobResourceContents), @@ -427,6 +433,7 @@ impl ModelContextProtocolNotification for InitializedNotification { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum JSONRPCBatchRequestItem { JSONRPCRequest(JSONRPCRequest), JSONRPCNotification(JSONRPCNotification), @@ -435,6 +442,7 @@ pub enum JSONRPCBatchRequestItem { pub type JSONRPCBatchRequest = Vec; #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum JSONRPCBatchResponseItem { JSONRPCResponse(JSONRPCResponse), JSONRPCError(JSONRPCError), @@ -852,6 +860,7 @@ pub struct PromptMessage { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum PromptMessageContent { TextContent(TextContent), ImageContent(ImageContent), @@ -887,6 +896,7 @@ pub struct ReadResourceResult { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum ReadResourceResultContents { TextResourceContents(TextResourceContents), BlobResourceContents(BlobResourceContents), @@ -1012,6 +1022,7 @@ pub struct SamplingMessage { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum SamplingMessageContent { TextContent(TextContent), ImageContent(ImageContent), @@ -1100,6 +1111,7 @@ pub enum ServerNotification { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum ServerRequest { PingRequest(PingRequest), CreateMessageRequest(CreateMessageRequest), @@ -1107,6 +1119,7 @@ pub enum ServerRequest { } #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(untagged)] pub enum ServerResult { Result(Result), InitializeResult(InitializeResult),