diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index b1aa13a1..a25e0f8b 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -399,6 +399,15 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "bstr" version = "1.12.0" @@ -671,6 +680,7 @@ dependencies = [ "seccompiler", "serde", "serde_json", + "sha1", "strum_macros 0.27.1", "tempfile", "thiserror 2.0.12", @@ -933,6 +943,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.4.2" @@ -1007,6 +1026,16 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "ctor" version = "0.1.26" @@ -1157,6 +1186,16 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "dirs" version = "6.0.0" @@ -1646,6 +1685,16 @@ dependencies = [ "byteorder", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.23" @@ -3945,6 +3994,17 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -4852,6 +4912,12 @@ dependencies = [ "unicode-width 0.2.0", ] +[[package]] +name = "typenum" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" + [[package]] name = "unicase" version = "2.8.1" diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index ff066cc5..e192a71f 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -28,6 +28,7 @@ rand = "0.9" reqwest = { version = "0.12", features = ["json", "stream"] } serde = { version = "1", features = ["derive"] } serde_json = "1" +sha1 = "0.10.6" strum_macros = "0.27.1" thiserror = "2.0.12" time = { version = "0.3", features = ["formatting", "local-offset", "macros"] } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 5227f93c..d4e73b2e 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -51,7 +51,6 @@ use crate::exec::process_exec_tool_call; use crate::exec_env::create_env; use crate::flags::OPENAI_STREAM_MAX_RETRIES; use 
crate::mcp_connection_manager::McpConnectionManager; -use crate::mcp_connection_manager::try_parse_fully_qualified_tool_name; use crate::mcp_tool_call::handle_mcp_tool_call; use crate::models::ContentItem; use crate::models::FunctionCallOutputPayload; @@ -1292,7 +1291,7 @@ async fn handle_function_call( handle_container_exec_with_params(params, sess, sub_id, call_id).await } _ => { - match try_parse_fully_qualified_tool_name(&name) { + match sess.mcp_connection_manager.parse_tool_name(&name) { Some((server, tool_name)) => { // TODO(mbolin): Determine appropriate timeout for tool call. let timeout = None; diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 7cf67627..c8161c9b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -7,6 +7,7 @@ //! `""` as the key. use std::collections::HashMap; +use std::collections::HashSet; use std::time::Duration; use anyhow::Context; @@ -16,8 +17,12 @@ use codex_mcp_client::McpClient; use mcp_types::ClientCapabilities; use mcp_types::Implementation; use mcp_types::Tool; + +use sha1::Digest; +use sha1::Sha1; use tokio::task::JoinSet; use tracing::info; +use tracing::warn; use crate::config_types::McpServerConfig; @@ -26,7 +31,8 @@ use crate::config_types::McpServerConfig; /// /// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must /// choose a delimiter from this character set. -const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__"; +const MCP_TOOL_NAME_DELIMITER: &str = "__"; +const MAX_TOOL_NAME_LENGTH: usize = 64; /// Timeout for the `tools/list` request. const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10); @@ -35,16 +41,42 @@ const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10); /// spawned successfully. 
pub type ClientStartErrors = HashMap;

-fn fully_qualified_tool_name(server: &str, tool: &str) -> String {
-    format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}")
+/// Builds the map from model-facing (fully qualified) tool name to the
+/// [`ToolInfo`] it refers to.
+///
+/// The qualified name is `{server_name}{MCP_TOOL_NAME_DELIMITER}{tool_name}`.
+/// When that exceeds [`MAX_TOOL_NAME_LENGTH`] bytes it is shortened to a prefix
+/// followed by the hex SHA-1 of the full name, so distinct long names stay
+/// distinct. If two tools still end up with the same qualified name, the first
+/// one wins and later ones are skipped with a warning.
+fn qualify_tools(tools: Vec<ToolInfo>) -> HashMap<String, ToolInfo> {
+    let mut used_names = HashSet::new();
+    let mut qualified_tools = HashMap::new();
+    for tool in tools {
+        let mut qualified_name = format!(
+            "{}{}{}",
+            tool.server_name, MCP_TOOL_NAME_DELIMITER, tool.tool_name
+        );
+        if qualified_name.len() > MAX_TOOL_NAME_LENGTH {
+            let mut hasher = Sha1::new();
+            hasher.update(qualified_name.as_bytes());
+            let sha1 = hasher.finalize();
+            let sha1_str = format!("{sha1:x}");
+
+            // Truncate to make room for the hash suffix. Nothing in this function
+            // restricts tool names to ASCII, so back off to a char boundary:
+            // slicing through a multi-byte character would panic.
+            let mut prefix_len = MAX_TOOL_NAME_LENGTH - sha1_str.len();
+            while !qualified_name.is_char_boundary(prefix_len) {
+                prefix_len -= 1;
+            }
+
+            qualified_name = format!("{}{}", &qualified_name[..prefix_len], sha1_str);
+        }
+
+        // `insert` returns false when the name was already present.
+        if !used_names.insert(qualified_name.clone()) {
+            warn!("skipping duplicated tool {}", qualified_name);
+            continue;
+        }
+
+        qualified_tools.insert(qualified_name, tool);
+    }
+
+    qualified_tools
 }

-pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> {
-    let (server, tool) = fq_name.split_once(MCP_TOOL_NAME_DELIMITER)?;
-    if server.is_empty() || tool.is_empty() {
-        return None;
-    }
-    Some((server.to_string(), tool.to_string()))
+/// A tool as reported by one MCP server, together with the names needed to
+/// route a call back to it.
+struct ToolInfo {
+    /// Name of the MCP server that listed the tool.
+    server_name: String,
+    /// The tool's own (unqualified) name, copied from `tool.name`.
+    tool_name: String,
+    /// The tool definition returned by the server.
+    tool: Tool,
 }

 /// A thin wrapper around a set of running [`McpClient`] instances.
@@ -57,7 +89,7 @@ pub(crate) struct McpConnectionManager {
     clients: HashMap>,

     /// Fully qualified tool name -> tool instance.
-    tools: HashMap<String, Tool>,
+    tools: HashMap<String, ToolInfo>,
 }

 impl McpConnectionManager {
@@ -141,7 +173,9 @@ impl McpConnectionManager {
             }
         }

-        let tools = list_all_tools(&clients).await?;
+        let all_tools = list_all_tools(&clients).await?;
+
+        let tools = qualify_tools(all_tools);

         Ok((Self { clients, tools }, errors))
     }
@@ -149,7 +183,10 @@ impl McpConnectionManager {

     /// Returns a single map that contains **all** tools. Each key is the
     /// fully-qualified name for the tool.
     pub fn list_all_tools(&self) -> HashMap<String, Tool> {
-        self.tools.clone()
+        self.tools
+            .iter()
+            .map(|(name, tool)| (name.clone(), tool.tool.clone()))
+            .collect()
     }

     /// Invoke the tool indicated by the (server, tool) pair.
@@ -171,13 +208,19 @@ impl McpConnectionManager {
             .await
             .with_context(|| format!("tool call failed for `{server}/{tool}`"))
     }
+
+    /// Resolves a fully qualified tool name to its `(server, tool)` pair.
+    ///
+    /// Returns `None` when `tool_name` is not a key of the tool map, i.e. it is
+    /// not one of the names produced by `qualify_tools` when the manager was
+    /// built.
+    pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
+        self.tools
+            .get(tool_name)
+            .map(|tool| (tool.server_name.clone(), tool.tool_name.clone()))
+    }
 }

-/// Query every server for its available tools and return a single map that
-/// contains **all** tools. Each key is the fully-qualified name for the tool.
-pub async fn list_all_tools(
+/// Query every server for its available tools and return them as a flat list
+/// of [`ToolInfo`] entries, one per tool reported by a server.
+async fn list_all_tools(
     clients: &HashMap>,
-) -> Result<HashMap<String, Tool>> {
+) -> Result<Vec<ToolInfo>> {
     let mut join_set = JoinSet::new();

     // Spawn one task per server so we can query them concurrently. This
@@ -194,18 +237,19 @@ pub async fn list_all_tools(
         });
     }

-    let mut aggregated: HashMap<String, Tool> = HashMap::with_capacity(join_set.len());
+    let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());

     while let Some(join_res) = join_set.join_next().await {
         let (server_name, list_result) = join_res?;
         let list_result = list_result?;

         for tool in list_result.tools {
-            // TODO(mbolin): escape tool names that contain invalid characters.
-            let fq_name = fully_qualified_tool_name(&server_name, &tool.name);
-            if aggregated.insert(fq_name.clone(), tool).is_some() {
-                panic!("tool name collision for '{fq_name}': suspicious");
-            }
+            let tool_info = ToolInfo {
+                server_name: server_name.clone(),
+                tool_name: tool.name.clone(),
+                tool,
+            };
+            aggregated.push(tool_info);
         }
     }

@@ -224,3 +268,90 @@ fn is_valid_mcp_server_name(server_name: &str) -> bool {
         .chars()
         .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
 }
+
+#[cfg(test)]
+#[allow(clippy::unwrap_used)]
+mod tests {
+    use super::*;
+    use mcp_types::ToolInputSchema;
+
+    /// Builds a minimal `ToolInfo` whose inner `Tool` carries `tool_name`.
+    fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
+        ToolInfo {
+            server_name: server_name.to_string(),
+            tool_name: tool_name.to_string(),
+            tool: Tool {
+                annotations: None,
+                description: Some(format!("Test tool: {tool_name}")),
+                input_schema: ToolInputSchema {
+                    properties: None,
+                    required: None,
+                    r#type: "object".to_string(),
+                },
+                name: tool_name.to_string(),
+            },
+        }
+    }
+
+    #[test]
+    fn test_qualify_tools_short_non_duplicated_names() {
+        let tools = vec![
+            create_test_tool("server1", "tool1"),
+            create_test_tool("server1", "tool2"),
+        ];
+
+        let qualified_tools = qualify_tools(tools);
+
+        assert_eq!(qualified_tools.len(), 2);
+        assert!(qualified_tools.contains_key("server1__tool1"));
+        assert!(qualified_tools.contains_key("server1__tool2"));
+
+        // The value must still identify the original (server, tool) pair, since
+        // that is what a lookup by qualified name hands back to the caller.
+        let tool1 = qualified_tools.get("server1__tool1").unwrap();
+        assert_eq!(tool1.server_name, "server1");
+        assert_eq!(tool1.tool_name, "tool1");
+        let tool2 = qualified_tools.get("server1__tool2").unwrap();
+        assert_eq!(tool2.server_name, "server1");
+        assert_eq!(tool2.tool_name, "tool2");
+    }
+
+    #[test]
+    fn test_qualify_tools_duplicated_names_skipped() {
+        let tools = vec![
+            create_test_tool("server1", "duplicate_tool"),
+            create_test_tool("server1", "duplicate_tool"),
+        ];
+
+        let qualified_tools = qualify_tools(tools);
+
+        // Only the first tool should remain, the second is skipped
+        assert_eq!(qualified_tools.len(), 1);
+        assert!(qualified_tools.contains_key("server1__duplicate_tool"));
+    }
+
+    #[test]
+    fn test_qualify_tools_long_names_same_server() {
+        let server_name = "my_server";
+        let first_long_name =
+            "extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits";
+        let second_long_name = "yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits";
+
+        let tools = vec![
+            create_test_tool(server_name, first_long_name),
+            create_test_tool(server_name, second_long_name),
+        ];
+
+        let qualified_tools = qualify_tools(tools);
+
+        assert_eq!(qualified_tools.len(), 2);
+
+        let mut keys: Vec<_> = qualified_tools.keys().cloned().collect();
+        keys.sort();
+
+        assert_eq!(keys[0].len(), 64);
+        assert_eq!(
+            keys[0],
+            "my_server__extremely_lena02e507efc5a9de88637e436690364fd4219e4ef"
+        );
+
+        assert_eq!(keys[1].len(), 64);
+        assert_eq!(
+            keys[1],
+            "my_server__yet_another_e1c3987bd9c50b826cbe1687966f79f0c602d19ca"
+        );
+
+        // The shortened keys must still resolve to the full, untruncated names.
+        let first = qualified_tools.get(&keys[0]).unwrap();
+        assert_eq!(first.server_name, server_name);
+        assert_eq!(first.tool_name, first_long_name);
+        let second = qualified_tools.get(&keys[1]).unwrap();
+        assert_eq!(second.server_name, server_name);
+        assert_eq!(second.tool_name, second_long_name);
+    }
+}