Add an elicitation for approve patch and refactor tool calls (#1642)
1. Added an elicitation for `approve-patch` which is very similar to `approve-exec`. 2. Extracted both elicitations to their own files to prevent `codex_tool_runner` from blowing up in size.
This commit is contained in:
@@ -3,38 +3,31 @@
|
||||
//! and to make future feature-growth easier to manage.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::codex_wrapper::init_codex;
|
||||
use codex_core::config::Config as CodexConfig;
|
||||
use codex_core::protocol::AgentMessageEvent;
|
||||
use codex_core::protocol::ApplyPatchApprovalRequestEvent;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::ExecApprovalRequestEvent;
|
||||
use codex_core::protocol::InputItem;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_core::protocol::Submission;
|
||||
use codex_core::protocol::TaskCompleteEvent;
|
||||
use mcp_types::CallToolResult;
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use mcp_types::TextContent;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::exec_approval::handle_exec_approval_request;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::patch_approval::handle_patch_approval_request;
|
||||
|
||||
const INVALID_PARAMS_ERROR_CODE: i64 = -32602;
|
||||
pub(crate) const INVALID_PARAMS_ERROR_CODE: i64 = -32602;
|
||||
|
||||
/// Run a complete Codex session and stream events back to the client.
|
||||
///
|
||||
@@ -120,7 +113,7 @@ async fn run_codex_tool_session_inner(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
) {
|
||||
let sub_id = match &request_id {
|
||||
let request_id_str = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(n) => n.to_string(),
|
||||
};
|
||||
@@ -138,80 +131,34 @@ async fn run_codex_tool_session_inner(
|
||||
cwd,
|
||||
reason: _,
|
||||
}) => {
|
||||
let escaped_command = shlex::try_join(command.iter().map(|s| s.as_str()))
|
||||
.unwrap_or_else(|_| command.join(" "));
|
||||
let message = format!(
|
||||
"Allow Codex to run `{escaped_command}` in `{cwd}`?",
|
||||
cwd = cwd.to_string_lossy()
|
||||
);
|
||||
|
||||
let params = ExecApprovalElicitRequestParams {
|
||||
message,
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "exec-approval".to_string(),
|
||||
codex_mcp_tool_call_id: sub_id.clone(),
|
||||
codex_event_id: event.id.clone(),
|
||||
codex_command: command,
|
||||
codex_cwd: cwd,
|
||||
};
|
||||
let params_json = match serde_json::to_value(¶ms) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
let message = format!(
|
||||
"Failed to serialize ExecApprovalElicitRequestParams: {err}"
|
||||
);
|
||||
tracing::error!("{message}");
|
||||
|
||||
outgoing
|
||||
.send_error(
|
||||
request_id.clone(),
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_PARAMS_ERROR_CODE,
|
||||
message,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let on_response = outgoing
|
||||
.send_request(ElicitRequest::METHOD, Some(params_json))
|
||||
.await;
|
||||
|
||||
// Listen for the response on a separate task so we do
|
||||
// not block the main loop of this function.
|
||||
{
|
||||
let codex = codex.clone();
|
||||
let event_id = event.id.clone();
|
||||
tokio::spawn(async move {
|
||||
on_exec_approval_response(event_id, on_response, codex).await;
|
||||
});
|
||||
}
|
||||
|
||||
// Continue, don't break so the session continues.
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::ApplyPatchApprovalRequest(_) => {
|
||||
let result = CallToolResult {
|
||||
content: vec![ContentBlock::TextContent(TextContent {
|
||||
r#type: "text".to_string(),
|
||||
text: "PATCH_APPROVAL_REQUIRED".to_string(),
|
||||
annotations: None,
|
||||
})],
|
||||
is_error: None,
|
||||
structured_content: None,
|
||||
};
|
||||
outgoing
|
||||
.send_response(request_id.clone(), result.into())
|
||||
.await;
|
||||
// Continue, don't break so the session continues.
|
||||
EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent {
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
}) => {
|
||||
handle_patch_approval_request(
|
||||
reason,
|
||||
grant_root,
|
||||
changes,
|
||||
outgoing.clone(),
|
||||
codex.clone(),
|
||||
request_id.clone(),
|
||||
request_id_str.clone(),
|
||||
event.id.clone(),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }) => {
|
||||
@@ -286,71 +233,3 @@ async fn run_codex_tool_session_inner(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_exec_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<Codex>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("request failed: {err:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Try to deserialize `value` and then make the appropriate call to `codex`.
|
||||
let response = match serde_json::from_value::<ExecApprovalResponse>(value) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
error!("failed to deserialize ExecApprovalResponse: {err}");
|
||||
// If we cannot deserialize the response, we deny the request to be
|
||||
// conservative.
|
||||
ExecApprovalResponse {
|
||||
decision: ReviewDecision::Denied,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: event_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit ExecApproval: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See:
|
||||
// - https://github.com/modelcontextprotocol/modelcontextprotocol/blob/f962dc1780fa5eed7fb7c8a0232f1fc83ef220cd/schema/2025-06-18/schema.json#L617-L636
|
||||
// - https://modelcontextprotocol.io/specification/draft/client/elicitation#protocol-messages
|
||||
// It should have "action" and "content" fields.
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ExecApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the
|
||||
/// `params` field of an [`mcp_types::ElicitRequest`].
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecApprovalElicitRequestParams {
|
||||
// These fields are required so that `params`
|
||||
// conforms to ElicitRequestParams.
|
||||
pub message: String,
|
||||
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
|
||||
// These are additional fields the client can use to
|
||||
// correlate the request with the codex tool call.
|
||||
pub codex_elicitation: String,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
pub codex_command: Vec<String>,
|
||||
pub codex_cwd: PathBuf,
|
||||
}
|
||||
|
||||
145
codex-rs/mcp-server/src/exec_approval.rs
Normal file
145
codex-rs/mcp-server/src/exec_approval.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
|
||||
use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE;
|
||||
|
||||
/// Conforms to [`mcp_types::ElicitRequestParams`] so that it can be used as the
|
||||
/// `params` field of an [`ElicitRequest`].
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ExecApprovalElicitRequestParams {
|
||||
// These fields are required so that `params`
|
||||
// conforms to ElicitRequestParams.
|
||||
pub message: String,
|
||||
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
|
||||
// These are additional fields the client can use to
|
||||
// correlate the request with the codex tool call.
|
||||
pub codex_elicitation: String,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
pub codex_command: Vec<String>,
|
||||
pub codex_cwd: PathBuf,
|
||||
}
|
||||
|
||||
// TODO(mbolin): ExecApprovalResponse does not conform to ElicitResult. See:
|
||||
// - https://github.com/modelcontextprotocol/modelcontextprotocol/blob/f962dc1780fa5eed7fb7c8a0232f1fc83ef220cd/schema/2025-06-18/schema.json#L617-L636
|
||||
// - https://modelcontextprotocol.io/specification/draft/client/elicitation#protocol-messages
|
||||
// It should have "action" and "content" fields.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ExecApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_exec_approval_request(
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
outgoing: Arc<crate::outgoing_message::OutgoingMessageSender>,
|
||||
codex: Arc<Codex>,
|
||||
request_id: RequestId,
|
||||
tool_call_id: String,
|
||||
event_id: String,
|
||||
) {
|
||||
let escaped_command =
|
||||
shlex::try_join(command.iter().map(|s| s.as_str())).unwrap_or_else(|_| command.join(" "));
|
||||
let message = format!(
|
||||
"Allow Codex to run `{escaped_command}` in `{cwd}`?",
|
||||
cwd = cwd.to_string_lossy()
|
||||
);
|
||||
|
||||
let params = ExecApprovalElicitRequestParams {
|
||||
message,
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "exec-approval".to_string(),
|
||||
codex_mcp_tool_call_id: tool_call_id.clone(),
|
||||
codex_event_id: event_id.clone(),
|
||||
codex_command: command,
|
||||
codex_cwd: cwd,
|
||||
};
|
||||
let params_json = match serde_json::to_value(¶ms) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
let message = format!("Failed to serialize ExecApprovalElicitRequestParams: {err}");
|
||||
error!("{message}");
|
||||
|
||||
outgoing
|
||||
.send_error(
|
||||
request_id.clone(),
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_PARAMS_ERROR_CODE,
|
||||
message,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let on_response = outgoing
|
||||
.send_request(ElicitRequest::METHOD, Some(params_json))
|
||||
.await;
|
||||
|
||||
// Listen for the response on a separate task so we don't block the main agent loop.
|
||||
{
|
||||
let codex = codex.clone();
|
||||
let event_id = event_id.clone();
|
||||
tokio::spawn(async move {
|
||||
on_exec_approval_response(event_id, on_response, codex).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn on_exec_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<Codex>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("request failed: {err:?}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Try to deserialize `value` and then make the appropriate call to `codex`.
|
||||
let response = serde_json::from_value::<ExecApprovalResponse>(value).unwrap_or_else(|err| {
|
||||
error!("failed to deserialize ExecApprovalResponse: {err}");
|
||||
// If we cannot deserialize the response, we deny the request to be
|
||||
// conservative.
|
||||
ExecApprovalResponse {
|
||||
decision: ReviewDecision::Denied,
|
||||
}
|
||||
});
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::ExecApproval {
|
||||
id: event_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit ExecApproval: {err}");
|
||||
}
|
||||
}
|
||||
@@ -16,17 +16,21 @@ use tracing::info;
|
||||
|
||||
mod codex_tool_config;
|
||||
mod codex_tool_runner;
|
||||
mod exec_approval;
|
||||
mod json_to_toml;
|
||||
mod message_processor;
|
||||
mod outgoing_message;
|
||||
mod patch_approval;
|
||||
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
pub use crate::codex_tool_config::CodexToolCallParam;
|
||||
pub use crate::codex_tool_runner::ExecApprovalElicitRequestParams;
|
||||
pub use crate::codex_tool_runner::ExecApprovalResponse;
|
||||
pub use crate::exec_approval::ExecApprovalElicitRequestParams;
|
||||
pub use crate::exec_approval::ExecApprovalResponse;
|
||||
pub use crate::patch_approval::PatchApprovalElicitRequestParams;
|
||||
pub use crate::patch_approval::PatchApprovalResponse;
|
||||
|
||||
/// Size of the bounded channels used to communicate between tasks. The value
|
||||
/// is a balance between throughput and memory usage – 128 messages should be
|
||||
|
||||
147
codex-rs/mcp-server/src/patch_approval.rs
Normal file
147
codex-rs/mcp-server/src/patch_approval.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPCErrorError;
|
||||
use mcp_types::ModelContextProtocolRequest;
|
||||
use mcp_types::RequestId;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use tracing::error;
|
||||
|
||||
use crate::codex_tool_runner::INVALID_PARAMS_ERROR_CODE;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PatchApprovalElicitRequestParams {
|
||||
pub message: String,
|
||||
#[serde(rename = "requestedSchema")]
|
||||
pub requested_schema: ElicitRequestParamsRequestedSchema,
|
||||
pub codex_elicitation: String,
|
||||
pub codex_mcp_tool_call_id: String,
|
||||
pub codex_event_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_reason: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub codex_grant_root: Option<PathBuf>,
|
||||
pub codex_changes: HashMap<PathBuf, FileChange>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct PatchApprovalResponse {
|
||||
pub decision: ReviewDecision,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_patch_approval_request(
|
||||
reason: Option<String>,
|
||||
grant_root: Option<PathBuf>,
|
||||
changes: HashMap<PathBuf, FileChange>,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
codex: Arc<Codex>,
|
||||
request_id: RequestId,
|
||||
tool_call_id: String,
|
||||
event_id: String,
|
||||
) {
|
||||
let mut message_lines = Vec::new();
|
||||
if let Some(r) = &reason {
|
||||
message_lines.push(r.clone());
|
||||
}
|
||||
message_lines.push("Allow Codex to apply proposed code changes?".to_string());
|
||||
|
||||
let params = PatchApprovalElicitRequestParams {
|
||||
message: message_lines.join("\n"),
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "patch-approval".to_string(),
|
||||
codex_mcp_tool_call_id: tool_call_id.clone(),
|
||||
codex_event_id: event_id.clone(),
|
||||
codex_reason: reason,
|
||||
codex_grant_root: grant_root,
|
||||
codex_changes: changes,
|
||||
};
|
||||
let params_json = match serde_json::to_value(¶ms) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
let message = format!("Failed to serialize PatchApprovalElicitRequestParams: {err}");
|
||||
error!("{message}");
|
||||
|
||||
outgoing
|
||||
.send_error(
|
||||
request_id.clone(),
|
||||
JSONRPCErrorError {
|
||||
code: INVALID_PARAMS_ERROR_CODE,
|
||||
message,
|
||||
data: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let on_response = outgoing
|
||||
.send_request(ElicitRequest::METHOD, Some(params_json))
|
||||
.await;
|
||||
|
||||
// Listen for the response on a separate task so we don't block the main agent loop.
|
||||
{
|
||||
let codex = codex.clone();
|
||||
let event_id = event_id.clone();
|
||||
tokio::spawn(async move {
|
||||
on_patch_approval_response(event_id, on_response, codex).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn on_patch_approval_response(
|
||||
event_id: String,
|
||||
receiver: tokio::sync::oneshot::Receiver<mcp_types::Result>,
|
||||
codex: Arc<Codex>,
|
||||
) {
|
||||
let response = receiver.await;
|
||||
let value = match response {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
error!("request failed: {err:?}");
|
||||
if let Err(submit_err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id.clone(),
|
||||
decision: ReviewDecision::Denied,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit denied PatchApproval after request failure: {submit_err}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let response = serde_json::from_value::<PatchApprovalResponse>(value).unwrap_or_else(|err| {
|
||||
error!("failed to deserialize PatchApprovalResponse: {err}");
|
||||
PatchApprovalResponse {
|
||||
decision: ReviewDecision::Denied,
|
||||
}
|
||||
});
|
||||
|
||||
if let Err(err) = codex
|
||||
.submit(Op::PatchApproval {
|
||||
id: event_id,
|
||||
decision: response.decision,
|
||||
})
|
||||
.await
|
||||
{
|
||||
error!("failed to submit PatchApproval: {err}");
|
||||
}
|
||||
}
|
||||
@@ -139,14 +139,18 @@ impl McpProcess {
|
||||
|
||||
/// Returns the id used to make the request so it can be used when
|
||||
/// correlating notifications.
|
||||
pub async fn send_codex_tool_call(&mut self, prompt: &str) -> anyhow::Result<i64> {
|
||||
pub async fn send_codex_tool_call(
|
||||
&mut self,
|
||||
cwd: Option<String>,
|
||||
prompt: &str,
|
||||
) -> anyhow::Result<i64> {
|
||||
let codex_tool_call_params = CallToolRequestParams {
|
||||
name: "codex".to_string(),
|
||||
arguments: Some(serde_json::to_value(CodexToolCallParam {
|
||||
cwd,
|
||||
prompt: prompt.to_string(),
|
||||
model: None,
|
||||
profile: None,
|
||||
cwd: None,
|
||||
approval_policy: None,
|
||||
sandbox: None,
|
||||
config: None,
|
||||
|
||||
@@ -4,5 +4,6 @@ mod responses;
|
||||
|
||||
pub use mcp_process::McpProcess;
|
||||
pub use mock_model_server::create_mock_chat_completions_server;
|
||||
pub use responses::create_apply_patch_sse_response;
|
||||
pub use responses::create_final_assistant_message_sse_response;
|
||||
pub use responses::create_shell_sse_response;
|
||||
|
||||
@@ -57,3 +57,39 @@ pub fn create_final_assistant_message_sse_response(message: &str) -> anyhow::Res
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
pub fn create_apply_patch_sse_response(
|
||||
patch_content: &str,
|
||||
call_id: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
// Use shell command to call apply_patch with heredoc format
|
||||
let shell_command = format!("apply_patch <<'EOF'\n{patch_content}\nEOF");
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"command": ["bash", "-lc", shell_command]
|
||||
}))?;
|
||||
|
||||
let tool_call = json!({
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": call_id,
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"arguments": tool_call_arguments
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "tool_calls"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let sse = format!(
|
||||
"data: {}\n\ndata: DONE\n\n",
|
||||
serde_json::to_string(&tool_call)?
|
||||
);
|
||||
Ok(sse)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
mod common;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
use codex_core::protocol::FileChange;
|
||||
use codex_core::protocol::ReviewDecision;
|
||||
use codex_mcp_server::ExecApprovalElicitRequestParams;
|
||||
use codex_mcp_server::ExecApprovalResponse;
|
||||
use codex_mcp_server::PatchApprovalElicitRequestParams;
|
||||
use codex_mcp_server::PatchApprovalResponse;
|
||||
use mcp_types::ElicitRequest;
|
||||
use mcp_types::ElicitRequestParamsRequestedSchema;
|
||||
use mcp_types::JSONRPC_VERSION;
|
||||
@@ -17,8 +23,10 @@ use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::MockServer;
|
||||
|
||||
use crate::common::McpProcess;
|
||||
use crate::common::create_apply_patch_sse_response;
|
||||
use crate::common::create_final_assistant_message_sse_response;
|
||||
use crate::common::create_mock_chat_completions_server;
|
||||
use crate::common::create_shell_sse_response;
|
||||
@@ -30,7 +38,7 @@ const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs
|
||||
/// command, as expected.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shell_command_approval_triggers_elicitation() {
|
||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
@@ -49,12 +57,11 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
let shell_command = vec!["git".to_string(), "init".to_string()];
|
||||
let workdir_for_shell_function_call = TempDir::new()?;
|
||||
|
||||
// Configure the mock server so it makes two responses:
|
||||
// 1. The first response is a shell function call that will trigger an
|
||||
// elicitation request.
|
||||
// 2. The second response is the final assistant message that should be
|
||||
// returned after the elicitation is approved and the command is run.
|
||||
let server = create_mock_chat_completions_server(vec![
|
||||
let McpHandle {
|
||||
process: mut mcp_process,
|
||||
server: _server,
|
||||
dir: _dir,
|
||||
} = create_mcp_process(vec![
|
||||
create_shell_sse_response(
|
||||
shell_command.clone(),
|
||||
Some(workdir_for_shell_function_call.path()),
|
||||
@@ -63,18 +70,14 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
)?,
|
||||
create_final_assistant_message_sse_response("Enjoy your new git repo!")?,
|
||||
])
|
||||
.await;
|
||||
|
||||
// Run `codex mcp` with a specific config.toml.
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
.await?;
|
||||
|
||||
// Send a "codex" tool request, which should hit the completions endpoint.
|
||||
// In turn, it should reply with a tool call, which the MCP should forward
|
||||
// as an elicitation.
|
||||
let codex_request_id = mcp_process.send_codex_tool_call("run `git init`").await?;
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(None, "run `git init`")
|
||||
.await?;
|
||||
let elicitation_request = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_request_message(),
|
||||
@@ -136,32 +139,6 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a Codex config that uses the mock server as the model provider.
|
||||
/// It also uses `approval_policy = "untrusted"` so that we exercise the
|
||||
/// elicitation code path for shell commands.
|
||||
fn create_config_toml(codex_home: &Path, server_uri: String) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "untrusted"
|
||||
sandbox_policy = "read-only"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn create_expected_elicitation_request(
|
||||
elicitation_request_id: RequestId,
|
||||
command: Vec<String>,
|
||||
@@ -193,3 +170,197 @@ fn create_expected_elicitation_request(
|
||||
})?),
|
||||
})
|
||||
}
|
||||
|
||||
/// Test that patch approval triggers an elicitation request to the MCP and that
|
||||
/// sending the approval applies the patch, as expected.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_patch_approval_triggers_elicitation() {
|
||||
if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||
println!(
|
||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if let Err(err) = patch_approval_triggers_elicitation().await {
|
||||
panic!("failure: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> {
|
||||
let cwd = TempDir::new()?;
|
||||
let test_file = cwd.path().join("destination_file.txt");
|
||||
std::fs::write(&test_file, "original content\n")?;
|
||||
|
||||
let patch_content = format!(
|
||||
"*** Begin Patch\n*** Update File: {}\n-original content\n+modified content\n*** End Patch",
|
||||
test_file.as_path().to_string_lossy()
|
||||
);
|
||||
|
||||
let McpHandle {
|
||||
process: mut mcp_process,
|
||||
server: _server,
|
||||
dir: _dir,
|
||||
} = create_mcp_process(vec![
|
||||
create_apply_patch_sse_response(&patch_content, "call1234")?,
|
||||
create_final_assistant_message_sse_response("Patch has been applied successfully!")?,
|
||||
])
|
||||
.await?;
|
||||
|
||||
// Send a "codex" tool request that will trigger the apply_patch command
|
||||
let codex_request_id = mcp_process
|
||||
.send_codex_tool_call(
|
||||
Some(cwd.path().to_string_lossy().to_string()),
|
||||
"please modify the test file",
|
||||
)
|
||||
.await?;
|
||||
let elicitation_request = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_request_message(),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let elicitation_request_id = RequestId::Integer(0);
|
||||
|
||||
let mut expected_changes = HashMap::new();
|
||||
expected_changes.insert(
|
||||
test_file.as_path().to_path_buf(),
|
||||
FileChange::Update {
|
||||
unified_diff: "@@ -1 +1 @@\n-original content\n+modified content\n".to_string(),
|
||||
move_path: None,
|
||||
},
|
||||
);
|
||||
|
||||
let expected_elicitation_request = create_expected_patch_approval_elicitation_request(
|
||||
elicitation_request_id.clone(),
|
||||
expected_changes,
|
||||
None, // No grant_root expected
|
||||
None, // No reason expected
|
||||
codex_request_id.to_string(),
|
||||
"1".to_string(),
|
||||
)?;
|
||||
assert_eq!(expected_elicitation_request, elicitation_request);
|
||||
|
||||
// Accept the patch approval request by responding to the elicitation
|
||||
mcp_process
|
||||
.send_response(
|
||||
elicitation_request_id,
|
||||
serde_json::to_value(PatchApprovalResponse {
|
||||
decision: ReviewDecision::Approved,
|
||||
})?,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Verify the original `codex` tool call completes
|
||||
let codex_response = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)),
|
||||
)
|
||||
.await??;
|
||||
assert_eq!(
|
||||
JSONRPCResponse {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: RequestId::Integer(codex_request_id),
|
||||
result: json!({
|
||||
"content": [
|
||||
{
|
||||
"text": "Patch has been applied successfully!",
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}),
|
||||
},
|
||||
codex_response
|
||||
);
|
||||
|
||||
let file_contents = std::fs::read_to_string(test_file.as_path())?;
|
||||
assert_eq!(file_contents, "modified content\n");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_expected_patch_approval_elicitation_request(
|
||||
elicitation_request_id: RequestId,
|
||||
changes: HashMap<PathBuf, FileChange>,
|
||||
grant_root: Option<PathBuf>,
|
||||
reason: Option<String>,
|
||||
codex_mcp_tool_call_id: String,
|
||||
codex_event_id: String,
|
||||
) -> anyhow::Result<JSONRPCRequest> {
|
||||
let mut message_lines = Vec::new();
|
||||
if let Some(r) = &reason {
|
||||
message_lines.push(r.clone());
|
||||
}
|
||||
message_lines.push("Allow Codex to apply proposed code changes?".to_string());
|
||||
|
||||
Ok(JSONRPCRequest {
|
||||
jsonrpc: JSONRPC_VERSION.into(),
|
||||
id: elicitation_request_id,
|
||||
method: ElicitRequest::METHOD.to_string(),
|
||||
params: Some(serde_json::to_value(&PatchApprovalElicitRequestParams {
|
||||
message: message_lines.join("\n"),
|
||||
requested_schema: ElicitRequestParamsRequestedSchema {
|
||||
r#type: "object".to_string(),
|
||||
properties: json!({}),
|
||||
required: None,
|
||||
},
|
||||
codex_elicitation: "patch-approval".to_string(),
|
||||
codex_mcp_tool_call_id,
|
||||
codex_event_id,
|
||||
codex_reason: reason,
|
||||
codex_grant_root: grant_root,
|
||||
codex_changes: changes,
|
||||
})?),
|
||||
})
|
||||
}
|
||||
|
||||
/// This handle is used to ensure that the MockServer and TempDir are not dropped while
|
||||
/// the McpProcess is still running.
|
||||
pub struct McpHandle {
|
||||
pub process: McpProcess,
|
||||
/// Retain the server for the lifetime of the McpProcess.
|
||||
#[allow(dead_code)]
|
||||
server: MockServer,
|
||||
/// Retain the temporary directory for the lifetime of the McpProcess.
|
||||
#[allow(dead_code)]
|
||||
dir: TempDir,
|
||||
}
|
||||
|
||||
async fn create_mcp_process(responses: Vec<String>) -> anyhow::Result<McpHandle> {
|
||||
let server = create_mock_chat_completions_server(responses).await;
|
||||
let codex_home = TempDir::new()?;
|
||||
create_config_toml(codex_home.path(), &server.uri())?;
|
||||
let mut mcp_process = McpProcess::new(codex_home.path()).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??;
|
||||
Ok(McpHandle {
|
||||
process: mcp_process,
|
||||
server,
|
||||
dir: codex_home,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a Codex config that uses the mock server as the model provider.
|
||||
/// It also uses `approval_policy = "untrusted"` so that we exercise the
|
||||
/// elicitation code path for shell commands.
|
||||
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
|
||||
let config_toml = codex_home.join("config.toml");
|
||||
std::fs::write(
|
||||
config_toml,
|
||||
format!(
|
||||
r#"
|
||||
model = "mock-model"
|
||||
approval_policy = "untrusted"
|
||||
sandbox_policy = "read-only"
|
||||
|
||||
model_provider = "mock_provider"
|
||||
|
||||
[model_providers.mock_provider]
|
||||
name = "Mock provider for test"
|
||||
base_url = "{server_uri}/v1"
|
||||
wire_api = "chat"
|
||||
request_max_retries = 0
|
||||
stream_max_retries = 0
|
||||
"#
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user