Files
llmx/codex-rs/mcp-server/src/message_processor.rs
Michael Bolin d49d802b06 test: add integration test for MCP server (#1633)
This PR introduces a single integration test for `cargo mcp`, though it
also introduces a number of reusable components so that it should be
easier to introduce more integration tests going forward.

The new test is introduced in `codex-rs/mcp-server/tests/elicitation.rs`
and the reusable pieces are in `codex-rs/mcp-server/tests/common`.

The test itself verifies new functionality around elicitations
introduced in https://github.com/openai/codex/pull/1623 (and the fix
introduced in https://github.com/openai/codex/pull/1629) by doing the
following:

- starts a mock model provider with canned responses for
`/v1/chat/completions`
- starts the MCP server with a `config.toml` to use that model provider
(and `approval_policy = "untrusted"`)
- sends the `codex` tool call which causes the mock model provider to
request a shell call for `git init`
- the MCP server sends an elicitation to the client to approve the
request
- the client replies to the elicitation with `"approved"`
- the MCP server runs the command and re-samples the model, getting a
`"finish_reason": "stop"`
- in turn, the MCP server sends the final response to the original
`codex` tool call
- verifies that `git init` ran as expected

To test:

```
cargo test shell_command_approval_triggers_elicitation
```

In writing this test, I discovered that `ExecApprovalResponse` does not
conform to `ElicitResult`, so I added a TODO to fix that, since I think
that should be updated in a separate PR. As it stands, this PR does not
update any business logic, though it does make a number of members of
the `mcp-server` crate `pub` so they can be used in the test.

One additional learning from this PR is that `cargo_bin()`, provided by
the `assert_cmd` crate's `CommandCargoExt` trait, is only implemented for
`std::process::Command`, but we really want to use
`tokio::process::Command` so that everything is async and we can
leverage utilities like `tokio::time::timeout()`. The trick I came up
with was to use `cargo_bin()` to locate the program, and then to pass the
result of `std::process::Command::get_program()` to
`tokio::process::Command::new()`.
2025-07-21 10:27:07 -07:00

444 lines
16 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::path::PathBuf;
use std::sync::Arc;
use crate::codex_tool_config::CodexToolCallParam;
use crate::codex_tool_config::create_tool_for_codex_tool_call_param;
use crate::outgoing_message::OutgoingMessageSender;
use codex_core::config::Config as CodexConfig;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::ClientRequest;
use mcp_types::ContentBlock;
use mcp_types::JSONRPCError;
use mcp_types::JSONRPCErrorError;
use mcp_types::JSONRPCNotification;
use mcp_types::JSONRPCRequest;
use mcp_types::JSONRPCResponse;
use mcp_types::ListToolsResult;
use mcp_types::ModelContextProtocolRequest;
use mcp_types::RequestId;
use mcp_types::ServerCapabilitiesTools;
use mcp_types::ServerNotification;
use mcp_types::TextContent;
use serde_json::json;
use tokio::task;
/// Routes incoming JSON-RPC requests, responses, notifications and errors to
/// per-message handlers and sends replies through a shared outgoing sender.
pub(crate) struct MessageProcessor {
/// Shared handle used to send responses and errors to the peer; it is also
/// cloned into the task spawned for each `codex` tool call.
outgoing: Arc<OutgoingMessageSender>,
/// Set to `true` by the first `initialize` request; a later `initialize`
/// is rejected with a JSON-RPC error.
initialized: bool,
/// Optional path forwarded to `CodexToolCallParam::into_config` when a
/// `codex` tool call is configured.
/// NOTE(review): presumably the Linux sandbox helper executable - confirm.
codex_linux_sandbox_exe: Option<PathBuf>,
}
impl MessageProcessor {
/// Builds a `MessageProcessor` that has not yet handled an `initialize`
/// request.
///
/// `outgoing` is wrapped in an `Arc` so handlers (and the tasks they spawn)
/// can share it; `codex_linux_sandbox_exe` is stored unchanged.
pub(crate) fn new(
    outgoing: OutgoingMessageSender,
    codex_linux_sandbox_exe: Option<PathBuf>,
) -> Self {
    let outgoing = Arc::new(outgoing);
    Self {
        outgoing,
        initialized: false,
        codex_linux_sandbox_exe,
    }
}
pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) {
// Hold on to the ID so we can respond.
let request_id = request.id.clone();
let client_request = match ClientRequest::try_from(request) {
Ok(client_request) => client_request,
Err(e) => {
tracing::warn!("Failed to convert request: {e}");
return;
}
};
// Dispatch to a dedicated handler for each request type.
match client_request {
ClientRequest::InitializeRequest(params) => {
self.handle_initialize(request_id, params).await;
}
ClientRequest::PingRequest(params) => {
self.handle_ping(request_id, params).await;
}
ClientRequest::ListResourcesRequest(params) => {
self.handle_list_resources(params);
}
ClientRequest::ListResourceTemplatesRequest(params) => {
self.handle_list_resource_templates(params);
}
ClientRequest::ReadResourceRequest(params) => {
self.handle_read_resource(params);
}
ClientRequest::SubscribeRequest(params) => {
self.handle_subscribe(params);
}
ClientRequest::UnsubscribeRequest(params) => {
self.handle_unsubscribe(params);
}
ClientRequest::ListPromptsRequest(params) => {
self.handle_list_prompts(params);
}
ClientRequest::GetPromptRequest(params) => {
self.handle_get_prompt(params);
}
ClientRequest::ListToolsRequest(params) => {
self.handle_list_tools(request_id, params).await;
}
ClientRequest::CallToolRequest(params) => {
self.handle_call_tool(request_id, params).await;
}
ClientRequest::SetLevelRequest(params) => {
self.handle_set_level(params);
}
ClientRequest::CompleteRequest(params) => {
self.handle_complete(params);
}
}
}
/// Logs a JSON-RPC response received from the peer and hands its `id` and
/// `result` to the outgoing sender via `notify_client_response`.
pub(crate) async fn process_response(&mut self, response: JSONRPCResponse) {
    tracing::info!("<- response: {:?}", response);

    // Only the ID and the payload are forwarded; other fields are dropped.
    let id = response.id;
    let result = response.result;
    self.outgoing.notify_client_response(id, result).await;
}
/// Converts a JSON-RPC notification into a typed `ServerNotification` and
/// routes it to the handler for that notification type.
///
/// A notification that cannot be converted is logged at `warn` level and
/// otherwise ignored; notifications never produce a reply.
///
/// NOTE(review): notifications reaching a server normally originate from the
/// client - confirm that `ServerNotification` (rather than a client-side
/// notification type) is the intended conversion target.
pub(crate) fn process_notification(&mut self, notification: JSONRPCNotification) {
    // Each variant has its own stub handler so that real logic can be added
    // incrementally.
    match ServerNotification::try_from(notification) {
        Err(e) => {
            tracing::warn!("Failed to convert notification: {e}");
        }
        Ok(ServerNotification::CancelledNotification(params)) => {
            self.handle_cancelled_notification(params);
        }
        Ok(ServerNotification::ProgressNotification(params)) => {
            self.handle_progress_notification(params);
        }
        Ok(ServerNotification::ResourceListChangedNotification(params)) => {
            self.handle_resource_list_changed(params);
        }
        Ok(ServerNotification::ResourceUpdatedNotification(params)) => {
            self.handle_resource_updated(params);
        }
        Ok(ServerNotification::PromptListChangedNotification(params)) => {
            self.handle_prompt_list_changed(params);
        }
        Ok(ServerNotification::ToolListChangedNotification(params)) => {
            self.handle_tool_list_changed(params);
        }
        Ok(ServerNotification::LoggingMessageNotification(params)) => {
            self.handle_logging_message(params);
        }
    }
}
/// Records a JSON-RPC error object received from the peer.
///
/// The error is only logged at `error` level; nothing is sent back.
pub(crate) fn process_error(&mut self, err: JSONRPCError) {
    tracing::error!("<- error: {:?}", err);
}
/// Answers an `initialize` request.
///
/// The first call marks the processor as initialized and replies with an
/// `InitializeResult` that advertises only the `tools` capability, echoes the
/// protocol version sent by the client, and identifies this server. Any later
/// call is rejected with a JSON-RPC `-32600` error.
async fn handle_initialize(
    &mut self,
    id: RequestId,
    params: <mcp_types::InitializeRequest as ModelContextProtocolRequest>::Params,
) {
    tracing::info!("initialize -> params: {:?}", params);

    // Guard: `initialize` may only be handled once per processor.
    if self.initialized {
        let already_initialized = JSONRPCErrorError {
            code: -32600, // Invalid Request
            message: "initialize called more than once".to_string(),
            data: None,
        };
        self.outgoing.send_error(id, already_initialized).await;
        return;
    }
    self.initialized = true;

    // Only tool support is advertised; every other capability is left unset.
    let tools = ServerCapabilitiesTools {
        list_changed: Some(true),
    };
    let capabilities = mcp_types::ServerCapabilities {
        completions: None,
        experimental: None,
        logging: None,
        prompts: None,
        resources: None,
        tools: Some(tools),
    };
    let server_info = mcp_types::Implementation {
        name: "codex-mcp-server".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
        title: Some("Codex".to_string()),
    };

    let result = mcp_types::InitializeResult {
        capabilities,
        instructions: None,
        protocol_version: params.protocol_version.clone(),
        server_info,
    };
    self.send_response::<mcp_types::InitializeRequest>(id, result)
        .await;
}
async fn send_response<T>(&self, id: RequestId, result: T::Result)
where
T: ModelContextProtocolRequest,
{
// result has `Serialized` instance so should never fail
#[expect(clippy::unwrap_used)]
let result = serde_json::to_value(result).unwrap();
self.outgoing.send_response(id, result).await;
}
async fn handle_ping(
&self,
id: RequestId,
params: <mcp_types::PingRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("ping -> params: {:?}", params);
let result = json!({});
self.send_response::<mcp_types::PingRequest>(id, result)
.await;
}
fn handle_list_resources(
&self,
params: <mcp_types::ListResourcesRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("resources/list -> params: {:?}", params);
}
fn handle_list_resource_templates(
&self,
params:
<mcp_types::ListResourceTemplatesRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("resources/templates/list -> params: {:?}", params);
}
fn handle_read_resource(
&self,
params: <mcp_types::ReadResourceRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("resources/read -> params: {:?}", params);
}
fn handle_subscribe(
&self,
params: <mcp_types::SubscribeRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("resources/subscribe -> params: {:?}", params);
}
fn handle_unsubscribe(
&self,
params: <mcp_types::UnsubscribeRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("resources/unsubscribe -> params: {:?}", params);
}
fn handle_list_prompts(
&self,
params: <mcp_types::ListPromptsRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("prompts/list -> params: {:?}", params);
}
fn handle_get_prompt(
&self,
params: <mcp_types::GetPromptRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("prompts/get -> params: {:?}", params);
}
async fn handle_list_tools(
&self,
id: RequestId,
params: <mcp_types::ListToolsRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::trace!("tools/list -> {params:?}");
let result = ListToolsResult {
tools: vec![create_tool_for_codex_tool_call_param()],
next_cursor: None,
};
self.send_response::<mcp_types::ListToolsRequest>(id, result)
.await;
}
async fn handle_call_tool(
&self,
id: RequestId,
params: <mcp_types::CallToolRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("tools/call -> params: {:?}", params);
let CallToolRequestParams { name, arguments } = params;
// We only support the "codex" tool for now.
if name != "codex" {
// Tool not found return error result so the LLM can react.
let result = CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text: format!("Unknown tool '{name}'"),
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(id, result)
.await;
return;
}
let (initial_prompt, config): (String, CodexConfig) = match arguments {
Some(json_val) => match serde_json::from_value::<CodexToolCallParam>(json_val) {
Ok(tool_cfg) => match tool_cfg.into_config(self.codex_linux_sandbox_exe.clone()) {
Ok(cfg) => cfg,
Err(e) => {
let result = CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_owned(),
text: format!(
"Failed to load Codex configuration from overrides: {e}"
),
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(id, result)
.await;
return;
}
},
Err(e) => {
let result = CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_owned(),
text: format!("Failed to parse configuration for Codex tool: {e}"),
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(id, result)
.await;
return;
}
},
None => {
let result = CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
r#type: "text".to_string(),
text:
"Missing arguments for codex tool-call; the `prompt` field is required."
.to_string(),
annotations: None,
})],
is_error: Some(true),
structured_content: None,
};
self.send_response::<mcp_types::CallToolRequest>(id, result)
.await;
return;
}
};
// Clone outgoing sender to move into async task.
let outgoing = self.outgoing.clone();
// Spawn an async task to handle the Codex session so that we do not
// block the synchronous message-processing loop.
task::spawn(async move {
// Run the Codex session and stream events back to the client.
crate::codex_tool_runner::run_codex_tool_session(id, initial_prompt, config, outgoing)
.await;
});
}
fn handle_set_level(
&self,
params: <mcp_types::SetLevelRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("logging/setLevel -> params: {:?}", params);
}
fn handle_complete(
&self,
params: <mcp_types::CompleteRequest as mcp_types::ModelContextProtocolRequest>::Params,
) {
tracing::info!("completion/complete -> params: {:?}", params);
}
// ---------------------------------------------------------------------
// Notification handlers
// ---------------------------------------------------------------------
/// Stub for `notifications/cancelled`: logs the parameters and does nothing
/// else.
fn handle_cancelled_notification(
    &self,
    params: <mcp_types::CancelledNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
    tracing::info!(
        "notifications/cancelled -> params: {:?}",
        params
    );
}
/// Stub for `notifications/progress`: logs the parameters and does nothing
/// else.
fn handle_progress_notification(
    &self,
    params: <mcp_types::ProgressNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
    tracing::info!(
        "notifications/progress -> params: {:?}",
        params
    );
}
/// Stub for `notifications/resources/list_changed`: logs the parameters and
/// does nothing else.
fn handle_resource_list_changed(
    &self,
    params: <mcp_types::ResourceListChangedNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
    tracing::info!("notifications/resources/list_changed -> params: {:?}", params);
}
/// Stub for `notifications/resources/updated`: logs the parameters and does
/// nothing else.
fn handle_resource_updated(
    &self,
    params: <mcp_types::ResourceUpdatedNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
    tracing::info!(
        "notifications/resources/updated -> params: {:?}",
        params
    );
}
/// Stub for `notifications/prompts/list_changed`: logs the parameters and
/// does nothing else.
fn handle_prompt_list_changed(
    &self,
    params: <mcp_types::PromptListChangedNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
    tracing::info!(
        "notifications/prompts/list_changed -> params: {:?}",
        params
    );
}
/// Stub for `notifications/tools/list_changed`: logs the parameters and does
/// nothing else.
fn handle_tool_list_changed(
    &self,
    params: <mcp_types::ToolListChangedNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
    tracing::info!(
        "notifications/tools/list_changed -> params: {:?}",
        params
    );
}
/// Stub for `notifications/message`: logs the parameters and does nothing
/// else.
fn handle_logging_message(
    &self,
    params: <mcp_types::LoggingMessageNotification as mcp_types::ModelContextProtocolNotification>::Params,
) {
    tracing::info!(
        "notifications/message -> params: {:?}",
        params
    );
}
}