Adding interrupt Support to MCP (#1646)

This commit is contained in:
aibrahim-oai
2025-07-22 13:33:49 -07:00
committed by GitHub
parent 4082246f6a
commit 01c0896f0f
8 changed files with 389 additions and 26 deletions

View File

@@ -168,7 +168,7 @@ impl CodexToolCallParam {
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub(crate) struct CodexToolCallReplyParam {
pub struct CodexToolCallReplyParam {
/// The *session id* for this conversation.
pub session_id: String,

View File

@@ -20,6 +20,7 @@ use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::RequestId;
use mcp_types::TextContent;
use serde_json::json;
use tokio::sync::Mutex;
use uuid::Uuid;
@@ -39,6 +40,7 @@ pub async fn run_codex_tool_session(
config: CodexConfig,
outgoing: Arc<OutgoingMessageSender>,
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
) {
let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await {
Ok(res) => res,
@@ -73,7 +75,10 @@ pub async fn run_codex_tool_session(
RequestId::String(s) => s.clone(),
RequestId::Integer(n) => n.to_string(),
};
running_requests_id_to_codex_uuid
.lock()
.await
.insert(id.clone(), session_id);
let submission = Submission {
id: sub_id.clone(),
op: Op::UserInput {
@@ -85,9 +90,12 @@ pub async fn run_codex_tool_session(
if let Err(e) = codex.submit_with_id(submission).await {
tracing::error!("Failed to submit initial prompt: {e}");
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid.lock().await.remove(&id);
return;
}
run_codex_tool_session_inner(codex, outgoing, id).await;
run_codex_tool_session_inner(codex, outgoing, id, running_requests_id_to_codex_uuid).await;
}
pub async fn run_codex_tool_session_reply(
@@ -95,7 +103,13 @@ pub async fn run_codex_tool_session_reply(
outgoing: Arc<OutgoingMessageSender>,
request_id: RequestId,
prompt: String,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
session_id: Uuid,
) {
running_requests_id_to_codex_uuid
.lock()
.await
.insert(request_id.clone(), session_id);
if let Err(e) = codex
.submit(Op::UserInput {
items: vec![InputItem::Text { text: prompt }],
@@ -103,15 +117,28 @@ pub async fn run_codex_tool_session_reply(
.await
{
tracing::error!("Failed to submit user input: {e}");
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
return;
}
run_codex_tool_session_inner(codex, outgoing, request_id).await;
run_codex_tool_session_inner(
codex,
outgoing,
request_id,
running_requests_id_to_codex_uuid,
)
.await;
}
async fn run_codex_tool_session_inner(
codex: Arc<Codex>,
outgoing: Arc<OutgoingMessageSender>,
request_id: RequestId,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
) {
let request_id_str = match &request_id {
RequestId::String(s) => s.clone(),
@@ -143,6 +170,14 @@ async fn run_codex_tool_session_inner(
.await;
continue;
}
EventMsg::Error(err_event) => {
// Return a response to conclude the tool call when the Codex session reports an error (e.g., interruption).
let result = json!({
"error": err_event.message,
});
outgoing.send_response(request_id.clone(), result).await;
break;
}
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
reason,
grant_root,
@@ -178,6 +213,11 @@ async fn run_codex_tool_session_inner(
outgoing
.send_response(request_id.clone(), result.into())
.await;
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
break;
}
EventMsg::SessionConfigured(_) => {
@@ -192,8 +232,7 @@ async fn run_codex_tool_session_inner(
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
// TODO: think how we want to support this in the MCP
}
EventMsg::Error(_)
| EventMsg::TaskStarted
EventMsg::TaskStarted
| EventMsg::TokenCount(_)
| EventMsg::AgentReasoning(_)
| EventMsg::McpToolCallBegin(_)

View File

@@ -27,6 +27,7 @@ use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingMessageSender;
pub use crate::codex_tool_config::CodexToolCallParam;
pub use crate::codex_tool_config::CodexToolCallReplyParam;
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
pub use crate::exec_approval::ExecApprovalResponse;
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
@@ -81,7 +82,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
match msg {
JSONRPCMessage::Request(r) => processor.process_request(r).await,
JSONRPCMessage::Response(r) => processor.process_response(r).await,
JSONRPCMessage::Notification(n) => processor.process_notification(n),
JSONRPCMessage::Notification(n) => processor.process_notification(n).await,
JSONRPCMessage::Error(e) => processor.process_error(e),
}
}

View File

@@ -10,6 +10,7 @@ use crate::outgoing_message::OutgoingMessageSender;
use codex_core::Codex;
use codex_core::config::Config as CodexConfig;
use codex_core::protocol::Submission;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::ClientRequest;
@@ -35,6 +36,7 @@ pub(crate) struct MessageProcessor {
initialized: bool,
codex_linux_sandbox_exe: Option<PathBuf>,
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
}
impl MessageProcessor {
@@ -49,6 +51,7 @@ impl MessageProcessor {
initialized: false,
codex_linux_sandbox_exe,
session_map: Arc::new(Mutex::new(HashMap::new())),
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
}
}
@@ -116,7 +119,7 @@ impl MessageProcessor {
}
/// Handle a fire-and-forget JSON-RPC notification.
pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) {
pub(crate) async fn process_notification(&mut self, notification: JSONRPCNotification) {
let server_notification = match ServerNotification::try_from(notification) {
Ok(n) => n,
Err(e) => {
@@ -129,7 +132,7 @@ impl MessageProcessor {
// handler so additional logic can be implemented incrementally.
match server_notification {
ServerNotification::CancelledNotification(params) => {
self.handle_cancelled_notification(params);
self.handle_cancelled_notification(params).await;
}
ServerNotification::ProgressNotification(params) => {
self.handle_progress_notification(params);
@@ -379,6 +382,7 @@ impl MessageProcessor {
// Clone outgoing and session map to move into async task.
let outgoing = self.outgoing.clone();
let session_map = self.session_map.clone();
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
// Spawn an async task to handle the Codex session so that we do not
// block the synchronous message-processing loop.
@@ -390,6 +394,7 @@ impl MessageProcessor {
config,
outgoing,
session_map,
running_requests_id_to_codex_uuid,
)
.await;
});
@@ -464,13 +469,12 @@ impl MessageProcessor {
// Clone outgoing and session map to move into async task.
let outgoing = self.outgoing.clone();
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
// Spawn an async task to handle the Codex session so that we do not
// block the synchronous message-processing loop.
task::spawn(async move {
let codex = {
let session_map = session_map_mutex.lock().await;
let codex = match session_map.get(&session_id) {
Some(codex) => codex,
match session_map.get(&session_id).cloned() {
Some(c) => c,
None => {
tracing::warn!("Session not found for session_id: {session_id}");
let result = CallToolResult {
@@ -482,21 +486,32 @@ impl MessageProcessor {
is_error: Some(true),
structured_content: None,
};
// unwrap_or_default is fine here because we know the result is valid JSON
outgoing
.send_response(request_id, serde_json::to_value(result).unwrap_or_default())
.await;
return;
}
};
}
};
crate::codex_tool_runner::run_codex_tool_session_reply(
codex.clone(),
outgoing,
request_id,
prompt.clone(),
)
.await;
// Spawn the long-running reply handler.
tokio::spawn({
let codex = codex.clone();
let outgoing = outgoing.clone();
let prompt = prompt.clone();
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();
async move {
crate::codex_tool_runner::run_codex_tool_session_reply(
codex,
outgoing,
request_id,
prompt,
running_requests_id_to_codex_uuid,
session_id,
)
.await;
}
});
}
@@ -518,11 +533,58 @@ impl MessageProcessor {
// Notification handlers
// ---------------------------------------------------------------------
fn handle_cancelled_notification(
async fn handle_cancelled_notification(
&self,
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
tracing::info!("notifications/cancelled -> params: {:?}", params);
let request_id = params.request_id;
// Create a stable string form early for logging and submission id.
let request_id_string = match &request_id {
RequestId::String(s) => s.clone(),
RequestId::Integer(i) => i.to_string(),
};
// Obtain the session_id while holding the first lock, then release.
let session_id = {
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
match map_guard.get(&request_id) {
Some(id) => *id, // Uuid is Copy
None => {
tracing::warn!("Session not found for request_id: {}", request_id_string);
return;
}
}
};
tracing::info!("session_id: {session_id}");
// Obtain the Codex Arc while holding the session_map lock, then release.
let codex_arc = {
let sessions_guard = self.session_map.lock().await;
match sessions_guard.get(&session_id) {
Some(codex) => Arc::clone(codex),
None => {
tracing::warn!("Session not found for session_id: {session_id}");
return;
}
}
};
// Submit interrupt to Codex.
let err = codex_arc
.submit_with_id(Submission {
id: request_id_string,
op: codex_core::protocol::Op::Interrupt,
})
.await;
if let Err(e) = err {
tracing::error!("Failed to submit interrupt to Codex: {e}");
return;
}
// unregister the id so we don't keep it in the map
self.running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
}
fn handle_progress_notification(

View File

@@ -12,6 +12,7 @@ use tokio::process::ChildStdout;
use anyhow::Context;
use assert_cmd::prelude::*;
use codex_mcp_server::CodexToolCallParam;
use codex_mcp_server::CodexToolCallReplyParam;
use mcp_types::CallToolRequestParams;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
@@ -154,6 +155,25 @@ impl McpProcess {
.await
}
pub async fn send_codex_reply_tool_call(
&mut self,
session_id: &str,
prompt: &str,
) -> anyhow::Result<i64> {
let codex_tool_call_params = CallToolRequestParams {
name: "codex-reply".to_string(),
arguments: Some(serde_json::to_value(CodexToolCallReplyParam {
prompt: prompt.to_string(),
session_id: session_id.to_string(),
})?),
};
self.send_request(
mcp_types::CallToolRequest::METHOD,
Some(serde_json::to_value(codex_tool_call_params)?),
)
.await
}
async fn send_request(
&mut self,
method: &str,
@@ -171,6 +191,8 @@ impl McpProcess {
Ok(request_id)
}
// allow dead code
#[allow(dead_code)]
pub async fn send_response(
&mut self,
id: RequestId,
@@ -198,7 +220,8 @@ impl McpProcess {
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
Ok(message)
}
// allow dead code
#[allow(dead_code)]
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
loop {
let message = self.read_jsonrpc_message().await?;
@@ -221,6 +244,8 @@ impl McpProcess {
}
}
// allow dead code
#[allow(dead_code)]
pub async fn read_stream_until_response_message(
&mut self,
request_id: RequestId,
@@ -247,4 +272,58 @@ impl McpProcess {
}
}
}
pub async fn read_stream_until_configured_response_message(
&mut self,
) -> anyhow::Result<String> {
loop {
let message = self.read_jsonrpc_message().await?;
eprint!("message: {message:?}");
match message {
JSONRPCMessage::Notification(notification) => {
if notification.method == "codex/event" {
if let Some(params) = notification.params {
if let Some(msg) = params.get("msg") {
if let Some(msg_type) = msg.get("type") {
if msg_type == "session_configured" {
if let Some(session_id) = msg.get("session_id") {
return Ok(session_id
.to_string()
.trim_matches('"')
.to_string());
}
}
}
}
}
}
}
JSONRPCMessage::Request(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
}
JSONRPCMessage::Error(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
}
JSONRPCMessage::Response(_) => {
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
}
}
}
}
// allow dead code
#[allow(dead_code)]
pub async fn send_notification(
&mut self,
method: &str,
params: Option<serde_json::Value>,
) -> anyhow::Result<()> {
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
jsonrpc: JSONRPC_VERSION.into(),
method: method.to_string(),
params,
}))
.await
}
}

View File

@@ -4,6 +4,8 @@ mod responses;
pub use mcp_process::McpProcess;
pub use mock_model_server::create_mock_chat_completions_server;
#[allow(unused_imports)]
pub use responses::create_apply_patch_sse_response;
#[allow(unused_imports)]
pub use responses::create_final_assistant_message_sse_response;
pub use responses::create_shell_sse_response;

View File

@@ -39,6 +39,8 @@ pub fn create_shell_sse_response(
Ok(sse)
}
// allow dead code
#[allow(dead_code)]
pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result<String> {
let assistant_message = json!({
"choices": [
@@ -58,6 +60,8 @@ pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Res
Ok(sse)
}
// allow dead code
#[allow(dead_code)]
pub fn create_apply_patch_sse_response(
patch_content: &str,
call_id: &str,

View File

@@ -0,0 +1,176 @@
#![cfg(unix)]
mod common;
use std::path::Path;
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use codex_mcp_server::CodexToolCallParam;
use mcp_types::JSONRPCResponse;
use mcp_types::RequestId;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::timeout;
use crate::common::McpProcess;
use crate::common::create_mock_chat_completions_server;
use crate::common::create_shell_sse_response;
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_shell_command_interruption() {
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
);
return;
}
if let Err(err) = shell_command_interruption().await {
panic!("failure: {err}");
}
}
async fn shell_command_interruption() -> anyhow::Result<()> {
// Use a cross-platform blocking command. On Windows plain `sleep` is not guaranteed to exist
// (MSYS/GNU coreutils may be absent) and the failure causes the tool call to finish immediately,
// which triggers a second model request before the test sends the explicit follow-up. That
// prematurely consumes the second mocked SSE response and leads to a third POST (panic: no response for 2).
// Powershell Start-Sleep is always available on Windows runners. On Unix we keep using `sleep`.
#[cfg(target_os = "windows")]
let shell_command = vec![
"powershell".to_string(),
"-Command".to_string(),
"Start-Sleep -Seconds 60".to_string(),
];
#[cfg(not(target_os = "windows"))]
let shell_command = vec!["sleep".to_string(), "60".to_string()];
let workdir_for_shell_function_call = TempDir::new()?;
// Create mock server with a single SSE response: the long sleep command
let server = create_mock_chat_completions_server(vec![
create_shell_sse_response(
shell_command.clone(),
Some(workdir_for_shell_function_call.path()),
Some(60_000), // 60 seconds timeout in ms
"call_sleep",
)?,
create_shell_sse_response(
shell_command.clone(),
Some(workdir_for_shell_function_call.path()),
Some(60_000), // 60 seconds timeout in ms
"call_sleep",
)?,
])
.await;
// Create Codex configuration
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), server.uri())?;
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
// Send codex tool call that triggers "sleep 60"
let codex_request_id = mcp_process
.send_codex_tool_call(CodexToolCallParam {
cwd: None,
prompt: "First Run: run `sleep 60`".to_string(),
model: None,
profile: None,
approval_policy: None,
sandbox: None,
config: None,
base_instructions: None,
})
.await?;
let session_id = mcp_process
.read_stream_until_configured_response_message()
.await?;
// Give the command a moment to start
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
// Send interrupt notification
mcp_process
.send_notification(
"notifications/cancelled",
Some(json!({ "requestId": codex_request_id })),
)
.await?;
// Expect Codex to return an error or interruption response
let codex_response: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
)
.await??;
assert!(
codex_response
.result
.as_object()
.map(|o| o.contains_key("error"))
.unwrap_or(false),
"Expected an interruption or error result, got: {codex_response:?}"
);
let codex_reply_request_id = mcp_process
.send_codex_reply_tool_call(&session_id, "Second Run: run `sleep 60`")
.await?;
// Give the command a moment to start
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
// Send interrupt notification
mcp_process
.send_notification(
"notifications/cancelled",
Some(json!({ "requestId": codex_reply_request_id })),
)
.await?;
// Expect Codex to return an error or interruption response
let codex_response: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_reply_request_id)),
)
.await??;
assert!(
codex_response
.result
.as_object()
.map(|o| o.contains_key("error"))
.unwrap_or(false),
"Expected an interruption or error result, got: {codex_response:?}"
);
Ok(())
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "danger-full-access"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}