344 lines
12 KiB
Rust
344 lines
12 KiB
Rust
use std::path::Path;
|
|
use std::process::Stdio;
|
|
use std::sync::atomic::AtomicI64;
|
|
use std::sync::atomic::Ordering;
|
|
use tokio::io::AsyncBufReadExt;
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio::io::BufReader;
|
|
use tokio::process::Child;
|
|
use tokio::process::ChildStdin;
|
|
use tokio::process::ChildStdout;
|
|
|
|
use anyhow::Context;
|
|
use assert_cmd::prelude::*;
|
|
use codex_mcp_server::CodexToolCallParam;
|
|
use codex_mcp_server::CodexToolCallReplyParam;
|
|
use mcp_types::CallToolRequestParams;
|
|
use mcp_types::ClientCapabilities;
|
|
use mcp_types::Implementation;
|
|
use mcp_types::InitializeRequestParams;
|
|
use mcp_types::JSONRPC_VERSION;
|
|
use mcp_types::JSONRPCMessage;
|
|
use mcp_types::JSONRPCNotification;
|
|
use mcp_types::JSONRPCRequest;
|
|
use mcp_types::JSONRPCResponse;
|
|
use mcp_types::ModelContextProtocolNotification;
|
|
use mcp_types::ModelContextProtocolRequest;
|
|
use mcp_types::RequestId;
|
|
use pretty_assertions::assert_eq;
|
|
use serde_json::json;
|
|
use std::process::Command as StdCommand;
|
|
use tokio::process::Command;
|
|
|
|
pub struct McpProcess {
|
|
next_request_id: AtomicI64,
|
|
/// Retain this child process until the client is dropped. The Tokio runtime
|
|
/// will make a "best effort" to reap the process after it exits, but it is
|
|
/// not a guarantee. See the `kill_on_drop` documentation for details.
|
|
#[allow(dead_code)]
|
|
process: Child,
|
|
stdin: ChildStdin,
|
|
stdout: BufReader<ChildStdout>,
|
|
}
|
|
|
|
impl McpProcess {
|
|
pub async fn new(codex_home: &Path) -> anyhow::Result<Self> {
|
|
// Use assert_cmd to locate the binary path and then switch to tokio::process::Command
|
|
let std_cmd = StdCommand::cargo_bin("codex-mcp-server")
|
|
.context("should find binary for codex-mcp-server")?;
|
|
|
|
let program = std_cmd.get_program().to_owned();
|
|
|
|
let mut cmd = Command::new(program);
|
|
|
|
cmd.stdin(Stdio::piped());
|
|
cmd.stdout(Stdio::piped());
|
|
cmd.env("CODEX_HOME", codex_home);
|
|
cmd.env("RUST_LOG", "debug");
|
|
|
|
let mut process = cmd
|
|
.kill_on_drop(true)
|
|
.spawn()
|
|
.context("codex-mcp-server proc should start")?;
|
|
let stdin = process
|
|
.stdin
|
|
.take()
|
|
.ok_or_else(|| anyhow::format_err!("mcp should have stdin fd"))?;
|
|
let stdout = process
|
|
.stdout
|
|
.take()
|
|
.ok_or_else(|| anyhow::format_err!("mcp should have stdout fd"))?;
|
|
let stdout = BufReader::new(stdout);
|
|
Ok(Self {
|
|
next_request_id: AtomicI64::new(0),
|
|
process,
|
|
stdin,
|
|
stdout,
|
|
})
|
|
}
|
|
|
|
/// Performs the initialization handshake with the MCP server.
|
|
pub async fn initialize(&mut self) -> anyhow::Result<()> {
|
|
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
|
|
|
let params = InitializeRequestParams {
|
|
capabilities: ClientCapabilities {
|
|
elicitation: Some(json!({})),
|
|
experimental: None,
|
|
roots: None,
|
|
sampling: None,
|
|
},
|
|
client_info: Implementation {
|
|
name: "elicitation test".into(),
|
|
title: Some("Elicitation Test".into()),
|
|
version: "0.0.0".into(),
|
|
},
|
|
protocol_version: mcp_types::MCP_SCHEMA_VERSION.into(),
|
|
};
|
|
let params_value = serde_json::to_value(params)?;
|
|
|
|
self.send_jsonrpc_message(JSONRPCMessage::Request(JSONRPCRequest {
|
|
jsonrpc: JSONRPC_VERSION.into(),
|
|
id: RequestId::Integer(request_id),
|
|
method: mcp_types::InitializeRequest::METHOD.into(),
|
|
params: Some(params_value),
|
|
}))
|
|
.await?;
|
|
|
|
let initialized = self.read_jsonrpc_message().await?;
|
|
assert_eq!(
|
|
JSONRPCMessage::Response(JSONRPCResponse {
|
|
jsonrpc: JSONRPC_VERSION.into(),
|
|
id: RequestId::Integer(request_id),
|
|
result: json!({
|
|
"capabilities": {
|
|
"tools": {
|
|
"listChanged": true
|
|
},
|
|
},
|
|
"serverInfo": {
|
|
"name": "codex-mcp-server",
|
|
"title": "Codex",
|
|
"version": "0.0.0"
|
|
},
|
|
"protocolVersion": mcp_types::MCP_SCHEMA_VERSION
|
|
})
|
|
}),
|
|
initialized
|
|
);
|
|
|
|
// Send notifications/initialized to ack the response.
|
|
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
|
|
jsonrpc: JSONRPC_VERSION.into(),
|
|
method: mcp_types::InitializedNotification::METHOD.into(),
|
|
params: None,
|
|
}))
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Returns the id used to make the request so it can be used when
|
|
/// correlating notifications.
|
|
pub async fn send_codex_tool_call(
|
|
&mut self,
|
|
params: CodexToolCallParam,
|
|
) -> anyhow::Result<i64> {
|
|
let codex_tool_call_params = CallToolRequestParams {
|
|
name: "codex".to_string(),
|
|
arguments: Some(serde_json::to_value(params)?),
|
|
};
|
|
self.send_request(
|
|
mcp_types::CallToolRequest::METHOD,
|
|
Some(serde_json::to_value(codex_tool_call_params)?),
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub async fn send_codex_reply_tool_call(
|
|
&mut self,
|
|
session_id: &str,
|
|
prompt: &str,
|
|
) -> anyhow::Result<i64> {
|
|
let codex_tool_call_params = CallToolRequestParams {
|
|
name: "codex-reply".to_string(),
|
|
arguments: Some(serde_json::to_value(CodexToolCallReplyParam {
|
|
prompt: prompt.to_string(),
|
|
session_id: session_id.to_string(),
|
|
})?),
|
|
};
|
|
self.send_request(
|
|
mcp_types::CallToolRequest::METHOD,
|
|
Some(serde_json::to_value(codex_tool_call_params)?),
|
|
)
|
|
.await
|
|
}
|
|
|
|
async fn send_request(
|
|
&mut self,
|
|
method: &str,
|
|
params: Option<serde_json::Value>,
|
|
) -> anyhow::Result<i64> {
|
|
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
|
|
|
let message = JSONRPCMessage::Request(JSONRPCRequest {
|
|
jsonrpc: JSONRPC_VERSION.into(),
|
|
id: RequestId::Integer(request_id),
|
|
method: method.to_string(),
|
|
params,
|
|
});
|
|
self.send_jsonrpc_message(message).await?;
|
|
Ok(request_id)
|
|
}
|
|
|
|
pub async fn send_response(
|
|
&mut self,
|
|
id: RequestId,
|
|
result: serde_json::Value,
|
|
) -> anyhow::Result<()> {
|
|
self.send_jsonrpc_message(JSONRPCMessage::Response(JSONRPCResponse {
|
|
jsonrpc: JSONRPC_VERSION.into(),
|
|
id,
|
|
result,
|
|
}))
|
|
.await
|
|
}
|
|
|
|
async fn send_jsonrpc_message(&mut self, message: JSONRPCMessage) -> anyhow::Result<()> {
|
|
let payload = serde_json::to_string(&message)?;
|
|
self.stdin.write_all(payload.as_bytes()).await?;
|
|
self.stdin.write_all(b"\n").await?;
|
|
self.stdin.flush().await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn read_jsonrpc_message(&mut self) -> anyhow::Result<JSONRPCMessage> {
|
|
let mut line = String::new();
|
|
self.stdout.read_line(&mut line).await?;
|
|
let message = serde_json::from_str::<JSONRPCMessage>(&line)?;
|
|
Ok(message)
|
|
}
|
|
pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result<JSONRPCRequest> {
|
|
loop {
|
|
let message = self.read_jsonrpc_message().await?;
|
|
eprint!("message: {message:?}");
|
|
|
|
match message {
|
|
JSONRPCMessage::Notification(_) => {
|
|
eprintln!("notification: {message:?}");
|
|
}
|
|
JSONRPCMessage::Request(jsonrpc_request) => {
|
|
return Ok(jsonrpc_request);
|
|
}
|
|
JSONRPCMessage::Error(_) => {
|
|
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
|
}
|
|
JSONRPCMessage::Response(_) => {
|
|
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn read_stream_until_response_message(
|
|
&mut self,
|
|
request_id: RequestId,
|
|
) -> anyhow::Result<JSONRPCResponse> {
|
|
loop {
|
|
let message = self.read_jsonrpc_message().await?;
|
|
eprint!("message: {message:?}");
|
|
|
|
match message {
|
|
JSONRPCMessage::Notification(_) => {
|
|
eprintln!("notification: {message:?}");
|
|
}
|
|
JSONRPCMessage::Request(_) => {
|
|
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
|
}
|
|
JSONRPCMessage::Error(_) => {
|
|
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
|
}
|
|
JSONRPCMessage::Response(jsonrpc_response) => {
|
|
if jsonrpc_response.id == request_id {
|
|
return Ok(jsonrpc_response);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn read_stream_until_configured_response_message(
|
|
&mut self,
|
|
) -> anyhow::Result<String> {
|
|
let mut sid_old: Option<String> = None;
|
|
let mut sid_new: Option<String> = None;
|
|
loop {
|
|
let message = self.read_jsonrpc_message().await?;
|
|
eprint!("message: {message:?}");
|
|
|
|
match message {
|
|
JSONRPCMessage::Notification(notification) => {
|
|
if let Some(params) = notification.params {
|
|
// Back-compat schema: method == "codex/event" and msg.type == "session_configured"
|
|
if notification.method == "codex/event" {
|
|
if let Some(msg) = params.get("msg") {
|
|
if msg.get("type").and_then(|v| v.as_str())
|
|
== Some("session_configured")
|
|
{
|
|
if let Some(session_id) =
|
|
msg.get("session_id").and_then(|v| v.as_str())
|
|
{
|
|
sid_old = Some(session_id.to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// New schema: method is the Display of EventMsg::SessionConfigured => "SessionConfigured"
|
|
if notification.method == "sessionconfigured" {
|
|
if let Some(msg) = params.get("msg") {
|
|
if let Some(session_id) =
|
|
msg.get("session_id").and_then(|v| v.as_str())
|
|
{
|
|
sid_new = Some(session_id.to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if sid_old.is_some() && sid_new.is_some() {
|
|
// Both seen, they must match
|
|
assert_eq!(
|
|
sid_old.as_ref().unwrap(),
|
|
sid_new.as_ref().unwrap(),
|
|
"session_id mismatch between old and new schema"
|
|
);
|
|
return Ok(sid_old.unwrap());
|
|
}
|
|
}
|
|
JSONRPCMessage::Request(_) => {
|
|
anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}");
|
|
}
|
|
JSONRPCMessage::Error(_) => {
|
|
anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}");
|
|
}
|
|
JSONRPCMessage::Response(_) => {
|
|
anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn send_notification(
|
|
&mut self,
|
|
method: &str,
|
|
params: Option<serde_json::Value>,
|
|
) -> anyhow::Result<()> {
|
|
self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification {
|
|
jsonrpc: JSONRPC_VERSION.into(),
|
|
method: method.to_string(),
|
|
params,
|
|
}))
|
|
.await
|
|
}
|
|
}
|