diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index a25e0f8b..9171369a 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -799,6 +799,7 @@ dependencies = [ "schemars 0.8.22", "serde", "serde_json", + "shlex", "tokio", "toml 0.9.1", "tracing", diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index cb91bc61..886e4f8b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -18,6 +18,7 @@ use mcp_types::ClientCapabilities; use mcp_types::Implementation; use mcp_types::Tool; +use serde_json::json; use sha1::Digest; use sha1::Sha1; use tokio::task::JoinSet; @@ -135,7 +136,9 @@ impl McpConnectionManager { experimental: None, roots: None, sampling: None, - elicitation: None, + // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities + // indicates this should be an empty object. + elicitation: Some(json!({})), }, client_info: Implementation { name: "codex-mcp-client".to_owned(), diff --git a/codex-rs/mcp-server/Cargo.toml b/codex-rs/mcp-server/Cargo.toml index f91a3dc8..640a9993 100644 --- a/codex-rs/mcp-server/Cargo.toml +++ b/codex-rs/mcp-server/Cargo.toml @@ -22,6 +22,7 @@ mcp-types = { path = "../mcp-types" } schemars = "0.8.22" serde = { version = "1", features = ["derive"] } serde_json = "1" +shlex = "1.3.0" toml = "0.9" tracing = { version = "0.1.41", features = ["log"] } tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index a20566d6..3036df51 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -4,18 +4,27 @@ use std::sync::Arc; +use codex_core::Codex; use codex_core::codex_wrapper::init_codex; use codex_core::config::Config as CodexConfig; use codex_core::protocol::AgentMessageEvent; use codex_core::protocol::EventMsg; +use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::InputItem; use codex_core::protocol::Op; +use codex_core::protocol::ReviewDecision; use codex_core::protocol::Submission; use codex_core::protocol::TaskCompleteEvent; use mcp_types::CallToolResult; use mcp_types::ContentBlock; +use mcp_types::ElicitRequest; +use mcp_types::ElicitRequestParamsRequestedSchema; +use mcp_types::ModelContextProtocolRequest; use mcp_types::RequestId; use mcp_types::TextContent; +use serde::Deserialize; +use serde_json::json; +use tracing::error; use crate::outgoing_message::OutgoingMessageSender; @@ -45,6 +54,7 @@ pub async fn run_codex_tool_session( return; } }; + let codex = Arc::new(codex); // Send initial SessionConfigured event. outgoing.send_event_as_notification(&first_event).await; @@ -58,7 +68,7 @@ pub async fn run_codex_tool_session( }; let submission = Submission { - id: sub_id, + id: sub_id.clone(), op: Op::UserInput { items: vec![InputItem::Text { text: initial_prompt.clone(), @@ -77,18 +87,50 @@ pub async fn run_codex_tool_session( Ok(event) => { outgoing.send_event_as_notification(&event).await; - match &event.msg { - EventMsg::ExecApprovalRequest(_) => { - let result = CallToolResult { - content: vec![ContentBlock::TextContent(TextContent { - r#type: "text".to_string(), - text: "EXEC_APPROVAL_REQUIRED".to_string(), - annotations: None, - })], - is_error: None, - structured_content: None, - }; - outgoing.send_response(id.clone(), result.into()).await; + match event.msg { + EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { + command, + cwd, + reason: _, + }) => { + let escaped_command = shlex::try_join(command.iter().map(|s| s.as_str())) + .unwrap_or_else(|_| command.join(" ")); + let message = format!("Allow Codex to run `{escaped_command}` in {cwd:?}?"); + + let params = json!({ + // These fields are required so that `params` + // conforms to ElicitRequestParams. + "message": message, + "requestedSchema": ElicitRequestParamsRequestedSchema { + r#type: "object".to_string(), + properties: json!({}), + required: None, + }, + + // These are additional fields the client can use to + // correlate the request with the codex tool call. + "codex_elicitation": "exec-approval", + "codex_mcp_tool_call_id": sub_id, + "codex_event_id": event.id, + "codex_command": command, + // Could convert it to base64 encoded bytes if we + // don't want to use to_string_lossy() here? + "codex_cwd": cwd.to_string_lossy().to_string() + }); + let on_response = outgoing + .send_request(ElicitRequest::METHOD, Some(params)) + .await; + + // Listen for the response on a separate task so we do + // not block the main loop of this function. + { + let codex = codex.clone(); + let event_id = event.id.clone(); + tokio::spawn(async move { + on_exec_approval_response(event_id, on_response, codex).await; + }); + } + break; } EventMsg::ApplyPatchApprovalRequest(_) => { @@ -172,3 +214,42 @@ pub async fn run_codex_tool_session( } } } + +async fn on_exec_approval_response( + event_id: String, + receiver: tokio::sync::oneshot::Receiver, + codex: Arc, +) { + let response = receiver.await; + let value = match response { + Ok(value) => value, + Err(err) => { + error!("request failed: {err:?}"); + return; + } + }; + + // Try to deserialize `value` and then make the appropriate call to `codex`. + let response = match serde_json::from_value::(value) { + Ok(response) => response, + Err(err) => { + error!("failed to deserialize ExecApprovalResponse: {err}"); + return; + } + }; + + if let Err(err) = codex + .submit(Op::ExecApproval { + id: event_id, + decision: response.decision, + }) + .await + { + error!("failed to submit ExecApproval: {err}"); + } +} + +#[derive(Debug, Deserialize)] +pub struct ExecApprovalResponse { + pub decision: ReviewDecision, +} diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index b968b497..3b984ecf 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -72,7 +72,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option) -> IoResult<()> while let Some(msg) = incoming_rx.recv().await { match msg { JSONRPCMessage::Request(r) => processor.process_request(r).await, - JSONRPCMessage::Response(r) => processor.process_response(r), + JSONRPCMessage::Response(r) => processor.process_response(r).await, JSONRPCMessage::Notification(n) => processor.process_notification(n), JSONRPCMessage::Error(e) => processor.process_error(e), } diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index aad7f211..d994d8a7 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -101,8 +101,10 @@ impl MessageProcessor { } /// Handle a standalone JSON-RPC response originating from the peer. - pub(crate) fn process_response(&mut self, response: JSONRPCResponse) { + pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) { tracing::info!("<- response: {:?}", response); + let JSONRPCResponse { id, result, .. } = response; + self.outgoing.notify_client_response(id, result).await } /// Handle a fire-and-forget JSON-RPC notification. diff --git a/codex-rs/mcp-server/src/outgoing_message.rs b/codex-rs/mcp-server/src/outgoing_message.rs index 93a760d3..a1eea65f 100644 --- a/codex-rs/mcp-server/src/outgoing_message.rs +++ b/codex-rs/mcp-server/src/outgoing_message.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; @@ -12,11 +13,15 @@ use mcp_types::JSONRPCResponse; use mcp_types::RequestId; use mcp_types::Result; use serde::Serialize; +use tokio::sync::Mutex; use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tracing::warn; pub(crate) struct OutgoingMessageSender { next_request_id: AtomicI64, sender: mpsc::Sender, + request_id_to_callback: Mutex>>, } impl OutgoingMessageSender { @@ -24,17 +29,48 @@ impl OutgoingMessageSender { Self { next_request_id: AtomicI64::new(0), sender, + request_id_to_callback: Mutex::new(HashMap::new()), } } - #[allow(dead_code)] - pub(crate) async fn send_request(&self, method: &str, params: Option) { + pub(crate) async fn send_request( + &self, + method: &str, + params: Option, + ) -> oneshot::Receiver { + let id = RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed)); + let outgoing_message_id = id.clone(); + let (tx_approve, rx_approve) = oneshot::channel(); + { + let mut request_id_to_callback = self.request_id_to_callback.lock().await; + request_id_to_callback.insert(id, tx_approve); + } + let outgoing_message = OutgoingMessage::Request(OutgoingRequest { - id: RequestId::Integer(self.next_request_id.fetch_add(1, Ordering::Relaxed)), + id: outgoing_message_id, method: method.to_string(), params, }); let _ = self.sender.send(outgoing_message).await; + rx_approve + } + + pub(crate) async fn notify_client_response(&self, id: RequestId, result: Result) { + let entry = { + let mut request_id_to_callback = self.request_id_to_callback.lock().await; + request_id_to_callback.remove_entry(&id) + }; + + match entry { + Some((id, sender)) => { + if let Err(err) = sender.send(result) { + warn!("could not notify callback for {id:?} due to: {err:?}"); + } + } + None => { + warn!("could not find callback for {id:?}"); + } + } } pub(crate) async fn send_response(&self, id: RequestId, result: Result) { diff --git a/codex-rs/mcp-types/generate_mcp_types.py b/codex-rs/mcp-types/generate_mcp_types.py index 224e04c0..38f57e9a 100755 --- a/codex-rs/mcp-types/generate_mcp_types.py +++ b/codex-rs/mcp-types/generate_mcp_types.py @@ -18,6 +18,9 @@ SCHEMA_VERSION = "2025-06-18" JSONRPC_VERSION = "2.0" STANDARD_DERIVE = "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]\n" +STANDARD_HASHABLE_DERIVE = ( + "#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)]\n" +) # Will be populated with the schema's `definitions` map in `main()` so that # helper functions (for example `define_any_of`) can perform look-ups while @@ -391,7 +394,7 @@ def define_string_enum( def define_untagged_enum(name: str, type_list: list[str], out: list[str]) -> None: - out.append(STANDARD_DERIVE) + out.append(STANDARD_HASHABLE_DERIVE) out.append("#[serde(untagged)]\n") out.append(f"pub enum {name} {{\n") for simple_type in type_list: diff --git a/codex-rs/mcp-types/src/lib.rs b/codex-rs/mcp-types/src/lib.rs index 6341fb62..cf09d67e 100644 --- a/codex-rs/mcp-types/src/lib.rs +++ b/codex-rs/mcp-types/src/lib.rs @@ -931,7 +931,7 @@ pub struct ProgressNotificationParams { pub total: Option, } -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)] #[serde(untagged)] pub enum ProgressToken { String(String), @@ -1031,7 +1031,7 @@ pub struct Request { pub params: Option, } -#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Hash, Eq)] #[serde(untagged)] pub enum RequestId { String(String),