From 11fd3123beb2f09371e4ab0f6568673236130457 Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Sat, 19 Jul 2025 00:30:56 -0400 Subject: [PATCH] chore: introduce OutgoingMessageSender (#1622) Previous to this change, `MessageProcessor` had a `tokio::sync::mpsc::Sender` as an abstraction for server code to send a message down to the MCP client. Because `Sender` is cheap to `clone()`, it was straightforward to make it available to tasks scheduled with `tokio::task::spawn()`. This worked well when we were only sending notifications or responses back down to the client, but we want to add support for sending elicitations in #1623, which means that we need to be able to send _requests_ to the client, and now we need a bit of centralization to ensure all request ids are unique. To that end, this PR introduces `OutgoingMessageSender`, which houses the existing `Sender` as well as an `AtomicI64` to mint out new, unique request ids. It has methods like `send_request()` and `send_response()` so that callers do not have to deal with `JSONRPCMessage` directly, as having to set the `jsonrpc` for each message was a bit tedious (this cleans up `codex_tool_runner.rs` quite a bit). We do not have `OutgoingMessageSender` implement `Clone` because it is important that the `AtomicI64` is shared across all users of `OutgoingMessageSender`. As such, `Arc` must be used instead, as it is frequently shared with new tokio tasks. As part of this change, we update `message_processor.rs` to embrace `await`, though we must be careful that no individual handler blocks the main loop and prevents other messages from being handled. --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/openai/codex/pull/1622). * #1623 * __->__ #1622 * #1621 * #1620 --- codex-rs/mcp-server/src/codex_tool_runner.rs | 65 ++-------- codex-rs/mcp-server/src/lib.rs | 13 +- codex-rs/mcp-server/src/message_processor.rs | 82 ++++++------ codex-rs/mcp-server/src/outgoing_message.rs | 129 +++++++++++++++++++ 4 files changed, 186 insertions(+), 103 deletions(-) create mode 100644 codex-rs/mcp-server/src/outgoing_message.rs diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index 00cadcf0..a20566d6 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -2,10 +2,11 @@ //! Tokio task. Separated from `message_processor.rs` to keep that file small //! and to make future feature-growth easier to manage. +use std::sync::Arc; + use codex_core::codex_wrapper::init_codex; use codex_core::config::Config as CodexConfig; use codex_core::protocol::AgentMessageEvent; -use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; use codex_core::protocol::Op; @@ -13,22 +14,10 @@ use codex_core::protocol::Submission; use codex_core::protocol::TaskCompleteEvent; use mcp_types::CallToolResult; use mcp_types::ContentBlock; -use mcp_types::JSONRPC_VERSION; -use mcp_types::JSONRPCMessage; -use mcp_types::JSONRPCResponse; use mcp_types::RequestId; use mcp_types::TextContent; -use tokio::sync::mpsc::Sender; -/// Convert a Codex [`Event`] to an MCP notification. -fn codex_event_to_notification(event: &Event) -> JSONRPCMessage { - #[expect(clippy::expect_used)] - JSONRPCMessage::Notification(mcp_types::JSONRPCNotification { - jsonrpc: JSONRPC_VERSION.into(), - method: "codex/event".into(), - params: Some(serde_json::to_value(event).expect("Event must serialize")), - }) -} +use crate::outgoing_message::OutgoingMessageSender; /// Run a complete Codex session and stream events back to the client. /// @@ -38,7 +27,7 @@ pub async fn run_codex_tool_session( id: RequestId, initial_prompt: String, config: CodexConfig, - outgoing: Sender, + outgoing: Arc, ) { let (codex, first_event, _ctrl_c) = match init_codex(config).await { Ok(res) => res, @@ -52,21 +41,13 @@ pub async fn run_codex_tool_session( is_error: Some(true), structured_content: None, }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id, - result: result.into(), - })) - .await; + outgoing.send_response(id.clone(), result.into()).await; return; } }; // Send initial SessionConfigured event. - let _ = outgoing - .send(codex_event_to_notification(&first_event)) - .await; + outgoing.send_event_as_notification(&first_event).await; // Use the original MCP request ID as the `sub_id` for the Codex submission so that // any events emitted for this tool-call can be correlated with the @@ -94,7 +75,7 @@ pub async fn run_codex_tool_session( loop { match codex.next_event().await { Ok(event) => { - let _ = outgoing.send(codex_event_to_notification(&event)).await; + outgoing.send_event_as_notification(&event).await; match &event.msg { EventMsg::ExecApprovalRequest(_) => { @@ -107,13 +88,7 @@ pub async fn run_codex_tool_session( is_error: None, structured_content: None, }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: id.clone(), - result: result.into(), - })) - .await; + outgoing.send_response(id.clone(), result.into()).await; break; } EventMsg::ApplyPatchApprovalRequest(_) => { @@ -126,13 +101,7 @@ pub async fn run_codex_tool_session( is_error: None, structured_content: None, }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: id.clone(), - result: result.into(), - })) - .await; + outgoing.send_response(id.clone(), result.into()).await; break; } EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => { @@ -149,13 +118,7 @@ pub async fn run_codex_tool_session( is_error: None, structured_content: None, }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: id.clone(), - result: result.into(), - })) - .await; + outgoing.send_response(id.clone(), result.into()).await; break; } EventMsg::SessionConfigured(_) => { @@ -203,13 +166,7 @@ pub async fn run_codex_tool_session( // structured way. structured_content: None, }; - let _ = outgoing - .send(JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: id.clone(), - result: result.into(), - })) - .await; + outgoing.send_response(id.clone(), result.into()).await; break; } } diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index db41013a..b968b497 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -18,8 +18,11 @@ mod codex_tool_config; mod codex_tool_runner; mod json_to_toml; mod message_processor; +mod outgoing_message; use crate::message_processor::MessageProcessor; +use crate::outgoing_message::OutgoingMessage; +use crate::outgoing_message::OutgoingMessageSender; /// Size of the bounded channels used to communicate between tasks. The value /// is a balance between throughput and memory usage – 128 messages should be @@ -35,7 +38,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> // Set up channels. let (incoming_tx, mut incoming_rx) = mpsc::channel::(CHANNEL_CAPACITY); - let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(CHANNEL_CAPACITY); // Task: read from stdin, push to `incoming_tx`. let stdin_reader_handle = tokio::spawn({ @@ -63,11 +66,12 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> // Task: process incoming messages. let processor_handle = tokio::spawn({ - let mut processor = MessageProcessor::new(outgoing_tx.clone(), codex_linux_sandbox_exe); + let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx); + let mut processor = MessageProcessor::new(outgoing_message_sender, codex_linux_sandbox_exe); async move { while let Some(msg) = incoming_rx.recv().await { match msg { - JSONRPCMessage::Request(r) => processor.process_request(r), + JSONRPCMessage::Request(r) => processor.process_request(r).await, JSONRPCMessage::Response(r) => processor.process_response(r), JSONRPCMessage::Notification(n) => processor.process_notification(n), JSONRPCMessage::Error(e) => processor.process_error(e), @@ -81,7 +85,8 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> // Task: write outgoing messages to stdout. let stdout_writer_handle = tokio::spawn(async move { let mut stdout = io::stdout(); - while let Some(msg) = outgoing_rx.recv().await { + while let Some(outgoing_message) = outgoing_rx.recv().await { + let msg: JSONRPCMessage = outgoing_message.into(); match serde_json::to_string(&msg) { Ok(json) => { if let Err(e) = stdout.write_all(json.as_bytes()).await { diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index dcc6ae62..aad7f211 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -1,17 +1,17 @@ use std::path::PathBuf; +use std::sync::Arc; use crate::codex_tool_config::CodexToolCallParam; use crate::codex_tool_config::create_tool_for_codex_tool_call_param; +use crate::outgoing_message::OutgoingMessageSender; use codex_core::config::Config as CodexConfig; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::ClientRequest; use mcp_types::ContentBlock; -use mcp_types::JSONRPC_VERSION; use mcp_types::JSONRPCError; use mcp_types::JSONRPCErrorError; -use mcp_types::JSONRPCMessage; use mcp_types::JSONRPCNotification; use mcp_types::JSONRPCRequest; use mcp_types::JSONRPCResponse; @@ -22,11 +22,10 @@ use mcp_types::ServerCapabilitiesTools; use mcp_types::ServerNotification; use mcp_types::TextContent; use serde_json::json; -use tokio::sync::mpsc; use tokio::task; pub(crate) struct MessageProcessor { - outgoing: mpsc::Sender, + outgoing: Arc, initialized: bool, codex_linux_sandbox_exe: Option, } @@ -35,17 +34,17 @@ impl MessageProcessor { /// Create a new `MessageProcessor`, retaining a handle to the outgoing /// `Sender` so handlers can enqueue messages to be written to stdout. pub(crate) fn new( - outgoing: mpsc::Sender, + outgoing: OutgoingMessageSender, codex_linux_sandbox_exe: Option, ) -> Self { Self { - outgoing, + outgoing: Arc::new(outgoing), initialized: false, codex_linux_sandbox_exe, } } - pub(crate) fn process_request(&mut self, request: JSONRPCRequest) { + pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) { // Hold on to the ID so we can respond. let request_id = request.id.clone(); @@ -60,10 +59,10 @@ impl MessageProcessor { // Dispatch to a dedicated handler for each request type. match client_request { ClientRequest::InitializeRequest(params) => { - self.handle_initialize(request_id, params); + self.handle_initialize(request_id, params).await; } ClientRequest::PingRequest(params) => { - self.handle_ping(request_id, params); + self.handle_ping(request_id, params).await; } ClientRequest::ListResourcesRequest(params) => { self.handle_list_resources(params); @@ -87,10 +86,10 @@ impl MessageProcessor { self.handle_get_prompt(params); } ClientRequest::ListToolsRequest(params) => { - self.handle_list_tools(request_id, params); + self.handle_list_tools(request_id, params).await; } ClientRequest::CallToolRequest(params) => { - self.handle_call_tool(request_id, params); + self.handle_call_tool(request_id, params).await; } ClientRequest::SetLevelRequest(params) => { self.handle_set_level(params); @@ -148,7 +147,7 @@ impl MessageProcessor { tracing::error!("<- error: {:?}", err); } - fn handle_initialize( + async fn handle_initialize( &mut self, id: RequestId, params: ::Params, @@ -157,19 +156,12 @@ impl MessageProcessor { if self.initialized { // Already initialised: send JSON-RPC error response. - let error_msg = JSONRPCMessage::Error(JSONRPCError { - jsonrpc: JSONRPC_VERSION.into(), - id, - error: JSONRPCErrorError { - code: -32600, // Invalid Request - message: "initialize called more than once".to_string(), - data: None, - }, - }); - - if let Err(e) = self.outgoing.try_send(error_msg) { - tracing::error!("Failed to send initialization error: {e}"); - } + let error = JSONRPCErrorError { + code: -32600, // Invalid Request + message: "initialize called more than once".to_string(), + data: None, + }; + self.outgoing.send_error(id, error).await; return; } @@ -196,34 +188,29 @@ impl MessageProcessor { }, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; } - fn send_response(&self, id: RequestId, result: T::Result) + async fn send_response(&self, id: RequestId, result: T::Result) where T: ModelContextProtocolRequest, { // result has `Serialized` instance so should never fail #[expect(clippy::unwrap_used)] - let response = JSONRPCMessage::Response(JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id, - result: serde_json::to_value(result).unwrap(), - }); - - if let Err(e) = self.outgoing.try_send(response) { - tracing::error!("Failed to send response: {e}"); - } + let result = serde_json::to_value(result).unwrap(); + self.outgoing.send_response(id, result).await; } - fn handle_ping( + async fn handle_ping( &self, id: RequestId, params: ::Params, ) { tracing::info!("ping -> params: {:?}", params); let result = json!({}); - self.send_response::(id, result); + self.send_response::(id, result) + .await; } fn handle_list_resources( @@ -276,7 +263,7 @@ impl MessageProcessor { tracing::info!("prompts/get -> params: {:?}", params); } - fn handle_list_tools( + async fn handle_list_tools( &self, id: RequestId, params: ::Params, @@ -287,10 +274,11 @@ impl MessageProcessor { next_cursor: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; } - fn handle_call_tool( + async fn handle_call_tool( &self, id: RequestId, params: ::Params, @@ -310,7 +298,8 @@ impl MessageProcessor { is_error: Some(true), structured_content: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; return; } @@ -330,7 +319,8 @@ impl MessageProcessor { is_error: Some(true), structured_content: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; return; } }, @@ -344,7 +334,8 @@ impl MessageProcessor { is_error: Some(true), structured_content: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; return; } }, @@ -360,7 +351,8 @@ impl MessageProcessor { is_error: Some(true), structured_content: None, }; - self.send_response::(id, result); + self.send_response::(id, result) + .await; return; } }; diff --git a/codex-rs/mcp-server/src/outgoing_message.rs b/codex-rs/mcp-server/src/outgoing_message.rs new file mode 100644 index 00000000..93a760d3 --- /dev/null +++ b/codex-rs/mcp-server/src/outgoing_message.rs @@ -0,0 +1,129 @@ +use std::sync::atomic::AtomicI64; +use std::sync::atomic::Ordering; + +use codex_core::protocol::Event; +use mcp_types::JSONRPC_VERSION; +use mcp_types::JSONRPCError; +use mcp_types::JSONRPCErrorError; +use mcp_types::JSONRPCMessage; +use mcp_types::JSONRPCNotification; +use mcp_types::JSONRPCRequest; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; +use mcp_types::Result; +use serde::Serialize; +use tokio::sync::mpsc; + +pub(crate) struct OutgoingMessageSender { + next_request_id: AtomicI64, + sender: mpsc::Sender, +} + +impl OutgoingMessageSender { + pub(crate) fn new(sender: mpsc::Sender) -> Self { + Self { + next_request_id: AtomicI64::new(0), + sender, + } + } + + #[allow(dead_code)] + pub(crate) async fn send_request(&self, method: &str, params: Option) { + let outgoing_message = OutgoingMessage::Request(OutgoingRequest { + id: RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed)), + method: method.to_string(), + params, + }); + let _ = self.sender.send(outgoing_message).await; + } + + pub(crate) async fn send_response(&self, id: RequestId, result: Result) { + let outgoing_message = OutgoingMessage::Response(OutgoingResponse { id, result }); + let _ = self.sender.send(outgoing_message).await; + } + + pub(crate) async fn send_event_as_notification(&self, event: &Event) { + #[expect(clippy::expect_used)] + let params = Some(serde_json::to_value(event).expect("Event must serialize")); + let outgoing_message = OutgoingMessage::Notification(OutgoingNotification { + method: "codex/event".to_string(), + params, + }); + let _ = self.sender.send(outgoing_message).await; + } + + pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) { + let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error }); + let _ = self.sender.send(outgoing_message).await; + } +} + +/// Outgoing message from the server to the client. +pub(crate) enum OutgoingMessage { + Request(OutgoingRequest), + Notification(OutgoingNotification), + Response(OutgoingResponse), + Error(OutgoingError), +} + +impl From for JSONRPCMessage { + fn from(val: OutgoingMessage) -> Self { + use OutgoingMessage::*; + match val { + Request(OutgoingRequest { id, method, params }) => { + JSONRPCMessage::Request(JSONRPCRequest { + jsonrpc: JSONRPC_VERSION.into(), + id, + method, + params, + }) + } + Notification(OutgoingNotification { method, params }) => { + JSONRPCMessage::Notification(JSONRPCNotification { + jsonrpc: JSONRPC_VERSION.into(), + method, + params, + }) + } + Response(OutgoingResponse { id, result }) => { + JSONRPCMessage::Response(JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id, + result, + }) + } + Error(OutgoingError { id, error }) => JSONRPCMessage::Error(JSONRPCError { + jsonrpc: JSONRPC_VERSION.into(), + id, + error, + }), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingRequest { + pub id: RequestId, + pub method: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingNotification { + pub method: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingResponse { + pub id: RequestId, + pub result: Result, +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub(crate) struct OutgoingError { + pub error: JSONRPCErrorError, + pub id: RequestId, +}