From 01c0896f0f0867ba072ca8e742532252220049bc Mon Sep 17 00:00:00 2001 From: aibrahim-oai Date: Tue, 22 Jul 2025 13:33:49 -0700 Subject: [PATCH] Adding interrupt Support to MCP (#1646) --- codex-rs/mcp-server/src/codex_tool_config.rs | 2 +- codex-rs/mcp-server/src/codex_tool_runner.rs | 49 ++++- codex-rs/mcp-server/src/lib.rs | 3 +- codex-rs/mcp-server/src/message_processor.rs | 98 ++++++++-- .../mcp-server/tests/common/mcp_process.rs | 81 +++++++- codex-rs/mcp-server/tests/common/mod.rs | 2 + codex-rs/mcp-server/tests/common/responses.rs | 4 + codex-rs/mcp-server/tests/interrupt.rs | 176 ++++++++++++++++++ 8 files changed, 389 insertions(+), 26 deletions(-) create mode 100644 codex-rs/mcp-server/tests/interrupt.rs diff --git a/codex-rs/mcp-server/src/codex_tool_config.rs b/codex-rs/mcp-server/src/codex_tool_config.rs index 6357c94b..9f6f7a78 100644 --- a/codex-rs/mcp-server/src/codex_tool_config.rs +++ b/codex-rs/mcp-server/src/codex_tool_config.rs @@ -168,7 +168,7 @@ impl CodexToolCallParam { #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "camelCase")] -pub(crate) struct CodexToolCallReplyParam { +pub struct CodexToolCallReplyParam { /// The *session id* for this conversation. pub session_id: String, diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index df2154dd..9aaab543 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -20,6 +20,7 @@ use mcp_types::CallToolResult; use mcp_types::ContentBlock; use mcp_types::RequestId; use mcp_types::TextContent; +use serde_json::json; use tokio::sync::Mutex; use uuid::Uuid; @@ -39,6 +40,7 @@ pub async fn run_codex_tool_session( config: CodexConfig, outgoing: Arc, session_map: Arc>>>, + running_requests_id_to_codex_uuid: Arc>>, ) { let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await { Ok(res) => res, @@ -73,7 +75,10 @@ pub async fn run_codex_tool_session( RequestId::String(s) => s.clone(), RequestId::Integer(n) => n.to_string(), }; - + running_requests_id_to_codex_uuid + .lock() + .await + .insert(id.clone(), session_id); let submission = Submission { id: sub_id.clone(), op: Op::UserInput { @@ -85,9 +90,12 @@ pub async fn run_codex_tool_session( if let Err(e) = codex.submit_with_id(submission).await { tracing::error!("Failed to submit initial prompt: {e}"); + // unregister the id so we don't keep it in the map + running_requests_id_to_codex_uuid.lock().await.remove(&id); + return; } - run_codex_tool_session_inner(codex, outgoing, id).await; + run_codex_tool_session_inner(codex, outgoing, id, running_requests_id_to_codex_uuid).await; } pub async fn run_codex_tool_session_reply( @@ -95,7 +103,13 @@ pub async fn run_codex_tool_session_reply( outgoing: Arc, request_id: RequestId, prompt: String, + running_requests_id_to_codex_uuid: Arc>>, + session_id: Uuid, ) { + running_requests_id_to_codex_uuid + .lock() + .await + .insert(request_id.clone(), session_id); if let Err(e) = codex .submit(Op::UserInput { items: vec![InputItem::Text { text: prompt }], @@ -103,15 +117,28 @@ pub async fn run_codex_tool_session_reply( .await { tracing::error!("Failed to submit user input: {e}"); + // unregister the id so we don't keep it in the map + running_requests_id_to_codex_uuid + .lock() + .await + .remove(&request_id); + return; } - run_codex_tool_session_inner(codex, outgoing, request_id).await; + run_codex_tool_session_inner( + codex, + outgoing, + request_id, + running_requests_id_to_codex_uuid, + ) + .await; } async fn run_codex_tool_session_inner( codex: Arc, outgoing: Arc, request_id: RequestId, + running_requests_id_to_codex_uuid: Arc>>, ) { let request_id_str = match &request_id { RequestId::String(s) => s.clone(), @@ -143,6 +170,14 @@ async fn run_codex_tool_session_inner( .await; continue; } + EventMsg::Error(err_event) => { + // Return a response to conclude the tool call when the Codex session reports an error (e.g., interruption). + let result = json!({ + "error": err_event.message, + }); + outgoing.send_response(request_id.clone(), result).await; + break; + } EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { reason, grant_root, @@ -178,6 +213,11 @@ async fn run_codex_tool_session_inner( outgoing .send_response(request_id.clone(), result.into()) .await; + // unregister the id so we don't keep it in the map + running_requests_id_to_codex_uuid + .lock() + .await + .remove(&request_id); break; } EventMsg::SessionConfigured(_) => { @@ -192,8 +232,7 @@ async fn run_codex_tool_session_inner( EventMsg::AgentMessage(AgentMessageEvent { .. }) => { // TODO: think how we want to support this in the MCP } - EventMsg::Error(_) - | EventMsg::TaskStarted + EventMsg::TaskStarted | EventMsg::TokenCount(_) | EventMsg::AgentReasoning(_) | EventMsg::McpToolCallBegin(_) diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index 300d1b5f..79981e49 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -27,6 +27,7 @@ use crate::outgoing_message::OutgoingMessage; use crate::outgoing_message::OutgoingMessageSender; pub use crate::codex_tool_config::CodexToolCallParam; +pub use crate::codex_tool_config::CodexToolCallReplyParam; pub use crate::exec_approval::ExecApprovalElicitRequestParams; pub use crate::exec_approval::ExecApprovalResponse; pub use crate::patch_approval::PatchApprovalElicitRequestParams; @@ -81,7 +82,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> match msg { JSONRPCMessage::Request(r) => processor.process_request(r).await, JSONRPCMessage::Response(r) => processor.process_response(r).await, - JSONRPCMessage::Notification(n) => processor.process_notification(n), + JSONRPCMessage::Notification(n) => processor.process_notification(n).await, JSONRPCMessage::Error(e) => processor.process_error(e), } } diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index e72a52e0..7ba827d6 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -10,6 +10,7 @@ use crate::outgoing_message::OutgoingMessageSender; use codex_core::Codex; use codex_core::config::Config as CodexConfig; +use codex_core::protocol::Submission; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::ClientRequest; @@ -35,6 +36,7 @@ pub(crate) struct MessageProcessor { initialized: bool, codex_linux_sandbox_exe: Option, session_map: Arc>>>, + running_requests_id_to_codex_uuid: Arc>>, } impl MessageProcessor { @@ -49,6 +51,7 @@ impl MessageProcessor { initialized: false, codex_linux_sandbox_exe, session_map: Arc::new(Mutex::new(HashMap::new())), + running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())), } } @@ -116,7 +119,7 @@ impl MessageProcessor { } /// Handle a fire-and-forget JSON-RPC notification. - pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) { + pub(crate) async fn process_notification(&mut self, notification: JSONRPCNotification) { let server_notification = match ServerNotification::try_from(notification) { Ok(n) => n, Err(e) => { @@ -129,7 +132,7 @@ impl MessageProcessor { // handler so additional logic can be implemented incrementally. match server_notification { ServerNotification::CancelledNotification(params) => { - self.handle_cancelled_notification(params); + self.handle_cancelled_notification(params).await; } ServerNotification::ProgressNotification(params) => { self.handle_progress_notification(params); @@ -379,6 +382,7 @@ impl MessageProcessor { // Clone outgoing and session map to move into async task. let outgoing = self.outgoing.clone(); let session_map = self.session_map.clone(); + let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone(); // Spawn an async task to handle the Codex session so that we do not // block the synchronous message-processing loop. @@ -390,6 +394,7 @@ impl MessageProcessor { config, outgoing, session_map, + running_requests_id_to_codex_uuid, ) .await; }); @@ -464,13 +469,12 @@ impl MessageProcessor { // Clone outgoing and session map to move into async task. let outgoing = self.outgoing.clone(); + let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone(); - // Spawn an async task to handle the Codex session so that we do not - // block the synchronous message-processing loop. - task::spawn(async move { + let codex = { let session_map = session_map_mutex.lock().await; - let codex = match session_map.get(&session_id) { - Some(codex) => codex, + match session_map.get(&session_id).cloned() { + Some(c) => c, None => { tracing::warn!("Session not found for session_id: {session_id}"); let result = CallToolResult { @@ -482,21 +486,32 @@ impl MessageProcessor { is_error: Some(true), structured_content: None, }; - // unwrap_or_default is fine here because we know the result is valid JSON outgoing .send_response(request_id, serde_json::to_value(result).unwrap_or_default()) .await; return; } - }; + } + }; - crate::codex_tool_runner::run_codex_tool_session_reply( - codex.clone(), - outgoing, - request_id, - prompt.clone(), - ) - .await; + // Spawn the long-running reply handler. + tokio::spawn({ + let codex = codex.clone(); + let outgoing = outgoing.clone(); + let prompt = prompt.clone(); + let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone(); + + async move { + crate::codex_tool_runner::run_codex_tool_session_reply( + codex, + outgoing, + request_id, + prompt, + running_requests_id_to_codex_uuid, + session_id, + ) + .await; + } }); } @@ -518,11 +533,58 @@ impl MessageProcessor { // Notification handlers // --------------------------------------------------------------------- - fn handle_cancelled_notification( + async fn handle_cancelled_notification( &self, params: ::Params, ) { - tracing::info!("notifications/cancelled -> params: {:?}", params); + let request_id = params.request_id; + // Create a stable string form early for logging and submission id. + let request_id_string = match &request_id { + RequestId::String(s) => s.clone(), + RequestId::Integer(i) => i.to_string(), + }; + + // Obtain the session_id while holding the first lock, then release. + let session_id = { + let map_guard = self.running_requests_id_to_codex_uuid.lock().await; + match map_guard.get(&request_id) { + Some(id) => *id, // Uuid is Copy + None => { + tracing::warn!("Session not found for request_id: {}", request_id_string); + return; + } + } + }; + tracing::info!("session_id: {session_id}"); + + // Obtain the Codex Arc while holding the session_map lock, then release. + let codex_arc = { + let sessions_guard = self.session_map.lock().await; + match sessions_guard.get(&session_id) { + Some(codex) => Arc::clone(codex), + None => { + tracing::warn!("Session not found for session_id: {session_id}"); + return; + } + } + }; + + // Submit interrupt to Codex. + let err = codex_arc + .submit_with_id(Submission { + id: request_id_string, + op: codex_core::protocol::Op::Interrupt, + }) + .await; + if let Err(e) = err { + tracing::error!("Failed to submit interrupt to Codex: {e}"); + return; + } + // unregister the id so we don't keep it in the map + self.running_requests_id_to_codex_uuid + .lock() + .await + .remove(&request_id); } fn handle_progress_notification( diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs index a86deaab..8f1f7a9e 100644 --- a/codex-rs/mcp-server/tests/common/mcp_process.rs +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -12,6 +12,7 @@ use tokio::process::ChildStdout; use anyhow::Context; use assert_cmd::prelude::*; use codex_mcp_server::CodexToolCallParam; +use codex_mcp_server::CodexToolCallReplyParam; use mcp_types::CallToolRequestParams; use mcp_types::ClientCapabilities; use mcp_types::Implementation; @@ -154,6 +155,25 @@ impl McpProcess { .await } + pub async fn send_codex_reply_tool_call( + &mut self, + session_id: &str, + prompt: &str, + ) -> anyhow::Result { + let codex_tool_call_params = CallToolRequestParams { + name: "codex-reply".to_string(), + arguments: Some(serde_json::to_value(CodexToolCallReplyParam { + prompt: prompt.to_string(), + session_id: session_id.to_string(), + })?), + }; + self.send_request( + mcp_types::CallToolRequest::METHOD, + Some(serde_json::to_value(codex_tool_call_params)?), + ) + .await + } + async fn send_request( &mut self, method: &str, @@ -171,6 +191,8 @@ impl McpProcess { Ok(request_id) } + // allow dead code + #[allow(dead_code)] pub async fn send_response( &mut self, id: RequestId, @@ -198,7 +220,8 @@ impl McpProcess { let message = serde_json::from_str::(&line)?; Ok(message) } - + // allow dead code + #[allow(dead_code)] pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result { loop { let message = self.read_jsonrpc_message().await?; @@ -221,6 +244,8 @@ impl McpProcess { } } + // allow dead code + #[allow(dead_code)] pub async fn read_stream_until_response_message( &mut self, request_id: RequestId, @@ -247,4 +272,58 @@ impl McpProcess { } } } + + pub async fn read_stream_until_configured_response_message( + &mut self, + ) -> anyhow::Result { + loop { + let message = self.read_jsonrpc_message().await?; + eprint!("message: {message:?}"); + + match message { + JSONRPCMessage::Notification(notification) => { + if notification.method == "codex/event" { + if let Some(params) = notification.params { + if let Some(msg) = params.get("msg") { + if let Some(msg_type) = msg.get("type") { + if msg_type == "session_configured" { + if let Some(session_id) = msg.get("session_id") { + return Ok(session_id + .to_string() + .trim_matches('"') + .to_string()); + } + } + } + } + } + } + } + JSONRPCMessage::Request(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}"); + } + JSONRPCMessage::Error(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); + } + JSONRPCMessage::Response(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}"); + } + } + } + } + + // allow dead code + #[allow(dead_code)] + pub async fn send_notification( + &mut self, + method: &str, + params: Option, + ) -> anyhow::Result<()> { + self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification { + jsonrpc: JSONRPC_VERSION.into(), + method: method.to_string(), + params, + })) + .await + } } diff --git a/codex-rs/mcp-server/tests/common/mod.rs b/codex-rs/mcp-server/tests/common/mod.rs index b338e2e8..a9593e39 100644 --- a/codex-rs/mcp-server/tests/common/mod.rs +++ b/codex-rs/mcp-server/tests/common/mod.rs @@ -4,6 +4,8 @@ mod responses; pub use mcp_process::McpProcess; pub use mock_model_server::create_mock_chat_completions_server; +#[allow(unused_imports)] pub use responses::create_apply_patch_sse_response; +#[allow(unused_imports)] pub use responses::create_final_assistant_message_sse_response; pub use responses::create_shell_sse_response; diff --git a/codex-rs/mcp-server/tests/common/responses.rs b/codex-rs/mcp-server/tests/common/responses.rs index 9a827fb9..f47952a5 100644 --- a/codex-rs/mcp-server/tests/common/responses.rs +++ b/codex-rs/mcp-server/tests/common/responses.rs @@ -39,6 +39,8 @@ pub fn create_shell_sse_response( Ok(sse) } +// allow dead code +#[allow(dead_code)] pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result { let assistant_message = json!({ "choices": [ @@ -58,6 +60,8 @@ pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Res Ok(sse) } +// allow dead code +#[allow(dead_code)] pub fn create_apply_patch_sse_response( patch_content: &str, call_id: &str, diff --git a/codex-rs/mcp-server/tests/interrupt.rs b/codex-rs/mcp-server/tests/interrupt.rs new file mode 100644 index 00000000..64cf8b47 --- /dev/null +++ b/codex-rs/mcp-server/tests/interrupt.rs @@ -0,0 +1,176 @@ +#![cfg(unix)] +mod common; + +use std::path::Path; + +use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; +use codex_mcp_server::CodexToolCallParam; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; +use serde_json::json; +use tempfile::TempDir; +use tokio::time::timeout; + +use crate::common::McpProcess; +use crate::common::create_mock_chat_completions_server; +use crate::common::create_shell_sse_response; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_shell_command_interruption() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + if let Err(err) = shell_command_interruption().await { + panic!("failure: {err}"); + } +} + +async fn shell_command_interruption() -> anyhow::Result<()> { + // Use a cross-platform blocking command. On Windows plain `sleep` is not guaranteed to exist + // (MSYS/GNU coreutils may be absent) and the failure causes the tool call to finish immediately, + // which triggers a second model request before the test sends the explicit follow-up. That + // prematurely consumes the second mocked SSE response and leads to a third POST (panic: no response for 2). + // Powershell Start-Sleep is always available on Windows runners. On Unix we keep using `sleep`. + #[cfg(target_os = "windows")] + let shell_command = vec![ + "powershell".to_string(), + "-Command".to_string(), + "Start-Sleep -Seconds 60".to_string(), + ]; + #[cfg(not(target_os = "windows"))] + let shell_command = vec!["sleep".to_string(), "60".to_string()]; + let workdir_for_shell_function_call = TempDir::new()?; + + // Create mock server with a single SSE response: the long sleep command + let server = create_mock_chat_completions_server(vec![ + create_shell_sse_response( + shell_command.clone(), + Some(workdir_for_shell_function_call.path()), + Some(60_000), // 60 seconds timeout in ms + "call_sleep", + )?, + create_shell_sse_response( + shell_command.clone(), + Some(workdir_for_shell_function_call.path()), + Some(60_000), // 60 seconds timeout in ms + "call_sleep", + )?, + ]) + .await; + + // Create Codex configuration + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), server.uri())?; + let mut mcp_process = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + + // Send codex tool call that triggers "sleep 60" + let codex_request_id = mcp_process + .send_codex_tool_call(CodexToolCallParam { + cwd: None, + prompt: "First Run: run `sleep 60`".to_string(), + model: None, + profile: None, + approval_policy: None, + sandbox: None, + config: None, + base_instructions: None, + }) + .await?; + + let session_id = mcp_process + .read_stream_until_configured_response_message() + .await?; + + // Give the command a moment to start + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + // Send interrupt notification + mcp_process + .send_notification( + "notifications/cancelled", + Some(json!({ "requestId": codex_request_id })), + ) + .await?; + + // Expect Codex to return an error or interruption response + let codex_response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + ) + .await??; + + assert!( + codex_response + .result + .as_object() + .map(|o| o.contains_key("error")) + .unwrap_or(false), + "Expected an interruption or error result, got: {codex_response:?}" + ); + + let codex_reply_request_id = mcp_process + .send_codex_reply_tool_call(&session_id, "Second Run: run `sleep 60`") + .await?; + + // Give the command a moment to start + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + // Send interrupt notification + mcp_process + .send_notification( + "notifications/cancelled", + Some(json!({ "requestId": codex_reply_request_id })), + ) + .await?; + + // Expect Codex to return an error or interruption response + let codex_response: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_reply_request_id)), + ) + .await??; + + assert!( + codex_response + .result + .as_object() + .map(|o| o.contains_key("error")) + .unwrap_or(false), + "Expected an interruption or error result, got: {codex_response:?}" + ); + Ok(()) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "danger-full-access" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +}