diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index f118cb67..9da3595e 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2783,6 +2783,7 @@ dependencies = [ "codex-mcp-server", "mcp-types", "pretty_assertions", + "serde", "serde_json", "shlex", "tempfile", diff --git a/codex-rs/mcp-server/src/conversation_loop.rs b/codex-rs/mcp-server/src/conversation_loop.rs deleted file mode 100644 index 7a3ae3e7..00000000 --- a/codex-rs/mcp-server/src/conversation_loop.rs +++ /dev/null @@ -1,124 +0,0 @@ -use std::sync::Arc; - -use crate::exec_approval::handle_exec_approval_request; -use crate::outgoing_message::OutgoingMessageSender; -use crate::outgoing_message::OutgoingNotificationMeta; -use crate::patch_approval::handle_patch_approval_request; -use codex_core::CodexConversation; -use codex_core::protocol::AgentMessageEvent; -use codex_core::protocol::ApplyPatchApprovalRequestEvent; -use codex_core::protocol::EventMsg; -use codex_core::protocol::ExecApprovalRequestEvent; -use mcp_types::RequestId; -use tracing::error; - -pub async fn run_conversation_loop( - codex: Arc, - outgoing: Arc, - request_id: RequestId, -) { - let request_id_str = match &request_id { - RequestId::String(s) => s.clone(), - RequestId::Integer(n) => n.to_string(), - }; - - // Stream events until the task needs to pause for user interaction or - // completes. - loop { - match codex.next_event().await { - Ok(event) => { - outgoing - .send_event_as_notification( - &event, - Some(OutgoingNotificationMeta::new(Some(request_id.clone()))), - ) - .await; - - match event.msg { - EventMsg::ExecApprovalRequest(ExecApprovalRequestEvent { - command, - cwd, - call_id, - reason: _, - }) => { - handle_exec_approval_request( - command, - cwd, - outgoing.clone(), - codex.clone(), - request_id.clone(), - request_id_str.clone(), - event.id.clone(), - call_id, - ) - .await; - } - EventMsg::Error(_) => { - error!("Codex runtime error"); - } - EventMsg::ApplyPatchApprovalRequest(ApplyPatchApprovalRequestEvent { - call_id, - reason, - grant_root, - changes, - }) => { - handle_patch_approval_request( - call_id, - reason, - grant_root, - changes, - outgoing.clone(), - codex.clone(), - request_id.clone(), - request_id_str.clone(), - event.id.clone(), - ) - .await; - } - EventMsg::TaskComplete(_) => {} - EventMsg::SessionConfigured(_) => { - tracing::error!("unexpected SessionConfigured event"); - } - EventMsg::AgentMessageDelta(_) => { - // TODO: think how we want to support this in the MCP - } - EventMsg::AgentReasoningDelta(_) => { - // TODO: think how we want to support this in the MCP - } - EventMsg::AgentMessage(AgentMessageEvent { .. }) => { - // TODO: think how we want to support this in the MCP - } - EventMsg::AgentReasoningRawContent(_) - | EventMsg::AgentReasoningRawContentDelta(_) - | EventMsg::TaskStarted - | EventMsg::TokenCount(_) - | EventMsg::AgentReasoning(_) - | EventMsg::AgentReasoningSectionBreak(_) - | EventMsg::McpToolCallBegin(_) - | EventMsg::McpToolCallEnd(_) - | EventMsg::ExecCommandBegin(_) - | EventMsg::ExecCommandEnd(_) - | EventMsg::TurnDiff(_) - | EventMsg::BackgroundEvent(_) - | EventMsg::ExecCommandOutputDelta(_) - | EventMsg::PatchApplyBegin(_) - | EventMsg::PatchApplyEnd(_) - | EventMsg::GetHistoryEntryResponse(_) - | EventMsg::PlanUpdate(_) - | EventMsg::TurnAborted(_) - | EventMsg::ShutdownComplete => { - // For now, we do not do anything extra for these - // events. Note that - // send(codex_event_to_notification(&event)) above has - // already dispatched these events as notifications, - // though we may want to do give different treatment to - // individual events in the future. - } - } - } - Err(e) => { - error!("Codex runtime error: {e}"); - } - } - } -} diff --git a/codex-rs/mcp-server/src/lib.rs b/codex-rs/mcp-server/src/lib.rs index b6dcf824..f30daa92 100644 --- a/codex-rs/mcp-server/src/lib.rs +++ b/codex-rs/mcp-server/src/lib.rs @@ -18,15 +18,12 @@ use tracing_subscriber::EnvFilter; mod codex_message_processor; mod codex_tool_config; mod codex_tool_runner; -mod conversation_loop; mod error_code; mod exec_approval; mod json_to_toml; -pub mod mcp_protocol; pub(crate) mod message_processor; mod outgoing_message; mod patch_approval; -pub(crate) mod tool_handlers; pub mod wire_format; use crate::message_processor::MessageProcessor; diff --git a/codex-rs/mcp-server/src/mcp_protocol.rs b/codex-rs/mcp-server/src/mcp_protocol.rs deleted file mode 100644 index 26c6655f..00000000 --- a/codex-rs/mcp-server/src/mcp_protocol.rs +++ /dev/null @@ -1,1054 +0,0 @@ -use codex_core::config_types::SandboxMode; -use codex_core::protocol::AskForApproval; -use codex_core::protocol::EventMsg; -use codex_core::protocol::InputItem; -use serde::Deserialize; -use serde::Serialize; -use strum_macros::Display; -use uuid::Uuid; - -use mcp_types::CallToolResult; -use mcp_types::ContentBlock; -use mcp_types::RequestId; -use mcp_types::TextContent; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(transparent)] -pub struct ConversationId(pub Uuid); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(transparent)] -pub struct MessageId(pub Uuid); - -// Requests -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCallRequest { - #[serde(rename = "jsonrpc")] - pub jsonrpc: &'static str, - pub id: RequestId, - pub method: &'static str, - pub params: ToolCallRequestParams, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(tag = "name", content = "arguments", rename_all = "camelCase")] -pub enum ToolCallRequestParams { - ConversationCreate(ConversationCreateArgs), - ConversationStream(ConversationStreamArgs), - ConversationSendMessage(ConversationSendMessageArgs), - ConversationsList(ConversationsListArgs), -} - -impl ToolCallRequestParams { - /// Wrap this request in a JSON-RPC request. - #[allow(dead_code)] - pub fn into_request(self, id: RequestId) -> ToolCallRequest { - ToolCallRequest { - jsonrpc: "2.0", - id, - method: "tools/call", - params: self, - } - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationCreateArgs { - pub prompt: String, - pub model: String, - pub cwd: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub approval_policy: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub sandbox: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub config: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub profile: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub base_instructions: Option, -} - -/// Optional overrides for an existing conversation's execution context when sending a message. -/// Fields left as `None` inherit the current conversation/session settings. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationOverrides { - #[serde(default, skip_serializing_if = "Option::is_none")] - pub model: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub cwd: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub approval_policy: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub sandbox: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub config: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub profile: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub base_instructions: Option, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationStreamArgs { - pub conversation_id: ConversationId, -} - -/// If omitted, the message continues from the latest turn. -/// Set to resume/edit from an earlier parent message in the thread. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationSendMessageArgs { - pub conversation_id: ConversationId, - pub content: Vec, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub parent_message_id: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - #[serde(flatten)] - pub conversation_overrides: Option, -} -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationsListArgs { - #[serde(default, skip_serializing_if = "Option::is_none")] - pub limit: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub cursor: Option, -} - -// Responses -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolCallResponse { - pub request_id: RequestId, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub is_error: Option, - #[serde(default, skip_serializing_if = "Option::is_none", flatten)] - pub result: Option, -} - -impl From for CallToolResult { - fn from(val: ToolCallResponse) -> Self { - let ToolCallResponse { - request_id: _request_id, - is_error, - result, - } = val; - match result { - Some(res) => match serde_json::to_value(&res) { - Ok(v) => CallToolResult { - content: vec![ContentBlock::TextContent(TextContent { - r#type: "text".to_string(), - text: v.to_string(), - annotations: None, - })], - is_error, - structured_content: Some(v), - }, - Err(e) => CallToolResult { - content: vec![ContentBlock::TextContent(TextContent { - r#type: "text".to_string(), - text: format!("Failed to serialize tool result: {e}"), - annotations: None, - })], - is_error: Some(true), - structured_content: None, - }, - }, - None => CallToolResult { - content: vec![], - is_error, - structured_content: None, - }, - } - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ToolCallResponseResult { - ConversationCreate(ConversationCreateResult), - ConversationStream(ConversationStreamResult), - ConversationSendMessage(ConversationSendMessageResult), - ConversationsList(ConversationsListResult), -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ConversationCreateResult { - Ok { - conversation_id: ConversationId, - model: String, - }, - Error { - message: String, - }, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationStreamResult {} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -// TODO: remove this status because we have is_error field in the response. -#[serde(tag = "status", rename_all = "camelCase")] -pub enum ConversationSendMessageResult { - Ok, - Error { message: String }, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationsListResult { - pub conversations: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationSummary { - pub conversation_id: ConversationId, - pub title: String, -} - -// Notifications -#[derive(Debug, Clone, Deserialize, Display)] -pub enum ServerNotification { - InitialState(InitialStateNotificationParams), - StreamDisconnected(StreamDisconnectedNotificationParams), - CodexEvent(Box), -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct NotificationMeta { - #[serde(skip_serializing_if = "Option::is_none")] - pub conversation_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub request_id: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InitialStateNotificationParams { - #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] - pub meta: Option, - pub initial_state: InitialStatePayload, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InitialStatePayload { - #[serde(default)] - pub events: Vec, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct StreamDisconnectedNotificationParams { - #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] - pub meta: Option, - pub reason: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CodexEventNotificationParams { - #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] - pub meta: Option, - pub msg: EventMsg, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CancelNotificationParams { - pub request_id: RequestId, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub reason: Option, -} - -impl Serialize for ServerNotification { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - use serde::ser::SerializeMap; - - let mut map = serializer.serialize_map(Some(2))?; - match self { - ServerNotification::CodexEvent(p) => { - map.serialize_entry("method", &format!("notifications/{}", p.msg))?; - map.serialize_entry("params", p)?; - } - ServerNotification::InitialState(p) => { - map.serialize_entry("method", "notifications/initial_state")?; - map.serialize_entry("params", p)?; - } - ServerNotification::StreamDisconnected(p) => { - map.serialize_entry("method", "notifications/stream_disconnected")?; - map.serialize_entry("params", p)?; - } - } - map.end() - } -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(tag = "method", content = "params", rename_all = "camelCase")] -pub enum ClientNotification { - #[serde(rename = "notifications/cancelled")] - Cancelled(CancelNotificationParams), -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use super::*; - use codex_core::protocol::McpInvocation; - use codex_core::protocol::McpToolCallBeginEvent; - use pretty_assertions::assert_eq; - use serde::Serialize; - use serde_json::Value; - use serde_json::json; - use uuid::uuid; - - fn to_val(v: &T) -> Value { - serde_json::to_value(v).expect("serialize to Value") - } - - // ----- Requests ----- - - #[test] - fn serialize_tool_call_request_params_conversation_create_minimal() { - let req = ToolCallRequestParams::ConversationCreate(ConversationCreateArgs { - prompt: "".into(), - model: "o3".into(), - cwd: "/repo".into(), - approval_policy: None, - sandbox: None, - config: None, - profile: None, - base_instructions: None, - }); - - let observed = to_val(&req.into_request(mcp_types::RequestId::Integer(2))); - let expected = json!({ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": { - "name": "conversationCreate", - "arguments": { - "prompt": "", - "model": "o3", - "cwd": "/repo" - } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_tool_call_request_params_conversation_send_message_with_overrides_and_parent_message_id() - { - let req = ToolCallRequestParams::ConversationSendMessage(ConversationSendMessageArgs { - conversation_id: ConversationId(uuid!("d0f6ecbe-84a2-41c1-b23d-b20473b25eab")), - content: vec![ - InputItem::Text { text: "Hi".into() }, - InputItem::Image { - image_url: "https://example.com/cat.jpg".into(), - }, - InputItem::LocalImage { - path: "notes.txt".into(), - }, - ], - parent_message_id: Some(MessageId(uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"))), - conversation_overrides: Some(ConversationOverrides { - model: Some("o4-mini".into()), - cwd: Some("/workdir".into()), - approval_policy: None, - sandbox: Some(SandboxMode::DangerFullAccess), - config: Some(json!({"temp": 0.2})), - profile: Some("eng".into()), - base_instructions: Some("Be terse".into()), - }), - }); - - let observed = to_val(&req.into_request(mcp_types::RequestId::Integer(2))); - let expected = json!({ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": { - "name": "conversationSendMessage", - "arguments": { - "conversation_id": "d0f6ecbe-84a2-41c1-b23d-b20473b25eab", - "content": [ - { "type": "text", "text": "Hi" }, - { "type": "image", "image_url": "https://example.com/cat.jpg" }, - { "type": "local_image", "path": "notes.txt" } - ], - "parent_message_id": "67e55044-10b1-426f-9247-bb680e5fe0c8", - "model": "o4-mini", - "cwd": "/workdir", - "sandbox": "danger-full-access", - "config": { "temp": 0.2 }, - "profile": "eng", - "base_instructions": "Be terse" - } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_tool_call_request_params_conversations_list_with_opts() { - let req = ToolCallRequestParams::ConversationsList(ConversationsListArgs { - limit: Some(50), - cursor: Some("abc".into()), - }); - - let observed = to_val(&req.into_request(RequestId::Integer(2))); - let expected = json!({ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": { - "name": "conversationsList", - "arguments": { - "limit": 50, - "cursor": "abc" - } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_tool_call_request_params_conversation_stream() { - let req = ToolCallRequestParams::ConversationStream(ConversationStreamArgs { - conversation_id: ConversationId(uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8")), - }); - - let observed = to_val(&req.into_request(mcp_types::RequestId::Integer(2))); - let expected = json!({ - "jsonrpc": "2.0", - "id": 2, - "method": "tools/call", - "params": { - "name": "conversationStream", - "arguments": { - "conversation_id": "67e55044-10b1-426f-9247-bb680e5fe0c8" - } - } - }); - assert_eq!(observed, expected); - } - - // ----- Message inputs / sources ----- - - #[test] - fn serialize_message_input_image_url() { - let item = InputItem::Image { - image_url: "https://example.com/x.png".into(), - }; - let observed = to_val(&item); - let expected = json!({ - "type": "image", - "image_url": "https://example.com/x.png" - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_message_input_local_image_path() { - let url = InputItem::LocalImage { - path: PathBuf::from("https://example.com/a.pdf"), - }; - let id = InputItem::LocalImage { - path: PathBuf::from("file_456"), - }; - let observed_url = to_val(&url); - let expected_url = json!({"type":"local_image","path":"https://example.com/a.pdf"}); - assert_eq!( - observed_url, expected_url, - "LocalImage with URL path should serialize as image_url" - ); - let observed_id = to_val(&id); - let expected_id = json!({"type":"local_image","path":"file_456"}); - assert_eq!( - observed_id, expected_id, - "LocalImage with file id should serialize as image_url" - ); - } - - #[test] - fn serialize_message_input_image_url_without_detail() { - let item = InputItem::Image { - image_url: "https://example.com/x.png".into(), - }; - let observed = to_val(&item); - let expected = json!({ - "type": "image", - "image_url": "https://example.com/x.png" - }); - assert_eq!(observed, expected); - } - - // ----- Responses ----- - - #[test] - fn response_success_conversation_create_full_schema() { - let env = ToolCallResponse { - request_id: RequestId::Integer(1), - is_error: None, - result: Some(ToolCallResponseResult::ConversationCreate( - ConversationCreateResult::Ok { - conversation_id: ConversationId(uuid!("d0f6ecbe-84a2-41c1-b23d-b20473b25eab")), - model: "o3".into(), - }, - )), - }; - let req_id = env.request_id.clone(); - let observed = to_val(&CallToolResult::from(env)); - let expected = json!({ - "content": [ - { "type": "text", "text": "{\"conversation_id\":\"d0f6ecbe-84a2-41c1-b23d-b20473b25eab\",\"model\":\"o3\"}" } - ], - "structuredContent": { - "conversation_id": "d0f6ecbe-84a2-41c1-b23d-b20473b25eab", - "model": "o3" - } - }); - assert_eq!( - observed, expected, - "response (ConversationCreate) must match" - ); - assert_eq!(req_id, RequestId::Integer(1)); - } - - #[test] - fn response_error_conversation_create_full_schema() { - let env = ToolCallResponse { - request_id: RequestId::Integer(2), - is_error: Some(true), - result: Some(ToolCallResponseResult::ConversationCreate( - ConversationCreateResult::Error { - message: "Failed to initialize session".into(), - }, - )), - }; - let req_id = env.request_id.clone(); - let observed = to_val(&CallToolResult::from(env)); - let expected = json!({ - "content": [ - { "type": "text", "text": "{\"message\":\"Failed to initialize session\"}" } - ], - "isError": true, - "structuredContent": { - "message": "Failed to initialize session" - } - }); - assert_eq!( - observed, expected, - "error response (ConversationCreate) must match" - ); - assert_eq!(req_id, RequestId::Integer(2)); - } - - #[test] - fn response_success_conversation_stream_empty_result_object() { - let env = ToolCallResponse { - request_id: RequestId::Integer(2), - is_error: None, - result: Some(ToolCallResponseResult::ConversationStream( - ConversationStreamResult {}, - )), - }; - let req_id = env.request_id.clone(); - let observed = to_val(&CallToolResult::from(env)); - let expected = json!({ - "content": [ { "type": "text", "text": "{}" } ], - "structuredContent": {} - }); - assert_eq!( - observed, expected, - "response (ConversationStream) must have empty object result" - ); - assert_eq!(req_id, RequestId::Integer(2)); - } - - #[test] - fn response_success_send_message_accepted_full_schema() { - let env = ToolCallResponse { - request_id: RequestId::Integer(3), - is_error: None, - result: Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult::Ok, - )), - }; - let req_id = env.request_id.clone(); - let observed = to_val(&CallToolResult::from(env)); - let expected = json!({ - "content": [ { "type": "text", "text": "{\"status\":\"ok\"}" } ], - "structuredContent": { "status": "ok" } - }); - assert_eq!( - observed, expected, - "response (ConversationSendMessageAccepted) must match" - ); - assert_eq!(req_id, RequestId::Integer(3)); - } - - #[test] - fn response_success_conversations_list_with_next_cursor_full_schema() { - let env = ToolCallResponse { - request_id: RequestId::Integer(4), - is_error: None, - result: Some(ToolCallResponseResult::ConversationsList( - ConversationsListResult { - conversations: vec![ConversationSummary { - conversation_id: ConversationId(uuid!( - "67e55044-10b1-426f-9247-bb680e5fe0c8" - )), - title: "Refactor config loader".into(), - }], - next_cursor: Some("next123".into()), - }, - )), - }; - let req_id = env.request_id.clone(); - let observed = to_val(&CallToolResult::from(env)); - let expected = json!({ - "content": [ - { "type": "text", "text": "{\"conversations\":[{\"conversation_id\":\"67e55044-10b1-426f-9247-bb680e5fe0c8\",\"title\":\"Refactor config loader\"}],\"next_cursor\":\"next123\"}" } - ], - "structuredContent": { - "conversations": [ - { - "conversation_id": "67e55044-10b1-426f-9247-bb680e5fe0c8", - "title": "Refactor config loader" - } - ], - "next_cursor": "next123" - } - }); - assert_eq!( - observed, expected, - "response (ConversationsList with cursor) must match" - ); - assert_eq!(req_id, RequestId::Integer(4)); - } - - #[test] - fn response_error_only_is_error_and_request_id_string() { - let env = ToolCallResponse { - request_id: RequestId::Integer(4), - is_error: Some(true), - result: None, - }; - let req_id = env.request_id.clone(); - let observed = to_val(&CallToolResult::from(env)); - let expected = json!({ - "content": [], - "isError": true - }); - assert_eq!( - observed, expected, - "error response must omit `result` and include `isError`" - ); - assert_eq!(req_id, RequestId::Integer(4)); - } - - // ----- Notifications ----- - - #[test] - fn serialize_notification_initial_state_minimal() { - let params = InitialStateNotificationParams { - meta: Some(NotificationMeta { - conversation_id: Some(ConversationId(uuid!( - "67e55044-10b1-426f-9247-bb680e5fe0c8" - ))), - request_id: Some(RequestId::Integer(44)), - }), - initial_state: InitialStatePayload { - events: vec![ - CodexEventNotificationParams { - meta: None, - msg: EventMsg::TaskStarted, - }, - CodexEventNotificationParams { - meta: None, - msg: EventMsg::AgentMessageDelta( - codex_core::protocol::AgentMessageDeltaEvent { - delta: "Loading...".into(), - }, - ), - }, - ], - }, - }; - - let observed = to_val(&ServerNotification::InitialState(params.clone())); - let expected = json!({ - "method": "notifications/initial_state", - "params": { - "_meta": { - "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8", - "requestId": 44 - }, - "initial_state": { - "events": [ - { "msg": { "type": "task_started" } }, - { "msg": { "type": "agent_message_delta", "delta": "Loading..." } } - ] - } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_initial_state_omits_empty_events_full_json() { - let params = InitialStateNotificationParams { - meta: None, - initial_state: InitialStatePayload { events: vec![] }, - }; - - let observed = to_val(&ServerNotification::InitialState(params)); - let expected = json!({ - "method": "notifications/initial_state", - "params": { - "initial_state": { "events": [] } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_stream_disconnected() { - let params = StreamDisconnectedNotificationParams { - meta: Some(NotificationMeta { - conversation_id: Some(ConversationId(uuid!( - "67e55044-10b1-426f-9247-bb680e5fe0c8" - ))), - request_id: None, - }), - reason: "New stream() took over".into(), - }; - - let observed = to_val(&ServerNotification::StreamDisconnected(params)); - let expected = json!({ - "method": "notifications/stream_disconnected", - "params": { - "_meta": { "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8" }, - "reason": "New stream() took over" - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_uses_eventmsg_type_in_method() { - let params = CodexEventNotificationParams { - meta: Some(NotificationMeta { - conversation_id: Some(ConversationId(uuid!( - "67e55044-10b1-426f-9247-bb680e5fe0c8" - ))), - request_id: Some(RequestId::Integer(44)), - }), - msg: EventMsg::AgentMessage(codex_core::protocol::AgentMessageEvent { - message: "hi".into(), - }), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/agent_message", - "params": { - "_meta": { - "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8", - "requestId": 44 - }, - "msg": { "type": "agent_message", "message": "hi" } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_task_started_full_json() { - let params = CodexEventNotificationParams { - meta: Some(NotificationMeta { - conversation_id: Some(ConversationId(uuid!( - "67e55044-10b1-426f-9247-bb680e5fe0c8" - ))), - request_id: Some(RequestId::Integer(7)), - }), - msg: EventMsg::TaskStarted, - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/task_started", - "params": { - "_meta": { - "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8", - "requestId": 7 - }, - "msg": { "type": "task_started" } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_agent_message_delta_full_json() { - let params = CodexEventNotificationParams { - meta: None, - msg: EventMsg::AgentMessageDelta(codex_core::protocol::AgentMessageDeltaEvent { - delta: "stream...".into(), - }), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/agent_message_delta", - "params": { - "msg": { "type": "agent_message_delta", "delta": "stream..." } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_agent_message_full_json() { - let params = CodexEventNotificationParams { - meta: Some(NotificationMeta { - conversation_id: Some(ConversationId(uuid!( - "67e55044-10b1-426f-9247-bb680e5fe0c8" - ))), - request_id: Some(RequestId::Integer(44)), - }), - msg: EventMsg::AgentMessage(codex_core::protocol::AgentMessageEvent { - message: "hi".into(), - }), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/agent_message", - "params": { - "_meta": { - "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8", - "requestId": 44 - }, - "msg": { "type": "agent_message", "message": "hi" } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_agent_reasoning_full_json() { - let params = CodexEventNotificationParams { - meta: None, - msg: EventMsg::AgentReasoning(codex_core::protocol::AgentReasoningEvent { - text: "thinking…".into(), - }), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/agent_reasoning", - "params": { - "msg": { "type": "agent_reasoning", "text": "thinking…" } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_token_count_full_json() { - let usage = codex_core::protocol::TokenUsage { - input_tokens: 10, - cached_input_tokens: Some(2), - output_tokens: 5, - reasoning_output_tokens: Some(1), - total_tokens: 16, - }; - let params = CodexEventNotificationParams { - meta: None, - msg: EventMsg::TokenCount(usage), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/token_count", - "params": { - "msg": { - "type": "token_count", - "input_tokens": 10, - "cached_input_tokens": 2, - "output_tokens": 5, - "reasoning_output_tokens": 1, - "total_tokens": 16 - } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_session_configured_full_json() { - let params = CodexEventNotificationParams { - meta: Some(NotificationMeta { - conversation_id: Some(ConversationId(uuid!( - "67e55044-10b1-426f-9247-bb680e5fe0c8" - ))), - request_id: None, - }), - msg: EventMsg::SessionConfigured(codex_core::protocol::SessionConfiguredEvent { - session_id: uuid!("67e55044-10b1-426f-9247-bb680e5fe0c8"), - model: "codex-mini-latest".into(), - history_log_id: 42, - history_entry_count: 3, - }), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/session_configured", - "params": { - "_meta": { "conversationId": "67e55044-10b1-426f-9247-bb680e5fe0c8" }, - "msg": { - "type": "session_configured", - "session_id": "67e55044-10b1-426f-9247-bb680e5fe0c8", - "model": "codex-mini-latest", - "history_log_id": 42, - "history_entry_count": 3 - } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_exec_command_begin_full_json() { - let params = CodexEventNotificationParams { - meta: None, - msg: EventMsg::ExecCommandBegin(codex_core::protocol::ExecCommandBeginEvent { - call_id: "c1".into(), - command: vec!["bash".into(), "-lc".into(), "echo hi".into()], - cwd: std::path::PathBuf::from("/work"), - parsed_cmd: vec![], - }), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/exec_command_begin", - "params": { - "msg": { - "type": "exec_command_begin", - "call_id": "c1", - "command": ["bash", "-lc", "echo hi"], - "cwd": "/work", - "parsed_cmd": [] - } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_mcp_tool_call_begin_full_json() { - let params = CodexEventNotificationParams { - meta: None, - msg: EventMsg::McpToolCallBegin(McpToolCallBeginEvent { - call_id: "m1".into(), - invocation: McpInvocation { - server: "calc".into(), - tool: "add".into(), - arguments: Some(json!({"a":1,"b":2})), - }, - }), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/mcp_tool_call_begin", - "params": { - "msg": { - "type": "mcp_tool_call_begin", - "call_id": "m1", - "invocation": { - "server": "calc", - "tool": "add", - "arguments": { "a": 1, "b": 2 } - } - } - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_codex_event_patch_apply_end_full_json() { - let params = CodexEventNotificationParams { - meta: None, - msg: EventMsg::PatchApplyEnd(codex_core::protocol::PatchApplyEndEvent { - call_id: "p1".into(), - stdout: "ok".into(), - stderr: "".into(), - success: true, - }), - }; - - let observed = to_val(&ServerNotification::CodexEvent(Box::new(params))); - let expected = json!({ - "method": "notifications/patch_apply_end", - "params": { - "msg": { - "type": "patch_apply_end", - "call_id": "p1", - "stdout": "ok", - "stderr": "", - "success": true - } - } - }); - assert_eq!(observed, expected); - } - - // ----- Cancelled notifications ----- - - #[test] - fn serialize_notification_cancelled_with_reason_full_json() { - let params = CancelNotificationParams { - request_id: RequestId::String("r-123".into()), - reason: Some("user_cancelled".into()), - }; - - let observed = to_val(&ClientNotification::Cancelled(params)); - let expected = json!({ - "method": "notifications/cancelled", - "params": { - "requestId": "r-123", - "reason": "user_cancelled" - } - }); - assert_eq!(observed, expected); - } - - #[test] - fn serialize_notification_cancelled_without_reason_full_json() { - let params = CancelNotificationParams { - request_id: RequestId::Integer(77), - reason: None, - }; - - let observed = to_val(&ClientNotification::Cancelled(params)); - - // Check exact structure: reason must be omitted. - assert_eq!(observed["method"], "notifications/cancelled"); - assert_eq!(observed["params"]["requestId"], 77); - assert!( - observed["params"].get("reason").is_none(), - "reason must be omitted when None" - ); - } -} diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index 763e51bf..2143cecc 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; @@ -9,18 +8,12 @@ use crate::codex_tool_config::CodexToolCallReplyParam; use crate::codex_tool_config::create_tool_for_codex_tool_call_param; use crate::codex_tool_config::create_tool_for_codex_tool_call_reply_param; use crate::error_code::INVALID_REQUEST_ERROR_CODE; -use crate::mcp_protocol::ToolCallRequestParams; -use crate::mcp_protocol::ToolCallResponse; -use crate::mcp_protocol::ToolCallResponseResult; use crate::outgoing_message::OutgoingMessageSender; -use crate::tool_handlers::create_conversation::handle_create_conversation; -use crate::tool_handlers::send_message::handle_send_message; use crate::wire_format::ClientRequest; use codex_core::ConversationManager; use codex_core::config::Config as CodexConfig; use codex_core::protocol::Submission; -use mcp_types::CallToolRequest; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::ClientRequest as McpClientRequest; @@ -48,7 +41,6 @@ pub(crate) struct MessageProcessor { codex_linux_sandbox_exe: Option, conversation_manager: Arc, running_requests_id_to_codex_uuid: Arc>>, - running_session_ids: Arc>>, } impl MessageProcessor { @@ -72,22 +64,9 @@ impl MessageProcessor { codex_linux_sandbox_exe, conversation_manager, running_requests_id_to_codex_uuid: Arc::new(Mutex::new(HashMap::new())), - running_session_ids: Arc::new(Mutex::new(HashSet::new())), } } - pub(crate) fn get_conversation_manager(&self) -> &ConversationManager { - &self.conversation_manager - } - - pub(crate) fn outgoing(&self) -> Arc { - self.outgoing.clone() - } - - pub(crate) fn running_session_ids(&self) -> Arc>> { - self.running_session_ids.clone() - } - pub(crate) async fn process_request(&mut self, request: JSONRPCRequest) { if let Ok(request_json) = serde_json::to_value(request.clone()) && let Ok(codex_request) = serde_json::from_value::(request_json) @@ -341,14 +320,6 @@ impl MessageProcessor { params: ::Params, ) { tracing::info!("tools/call -> params: {:?}", params); - // Serialize params into JSON and try to parse as new type - if let Ok(new_params) = - serde_json::to_value(¶ms).and_then(serde_json::from_value::) - { - // New tool call matched → forward - self.handle_new_tool_calls(id, new_params).await; - return; - } let CallToolRequestParams { name, arguments } = params; match name.as_str() { @@ -372,30 +343,6 @@ impl MessageProcessor { } } } - async fn handle_new_tool_calls(&self, request_id: RequestId, params: ToolCallRequestParams) { - match params { - ToolCallRequestParams::ConversationCreate(args) => { - handle_create_conversation(self, request_id, args).await; - } - ToolCallRequestParams::ConversationSendMessage(args) => { - handle_send_message(self, request_id, args).await; - } - _ => { - let result = CallToolResult { - content: vec![ContentBlock::TextContent(TextContent { - r#type: "text".to_string(), - text: "Unknown tool".to_string(), - annotations: None, - })], - is_error: Some(true), - structured_content: None, - }; - self.send_response::(request_id, result) - .await; - } - } - } - async fn handle_tool_call_codex(&self, id: RequestId, arguments: Option) { let (initial_prompt, config): (String, CodexConfig) = match arguments { Some(json_val) => match serde_json::from_value::(json_val) { @@ -692,20 +639,4 @@ impl MessageProcessor { ) { tracing::info!("notifications/message -> params: {:?}", params); } - - pub(crate) async fn send_response_with_optional_error( - &self, - id: RequestId, - message: Option, - error: Option, - ) { - let response = ToolCallResponse { - request_id: id.clone(), - is_error: error, - result: message, - }; - let result: CallToolResult = response.into(); - self.send_response::(id.clone(), result) - .await; - } } diff --git a/codex-rs/mcp-server/src/outgoing_message.rs b/codex-rs/mcp-server/src/outgoing_message.rs index f408ca0a..c5e51a34 100644 --- a/codex-rs/mcp-server/src/outgoing_message.rs +++ b/codex-rs/mcp-server/src/outgoing_message.rs @@ -119,9 +119,6 @@ impl OutgoingMessageSender { params: Some(params.clone()), }) .await; - - self.send_event_as_notification_new_schema(event, Some(params.clone())) - .await; } pub(crate) async fn send_notification(&self, notification: OutgoingNotification) { @@ -129,19 +126,6 @@ impl OutgoingMessageSender { let _ = self.sender.send(outgoing_message).await; } - // should be backwards compatible. - // it will replace send_event_as_notification eventually. - async fn send_event_as_notification_new_schema( - &self, - event: &Event, - params: Option, - ) { - let outgoing_message = OutgoingMessage::Notification(OutgoingNotification { - method: event.msg.to_string(), - params, - }); - let _ = self.sender.send(outgoing_message).await; - } pub(crate) async fn send_error(&self, id: RequestId, error: JSONRPCErrorError) { let outgoing_message = OutgoingMessage::Error(OutgoingError { id, error }); let _ = self.sender.send(outgoing_message).await; @@ -281,17 +265,6 @@ mod tests { panic!("Event must serialize"); }; assert_eq!(params, Some(expected_params.clone())); - - let result2 = outgoing_rx.recv().await.unwrap(); - let OutgoingMessage::Notification(OutgoingNotification { - method: method2, - params: params2, - }) = result2 - else { - panic!("expected Notification for second message"); - }; - assert_eq!(method2, event.msg.to_string()); - assert_eq!(params2, Some(expected_params)); } #[tokio::test] @@ -336,16 +309,5 @@ mod tests { } }); assert_eq!(params.unwrap(), expected_params); - - let result2 = outgoing_rx.recv().await.unwrap(); - let OutgoingMessage::Notification(OutgoingNotification { - method: method2, - params: params2, - }) = result2 - else { - panic!("expected Notification for second message"); - }; - assert_eq!(method2, event.msg.to_string()); - assert_eq!(params2.unwrap(), expected_params); } } diff --git a/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs b/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs deleted file mode 100644 index eee2e1d5..00000000 --- a/codex-rs/mcp-server/src/tool_handlers/create_conversation.rs +++ /dev/null @@ -1,127 +0,0 @@ -use std::path::PathBuf; - -use codex_core::NewConversation; -use codex_core::config::Config as CodexConfig; -use codex_core::config::ConfigOverrides; -use mcp_types::RequestId; - -use crate::conversation_loop::run_conversation_loop; -use crate::json_to_toml::json_to_toml; -use crate::mcp_protocol::ConversationCreateArgs; -use crate::mcp_protocol::ConversationCreateResult; -use crate::mcp_protocol::ConversationId; -use crate::mcp_protocol::ToolCallResponseResult; -use crate::message_processor::MessageProcessor; - -pub(crate) async fn handle_create_conversation( - message_processor: &MessageProcessor, - id: RequestId, - args: ConversationCreateArgs, -) { - // Build ConfigOverrides from args - let ConversationCreateArgs { - prompt: _, // not used here; creation only establishes the session - model, - cwd, - approval_policy, - sandbox, - config, - profile, - base_instructions, - } = args; - - // Convert config overrides JSON into CLI-style TOML overrides - let cli_overrides: Vec<(String, toml::Value)> = match config { - Some(v) => match v.as_object() { - Some(map) => map - .into_iter() - .map(|(k, v)| (k.clone(), json_to_toml(v.clone()))) - .collect(), - None => Vec::new(), - }, - None => Vec::new(), - }; - - let overrides = ConfigOverrides { - model: Some(model.clone()), - cwd: Some(PathBuf::from(cwd)), - approval_policy, - sandbox_mode: sandbox, - model_provider: None, - config_profile: profile, - codex_linux_sandbox_exe: None, - base_instructions, - include_plan_tool: None, - include_apply_patch_tool: None, - disable_response_storage: None, - show_raw_agent_reasoning: None, - }; - - let cfg: CodexConfig = match CodexConfig::load_with_cli_overrides(cli_overrides, overrides) { - Ok(cfg) => cfg, - Err(e) => { - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationCreate( - ConversationCreateResult::Error { - message: format!("Failed to load config: {e}"), - }, - )), - Some(true), - ) - .await; - return; - } - }; - - // Initialize Codex session via server API - let NewConversation { - conversation_id: session_id, - conversation, - session_configured, - } = match message_processor - .get_conversation_manager() - .new_conversation(cfg) - .await - { - Ok(conv) => conv, - Err(e) => { - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationCreate( - ConversationCreateResult::Error { - message: format!("Failed to initialize session: {e}"), - }, - )), - Some(true), - ) - .await; - return; - } - }; - - let effective_model = session_configured.model.clone(); - - // Run the conversation loop in the background so this request can return immediately. - let outgoing = message_processor.outgoing(); - let spawn_id = id.clone(); - tokio::spawn(async move { - run_conversation_loop(conversation.clone(), outgoing, spawn_id).await; - }); - - // Reply with the new conversation id and effective model - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationCreate( - ConversationCreateResult::Ok { - conversation_id: ConversationId(session_id), - model: effective_model, - }, - )), - Some(false), - ) - .await; -} diff --git a/codex-rs/mcp-server/src/tool_handlers/send_message.rs b/codex-rs/mcp-server/src/tool_handlers/send_message.rs deleted file mode 100644 index 985854f8..00000000 --- a/codex-rs/mcp-server/src/tool_handlers/send_message.rs +++ /dev/null @@ -1,114 +0,0 @@ -use codex_core::protocol::Op; -use codex_core::protocol::Submission; -use mcp_types::RequestId; - -use crate::mcp_protocol::ConversationSendMessageArgs; -use crate::mcp_protocol::ConversationSendMessageResult; -use crate::mcp_protocol::ToolCallResponseResult; -use crate::message_processor::MessageProcessor; - -pub(crate) async fn handle_send_message( - message_processor: &MessageProcessor, - id: RequestId, - arguments: ConversationSendMessageArgs, -) { - let ConversationSendMessageArgs { - conversation_id, - content: items, - parent_message_id: _, - conversation_overrides: _, - } = arguments; - - if items.is_empty() { - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult::Error { - message: "No content items provided".to_string(), - }, - )), - Some(true), - ) - .await; - return; - } - - let session_id = conversation_id.0; - let Ok(codex) = message_processor - .get_conversation_manager() - .get_conversation(session_id) - .await - else { - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult::Error { - message: "Session does not exist".to_string(), - }, - )), - Some(true), - ) - .await; - return; - }; - - let running = { - let running_sessions = message_processor.running_session_ids(); - let mut running_sessions = running_sessions.lock().await; - !running_sessions.insert(session_id) - }; - - if running { - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult::Error { - message: "Session is already running".to_string(), - }, - )), - Some(true), - ) - .await; - return; - } - - let request_id_string = match &id { - RequestId::String(s) => s.clone(), - RequestId::Integer(i) => i.to_string(), - }; - - let submit_res = codex - .submit_with_id(Submission { - id: request_id_string, - op: Op::UserInput { items }, - }) - .await; - - if let Err(e) = submit_res { - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult::Error { - message: format!("Failed to submit user input: {e}"), - }, - )), - Some(true), - ) - .await; - return; - } - - message_processor - .send_response_with_optional_error( - id, - Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult::Ok, - )), - Some(false), - ) - .await; -} diff --git a/codex-rs/mcp-server/tests/codex_message_processor_flow.rs b/codex-rs/mcp-server/tests/codex_message_processor_flow.rs index 5b89f3fe..03706af1 100644 --- a/codex-rs/mcp-server/tests/codex_message_processor_flow.rs +++ b/codex-rs/mcp-server/tests/codex_message_processor_flow.rs @@ -20,11 +20,11 @@ use mcp_test_support::McpProcess; use mcp_test_support::create_final_assistant_message_sse_response; use mcp_test_support::create_mock_chat_completions_server; use mcp_test_support::create_shell_sse_response; +use mcp_test_support::to_response; use mcp_types::JSONRPCNotification; use mcp_types::JSONRPCResponse; use mcp_types::RequestId; use pretty_assertions::assert_eq; -use serde::de::DeserializeOwned; use std::env; use tempfile::TempDir; use tokio::time::timeout; @@ -168,12 +168,6 @@ async fn test_codex_jsonrpc_conversation_flow() { to_response(remove_listener_resp).expect("deserialize removeConversationListener response"); } -fn to_response(response: JSONRPCResponse) -> anyhow::Result { - let value = serde_json::to_value(response.result)?; - let codex_response = serde_json::from_value(value)?; - Ok(codex_response) -} - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn test_send_user_turn_changes_approval_policy_behavior() { if env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { diff --git a/codex-rs/mcp-server/tests/common/Cargo.toml b/codex-rs/mcp-server/tests/common/Cargo.toml index 420b9781..3528ad6e 100644 --- a/codex-rs/mcp-server/tests/common/Cargo.toml +++ b/codex-rs/mcp-server/tests/common/Cargo.toml @@ -1,7 +1,7 @@ [package] +edition = "2024" name = "mcp_test_support" version = { workspace = true } -edition = "2024" [lib] path = "lib.rs" @@ -9,10 +9,11 @@ path = "lib.rs" [dependencies] anyhow = "1" assert_cmd = "2" -codex-mcp-server = { path = "../.." } codex-core = { path = "../../../core" } +codex-mcp-server = { path = "../.." } mcp-types = { path = "../../../mcp-types" } pretty_assertions = "1.4.1" +serde = { version = "1" } serde_json = "1" shlex = "1.3.0" tempfile = "3" @@ -22,5 +23,5 @@ tokio = { version = "1", features = [ "process", "rt-multi-thread", ] } +uuid = { version = "1", features = ["serde", "v4"] } wiremock = "0.6" -uuid = { version = "1", features = ["serde", "v4"] } \ No newline at end of file diff --git a/codex-rs/mcp-server/tests/common/lib.rs b/codex-rs/mcp-server/tests/common/lib.rs index b338e2e8..d088b184 100644 --- a/codex-rs/mcp-server/tests/common/lib.rs +++ b/codex-rs/mcp-server/tests/common/lib.rs @@ -3,7 +3,15 @@ mod mock_model_server; mod responses; pub use mcp_process::McpProcess; +use mcp_types::JSONRPCResponse; pub use mock_model_server::create_mock_chat_completions_server; pub use responses::create_apply_patch_sse_response; pub use responses::create_final_assistant_message_sse_response; pub use responses::create_shell_sse_response; +use serde::de::DeserializeOwned; + +pub fn to_response(response: JSONRPCResponse) -> anyhow::Result { + let value = serde_json::to_value(response.result)?; + let codex_response = serde_json::from_value(value)?; + Ok(codex_response) +} diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs index dc783344..5d2dbac0 100644 --- a/codex-rs/mcp-server/tests/common/mcp_process.rs +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -11,14 +11,9 @@ use tokio::process::ChildStdout; use anyhow::Context; use assert_cmd::prelude::*; -use codex_core::protocol::InputItem; use codex_mcp_server::CodexToolCallParam; -use codex_mcp_server::CodexToolCallReplyParam; -use codex_mcp_server::mcp_protocol::ConversationCreateArgs; -use codex_mcp_server::mcp_protocol::ConversationId; -use codex_mcp_server::mcp_protocol::ConversationSendMessageArgs; -use codex_mcp_server::mcp_protocol::ToolCallRequestParams; use codex_mcp_server::wire_format::AddConversationListenerParams; +use codex_mcp_server::wire_format::InterruptConversationParams; use codex_mcp_server::wire_format::NewConversationParams; use codex_mcp_server::wire_format::RemoveConversationListenerParams; use codex_mcp_server::wire_format::SendUserMessageParams; @@ -40,7 +35,6 @@ use pretty_assertions::assert_eq; use serde_json::json; use std::process::Command as StdCommand; use tokio::process::Command; -use uuid::Uuid; pub struct McpProcess { next_request_id: AtomicI64, @@ -167,83 +161,6 @@ impl McpProcess { .await } - pub async fn send_codex_reply_tool_call( - &mut self, - session_id: &str, - prompt: &str, - ) -> anyhow::Result { - let codex_tool_call_params = CallToolRequestParams { - name: "codex-reply".to_string(), - arguments: Some(serde_json::to_value(CodexToolCallReplyParam { - prompt: prompt.to_string(), - session_id: session_id.to_string(), - })?), - }; - self.send_request( - mcp_types::CallToolRequest::METHOD, - Some(serde_json::to_value(codex_tool_call_params)?), - ) - .await - } - - pub async fn send_user_message_tool_call( - &mut self, - message: &str, - session_id: &str, - ) -> anyhow::Result { - let params = ToolCallRequestParams::ConversationSendMessage(ConversationSendMessageArgs { - conversation_id: ConversationId(Uuid::parse_str(session_id)?), - content: vec![InputItem::Text { - text: message.to_string(), - }], - parent_message_id: None, - conversation_overrides: None, - }); - self.send_request( - mcp_types::CallToolRequest::METHOD, - Some(serde_json::to_value(params)?), - ) - .await - } - - pub async fn send_conversation_create_tool_call( - &mut self, - prompt: &str, - model: &str, - cwd: &str, - ) -> anyhow::Result { - let params = ToolCallRequestParams::ConversationCreate(ConversationCreateArgs { - prompt: prompt.to_string(), - model: model.to_string(), - cwd: cwd.to_string(), - approval_policy: None, - sandbox: None, - config: None, - profile: None, - base_instructions: None, - }); - self.send_request( - mcp_types::CallToolRequest::METHOD, - Some(serde_json::to_value(params)?), - ) - .await - } - - pub async fn send_conversation_create_with_args( - &mut self, - args: ConversationCreateArgs, - ) -> anyhow::Result { - let params = ToolCallRequestParams::ConversationCreate(args); - self.send_request( - mcp_types::CallToolRequest::METHOD, - Some(serde_json::to_value(params)?), - ) - .await - } - - // --------------------------------------------------------------------- - // Codex JSON-RPC (non-tool) helpers - // --------------------------------------------------------------------- /// Send a `newConversation` JSON-RPC request. pub async fn send_new_conversation_request( &mut self, @@ -291,6 +208,15 @@ impl McpProcess { self.send_request("sendUserTurn", params).await } + /// Send a `interruptConversation` JSON-RPC request. + pub async fn send_interrupt_conversation_request( + &mut self, + params: InterruptConversationParams, + ) -> anyhow::Result { + let params = Some(serde_json::to_value(params)?); + self.send_request("interruptConversation", params).await + } + async fn send_request( &mut self, method: &str, @@ -335,6 +261,7 @@ impl McpProcess { let message = serde_json::from_str::(&line)?; Ok(message) } + pub async fn read_stream_until_request_message(&mut self) -> anyhow::Result { loop { let message = self.read_jsonrpc_message().await?; @@ -384,6 +311,33 @@ impl McpProcess { } } + pub async fn read_stream_until_error_message( + &mut self, + request_id: RequestId, + ) -> anyhow::Result { + loop { + let message = self.read_jsonrpc_message().await?; + eprint!("message: {message:?}"); + + match message { + JSONRPCMessage::Notification(_) => { + eprintln!("notification: {message:?}"); + } + JSONRPCMessage::Request(_) => { + anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}"); + } + JSONRPCMessage::Response(_) => { + // Keep scanning; we're waiting for an error with matching id. + } + JSONRPCMessage::Error(err) => { + if err.id == request_id { + return Ok(err); + } + } + } + } + } + pub async fn read_stream_until_notification_message( &mut self, method: &str, @@ -411,80 +365,6 @@ impl McpProcess { } } - pub async fn read_stream_until_configured_response_message( - &mut self, - ) -> anyhow::Result { - let mut sid_old: Option = None; - let mut sid_new: Option = None; - loop { - let message = self.read_jsonrpc_message().await?; - eprint!("message: {message:?}"); - - match message { - JSONRPCMessage::Notification(notification) => { - if let Some(params) = notification.params { - // Back-compat schema: method == "codex/event" and msg.type == "session_configured" - if notification.method == "codex/event" { - if let Some(msg) = params.get("msg") { - if msg.get("type").and_then(|v| v.as_str()) - == Some("session_configured") - { - if let Some(session_id) = - msg.get("session_id").and_then(|v| v.as_str()) - { - sid_old = Some(session_id.to_string()); - } - } - } - } - // New schema: method is the Display of EventMsg::SessionConfigured => "SessionConfigured" - if notification.method == "session_configured" { - if let Some(msg) = params.get("msg") { - if let Some(session_id) = - msg.get("session_id").and_then(|v| v.as_str()) - { - sid_new = Some(session_id.to_string()); - } - } - } - } - - if sid_old.is_some() && sid_new.is_some() { - // Both seen, they must match - assert_eq!( - sid_old.as_ref().unwrap(), - sid_new.as_ref().unwrap(), - "session_id mismatch between old and new schema" - ); - return Ok(sid_old.unwrap()); - } - } - JSONRPCMessage::Request(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Request: {message:?}"); - } - JSONRPCMessage::Error(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Error: {message:?}"); - } - JSONRPCMessage::Response(_) => { - anyhow::bail!("unexpected JSONRPCMessage::Response: {message:?}"); - } - } - } - } - - pub async fn send_notification( - &mut self, - method: &str, - params: Option, - ) -> anyhow::Result<()> { - self.send_jsonrpc_message(JSONRPCMessage::Notification(JSONRPCNotification { - jsonrpc: JSONRPC_VERSION.into(), - method: method.to_string(), - params, - })) - .await - } - /// Reads notifications until a legacy TaskComplete event is observed: /// Method "codex/event" with params.msg.type == "task_complete". pub async fn read_stream_until_legacy_task_complete_notification( diff --git a/codex-rs/mcp-server/tests/create_conversation.rs b/codex-rs/mcp-server/tests/create_conversation.rs index 2349b0b9..86f6963b 100644 --- a/codex-rs/mcp-server/tests/create_conversation.rs +++ b/codex-rs/mcp-server/tests/create_conversation.rs @@ -1,8 +1,16 @@ use std::path::Path; +use codex_mcp_server::wire_format::AddConversationListenerParams; +use codex_mcp_server::wire_format::AddConversationSubscriptionResponse; +use codex_mcp_server::wire_format::InputItem; +use codex_mcp_server::wire_format::NewConversationParams; +use codex_mcp_server::wire_format::NewConversationResponse; +use codex_mcp_server::wire_format::SendUserMessageParams; +use codex_mcp_server::wire_format::SendUserMessageResponse; use mcp_test_support::McpProcess; use mcp_test_support::create_final_assistant_message_sse_response; use mcp_test_support::create_mock_chat_completions_server; +use mcp_test_support::to_response; use mcp_types::JSONRPCResponse; use mcp_types::RequestId; use pretty_assertions::assert_eq; @@ -33,43 +41,64 @@ async fn test_conversation_create_and_send_message_ok() { .expect("init timeout") .expect("init failed"); - // Create a conversation via the new tool. - let req_id = mcp - .send_conversation_create_tool_call("", "o3", "/repo") + // Create a conversation via the new JSON-RPC API. + let new_conv_id = mcp + .send_new_conversation_request(NewConversationParams { + model: Some("o3".to_string()), + ..Default::default() + }) .await - .expect("send conversationCreate"); - - let resp: JSONRPCResponse = timeout( + .expect("send newConversation"); + let new_conv_resp: JSONRPCResponse = timeout( DEFAULT_READ_TIMEOUT, - mcp.read_stream_until_response_message(RequestId::Integer(req_id)), + mcp.read_stream_until_response_message(RequestId::Integer(new_conv_id)), ) .await - .expect("create response timeout") - .expect("create response error"); + .expect("newConversation timeout") + .expect("newConversation resp"); + let NewConversationResponse { + conversation_id, + model, + } = to_response::(new_conv_resp) + .expect("deserialize newConversation response"); + assert_eq!(model, "o3"); - // Structured content must include status=ok, a UUID conversation_id and the model we passed. - let sc = &resp.result["structuredContent"]; - let conv_id = sc["conversation_id"].as_str().expect("uuid string"); - assert!(!conv_id.is_empty()); - assert_eq!(sc["model"], json!("o3")); - - // Now send a message to the created conversation and expect an OK result. - let send_id = mcp - .send_user_message_tool_call("Hello", conv_id) + // Add a listener so we receive notifications for this conversation (not strictly required for this test). + let add_listener_id = mcp + .send_add_conversation_listener_request(AddConversationListenerParams { conversation_id }) .await - .expect("send message"); + .expect("send addConversationListener"); + let _sub: AddConversationSubscriptionResponse = + to_response::( + timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(add_listener_id)), + ) + .await + .expect("addConversationListener timeout") + .expect("addConversationListener resp"), + ) + .expect("deserialize addConversationListener response"); + // Now send a user message via the wire API and expect an OK (empty object) result. + let send_id = mcp + .send_send_user_message_request(SendUserMessageParams { + conversation_id, + items: vec![InputItem::Text { + text: "Hello".to_string(), + }], + }) + .await + .expect("send sendUserMessage"); let send_resp: JSONRPCResponse = timeout( DEFAULT_READ_TIMEOUT, mcp.read_stream_until_response_message(RequestId::Integer(send_id)), ) .await - .expect("send response timeout") - .expect("send response error"); - assert_eq!( - send_resp.result["structuredContent"], - json!({ "status": "ok" }) - ); + .expect("sendUserMessage timeout") + .expect("sendUserMessage resp"); + let _ok: SendUserMessageResponse = to_response::(send_resp) + .expect("deserialize sendUserMessage response"); // avoid race condition by waiting for the mock server to receive the chat.completions request let deadline = std::time::Instant::now() + DEFAULT_READ_TIMEOUT; diff --git a/codex-rs/mcp-server/tests/interrupt.rs b/codex-rs/mcp-server/tests/interrupt.rs index 365972e0..08406d37 100644 --- a/codex-rs/mcp-server/tests/interrupt.rs +++ b/codex-rs/mcp-server/tests/interrupt.rs @@ -3,15 +3,24 @@ use std::path::Path; +use codex_core::protocol::TurnAbortReason; use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; -use codex_mcp_server::CodexToolCallParam; -use serde_json::json; +use codex_mcp_server::wire_format::AddConversationListenerParams; +use codex_mcp_server::wire_format::InterruptConversationParams; +use codex_mcp_server::wire_format::InterruptConversationResponse; +use codex_mcp_server::wire_format::NewConversationParams; +use codex_mcp_server::wire_format::NewConversationResponse; +use codex_mcp_server::wire_format::SendUserMessageParams; +use codex_mcp_server::wire_format::SendUserMessageResponse; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; use tempfile::TempDir; use tokio::time::timeout; use mcp_test_support::McpProcess; use mcp_test_support::create_mock_chat_completions_server; use mcp_test_support::create_shell_sse_response; +use mcp_test_support::to_response; const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); @@ -39,93 +48,91 @@ async fn shell_command_interruption() -> anyhow::Result<()> { let shell_command = vec![ "powershell".to_string(), "-Command".to_string(), - "Start-Sleep -Seconds 60".to_string(), + "Start-Sleep -Seconds 10".to_string(), ]; #[cfg(not(target_os = "windows"))] - let shell_command = vec!["sleep".to_string(), "60".to_string()]; - let workdir_for_shell_function_call = TempDir::new()?; + let shell_command = vec!["sleep".to_string(), "10".to_string()]; + + let tmp = TempDir::new()?; + // Temporary Codex home with config pointing at the mock server. + let codex_home = tmp.path().join("codex_home"); + std::fs::create_dir(&codex_home)?; + let working_directory = tmp.path().join("workdir"); + std::fs::create_dir(&working_directory)?; // Create mock server with a single SSE response: the long sleep command - let server = create_mock_chat_completions_server(vec![ - create_shell_sse_response( - shell_command.clone(), - Some(workdir_for_shell_function_call.path()), - Some(60_000), // 60 seconds timeout in ms - "call_sleep", - )?, - create_shell_sse_response( - shell_command.clone(), - Some(workdir_for_shell_function_call.path()), - Some(60_000), // 60 seconds timeout in ms - "call_sleep", - )?, - ]) + let server = create_mock_chat_completions_server(vec![create_shell_sse_response( + shell_command.clone(), + Some(&working_directory), + Some(10_000), // 10 seconds timeout in ms + "call_sleep", + )?]) .await; + create_config_toml(&codex_home, server.uri())?; - // Create Codex configuration - let codex_home = TempDir::new()?; - create_config_toml(codex_home.path(), server.uri())?; - let mut mcp_process = McpProcess::new(codex_home.path()).await?; - timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + // Start MCP server and initialize. + let mut mcp = McpProcess::new(&codex_home).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??; - // Send codex tool call that triggers "sleep 60" - let codex_request_id = mcp_process - .send_codex_tool_call(CodexToolCallParam { - cwd: None, - prompt: "First Run: run `sleep 60`".to_string(), - model: None, - profile: None, - approval_policy: None, - sandbox: None, - config: None, - base_instructions: None, - include_plan_tool: None, + // 1) newConversation + let new_conv_id = mcp + .send_new_conversation_request(NewConversationParams { + cwd: Some(working_directory.to_string_lossy().into_owned()), + ..Default::default() }) .await?; + let new_conv_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(new_conv_id)), + ) + .await??; + let new_conv_resp = to_response::(new_conv_resp)?; + let NewConversationResponse { + conversation_id, .. + } = new_conv_resp; - let session_id = mcp_process - .read_stream_until_configured_response_message() + // 2) addConversationListener + let add_listener_id = mcp + .send_add_conversation_listener_request(AddConversationListenerParams { conversation_id }) .await?; + let _add_listener_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(add_listener_id)), + ) + .await??; + + // 3) sendUserMessage (should trigger notifications; we only validate an OK response) + let send_user_id = mcp + .send_send_user_message_request(SendUserMessageParams { + conversation_id, + items: vec![codex_mcp_server::wire_format::InputItem::Text { + text: "run first sleep command".to_string(), + }], + }) + .await?; + let send_user_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(send_user_id)), + ) + .await??; + let SendUserMessageResponse {} = to_response::(send_user_resp)?; // Give the command a moment to start tokio::time::sleep(std::time::Duration::from_secs(1)).await; - // Send interrupt notification - mcp_process - .send_notification( - "notifications/cancelled", - Some(json!({ "requestId": codex_request_id })), - ) + // 4) send interrupt request + let interrupt_id = mcp + .send_interrupt_conversation_request(InterruptConversationParams { conversation_id }) .await?; - - // Expect Codex to emit a TurnAborted event notification - let _turn_aborted = timeout( + let interrupt_resp: JSONRPCResponse = timeout( DEFAULT_READ_TIMEOUT, - mcp_process.read_stream_until_notification_message("turn_aborted"), + mcp.read_stream_until_response_message(RequestId::Integer(interrupt_id)), ) .await??; + let InterruptConversationResponse { abort_reason } = + to_response::(interrupt_resp)?; + assert_eq!(TurnAbortReason::Interrupted, abort_reason); - let codex_reply_request_id = mcp_process - .send_codex_reply_tool_call(&session_id, "Second Run: run `sleep 60`") - .await?; - - // Give the command a moment to start - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - // Send interrupt notification - mcp_process - .send_notification( - "notifications/cancelled", - Some(json!({ "requestId": codex_reply_request_id })), - ) - .await?; - - // Expect Codex to emit a TurnAborted event notification - let _turn_aborted = timeout( - DEFAULT_READ_TIMEOUT, - mcp_process.read_stream_until_notification_message("turn_aborted"), - ) - .await??; Ok(()) } diff --git a/codex-rs/mcp-server/tests/send_message.rs b/codex-rs/mcp-server/tests/send_message.rs index bf2966ef..d96d1e09 100644 --- a/codex-rs/mcp-server/tests/send_message.rs +++ b/codex-rs/mcp-server/tests/send_message.rs @@ -1,16 +1,21 @@ use std::path::Path; -use std::thread::sleep; -use std::time::Duration; -use codex_mcp_server::CodexToolCallParam; +use codex_mcp_server::wire_format::AddConversationListenerParams; +use codex_mcp_server::wire_format::AddConversationSubscriptionResponse; +use codex_mcp_server::wire_format::ConversationId; +use codex_mcp_server::wire_format::InputItem; +use codex_mcp_server::wire_format::NewConversationParams; +use codex_mcp_server::wire_format::NewConversationResponse; +use codex_mcp_server::wire_format::SendUserMessageParams; +use codex_mcp_server::wire_format::SendUserMessageResponse; use mcp_test_support::McpProcess; use mcp_test_support::create_final_assistant_message_sse_response; use mcp_test_support::create_mock_chat_completions_server; -use mcp_types::JSONRPC_VERSION; +use mcp_test_support::to_response; +use mcp_types::JSONRPCNotification; use mcp_types::JSONRPCResponse; use mcp_types::RequestId; use pretty_assertions::assert_eq; -use serde_json::json; use tempfile::TempDir; use tokio::time::timeout; @@ -31,76 +36,94 @@ async fn test_send_message_success() { create_config_toml(codex_home.path(), &server.uri()).expect("write config.toml"); // Start MCP server process and initialize. - let mut mcp_process = McpProcess::new(codex_home.path()) + let mut mcp = McpProcess::new(codex_home.path()) .await .expect("spawn mcp process"); - timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()) + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()) .await .expect("init timed out") .expect("init failed"); - // Kick off a Codex session so we have a valid session_id. - let codex_request_id = mcp_process - .send_codex_tool_call(CodexToolCallParam { - prompt: "Start a session".to_string(), - ..Default::default() - }) + // Start a conversation using the new wire API. + let new_conv_id = mcp + .send_new_conversation_request(NewConversationParams::default()) .await - .expect("send codex tool call"); - - // Wait for the session_configured event to get the session_id. - let session_id = mcp_process - .read_stream_until_configured_response_message() - .await - .expect("read session_configured"); - - // The original codex call will finish quickly given our mock; consume its response. - timeout( + .expect("send newConversation"); + let new_conv_resp: JSONRPCResponse = timeout( DEFAULT_READ_TIMEOUT, - mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + mcp.read_stream_until_response_message(RequestId::Integer(new_conv_id)), ) .await - .expect("codex response timeout") - .expect("codex response error"); + .expect("newConversation timeout") + .expect("newConversation resp"); + let NewConversationResponse { + conversation_id, .. + } = to_response::<_>(new_conv_resp).expect("deserialize newConversation response"); - // Now exercise the send-user-message tool. - let send_msg_request_id = mcp_process - .send_user_message_tool_call("Hello again", &session_id) + // 2) addConversationListener + let add_listener_id = mcp + .send_add_conversation_listener_request(AddConversationListenerParams { conversation_id }) .await - .expect("send send-message tool call"); + .expect("send addConversationListener"); + let add_listener_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(add_listener_id)), + ) + .await + .expect("addConversationListener timeout") + .expect("addConversationListener resp"); + let AddConversationSubscriptionResponse { subscription_id: _ } = + to_response::<_>(add_listener_resp).expect("deserialize addConversationListener response"); + + // Now exercise sendUserMessage twice. + send_message("Hello", conversation_id, &mut mcp).await; + send_message("Hello again", conversation_id, &mut mcp).await; +} + +#[expect(clippy::expect_used)] +async fn send_message(message: &str, conversation_id: ConversationId, mcp: &mut McpProcess) { + // Now exercise sendUserMessage. + let send_id = mcp + .send_send_user_message_request(SendUserMessageParams { + conversation_id, + items: vec![InputItem::Text { + text: message.to_string(), + }], + }) + .await + .expect("send sendUserMessage"); let response: JSONRPCResponse = timeout( DEFAULT_READ_TIMEOUT, - mcp_process.read_stream_until_response_message(RequestId::Integer(send_msg_request_id)), + mcp.read_stream_until_response_message(RequestId::Integer(send_id)), ) .await - .expect("send-user-message response timeout") - .expect("send-user-message response error"); + .expect("sendUserMessage response timeout") + .expect("sendUserMessage response error"); + let _ok: SendUserMessageResponse = to_response::(response) + .expect("deserialize sendUserMessage response"); + + // Verify the task_finished notification is received. + // Note this also ensures that the final request to the server was made. + let task_finished_notification: JSONRPCNotification = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_notification_message("codex/event/task_complete"), + ) + .await + .expect("task_finished_notification timeout") + .expect("task_finished_notification resp"); + let serde_json::Value::Object(map) = task_finished_notification + .params + .expect("notification should have params") + else { + panic!("task_finished_notification should have params"); + }; assert_eq!( - JSONRPCResponse { - jsonrpc: JSONRPC_VERSION.into(), - id: RequestId::Integer(send_msg_request_id), - result: json!({ - "content": [ - { - "text": "{\"status\":\"ok\"}", - "type": "text", - } - ], - "isError": false, - "structuredContent": { - "status": "ok" - } - }), - }, - response + map.get("conversationId") + .expect("should have conversationId"), + &serde_json::Value::String(conversation_id.to_string()) ); - // wait for the server to hear the user message - sleep(Duration::from_secs(5)); - - // Ensure the server and tempdir live until end of test - drop(server); } #[tokio::test] @@ -113,24 +136,26 @@ async fn test_send_message_session_not_found() { .expect("timeout") .expect("init"); - let unknown = uuid::Uuid::new_v4().to_string(); + let unknown = ConversationId(uuid::Uuid::new_v4()); let req_id = mcp - .send_user_message_tool_call("ping", &unknown) + .send_send_user_message_request(SendUserMessageParams { + conversation_id: unknown, + items: vec![InputItem::Text { + text: "ping".to_string(), + }], + }) .await - .expect("send tool"); + .expect("send sendUserMessage"); - let resp: JSONRPCResponse = timeout( + // Expect an error response for unknown conversation. + let err = timeout( DEFAULT_READ_TIMEOUT, - mcp.read_stream_until_response_message(RequestId::Integer(req_id)), + mcp.read_stream_until_error_message(RequestId::Integer(req_id)), ) .await .expect("timeout") - .expect("resp"); - - let result = resp.result.clone(); - let content = result["content"][0]["text"].as_str().unwrap_or(""); - assert!(content.contains("Session does not exist")); - assert_eq!(result["isError"], json!(true)); + .expect("error"); + assert_eq!(err.id, RequestId::Integer(req_id)); } // --------------------------------------------------------------------------- diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index 4c9ba6cf..4b9a2902 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -752,7 +752,7 @@ pub struct TurnAbortedEvent { pub reason: TurnAbortReason, } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] #[serde(rename_all = "snake_case")] pub enum TurnAbortReason { Interrupted,