From f918198bbbfa95e20c629eed758cbd53ad38ab97 Mon Sep 17 00:00:00 2001 From: aibrahim-oai Date: Fri, 1 Aug 2025 10:04:12 -0700 Subject: [PATCH] Introduce a new function to just send user message [Stack 3/3] (#1686) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - MCP server: add send-user-message tool to send user input to a running Codex session - Added an integration tests for the happy and sad paths Changes: • Add tool definition and schema. • Expose tool in capabilities. • Route and handle tool requests with validation. • Tests for success, bad UUID, and missing session. follow‑ups • Listen path not implemented yet; the tool is present but marked “don’t use yet” in code comments. • Session run flag reset: clear running_session_id_set appropriately after turn completion/errors. This is the third PR in a stack. Stack: Final: #1686 Intermediate: #1751 First: #1750 --- codex-rs/mcp-server/src/lib.rs | 3 +- codex-rs/mcp-server/src/mcp_protocol.rs | 36 ++-- codex-rs/mcp-server/src/message_processor.rs | 63 +++++-- codex-rs/mcp-server/src/tool_handlers/mod.rs | 1 + .../src/tool_handlers/send_message.rs | 124 +++++++++++++ codex-rs/mcp-server/tests/send_message.rs | 163 ++++++++++++++++++ 6 files changed, 358 insertions(+), 32 deletions(-) create mode 100644 codex-rs/mcp-server/src/tool_handlers/mod.rs create mode 100644 codex-rs/mcp-server/src/tool_handlers/send_message.rs create mode 100644 codex-rs/mcp-server/tests/send_message.rs diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index ebef8ca9..6b3c0ddb 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -20,9 +20,10 @@ mod codex_tool_runner; mod exec_approval; mod json_to_toml; pub mod mcp_protocol; -mod message_processor; +pub(crate) mod message_processor; mod outgoing_message; mod patch_approval; +pub(crate) mod tool_handlers; use crate::message_processor::MessageProcessor; use crate::outgoing_message::OutgoingMessage; diff --git a/codex-rs/mcp-server/src/mcp_protocol.rs b/codex-rs/mcp-server/src/mcp_protocol.rs index 23304dc4..287890bf 100644 --- a/codex-rs/mcp-server/src/mcp_protocol.rs +++ b/codex-rs/mcp-server/src/mcp_protocol.rs @@ -132,32 +132,32 @@ impl From for CallToolResult { is_error, result, } = val; - let (content, structured_content, is_error_out) = match result { + match result { Some(res) => match serde_json::to_value(&res) { - Ok(v) => { - let content = vec![ContentBlock::TextContent(TextContent { + Ok(v) => CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { r#type: "text".to_string(), text: v.to_string(), annotations: None, - })]; - (content, Some(v), is_error) - } - Err(e) => { - let content = vec![ContentBlock::TextContent(TextContent { + })], + is_error, + structured_content: Some(v), + }, + Err(e) => CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { r#type: "text".to_string(), text: format!("Failed to serialize tool result: {e}"), annotations: None, - })]; - (content, None, Some(true)) - } + })], + is_error: Some(true), + structured_content: None, + }, + }, + None => CallToolResult { + content: vec![], + is_error, + structured_content: None, }, - None => (vec![], None, is_error), - }; - - CallToolResult { - content, - is_error: is_error_out, - structured_content, } } } diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index a4013cc3..14e9bb38 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; @@ -7,11 +8,15 @@ use crate::codex_tool_config::CodexToolCallReplyParam; use crate::codex_tool_config::create_tool_for_codex_tool_call_param; use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param; use crate::mcp_protocol::ToolCallRequestParams; +use crate::mcp_protocol::ToolCallResponse; +use crate::mcp_protocol::ToolCallResponseResult; use crate::outgoing_message::OutgoingMessageSender; +use crate::tool_handlers::send_message::handle_send_message; use codex_core::Codex; use codex_core::config::Config as CodexConfig; use codex_core::protocol::Submission; +use mcp_types::CallToolRequest; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::ClientRequest; @@ -38,6 +43,7 @@ pub(crate) struct MessageProcessor { codex_linux_sandbox_exe: Option, session_map: Arc>>>, running_requests_id_to_codex_uuid: Arc>>, + running_session_ids: Arc>>, } impl MessageProcessor { @@ -53,9 +59,18 @@ impl MessageProcessor { codex_linux_sandbox_exe, session_map: Arc::new(Mutex::new(HashMap::new())), running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())), + running_session_ids: Arc::new(Mutex::new(HashSet::new())), } } + pub(crate) fn session_map(&self) -> Arc>>> { + self.session_map.clone() + } + + pub(crate) fn running_session_ids(&self) -> Arc>> { + self.running_session_ids.clone() + } + pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) { // Hold on to the ID so we can respond. let request_id = request.id.clone(); @@ -332,19 +347,25 @@ impl MessageProcessor { } } } - async fn handle_new_tool_calls(&self, request_id: RequestId, _params: ToolCallRequestParams) { - // TODO: implement the new tool calls - let result = CallToolResult { - content: vec![ContentBlock::TextContent(TextContent { - r#type: "text".to_string(), - text: "Unknown tool".to_string(), - annotations: None, - })], - is_error: Some(true), - structured_content: None, - }; - self.send_response::(request_id, result) - .await; + async fn handle_new_tool_calls(&self, request_id: RequestId, params: ToolCallRequestParams) { + match params { + ToolCallRequestParams::ConversationSendMessage(args) => { + handle_send_message(self, request_id, args).await; + } + _ => { + let result = CallToolResult { + content: vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_string(), + text: "Unknown tool".to_string(), + annotations: None, + })], + is_error: Some(true), + structured_content: None, + }; + self.send_response::(request_id, result) + .await; + } + } } async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option) { @@ -654,4 +675,20 @@ impl MessageProcessor { ) { tracing::info!("notifications/message -> params: {:?}", params); } + + pub(crate) async fn send_response_with_optional_error( + &self, + id: RequestId, + message: Option, + error: Option, + ) { + let response = ToolCallResponse { + request_id: id.clone(), + is_error: error, + result: message, + }; + let result: CallToolResult = response.into(); + self.send_response::(id.clone(), result) + .await; + } } diff --git a/codex-rs/mcp-server/src/tool_handlers/mod.rs b/codex-rs/mcp-server/src/tool_handlers/mod.rs new file mode 100644 index 00000000..1907ec64 --- /dev/null +++ b/codex-rs/mcp-server/src/tool_handlers/mod.rs @@ -0,0 +1 @@ +pub(crate) mod send_message; diff --git a/codex-rs/mcp-server/src/tool_handlers/send_message.rs b/codex-rs/mcp-server/src/tool_handlers/send_message.rs new file mode 100644 index 00000000..894176be --- /dev/null +++ b/codex-rs/mcp-server/src/tool_handlers/send_message.rs @@ -0,0 +1,124 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use codex_core::Codex; +use codex_core::protocol::Op; +use codex_core::protocol::Submission; +use mcp_types::RequestId; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::mcp_protocol::ConversationSendMessageArgs; +use crate::mcp_protocol::ConversationSendMessageResult; +use crate::mcp_protocol::ToolCallResponseResult; +use crate::message_processor::MessageProcessor; + +pub(crate) async fn handle_send_message( + message_processor: &MessageProcessor, + id: RequestId, + arguments: ConversationSendMessageArgs, +) { + let ConversationSendMessageArgs { + conversation_id, + content: items, + parent_message_id: _, + conversation_overrides: _, + } = arguments; + + if items.is_empty() { + message_processor + .send_response_with_optional_error( + id, + Some(ToolCallResponseResult::ConversationSendMessage( + ConversationSendMessageResult::Error { + message: "No content items provided".to_string(), + }, + )), + Some(true), + ) + .await; + return; + } + + let session_id = conversation_id.0; + let Some(codex) = get_session(session_id, message_processor.session_map()).await else { + message_processor + .send_response_with_optional_error( + id, + Some(ToolCallResponseResult::ConversationSendMessage( + ConversationSendMessageResult::Error { + message: "Session does not exist".to_string(), + }, + )), + Some(true), + ) + .await; + return; + }; + + let running = { + let running_sessions = message_processor.running_session_ids(); + let mut running_sessions = running_sessions.lock().await; + !running_sessions.insert(session_id) + }; + + if running { + message_processor + .send_response_with_optional_error( + id, + Some(ToolCallResponseResult::ConversationSendMessage( + ConversationSendMessageResult::Error { + message: "Session is already running".to_string(), + }, + )), + Some(true), + ) + .await; + return; + } + + let request_id_string = match &id { + RequestId::String(s) => s.clone(), + RequestId::Integer(i) => i.to_string(), + }; + + let submit_res = codex + .submit_with_id(Submission { + id: request_id_string, + op: Op::UserInput { items }, + }) + .await; + + if let Err(e) = submit_res { + message_processor + .send_response_with_optional_error( + id, + Some(ToolCallResponseResult::ConversationSendMessage( + ConversationSendMessageResult::Error { + message: format!("Failed to submit user input: {e}"), + }, + )), + Some(true), + ) + .await; + return; + } + + message_processor + .send_response_with_optional_error( + id, + Some(ToolCallResponseResult::ConversationSendMessage( + ConversationSendMessageResult::Ok, + )), + Some(false), + ) + .await; +} + +pub(crate) async fn get_session( + session_id: Uuid, + session_map: Arc>>>, +) -> Option> { + let guard = session_map.lock().await; + guard.get(&session_id).cloned() +} diff --git a/codex-rs/mcp-server/tests/send_message.rs b/codex-rs/mcp-server/tests/send_message.rs new file mode 100644 index 00000000..fd4b210b --- /dev/null +++ b/codex-rs/mcp-server/tests/send_message.rs @@ -0,0 +1,163 @@ +#![allow(clippy::expect_used)] + +use std::path::Path; +use std::thread::sleep; +use std::time::Duration; + +use codex_mcp_server::CodexToolCallParam; +use mcp_test_support::McpProcess; +use mcp_test_support::create_final_assistant_message_sse_response; +use mcp_test_support::create_mock_chat_completions_server; +use mcp_types::JSONRPC_VERSION; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; +use pretty_assertions::assert_eq; +use serde_json::json; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_send_message_success() { + // Spin up a mock completions server that immediately ends the Codex turn. + // Two Codex turns hit the mock model (session start + send-user-message). Provide two SSE responses. + let responses = vec![ + create_final_assistant_message_sse_response("Done").expect("build mock assistant message"), + create_final_assistant_message_sse_response("Done").expect("build mock assistant message"), + ]; + let server = create_mock_chat_completions_server(responses).await; + + // Create a temporary Codex home with config pointing at the mock server. + let codex_home = TempDir::new().expect("create temp dir"); + create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml"); + + // Start MCP server process and initialize. + let mut mcp_process = McpProcess::new(codex_home.path()) + .await + .expect("spawn mcp process"); + timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()) + .await + .expect("init timed out") + .expect("init failed"); + + // Kick off a Codex session so we have a valid session_id. + let codex_request_id = mcp_process + .send_codex_tool_call(CodexToolCallParam { + prompt: "Start a session".to_string(), + ..Default::default() + }) + .await + .expect("send codex tool call"); + + // Wait for the session_configured event to get the session_id. + let session_id = mcp_process + .read_stream_until_configured_response_message() + .await + .expect("read session_configured"); + + // The original codex call will finish quickly given our mock; consume its response. + timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + ) + .await + .expect("codex response timeout") + .expect("codex response error"); + + // Now exercise the send-user-message tool. + let send_msg_request_id = mcp_process + .send_user_message_tool_call("Hello again", &session_id) + .await + .expect("send send-message tool call"); + + let response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(send_msg_request_id)), + ) + .await + .expect("send-user-message response timeout") + .expect("send-user-message response error"); + + assert_eq!( + JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(send_msg_request_id), + result: json!({ + "content": [ + { + "text": "{\"status\":\"ok\"}", + "type": "text", + } + ], + "isError": false, + "structuredContent": { + "status": "ok" + } + }), + }, + response + ); + // wait for the server to hear the user message + sleep(Duration::from_secs(1)); + + // Ensure the server and tempdir live until end of test + drop(server); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_send_message_session_not_found() { + // Start MCP without creating a Codex session + let codex_home = TempDir::new().expect("tempdir"); + let mut mcp = McpProcess::new(codex_home.path()).await.expect("spawn"); + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()) + .await + .expect("timeout") + .expect("init"); + + let unknown = uuid::Uuid::new_v4().to_string(); + let req_id = mcp + .send_user_message_tool_call("ping", &unknown) + .await + .expect("send tool"); + + let resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(req_id)), + ) + .await + .expect("timeout") + .expect("resp"); + + let result = resp.result.clone(); + let content = result["content"][0]["text"].as_str().unwrap_or(""); + assert!(content.contains("Session does not exist")); + assert_eq!(result["isError"], json!(true)); +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "danger-full-access" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +}