Adding interrupt Support to MCP (#1646)
This commit is contained in:
@@ -168,7 +168,7 @@ impl CodexToolCallParam {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub(crate) struct CodexToolCallReplyParam {
|
pub struct CodexToolCallReplyParam {
|
||||||
/// The *session id* for this conversation.
|
/// The *session id* for this conversation.
|
||||||
pub session_id: String,
|
pub session_id: String,
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ use mcp_types::CallToolResult;
|
|||||||
use mcp_types::ContentBlock;
|
use mcp_types::ContentBlock;
|
||||||
use mcp_types::RequestId;
|
use mcp_types::RequestId;
|
||||||
use mcp_types::TextContent;
|
use mcp_types::TextContent;
|
||||||
|
use serde_json::json;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@@ -39,6 +40,7 @@ pub async fn run_codex_tool_session(
|
|||||||
config: CodexConfig,
|
config: CodexConfig,
|
||||||
outgoing: Arc<OutgoingMessageSender>,
|
outgoing: Arc<OutgoingMessageSender>,
|
||||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||||
|
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||||
) {
|
) {
|
||||||
let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await {
|
let (codex, first_event, _ctrl_c, session_id) = match init_codex(config).await {
|
||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
@@ -73,7 +75,10 @@ pub async fn run_codex_tool_session(
|
|||||||
RequestId::String(s) => s.clone(),
|
RequestId::String(s) => s.clone(),
|
||||||
RequestId::Integer(n) => n.to_string(),
|
RequestId::Integer(n) => n.to_string(),
|
||||||
};
|
};
|
||||||
|
running_requests_id_to_codex_uuid
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.insert(id.clone(), session_id);
|
||||||
let submission = Submission {
|
let submission = Submission {
|
||||||
id: sub_id.clone(),
|
id: sub_id.clone(),
|
||||||
op: Op::UserInput {
|
op: Op::UserInput {
|
||||||
@@ -85,9 +90,12 @@ pub async fn run_codex_tool_session(
|
|||||||
|
|
||||||
if let Err(e) = codex.submit_with_id(submission).await {
|
if let Err(e) = codex.submit_with_id(submission).await {
|
||||||
tracing::error!("Failed to submit initial prompt: {e}");
|
tracing::error!("Failed to submit initial prompt: {e}");
|
||||||
|
// unregister the id so we don't keep it in the map
|
||||||
|
running_requests_id_to_codex_uuid.lock().await.remove(&id);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
run_codex_tool_session_inner(codex, outgoing, id).await;
|
run_codex_tool_session_inner(codex, outgoing, id, running_requests_id_to_codex_uuid).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_codex_tool_session_reply(
|
pub async fn run_codex_tool_session_reply(
|
||||||
@@ -95,7 +103,13 @@ pub async fn run_codex_tool_session_reply(
|
|||||||
outgoing: Arc<OutgoingMessageSender>,
|
outgoing: Arc<OutgoingMessageSender>,
|
||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||||
|
session_id: Uuid,
|
||||||
) {
|
) {
|
||||||
|
running_requests_id_to_codex_uuid
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.insert(request_id.clone(), session_id);
|
||||||
if let Err(e) = codex
|
if let Err(e) = codex
|
||||||
.submit(Op::UserInput {
|
.submit(Op::UserInput {
|
||||||
items: vec![InputItem::Text { text: prompt }],
|
items: vec![InputItem::Text { text: prompt }],
|
||||||
@@ -103,15 +117,28 @@ pub async fn run_codex_tool_session_reply(
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
tracing::error!("Failed to submit user input: {e}");
|
tracing::error!("Failed to submit user input: {e}");
|
||||||
|
// unregister the id so we don't keep it in the map
|
||||||
|
running_requests_id_to_codex_uuid
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.remove(&request_id);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
run_codex_tool_session_inner(codex, outgoing, request_id).await;
|
run_codex_tool_session_inner(
|
||||||
|
codex,
|
||||||
|
outgoing,
|
||||||
|
request_id,
|
||||||
|
running_requests_id_to_codex_uuid,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_codex_tool_session_inner(
|
async fn run_codex_tool_session_inner(
|
||||||
codex: Arc<Codex>,
|
codex: Arc<Codex>,
|
||||||
outgoing: Arc<OutgoingMessageSender>,
|
outgoing: Arc<OutgoingMessageSender>,
|
||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
|
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||||
) {
|
) {
|
||||||
let request_id_str = match &request_id {
|
let request_id_str = match &request_id {
|
||||||
RequestId::String(s) => s.clone(),
|
RequestId::String(s) => s.clone(),
|
||||||
@@ -143,6 +170,14 @@ async fn run_codex_tool_session_inner(
|
|||||||
.await;
|
.await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
EventMsg::Error(err_event) => {
|
||||||
|
// Return a response to conclude the tool call when the Codex session reports an error (e.g., interruption).
|
||||||
|
let result = json!({
|
||||||
|
"error": err_event.message,
|
||||||
|
});
|
||||||
|
outgoing.send_response(request_id.clone(), result).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||||
reason,
|
reason,
|
||||||
grant_root,
|
grant_root,
|
||||||
@@ -178,6 +213,11 @@ async fn run_codex_tool_session_inner(
|
|||||||
outgoing
|
outgoing
|
||||||
.send_response(request_id.clone(), result.into())
|
.send_response(request_id.clone(), result.into())
|
||||||
.await;
|
.await;
|
||||||
|
// unregister the id so we don't keep it in the map
|
||||||
|
running_requests_id_to_codex_uuid
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.remove(&request_id);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
EventMsg::SessionConfigured(_) => {
|
EventMsg::SessionConfigured(_) => {
|
||||||
@@ -192,8 +232,7 @@ async fn run_codex_tool_session_inner(
|
|||||||
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
EventMsg::AgentMessage(AgentMessageEvent { .. }) => {
|
||||||
// TODO: think how we want to support this in the MCP
|
// TODO: think how we want to support this in the MCP
|
||||||
}
|
}
|
||||||
EventMsg::Error(_)
|
EventMsg::TaskStarted
|
||||||
| EventMsg::TaskStarted
|
|
||||||
| EventMsg::TokenCount(_)
|
| EventMsg::TokenCount(_)
|
||||||
| EventMsg::AgentReasoning(_)
|
| EventMsg::AgentReasoning(_)
|
||||||
| EventMsg::McpToolCallBegin(_)
|
| EventMsg::McpToolCallBegin(_)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ use crate::outgoing_message::OutgoingMessage;
|
|||||||
use crate::outgoing_message::OutgoingMessageSender;
|
use crate::outgoing_message::OutgoingMessageSender;
|
||||||
|
|
||||||
pub use crate::codex_tool_config::CodexToolCallParam;
|
pub use crate::codex_tool_config::CodexToolCallParam;
|
||||||
|
pub use crate::codex_tool_config::CodexToolCallReplyParam;
|
||||||
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
|
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
|
||||||
pub use crate::exec_approval::ExecApprovalResponse;
|
pub use crate::exec_approval::ExecApprovalResponse;
|
||||||
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
|
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
|
||||||
@@ -81,7 +82,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
|||||||
match msg {
|
match msg {
|
||||||
JSONRPCMessage::Request(r) => processor.process_request(r).await,
|
JSONRPCMessage::Request(r) => processor.process_request(r).await,
|
||||||
JSONRPCMessage::Response(r) => processor.process_response(r).await,
|
JSONRPCMessage::Response(r) => processor.process_response(r).await,
|
||||||
JSONRPCMessage::Notification(n) => processor.process_notification(n),
|
JSONRPCMessage::Notification(n) => processor.process_notification(n).await,
|
||||||
JSONRPCMessage::Error(e) => processor.process_error(e),
|
JSONRPCMessage::Error(e) => processor.process_error(e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use crate::outgoing_message::OutgoingMessageSender;
|
|||||||
|
|
||||||
use codex_core::Codex;
|
use codex_core::Codex;
|
||||||
use codex_core::config::Config as CodexConfig;
|
use codex_core::config::Config as CodexConfig;
|
||||||
|
use codex_core::protocol::Submission;
|
||||||
use mcp_types::CallToolRequestParams;
|
use mcp_types::CallToolRequestParams;
|
||||||
use mcp_types::CallToolResult;
|
use mcp_types::CallToolResult;
|
||||||
use mcp_types::ClientRequest;
|
use mcp_types::ClientRequest;
|
||||||
@@ -35,6 +36,7 @@ pub(crate) struct MessageProcessor {
|
|||||||
initialized: bool,
|
initialized: bool,
|
||||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||||
|
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MessageProcessor {
|
impl MessageProcessor {
|
||||||
@@ -49,6 +51,7 @@ impl MessageProcessor {
|
|||||||
initialized: false,
|
initialized: false,
|
||||||
codex_linux_sandbox_exe,
|
codex_linux_sandbox_exe,
|
||||||
session_map: Arc::new(Mutex::new(HashMap::new())),
|
session_map: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +119,7 @@ impl MessageProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handle a fire-and-forget JSON-RPC notification.
|
/// Handle a fire-and-forget JSON-RPC notification.
|
||||||
pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) {
|
pub(crate) async fn process_notification(&mut self, notification: JSONRPCNotification) {
|
||||||
let server_notification = match ServerNotification::try_from(notification) {
|
let server_notification = match ServerNotification::try_from(notification) {
|
||||||
Ok(n) => n,
|
Ok(n) => n,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -129,7 +132,7 @@ impl MessageProcessor {
|
|||||||
// handler so additional logic can be implemented incrementally.
|
// handler so additional logic can be implemented incrementally.
|
||||||
match server_notification {
|
match server_notification {
|
||||||
ServerNotification::CancelledNotification(params) => {
|
ServerNotification::CancelledNotification(params) => {
|
||||||
self.handle_cancelled_notification(params);
|
self.handle_cancelled_notification(params).await;
|
||||||
}
|
}
|
||||||
ServerNotification::ProgressNotification(params) => {
|
ServerNotification::ProgressNotification(params) => {
|
||||||
self.handle_progress_notification(params);
|
self.handle_progress_notification(params);
|
||||||
@@ -379,6 +382,7 @@ impl MessageProcessor {
|
|||||||
// Clone outgoing and session map to move into async task.
|
// Clone outgoing and session map to move into async task.
|
||||||
let outgoing = self.outgoing.clone();
|
let outgoing = self.outgoing.clone();
|
||||||
let session_map = self.session_map.clone();
|
let session_map = self.session_map.clone();
|
||||||
|
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
|
||||||
|
|
||||||
// Spawn an async task to handle the Codex session so that we do not
|
// Spawn an async task to handle the Codex session so that we do not
|
||||||
// block the synchronous message-processing loop.
|
// block the synchronous message-processing loop.
|
||||||
@@ -390,6 +394,7 @@ impl MessageProcessor {
|
|||||||
config,
|
config,
|
||||||
outgoing,
|
outgoing,
|
||||||
session_map,
|
session_map,
|
||||||
|
running_requests_id_to_codex_uuid,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
});
|
});
|
||||||
@@ -464,13 +469,12 @@ impl MessageProcessor {
|
|||||||
|
|
||||||
// Clone outgoing and session map to move into async task.
|
// Clone outgoing and session map to move into async task.
|
||||||
let outgoing = self.outgoing.clone();
|
let outgoing = self.outgoing.clone();
|
||||||
|
let running_requests_id_to_codex_uuid = self.running_requests_id_to_codex_uuid.clone();
|
||||||
|
|
||||||
// Spawn an async task to handle the Codex session so that we do not
|
let codex = {
|
||||||
// block the synchronous message-processing loop.
|
|
||||||
task::spawn(async move {
|
|
||||||
let session_map = session_map_mutex.lock().await;
|
let session_map = session_map_mutex.lock().await;
|
||||||
let codex = match session_map.get(&session_id) {
|
match session_map.get(&session_id).cloned() {
|
||||||
Some(codex) => codex,
|
Some(c) => c,
|
||||||
None => {
|
None => {
|
||||||
tracing::warn!("Session not found for session_id: {session_id}");
|
tracing::warn!("Session not found for session_id: {session_id}");
|
||||||
let result = CallToolResult {
|
let result = CallToolResult {
|
||||||
@@ -482,21 +486,32 @@ impl MessageProcessor {
|
|||||||
is_error: Some(true),
|
is_error: Some(true),
|
||||||
structured_content: None,
|
structured_content: None,
|
||||||
};
|
};
|
||||||
// unwrap_or_default is fine here because we know the result is valid JSON
|
|
||||||
outgoing
|
outgoing
|
||||||
.send_response(request_id, serde_json::to_value(result).unwrap_or_default())
|
.send_response(request_id, serde_json::to_value(result).unwrap_or_default())
|
||||||
.await;
|
.await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
};
|
||||||
|
|
||||||
crate::codex_tool_runner::run_codex_tool_session_reply(
|
// Spawn the long-running reply handler.
|
||||||
codex.clone(),
|
tokio::spawn({
|
||||||
outgoing,
|
let codex = codex.clone();
|
||||||
request_id,
|
let outgoing = outgoing.clone();
|
||||||
prompt.clone(),
|
let prompt = prompt.clone();
|
||||||
)
|
let running_requests_id_to_codex_uuid = running_requests_id_to_codex_uuid.clone();
|
||||||
.await;
|
|
||||||
|
async move {
|
||||||
|
crate::codex_tool_runner::run_codex_tool_session_reply(
|
||||||
|
codex,
|
||||||
|
outgoing,
|
||||||
|
request_id,
|
||||||
|
prompt,
|
||||||
|
running_requests_id_to_codex_uuid,
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -518,11 +533,58 @@ impl MessageProcessor {
|
|||||||
// Notification handlers
|
// Notification handlers
|
||||||
// ---------------------------------------------------------------------
|
// ---------------------------------------------------------------------
|
||||||
|
|
||||||
fn handle_cancelled_notification(
|
async fn handle_cancelled_notification(
|
||||||
&self,
|
&self,
|
||||||
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
|
params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
|
||||||
) {
|
) {
|
||||||
tracing::info!("notifications/cancelled -> params: {:?}", params);
|
let request_id = params.request_id;
|
||||||
|
// Create a stable string form early for logging and submission id.
|
||||||
|
let request_id_string = match &request_id {
|
||||||
|
RequestId::String(s) => s.clone(),
|
||||||
|
RequestId::Integer(i) => i.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Obtain the session_id while holding the first lock, then release.
|
||||||
|
let session_id = {
|
||||||
|
let map_guard = self.running_requests_id_to_codex_uuid.lock().await;
|
||||||
|
match map_guard.get(&request_id) {
|
||||||
|
Some(id) => *id, // Uuid is Copy
|
||||||
|
None => {
|
||||||
|
tracing::warn!("Session not found for request_id: {}", request_id_string);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tracing::info!("session_id: {session_id}");
|
||||||
|
|
||||||
|
// Obtain the Codex Arc while holding the session_map lock, then release.
|
||||||
|
let codex_arc = {
|
||||||
|
let sessions_guard = self.session_map.lock().await;
|
||||||
|
match sessions_guard.get(&session_id) {
|
||||||
|
Some(codex) => Arc::clone(codex),
|
||||||
|
None => {
|
||||||
|
tracing::warn!("Session not found for session_id: {session_id}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Submit interrupt to Codex.
|
||||||
|
let err = codex_arc
|
||||||
|
.submit_with_id(Submission {
|
||||||
|
id: request_id_string,
|
||||||
|
op: codex_core::protocol::Op::Interrupt,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
if let Err(e) = err {
|
||||||
|
tracing::error!("Failed to submit interrupt to Codex: {e}");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// unregister the id so we don't keep it in the map
|
||||||
|
self.running_requests_id_to_codex_uuid
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.remove(&request_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_progress_notification(
|
fn handle_progress_notification(
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ use tokio::process::ChildStdout;
|
|||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use assert_cmd::prelude::*;
|
use assert_cmd::prelude::*;
|
||||||
use codex_mcp_server::CodexToolCallParam;
|
use codex_mcp_server::CodexToolCallParam;
|
||||||
|
use codex_mcp_server::CodexToolCallReplyParam;
|
||||||
use mcp_types::CallToolRequestParams;
|
use mcp_types::CallToolRequestParams;
|
||||||
use mcp_types::ClientCapabilities;
|
use mcp_types::ClientCapabilities;
|
||||||
use mcp_types::Implementation;
|
use mcp_types::Implementation;
|
||||||
@@ -154,6 +155,25 @@ impl McpProcess {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn send_codex_reply_tool_call(
|
||||||
|
&mut self,
|
||||||
|
session_id: &str,
|
||||||
|
prompt: &str,
|
||||||
|
) -> anyhow::Result<i64> {
|
||||||
|
let codex_tool_call_params = CallToolRequestParams {
|
||||||
|
name: "codex-reply".to_string(),
|
||||||
|
arguments: Some(serde_json::to_value(CodexToolCallReplyParam {
|
||||||
|
prompt: prompt.to_string(),
|
||||||
|
session_id: session_id.to_string(),
|
||||||
|
})?),
|
||||||
|
};
|
||||||
|
self.send_request(
|
||||||
|
mcp_types::CallToolRequest::METHOD,
|
||||||
|
Some(serde_json::to_value(codex_tool_call_params)?),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
async fn send_request(
|
async fn send_request(
|
||||||
&mut self,
|
&mut self,
|
||||||
method: &str,
|
method: &str,
|
||||||
@@ -171,6 +191,8 @@ impl McpProcess {
|
|||||||
Ok(request_id)
|
Ok(request_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allow dead code
|
||||||
|
#[allow(dead_code)]
|
||||||
pub async fn send_response(
|
pub async fn send_response(
|
||||||
&mut self,
|
&mut self,
|
||||||
id: RequestId,
|
id: RequestId,
|
||||||
@@ -198,7 +220,8 @@ impl McpProcess {
|
|||||||
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
|
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
|
||||||
Ok(message)
|
Ok(message)
|
||||||
}
|
}
|
||||||
|
// allow dead code
|
||||||
|
#[allow(dead_code)]
|
||||||
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
|
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
|
||||||
loop {
|
loop {
|
||||||
let message = self.read_jsonrpc_message().await?;
|
let message = self.read_jsonrpc_message().await?;
|
||||||
@@ -221,6 +244,8 @@ impl McpProcess {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allow dead code
|
||||||
|
#[allow(dead_code)]
|
||||||
pub async fn read_stream_until_response_message(
|
pub async fn read_stream_until_response_message(
|
||||||
&mut self,
|
&mut self,
|
||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
@@ -247,4 +272,58 @@ impl McpProcess {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn read_stream_until_configured_response_message(
|
||||||
|
&mut self,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
loop {
|
||||||
|
let message = self.read_jsonrpc_message().await?;
|
||||||
|
eprint!("message: {message:?}");
|
||||||
|
|
||||||
|
match message {
|
||||||
|
JSONRPCMessage::Notification(notification) => {
|
||||||
|
if notification.method == "codex/event" {
|
||||||
|
if let Some(params) = notification.params {
|
||||||
|
if let Some(msg) = params.get("msg") {
|
||||||
|
if let Some(msg_type) = msg.get("type") {
|
||||||
|
if msg_type == "session_configured" {
|
||||||
|
if let Some(session_id) = msg.get("session_id") {
|
||||||
|
return Ok(session_id
|
||||||
|
.to_string()
|
||||||
|
.trim_matches('"')
|
||||||
|
.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
JSONRPCMessage::Request(_) => {
|
||||||
|
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
||||||
|
}
|
||||||
|
JSONRPCMessage::Error(_) => {
|
||||||
|
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
||||||
|
}
|
||||||
|
JSONRPCMessage::Response(_) => {
|
||||||
|
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow dead code
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn send_notification(
|
||||||
|
&mut self,
|
||||||
|
method: &str,
|
||||||
|
params: Option<serde_json::Value>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
|
||||||
|
jsonrpc: JSONRPC_VERSION.into(),
|
||||||
|
method: method.to_string(),
|
||||||
|
params,
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ mod responses;
|
|||||||
|
|
||||||
pub use mcp_process::McpProcess;
|
pub use mcp_process::McpProcess;
|
||||||
pub use mock_model_server::create_mock_chat_completions_server;
|
pub use mock_model_server::create_mock_chat_completions_server;
|
||||||
|
#[allow(unused_imports)]
|
||||||
pub use responses::create_apply_patch_sse_response;
|
pub use responses::create_apply_patch_sse_response;
|
||||||
|
#[allow(unused_imports)]
|
||||||
pub use responses::create_final_assistant_message_sse_response;
|
pub use responses::create_final_assistant_message_sse_response;
|
||||||
pub use responses::create_shell_sse_response;
|
pub use responses::create_shell_sse_response;
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ pub fn create_shell_sse_response(
|
|||||||
Ok(sse)
|
Ok(sse)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allow dead code
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result<String> {
|
pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Result<String> {
|
||||||
let assistant_message = json!({
|
let assistant_message = json!({
|
||||||
"choices": [
|
"choices": [
|
||||||
@@ -58,6 +60,8 @@ pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Res
|
|||||||
Ok(sse)
|
Ok(sse)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// allow dead code
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn create_apply_patch_sse_response(
|
pub fn create_apply_patch_sse_response(
|
||||||
patch_content: &str,
|
patch_content: &str,
|
||||||
call_id: &str,
|
call_id: &str,
|
||||||
|
|||||||
176
codex-rs/mcp-server/tests/interrupt.rs
Normal file
176
codex-rs/mcp-server/tests/interrupt.rs
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
#![cfg(unix)]
|
||||||
|
mod common;
|
||||||
|
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||||
|
use codex_mcp_server::CodexToolCallParam;
|
||||||
|
use mcp_types::JSONRPCResponse;
|
||||||
|
use mcp_types::RequestId;
|
||||||
|
use serde_json::json;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
use crate::common::McpProcess;
|
||||||
|
use crate::common::create_mock_chat_completions_server;
|
||||||
|
use crate::common::create_shell_sse_response;
|
||||||
|
|
||||||
|
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn test_shell_command_interruption() {
|
||||||
|
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||||
|
println!(
|
||||||
|
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = shell_command_interruption().await {
|
||||||
|
panic!("failure: {err}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn shell_command_interruption() -> anyhow::Result<()> {
|
||||||
|
// Use a cross-platform blocking command. On Windows plain `sleep` is not guaranteed to exist
|
||||||
|
// (MSYS/GNU coreutils may be absent) and the failure causes the tool call to finish immediately,
|
||||||
|
// which triggers a second model request before the test sends the explicit follow-up. That
|
||||||
|
// prematurely consumes the second mocked SSE response and leads to a third POST (panic: no response for 2).
|
||||||
|
// Powershell Start-Sleep is always available on Windows runners. On Unix we keep using `sleep`.
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
let shell_command = vec![
|
||||||
|
"powershell".to_string(),
|
||||||
|
"-Command".to_string(),
|
||||||
|
"Start-Sleep -Seconds 60".to_string(),
|
||||||
|
];
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
let shell_command = vec!["sleep".to_string(), "60".to_string()];
|
||||||
|
let workdir_for_shell_function_call = TempDir::new()?;
|
||||||
|
|
||||||
|
// Create mock server with a single SSE response: the long sleep command
|
||||||
|
let server = create_mock_chat_completions_server(vec![
|
||||||
|
create_shell_sse_response(
|
||||||
|
shell_command.clone(),
|
||||||
|
Some(workdir_for_shell_function_call.path()),
|
||||||
|
Some(60_000), // 60 seconds timeout in ms
|
||||||
|
"call_sleep",
|
||||||
|
)?,
|
||||||
|
create_shell_sse_response(
|
||||||
|
shell_command.clone(),
|
||||||
|
Some(workdir_for_shell_function_call.path()),
|
||||||
|
Some(60_000), // 60 seconds timeout in ms
|
||||||
|
"call_sleep",
|
||||||
|
)?,
|
||||||
|
])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Create Codex configuration
|
||||||
|
let codex_home = TempDir::new()?;
|
||||||
|
create_config_toml(codex_home.path(), server.uri())?;
|
||||||
|
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||||
|
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||||
|
|
||||||
|
// Send codex tool call that triggers "sleep 60"
|
||||||
|
let codex_request_id = mcp_process
|
||||||
|
.send_codex_tool_call(CodexToolCallParam {
|
||||||
|
cwd: None,
|
||||||
|
prompt: "First Run: run `sleep 60`".to_string(),
|
||||||
|
model: None,
|
||||||
|
profile: None,
|
||||||
|
approval_policy: None,
|
||||||
|
sandbox: None,
|
||||||
|
config: None,
|
||||||
|
base_instructions: None,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let session_id = mcp_process
|
||||||
|
.read_stream_until_configured_response_message()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Give the command a moment to start
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||||
|
|
||||||
|
// Send interrupt notification
|
||||||
|
mcp_process
|
||||||
|
.send_notification(
|
||||||
|
"notifications/cancelled",
|
||||||
|
Some(json!({ "requestId": codex_request_id })),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Expect Codex to return an error or interruption response
|
||||||
|
let codex_response: JSONRPCResponse = timeout(
|
||||||
|
DEFAULT_READ_TIMEOUT,
|
||||||
|
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||||
|
)
|
||||||
|
.await??;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
codex_response
|
||||||
|
.result
|
||||||
|
.as_object()
|
||||||
|
.map(|o| o.contains_key("error"))
|
||||||
|
.unwrap_or(false),
|
||||||
|
"Expected an interruption or error result, got: {codex_response:?}"
|
||||||
|
);
|
||||||
|
|
||||||
|
let codex_reply_request_id = mcp_process
|
||||||
|
.send_codex_reply_tool_call(&session_id, "Second Run: run `sleep 60`")
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Give the command a moment to start
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||||
|
|
||||||
|
// Send interrupt notification
|
||||||
|
mcp_process
|
||||||
|
.send_notification(
|
||||||
|
"notifications/cancelled",
|
||||||
|
Some(json!({ "requestId": codex_reply_request_id })),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Expect Codex to return an error or interruption response
|
||||||
|
let codex_response: JSONRPCResponse = timeout(
|
||||||
|
DEFAULT_READ_TIMEOUT,
|
||||||
|
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_reply_request_id)),
|
||||||
|
)
|
||||||
|
.await??;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
codex_response
|
||||||
|
.result
|
||||||
|
.as_object()
|
||||||
|
.map(|o| o.contains_key("error"))
|
||||||
|
.unwrap_or(false),
|
||||||
|
"Expected an interruption or error result, got: {codex_response:?}"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> {
|
||||||
|
let config_toml = codex_home.join("config.toml");
|
||||||
|
std::fs::write(
|
||||||
|
config_toml,
|
||||||
|
format!(
|
||||||
|
r#"
|
||||||
|
model = "mock-model"
|
||||||
|
approval_policy = "never"
|
||||||
|
sandbox_mode = "danger-full-access"
|
||||||
|
|
||||||
|
model_provider = "mock_provider"
|
||||||
|
|
||||||
|
[model_providers.mock_provider]
|
||||||
|
name = "Mock provider for test"
|
||||||
|
base_url = "{server_uri}/v1"
|
||||||
|
wire_api = "chat"
|
||||||
|
request_max_retries = 0
|
||||||
|
stream_max_retries = 0
|
||||||
|
"#
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user