Adding interrupt Support to MCP (#1646)

This commit is contained in:
aibrahim-oai
2025-07-22 13:33:49 -07:00
committed by GitHub
parent 4082246f6a
commit 01c0896f0f
8 changed files with 389 additions and 26 deletions

View File

@@ -168,7 +168,7 @@ impl CodexToolCallParam {
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub(crate) struct CodexToolCallReplyParam {
pub struct CodexToolCallReplyParam {
/// The *session id* for this conversation.
pub session_id: String,

View File

@@ -20,6 +20,7 @@ use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::RequestId;
use mcp_types::TextContent;
use serde_json::json;
use tokio::sync::Mutex;
use uuid::Uuid;
@@ -39,6 +40,7 @@ pub async fn run_codex_tool_session(
config: CodexConfig,
outgoing: Arc<OutgoingMessageSender>,
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
) {
let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await {
Ok(res) => res,
@@ -73,7 +75,10 @@ pub async fn run_codex_tool_session(
RequestId::String(s) => s.clone(),
RequestId::Integer(n) => n.to_string(),
};
running_requests_id_to_codex_uuid
.lock()
.await
.insert(id.clone(), session_id);
let submission = Submission {
id: sub_id.clone(),
op: Op::UserInput {
@@ -85,9 +90,12 @@ pub async fn run_codex_tool_session(
if let Err(e) = codex.submit_with_id(submission).await {
tracing::error!("Failed to submit initial prompt: {e}");
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid.lock().await.remove(&id);
return;
}
run_codex_tool_session_inner(codex, outgoing, id).await;
run_codex_tool_session_inner(codex, outgoing, id, running_requests_id_to_codex_uuid).await;
}
pub async fn run_codex_tool_session_reply(
@@ -95,7 +103,13 @@ pub async fn run_codex_tool_session_reply(
outgoing: Arc<OutgoingMessageSender>,
request_id: RequestId,
prompt: String,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
session_id: Uuid,
) {
running_requests_id_to_codex_uuid
.lock()
.await
.insert(request_id.clone(), session_id);
if let Err(e) = codex
.submit(Op::UserInput {
items: vec![InputItem::Text { text: prompt }],
@@ -103,15 +117,28 @@ pub async fn run_codex_tool_session_reply(
.await
{
tracing::error!("Failed to submit user input: {e}");
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
return;
}
run_codex_tool_session_inner(codex, outgoing, request_id).await;
run_codex_tool_session_inner(
codex,
outgoing,
request_id,
running_requests_id_to_codex_uuid,
)
.await;
}
async fn run_codex_tool_session_inner(
codex: Arc<Codex>,
outgoing: Arc<OutgoingMessageSender>,
request_id: RequestId,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
) {
let request_id_str = match &request_id {
RequestId::String(s) => s.clone(),
@@ -143,6 +170,14 @@ async fn run_codex_tool_session_inner(
.await;
continue;
}
EventMsg::Error(err_event) => {
// Return a response to conclude the tool call when the Codex session reports an error (e.g., interruption).
let result = json!({
"error": err_event.message,
});
outgoing.send_response(request_id.clone(), result).await;
break;
}
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
reason,
grant_root,
@@ -178,6 +213,11 @@ async fn run_codex_tool_session_inner(
outgoing
.send_response(request_id.clone(), result.into())
.await;
// unregister the id so we don't keep it in the map
running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
break;
}
EventMsg::SessionConfigured(_) => {
@@ -192,8 +232,7 @@ async fn run_codex_tool_session_inner(
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
// TODO: think how we want to support this in the MCP
}
EventMsg::Error(_)
| EventMsg::TaskStarted
EventMsg::TaskStarted
| EventMsg::TokenCount(_)
| EventMsg::AgentReasoning(_)
| EventMsg::McpToolCallBegin(_)

View File

@@ -27,6 +27,7 @@ use crate::outgoing_message::OutgoingMessage;
use crate::outgoing_message::OutgoingMessageSender;
pub use crate::codex_tool_config::CodexToolCallParam;
pub use crate::codex_tool_config::CodexToolCallReplyParam;
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
pub use crate::exec_approval::ExecApprovalResponse;
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
@@ -81,7 +82,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
match msg {
JSONRPCMessage::Request(r) => processor.process_request(r).await,
JSONRPCMessage::Response(r) => processor.process_response(r).await,
JSONRPCMessage::Notification(n) => processor.process_notification(n),
JSONRPCMessage::Notification(n) => processor.process_notification(n).await,
JSONRPCMessage::Error(e) => processor.process_error(e),
}
}

View File

@@ -10,6 +10,7 @@ use crate::outgoing_message::OutgoingMessageSender;
use codex_core::Codex;
use codex_core::config::Config as CodexConfig;
use codex_core::protocol::Submission;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::ClientRequest;
@@ -35,6 +36,7 @@ pub(crate) struct MessageProcessor {
initialized: bool,
codex_linux_sandbox_exe: Option<PathBuf>,
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
}
impl MessageProcessor {
@@ -49,6 +51,7 @@ impl MessageProcessor {
initialized: false,
codex_linux_sandbox_exe,
session_map: Arc::new(Mutex::new(HashMap::new())),
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
}
}
@@ -116,7 +119,7 @@ impl MessageProcessor {
}
/// Handle a fire-and-forget JSON-RPC notification.
pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) {
pub(crate) async fn process_notification(&mut self, notification: JSONRPCNotification) {
let server_notification = match ServerNotification::try_from(notification) {
Ok(n) => n,
Err(e) => {
@@ -129,7 +132,7 @@ impl MessageProcessor {
// handler so additional logic can be implemented incrementally.
match server_notification {
ServerNotification::CancelledNotification(params) => {
self.handle_cancelled_notification(params);
self.handle_cancelled_notification(params).await;
}
ServerNotification::ProgressNotification(params) => {
self.handle_progress_notification(params);
@@ -379,6 +382,7 @@ impl MessageProcessor {
// Clone outgoing and session map to move into async task.
let outgoing = self.outgoing.clone();
let session_map = self.session_map.clone();
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
// Spawn an async task to handle the Codex session so that we do not
// block the synchronous message-processing loop.
@@ -390,6 +394,7 @@ impl MessageProcessor {
config,
outgoing,
session_map,
running_requests_id_to_codex_uuid,
)
.await;
});
@@ -464,13 +469,12 @@ impl MessageProcessor {
// Clone outgoing and session map to move into async task.
let outgoing = self.outgoing.clone();
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
// Spawn an async task to handle the Codex session so that we do not
// block the synchronous message-processing loop.
task::spawn(async move {
let codex = {
let session_map = session_map_mutex.lock().await;
let codex = match session_map.get(&session_id) {
Some(codex) => codex,
match session_map.get(&session_id).cloned() {
Some(c) => c,
None => {
tracing::warn!("Session not found for session_id: {session_id}");
let result = CallToolResult {
@@ -482,21 +486,32 @@ impl MessageProcessor {
is_error: Some(true),
structured_content: None,
};
// unwrap_or_default is fine here because we know the result is valid JSON
outgoing
.send_response(request_id, serde_json::to_value(result).unwrap_or_default())
.await;
return;
}
};
}
};
crate::codex_tool_runner::run_codex_tool_session_reply(
codex.clone(),
outgoing,
request_id,
prompt.clone(),
)
.await;
// Spawn the long-running reply handler.
tokio::spawn({
let codex = codex.clone();
let outgoing = outgoing.clone();
let prompt = prompt.clone();
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();
async move {
crate::codex_tool_runner::run_codex_tool_session_reply(
codex,
outgoing,
request_id,
prompt,
running_requests_id_to_codex_uuid,
session_id,
)
.await;
}
});
}
@@ -518,11 +533,58 @@ impl MessageProcessor {
// Notification handlers
// ---------------------------------------------------------------------
fn handle_cancelled_notification(
async fn handle_cancelled_notification(
&self,
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
tracing::info!("notifications/cancelled -> params: {:?}", params);
let request_id = params.request_id;
// Create a stable string form early for logging and submission id.
let request_id_string = match &request_id {
RequestId::String(s) => s.clone(),
RequestId::Integer(i) => i.to_string(),
};
// Obtain the session_id while holding the first lock, then release.
let session_id = {
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
match map_guard.get(&request_id) {
Some(id) => *id, // Uuid is Copy
None => {
tracing::warn!("Session not found for request_id: {}", request_id_string);
return;
}
}
};
tracing::info!("session_id: {session_id}");
// Obtain the Codex Arc while holding the session_map lock, then release.
let codex_arc = {
let sessions_guard = self.session_map.lock().await;
match sessions_guard.get(&session_id) {
Some(codex) => Arc::clone(codex),
None => {
tracing::warn!("Session not found for session_id: {session_id}");
return;
}
}
};
// Submit interrupt to Codex.
let err = codex_arc
.submit_with_id(Submission {
id: request_id_string,
op: codex_core::protocol::Op::Interrupt,
})
.await;
if let Err(e) = err {
tracing::error!("Failed to submit interrupt to Codex: {e}");
return;
}
// unregister the id so we don't keep it in the map
self.running_requests_id_to_codex_uuid
.lock()
.await
.remove(&request_id);
}
fn handle_progress_notification(