diff --git a/codex-rs/mcp-server/src/mcp_protocol.rs b/codex-rs/mcp-server/src/mcp_protocol.rs index e507376c..23304dc4 100644 --- a/codex-rs/mcp-server/src/mcp_protocol.rs +++ b/codex-rs/mcp-server/src/mcp_protocol.rs @@ -7,7 +7,10 @@ use serde::Serialize; use strum_macros::Display; use uuid::Uuid; +use mcp_types::CallToolResult; +use mcp_types::ContentBlock; use mcp_types::RequestId; +use mcp_types::TextContent; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] @@ -118,10 +121,47 @@ pub struct ToolCallResponse { pub request_id: RequestId, #[serde(default, skip_serializing_if = "Option::is_none")] pub is_error: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] + #[serde(default, skip_serializing_if = "Option::is_none", flatten)] pub result: Option, } +impl From for CallToolResult { + fn from(val: ToolCallResponse) -> Self { + let ToolCallResponse { + request_id: _request_id, + is_error, + result, + } = val; + let (content, structured_content, is_error_out) = match result { + Some(res) => match serde_json::to_value(&res) { + Ok(v) => { + let content = vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_string(), + text: v.to_string(), + annotations: None, + })]; + (content, Some(v), is_error) + } + Err(e) => { + let content = vec![ContentBlock::TextContent(TextContent { + r#type: "text".to_string(), + text: format!("Failed to serialize tool result: {e}"), + annotations: None, + })]; + (content, None, Some(true)) + } + }, + None => (vec![], None, is_error), + }; + + CallToolResult { + content, + is_error: is_error_out, + structured_content, + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(untagged)] pub enum ToolCallResponseResult { @@ -141,8 +181,10 @@ pub struct ConversationCreateResult { pub struct ConversationStreamResult {} #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct ConversationSendMessageResult { - pub success: bool, +#[serde(tag = "status", rename_all = "camelCase")] +pub enum ConversationSendMessageResult { + Ok, + Error { message: String }, } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -455,10 +497,13 @@ mod tests { }, )), }; - let observed = to_val(&env); + let req_id = env.request_id.clone(); + let observed = to_val(&CallToolResult::from(env)); let expected = json!({ - "requestId": 1, - "result": { + "content": [ + { "type": "text", "text": "{\"conversation_id\":\"d0f6ecbe-84a2-41c1-b23d-b20473b25eab\",\"model\":\"o3\"}" } + ], + "structuredContent": { "conversation_id": "d0f6ecbe-84a2-41c1-b23d-b20473b25eab", "model": "o3" } @@ -467,6 +512,7 @@ mod tests { observed, expected, "response (ConversationCreate) must match" ); + assert_eq!(req_id, RequestId::Integer(1)); } #[test] @@ -478,15 +524,17 @@ mod tests { ConversationStreamResult {}, )), }; - let observed = to_val(&env); + let req_id = env.request_id.clone(); + let observed = to_val(&CallToolResult::from(env)); let expected = json!({ - "requestId": 2, - "result": {} + "content": [ { "type": "text", "text": "{}" } ], + "structuredContent": {} }); assert_eq!( observed, expected, "response (ConversationStream) must have empty object result" ); + assert_eq!(req_id, RequestId::Integer(2)); } #[test] @@ -495,18 +543,20 @@ mod tests { request_id: RequestId::Integer(3), is_error: None, result: Some(ToolCallResponseResult::ConversationSendMessage( - ConversationSendMessageResult { success: true }, + ConversationSendMessageResult::Ok, )), }; - let observed = to_val(&env); + let req_id = env.request_id.clone(); + let observed = to_val(&CallToolResult::from(env)); let expected = json!({ - "requestId": 3, - "result": { "success": true } + "content": [ { "type": "text", "text": "{\"status\":\"ok\"}" } ], + "structuredContent": { "status": "ok" } }); assert_eq!( observed, expected, "response (ConversationSendMessageAccepted) must match" ); + assert_eq!(req_id, RequestId::Integer(3)); } #[test] @@ -526,10 +576,13 @@ mod tests { }, )), }; - let observed = to_val(&env); + let req_id = env.request_id.clone(); + let observed = to_val(&CallToolResult::from(env)); let expected = json!({ - "requestId": 4, - "result": { + "content": [ + { "type": "text", "text": "{\"conversations\":[{\"conversation_id\":\"67e55044-10b1-426f-9247-bb680e5fe0c8\",\"title\":\"Refactor config loader\"}],\"next_cursor\":\"next123\"}" } + ], + "structuredContent": { "conversations": [ { "conversation_id": "67e55044-10b1-426f-9247-bb680e5fe0c8", @@ -543,6 +596,7 @@ mod tests { observed, expected, "response (ConversationsList with cursor) must match" ); + assert_eq!(req_id, RequestId::Integer(4)); } #[test] @@ -552,15 +606,17 @@ mod tests { is_error: Some(true), result: None, }; - let observed = to_val(&env); + let req_id = env.request_id.clone(); + let observed = to_val(&CallToolResult::from(env)); let expected = json!({ - "requestId": 4, + "content": [], "isError": true }); assert_eq!( observed, expected, "error response must omit `result` and include `isError`" ); + assert_eq!(req_id, RequestId::Integer(4)); } // ----- Notifications -----