diff --git a/codex-rs/mcp-server/src/codex_tool_runner.rs b/codex-rs/mcp-server/src/codex_tool_runner.rs index 3893a485..df2154dd 100644 --- a/codex-rs/mcp-server/src/codex_tool_runner.rs +++ b/codex-rs/mcp-server/src/codex_tool_runner.rs @@ -3,38 +3,31 @@ //! and to make future feature-growth easier to manage. use std::collections::HashMap; -use std::path::PathBuf; use std::sync::Arc; use codex_core::Codex; use codex_core::codex_wrapper::init_codex; use codex_core::config::Config as CodexConfig; use codex_core::protocol::AgentMessageEvent; +use codex_core::protocol::ApplyPatchApprovalRequestEvent; use codex_core::protocol::EventMsg; use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::InputItem; use codex_core::protocol::Op; -use codex_core::protocol::ReviewDecision; use codex_core::protocol::Submission; use codex_core::protocol::TaskCompleteEvent; use mcp_types::CallToolResult; use mcp_types::ContentBlock; -use mcp_types::ElicitRequest; -use mcp_types::ElicitRequestParamsRequestedSchema; -use mcp_types::JSONRPCErrorError; -use mcp_types::ModelContextProtocolRequest; use mcp_types::RequestId; use mcp_types::TextContent; -use serde::Deserialize; -use serde::Serialize; -use serde_json::json; use tokio::sync::Mutex; -use tracing::error; use uuid::Uuid; +use crate::exec_approval::handle_exec_approval_request; use crate::outgoing_message::OutgoingMessageSender; +use crate::patch_approval::handle_patch_approval_request; -const INVALID_PARAMS_ERROR_CODE: i64 = -32602; +pub(crate) const INVALID_PARAMS_ERROR_CODE: i64 = -32602; /// Run a complete Codex session and stream events back to the client. /// @@ -120,7 +113,7 @@ async fn run_codex_tool_session_inner( outgoing: Arc, request_id: RequestId, ) { - let sub_id = match &request_id { + let request_id_str = match &request_id { RequestId::String(s) => s.clone(), RequestId::Integer(n) => n.to_string(), }; @@ -138,80 +131,34 @@ async fn run_codex_tool_session_inner( cwd, reason: _, }) => { - let escaped_command = shlex::try_join(command.iter().map(|s| s.as_str())) - .unwrap_or_else(|_| command.join(" ")); - let message = format!( - "Allow Codex to run `{escaped_command}` in `{cwd}`?", - cwd = cwd.to_string_lossy() - ); - - let params = ExecApprovalElicitRequestParams { - message, - requested_schema: ElicitRequestParamsRequestedSchema { - r#type: "object".to_string(), - properties: json!({}), - required: None, - }, - codex_elicitation: "exec-approval".to_string(), - codex_mcp_tool_call_id: sub_id.clone(), - codex_event_id: event.id.clone(), - codex_command: command, - codex_cwd: cwd, - }; - let params_json = match serde_json::to_value(¶ms) { - Ok(value) => value, - Err(err) => { - let message = format!( - "Failed to serialize ExecApprovalElicitRequestParams: {err}" - ); - tracing::error!("{message}"); - - outgoing - .send_error( - request_id.clone(), - JSONRPCErrorError { - code: INVALID_PARAMS_ERROR_CODE, - message, - data: None, - }, - ) - .await; - - continue; - } - }; - - let on_response = outgoing - .send_request(ElicitRequest::METHOD, Some(params_json)) - .await; - - // Listen for the response on a separate task so we do - // not block the main loop of this function. - { - let codex = codex.clone(); - let event_id = event.id.clone(); - tokio::spawn(async move { - on_exec_approval_response(event_id, on_response, codex).await; - }); - } - - // Continue, don't break so the session continues. + handle_exec_approval_request( + command, + cwd, + outgoing.clone(), + codex.clone(), + request_id.clone(), + request_id_str.clone(), + event.id.clone(), + ) + .await; continue; } - EventMsg::ApplyPatchApprovalRequest(_) => { - let result = CallToolResult { - content: vec![ContentBlock::TextContent(TextContent { - r#type: "text".to_string(), - text: "PATCH_APPROVAL_REQUIRED".to_string(), - annotations: None, - })], - is_error: None, - structured_content: None, - }; - outgoing - .send_response(request_id.clone(), result.into()) - .await; - // Continue, don't break so the session continues. + EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { + reason, + grant_root, + changes, + }) => { + handle_patch_approval_request( + reason, + grant_root, + changes, + outgoing.clone(), + codex.clone(), + request_id.clone(), + request_id_str.clone(), + event.id.clone(), + ) + .await; continue; } EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => { @@ -286,71 +233,3 @@ async fn run_codex_tool_session_inner( } } } - -async fn on_exec_approval_response( - event_id: String, - receiver: tokio::sync::oneshot::Receiver, - codex: Arc, -) { - let response = receiver.await; - let value = match response { - Ok(value) => value, - Err(err) => { - error!("request failed: {err:?}"); - return; - } - }; - - // Try to deserialize `value` and then make the appropriate call to `codex`. - let response = match serde_json::from_value::(value) { - Ok(response) => response, - Err(err) => { - error!("failed to deserialize ExecApprovalResponse: {err}"); - // If we cannot deserialize the response, we deny the request to be - // conservative. - ExecApprovalResponse { - decision: ReviewDecision::Denied, - } - } - }; - - if let Err(err) = codex - .submit(Op::ExecApproval { - id: event_id, - decision: response.decision, - }) - .await - { - error!("failed to submit ExecApproval: {err}"); - } -} - -// TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See: -// - https://github.com/modelcontextprotocol/modelcontextprotocol/blob/f962dc1780fa5eed7fb7c8a0232f1fc83ef220cd/schema/2025-06-18/schema.json#L617-L636 -// - https://modelcontextprotocol.io/specification/draft/client/elicitation#protocol-messages -// It should have "action" and "content" fields. - -#[derive(Debug, Serialize, Deserialize)] -pub struct ExecApprovalResponse { - pub decision: ReviewDecision, -} - -/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the -/// `params` field of an [`mcp_types::ElicitRequest`]. -#[derive(Debug, Serialize)] -pub struct ExecApprovalElicitRequestParams { - // These fields are required so that `params` - // conforms to ElicitRequestParams. - pub message: String, - - #[serde(rename = "requestedSchema")] - pub requested_schema: ElicitRequestParamsRequestedSchema, - - // These are additional fields the client can use to - // correlate the request with the codex tool call. - pub codex_elicitation: String, - pub codex_mcp_tool_call_id: String, - pub codex_event_id: String, - pub codex_command: Vec, - pub codex_cwd: PathBuf, -} diff --git a/codex-rs/mcp-server/src/exec_approval.rs b/codex-rs/mcp-server/src/exec_approval.rs new file mode 100644 index 00000000..fc0c41d0 --- /dev/null +++ b/codex-rs/mcp-server/src/exec_approval.rs @@ -0,0 +1,145 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use codex_core::Codex; +use codex_core::protocol::Op; +use codex_core::protocol::ReviewDecision; +use mcp_types::ElicitRequest; +use mcp_types::ElicitRequestParamsRequestedSchema; +use mcp_types::JSONRPCErrorError; +use mcp_types::ModelContextProtocolRequest; +use mcp_types::RequestId; +use serde::Deserialize; +use serde::Serialize; +use serde_json::json; +use tracing::error; + +use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE; + +/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the +/// `params` field of an [`ElicitRequest`]. +#[derive(Debug, Serialize)] +pub struct ExecApprovalElicitRequestParams { + // These fields are required so that `params` + // conforms to ElicitRequestParams. + pub message: String, + + #[serde(rename = "requestedSchema")] + pub requested_schema: ElicitRequestParamsRequestedSchema, + + // These are additional fields the client can use to + // correlate the request with the codex tool call. + pub codex_elicitation: String, + pub codex_mcp_tool_call_id: String, + pub codex_event_id: String, + pub codex_command: Vec, + pub codex_cwd: PathBuf, +} + +// TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See: +// - https://github.com/modelcontextprotocol/modelcontextprotocol/blob/f962dc1780fa5eed7fb7c8a0232f1fc83ef220cd/schema/2025-06-18/schema.json#L617-L636 +// - https://modelcontextprotocol.io/specification/draft/client/elicitation#protocol-messages +// It should have "action" and "content" fields. +#[derive(Debug, Serialize, Deserialize)] +pub struct ExecApprovalResponse { + pub decision: ReviewDecision, +} + +pub(crate) async fn handle_exec_approval_request( + command: Vec, + cwd: PathBuf, + outgoing: Arc, + codex: Arc, + request_id: RequestId, + tool_call_id: String, + event_id: String, +) { + let escaped_command = + shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" ")); + let message = format!( + "Allow Codex to run `{escaped_command}` in `{cwd}`?", + cwd = cwd.to_string_lossy() + ); + + let params = ExecApprovalElicitRequestParams { + message, + requested_schema: ElicitRequestParamsRequestedSchema { + r#type: "object".to_string(), + properties: json!({}), + required: None, + }, + codex_elicitation: "exec-approval".to_string(), + codex_mcp_tool_call_id: tool_call_id.clone(), + codex_event_id: event_id.clone(), + codex_command: command, + codex_cwd: cwd, + }; + let params_json = match serde_json::to_value(¶ms) { + Ok(value) => value, + Err(err) => { + let message = format!("Failed to serialize ExecApprovalElicitRequestParams: {err}"); + error!("{message}"); + + outgoing + .send_error( + request_id.clone(), + JSONRPCErrorError { + code: INVALID_PARAMS_ERROR_CODE, + message, + data: None, + }, + ) + .await; + + return; + } + }; + + let on_response = outgoing + .send_request(ElicitRequest::METHOD, Some(params_json)) + .await; + + // Listen for the response on a separate task so we don't block the main agent loop. + { + let codex = codex.clone(); + let event_id = event_id.clone(); + tokio::spawn(async move { + on_exec_approval_response(event_id, on_response, codex).await; + }); + } +} + +async fn on_exec_approval_response( + event_id: String, + receiver: tokio::sync::oneshot::Receiver, + codex: Arc, +) { + let response = receiver.await; + let value = match response { + Ok(value) => value, + Err(err) => { + error!("request failed: {err:?}"); + return; + } + }; + + // Try to deserialize `value` and then make the appropriate call to `codex`. + let response = serde_json::from_value::(value).unwrap_or_else(|err| { + error!("failed to deserialize ExecApprovalResponse: {err}"); + // If we cannot deserialize the response, we deny the request to be + // conservative. + ExecApprovalResponse { + decision: ReviewDecision::Denied, + } + }); + + if let Err(err) = codex + .submit(Op::ExecApproval { + id: event_id, + decision: response.decision, + }) + .await + { + error!("failed to submit ExecApproval: {err}"); + } +} diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index 1f1ecc3f..300d1b5f 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -16,17 +16,21 @@ use tracing::info; mod codex_tool_config; mod codex_tool_runner; +mod exec_approval; mod json_to_toml; mod message_processor; mod outgoing_message; +mod patch_approval; use crate::message_processor::MessageProcessor; use crate::outgoing_message::OutgoingMessage; use crate::outgoing_message::OutgoingMessageSender; pub use crate::codex_tool_config::CodexToolCallParam; -pub use crate::codex_tool_runner::ExecApprovalElicitRequestParams; -pub use crate::codex_tool_runner::ExecApprovalResponse; +pub use crate::exec_approval::ExecApprovalElicitRequestParams; +pub use crate::exec_approval::ExecApprovalResponse; +pub use crate::patch_approval::PatchApprovalElicitRequestParams; +pub use crate::patch_approval::PatchApprovalResponse; /// Size of the bounded channels used to communicate between tasks. The value /// is a balance between throughput and memory usage – 128 messages should be diff --git a/codex-rs/mcp-server/src/patch_approval.rs b/codex-rs/mcp-server/src/patch_approval.rs new file mode 100644 index 00000000..bfccfa50 --- /dev/null +++ b/codex-rs/mcp-server/src/patch_approval.rs @@ -0,0 +1,147 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use codex_core::Codex; +use codex_core::protocol::FileChange; +use codex_core::protocol::Op; +use codex_core::protocol::ReviewDecision; +use mcp_types::ElicitRequest; +use mcp_types::ElicitRequestParamsRequestedSchema; +use mcp_types::JSONRPCErrorError; +use mcp_types::ModelContextProtocolRequest; +use mcp_types::RequestId; +use serde::Deserialize; +use serde::Serialize; +use serde_json::json; +use tracing::error; + +use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE; +use crate::outgoing_message::OutgoingMessageSender; + +#[derive(Debug, Serialize)] +pub struct PatchApprovalElicitRequestParams { + pub message: String, + #[serde(rename = "requestedSchema")] + pub requested_schema: ElicitRequestParamsRequestedSchema, + pub codex_elicitation: String, + pub codex_mcp_tool_call_id: String, + pub codex_event_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub codex_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub codex_grant_root: Option, + pub codex_changes: HashMap, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct PatchApprovalResponse { + pub decision: ReviewDecision, +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn handle_patch_approval_request( + reason: Option, + grant_root: Option, + changes: HashMap, + outgoing: Arc, + codex: Arc, + request_id: RequestId, + tool_call_id: String, + event_id: String, +) { + let mut message_lines = Vec::new(); + if let Some(r) = &reason { + message_lines.push(r.clone()); + } + message_lines.push("Allow Codex to apply proposed code changes?".to_string()); + + let params = PatchApprovalElicitRequestParams { + message: message_lines.join("\n"), + requested_schema: ElicitRequestParamsRequestedSchema { + r#type: "object".to_string(), + properties: json!({}), + required: None, + }, + codex_elicitation: "patch-approval".to_string(), + codex_mcp_tool_call_id: tool_call_id.clone(), + codex_event_id: event_id.clone(), + codex_reason: reason, + codex_grant_root: grant_root, + codex_changes: changes, + }; + let params_json = match serde_json::to_value(¶ms) { + Ok(value) => value, + Err(err) => { + let message = format!("Failed to serialize PatchApprovalElicitRequestParams: {err}"); + error!("{message}"); + + outgoing + .send_error( + request_id.clone(), + JSONRPCErrorError { + code: INVALID_PARAMS_ERROR_CODE, + message, + data: None, + }, + ) + .await; + + return; + } + }; + + let on_response = outgoing + .send_request(ElicitRequest::METHOD, Some(params_json)) + .await; + + // Listen for the response on a separate task so we don't block the main agent loop. + { + let codex = codex.clone(); + let event_id = event_id.clone(); + tokio::spawn(async move { + on_patch_approval_response(event_id, on_response, codex).await; + }); + } +} + +pub(crate) async fn on_patch_approval_response( + event_id: String, + receiver: tokio::sync::oneshot::Receiver, + codex: Arc, +) { + let response = receiver.await; + let value = match response { + Ok(value) => value, + Err(err) => { + error!("request failed: {err:?}"); + if let Err(submit_err) = codex + .submit(Op::PatchApproval { + id: event_id.clone(), + decision: ReviewDecision::Denied, + }) + .await + { + error!("failed to submit denied PatchApproval after request failure: {submit_err}"); + } + return; + } + }; + + let response = serde_json::from_value::(value).unwrap_or_else(|err| { + error!("failed to deserialize PatchApprovalResponse: {err}"); + PatchApprovalResponse { + decision: ReviewDecision::Denied, + } + }); + + if let Err(err) = codex + .submit(Op::PatchApproval { + id: event_id, + decision: response.decision, + }) + .await + { + error!("failed to submit PatchApproval: {err}"); + } +} diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs index 42d15f78..df9cc98a 100644 --- a/codex-rs/mcp-server/tests/common/mcp_process.rs +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -139,14 +139,18 @@ impl McpProcess { /// Returns the id used to make the request so it can be used when /// correlating notifications. - pub async fn send_codex_tool_call(&mut self, prompt: &str) -> anyhow::Result { + pub async fn send_codex_tool_call( + &mut self, + cwd: Option, + prompt: &str, + ) -> anyhow::Result { let codex_tool_call_params = CallToolRequestParams { name: "codex".to_string(), arguments: Some(serde_json::to_value(CodexToolCallParam { + cwd, prompt: prompt.to_string(), model: None, profile: None, - cwd: None, approval_policy: None, sandbox: None, config: None, diff --git a/codex-rs/mcp-server/tests/common/mod.rs b/codex-rs/mcp-server/tests/common/mod.rs index 61a5774b..b338e2e8 100644 --- a/codex-rs/mcp-server/tests/common/mod.rs +++ b/codex-rs/mcp-server/tests/common/mod.rs @@ -4,5 +4,6 @@ mod responses; pub use mcp_process::McpProcess; pub use mock_model_server::create_mock_chat_completions_server; +pub use responses::create_apply_patch_sse_response; pub use responses::create_final_assistant_message_sse_response; pub use responses::create_shell_sse_response; diff --git a/codex-rs/mcp-server/tests/common/responses.rs b/codex-rs/mcp-server/tests/common/responses.rs index a11c72d0..9a827fb9 100644 --- a/codex-rs/mcp-server/tests/common/responses.rs +++ b/codex-rs/mcp-server/tests/common/responses.rs @@ -57,3 +57,39 @@ pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Res ); Ok(sse) } + +pub fn create_apply_patch_sse_response( + patch_content: &str, + call_id: &str, +) -> anyhow::Result { + // Use shell command to call apply_patch with heredoc format + let shell_command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF"); + let tool_call_arguments = serde_json::to_string(&json!({ + "command": ["bash", "-lc", shell_command] + }))?; + + let tool_call = json!({ + "choices": [ + { + "delta": { + "tool_calls": [ + { + "id": call_id, + "function": { + "name": "shell", + "arguments": tool_call_arguments + } + } + ] + }, + "finish_reason": "tool_calls" + } + ] + }); + + let sse = format!( + "data: {}\n\ndata: DONE\n\n", + serde_json::to_string(&tool_call)? + ); + Ok(sse) +} diff --git a/codex-rs/mcp-server/tests/elicitation.rs b/codex-rs/mcp-server/tests/elicitation.rs index 7fd68d67..ac9435e8 100644 --- a/codex-rs/mcp-server/tests/elicitation.rs +++ b/codex-rs/mcp-server/tests/elicitation.rs @@ -1,11 +1,17 @@ mod common; +use std::collections::HashMap; +use std::env; use std::path::Path; +use std::path::PathBuf; use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; +use codex_core::protocol::FileChange; use codex_core::protocol::ReviewDecision; use codex_mcp_server::ExecApprovalElicitRequestParams; use codex_mcp_server::ExecApprovalResponse; +use codex_mcp_server::PatchApprovalElicitRequestParams; +use codex_mcp_server::PatchApprovalResponse; use mcp_types::ElicitRequest; use mcp_types::ElicitRequestParamsRequestedSchema; use mcp_types::JSONRPC_VERSION; @@ -17,8 +23,10 @@ use pretty_assertions::assert_eq; use serde_json::json; use tempfile::TempDir; use tokio::time::timeout; +use wiremock::MockServer; use crate::common::McpProcess; +use crate::common::create_apply_patch_sse_response; use crate::common::create_final_assistant_message_sse_response; use crate::common::create_mock_chat_completions_server; use crate::common::create_shell_sse_response; @@ -30,7 +38,7 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs /// command, as expected. #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_shell_command_approval_triggers_elicitation() { - if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { println!( "Skipping test because it cannot execute when network is disabled in a Codex sandbox." ); @@ -49,12 +57,11 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> { let shell_command = vec!["git".to_string(), "init".to_string()]; let workdir_for_shell_function_call = TempDir::new()?; - // Configure the mock server so it makes two responses: - // 1. The first response is a shell function call that will trigger an - // elicitation request. - // 2. The second response is the final assistant message that should be - // returned after the elicitation is approved and the command is run. - let server = create_mock_chat_completions_server(vec![ + let McpHandle { + process: mut mcp_process, + server: _server, + dir: _dir, + } = create_mcp_process(vec![ create_shell_sse_response( shell_command.clone(), Some(workdir_for_shell_function_call.path()), @@ -63,18 +70,14 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> { )?, create_final_assistant_message_sse_response("Enjoy your new git repo!")?, ]) - .await; - - // Run `codex mcp` with a specific config.toml. - let codex_home = TempDir::new()?; - create_config_toml(codex_home.path(), server.uri())?; - let mut mcp_process = McpProcess::new(codex_home.path()).await?; - timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + .await?; // Send a "codex" tool request, which should hit the completions endpoint. // In turn, it should reply with a tool call, which the MCP should forward // as an elicitation. - let codex_request_id = mcp_process.send_codex_tool_call("run `git init`").await?; + let codex_request_id = mcp_process + .send_codex_tool_call(None, "run `git init`") + .await?; let elicitation_request = timeout( DEFAULT_READ_TIMEOUT, mcp_process.read_stream_until_request_message(), @@ -136,32 +139,6 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> { Ok(()) } -/// Create a Codex config that uses the mock server as the model provider. -/// It also uses `approval_policy = "untrusted"` so that we exercise the -/// elicitation code path for shell commands. -fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> { - let config_toml = codex_home.join("config.toml"); - std::fs::write( - config_toml, - format!( - r#" -model = "mock-model" -approval_policy = "untrusted" -sandbox_policy = "read-only" - -model_provider = "mock_provider" - -[model_providers.mock_provider] -name = "Mock provider for test" -base_url = "{server_uri}/v1" -wire_api = "chat" -request_max_retries = 0 -stream_max_retries = 0 -"# - ), - ) -} - fn create_expected_elicitation_request( elicitation_request_id: RequestId, command: Vec, @@ -193,3 +170,197 @@ fn create_expected_elicitation_request( })?), }) } + +/// Test that patch approval triggers an elicitation request to the MCP and that +/// sending the approval applies the patch, as expected. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_patch_approval_triggers_elicitation() { + if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + if let Err(err) = patch_approval_triggers_elicitation().await { + panic!("failure: {err}"); + } +} + +async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> { + let cwd = TempDir::new()?; + let test_file = cwd.path().join("destination_file.txt"); + std::fs::write(&test_file, "original content\n")?; + + let patch_content = format!( + "*** Begin Patch\n*** Update File: {}\n-original content\n+modified content\n*** End Patch", + test_file.as_path().to_string_lossy() + ); + + let McpHandle { + process: mut mcp_process, + server: _server, + dir: _dir, + } = create_mcp_process(vec![ + create_apply_patch_sse_response(&patch_content, "call1234")?, + create_final_assistant_message_sse_response("Patch has been applied successfully!")?, + ]) + .await?; + + // Send a "codex" tool request that will trigger the apply_patch command + let codex_request_id = mcp_process + .send_codex_tool_call( + Some(cwd.path().to_string_lossy().to_string()), + "please modify the test file", + ) + .await?; + let elicitation_request = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_request_message(), + ) + .await??; + + let elicitation_request_id = RequestId::Integer(0); + + let mut expected_changes = HashMap::new(); + expected_changes.insert( + test_file.as_path().to_path_buf(), + FileChange::Update { + unified_diff: "@@ -1 +1 @@\n-original content\n+modified content\n".to_string(), + move_path: None, + }, + ); + + let expected_elicitation_request = create_expected_patch_approval_elicitation_request( + elicitation_request_id.clone(), + expected_changes, + None, // No grant_root expected + None, // No reason expected + codex_request_id.to_string(), + "1".to_string(), + )?; + assert_eq!(expected_elicitation_request, elicitation_request); + + // Accept the patch approval request by responding to the elicitation + mcp_process + .send_response( + elicitation_request_id, + serde_json::to_value(PatchApprovalResponse { + decision: ReviewDecision::Approved, + })?, + ) + .await?; + + // Verify the original `codex` tool call completes + let codex_response = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + ) + .await??; + assert_eq!( + JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(codex_request_id), + result: json!({ + "content": [ + { + "text": "Patch has been applied successfully!", + "type": "text" + } + ] + }), + }, + codex_response + ); + + let file_contents = std::fs::read_to_string(test_file.as_path())?; + assert_eq!(file_contents, "modified content\n"); + + Ok(()) +} + +fn create_expected_patch_approval_elicitation_request( + elicitation_request_id: RequestId, + changes: HashMap, + grant_root: Option, + reason: Option, + codex_mcp_tool_call_id: String, + codex_event_id: String, +) -> anyhow::Result { + let mut message_lines = Vec::new(); + if let Some(r) = &reason { + message_lines.push(r.clone()); + } + message_lines.push("Allow Codex to apply proposed code changes?".to_string()); + + Ok(JSONRPCRequest { + jsonrpc: JSONRPC_VERSION.into(), + id: elicitation_request_id, + method: ElicitRequest::METHOD.to_string(), + params: Some(serde_json::to_value(&PatchApprovalElicitRequestParams { + message: message_lines.join("\n"), + requested_schema: ElicitRequestParamsRequestedSchema { + r#type: "object".to_string(), + properties: json!({}), + required: None, + }, + codex_elicitation: "patch-approval".to_string(), + codex_mcp_tool_call_id, + codex_event_id, + codex_reason: reason, + codex_grant_root: grant_root, + codex_changes: changes, + })?), + }) +} + +/// This handle is used to ensure that the MockServer and TempDir are not dropped while +/// the McpProcess is still running. +pub struct McpHandle { + pub process: McpProcess, + /// Retain the server for the lifetime of the McpProcess. + #[allow(dead_code)] + server: MockServer, + /// Retain the temporary directory for the lifetime of the McpProcess. + #[allow(dead_code)] + dir: TempDir, +} + +async fn create_mcp_process(responses: Vec) -> anyhow::Result { + let server = create_mock_chat_completions_server(responses).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + let mut mcp_process = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + Ok(McpHandle { + process: mcp_process, + server, + dir: codex_home, + }) +} + +/// Create a Codex config that uses the mock server as the model provider. +/// It also uses `approval_policy = "untrusted"` so that we exercise the +/// elicitation code path for shell commands. +fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + format!( + r#" +model = "mock-model" +approval_policy = "untrusted" +sandbox_policy = "read-only" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "{server_uri}/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"# + ), + ) +}