fix: trim MCP tool names to fit into tool name length limit (#1571)
Store fully qualified names along with tool entries so we don't have to re-parse them. Fixes: https://github.com/openai/codex/issues/1289
This commit is contained in:
@@ -51,7 +51,6 @@ use crate::exec::process_exec_tool_call;
|
||||
use crate::exec_env::create_env;
|
||||
use crate::flags::OPENAI_STREAM_MAX_RETRIES;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::mcp_connection_manager::try_parse_fully_qualified_tool_name;
|
||||
use crate::mcp_tool_call::handle_mcp_tool_call;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::FunctionCallOutputPayload;
|
||||
@@ -1292,7 +1291,7 @@ async fn handle_function_call(
|
||||
handle_container_exec_with_params(params, sess, sub_id, call_id).await
|
||||
}
|
||||
_ => {
|
||||
match try_parse_fully_qualified_tool_name(&name) {
|
||||
match sess.mcp_connection_manager.parse_tool_name(&name) {
|
||||
Some((server, tool_name)) => {
|
||||
// TODO(mbolin): Determine appropriate timeout for tool call.
|
||||
let timeout = None;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -16,8 +17,12 @@ use codex_mcp_client::McpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
use mcp_types::Tool;
|
||||
|
||||
use sha1::Digest;
|
||||
use sha1::Sha1;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::config_types::McpServerConfig;
|
||||
|
||||
@@ -26,7 +31,8 @@ use crate::config_types::McpServerConfig;
|
||||
///
|
||||
/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must
|
||||
/// choose a delimiter from this character set.
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__";
|
||||
const MCP_TOOL_NAME_DELIMITER: &str = "__";
|
||||
const MAX_TOOL_NAME_LENGTH: usize = 64;
|
||||
|
||||
/// Timeout for the `tools/list` request.
|
||||
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
@@ -35,16 +41,42 @@ const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
/// spawned successfully.
|
||||
pub type ClientStartErrors = HashMap<String, anyhow::Error>;
|
||||
|
||||
fn fully_qualified_tool_name(server: &str, tool: &str) -> String {
|
||||
format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}")
|
||||
fn qualify_tools(tools: Vec<ToolInfo>) -> HashMap<String, ToolInfo> {
|
||||
let mut used_names = HashSet::new();
|
||||
let mut qualified_tools = HashMap::new();
|
||||
for tool in tools {
|
||||
let mut qualified_name = format!(
|
||||
"{}{}{}",
|
||||
tool.server_name, MCP_TOOL_NAME_DELIMITER, tool.tool_name
|
||||
);
|
||||
if qualified_name.len() > MAX_TOOL_NAME_LENGTH {
|
||||
let mut hasher = Sha1::new();
|
||||
hasher.update(qualified_name.as_bytes());
|
||||
let sha1 = hasher.finalize();
|
||||
let sha1_str = format!("{sha1:x}");
|
||||
|
||||
// Truncate to make room for the hash suffix
|
||||
let prefix_len = MAX_TOOL_NAME_LENGTH - sha1_str.len();
|
||||
|
||||
qualified_name = format!("{}{}", &qualified_name[..prefix_len], sha1_str);
|
||||
}
|
||||
|
||||
if used_names.contains(&qualified_name) {
|
||||
warn!("skipping duplicated tool {}", qualified_name);
|
||||
continue;
|
||||
}
|
||||
|
||||
used_names.insert(qualified_name.clone());
|
||||
qualified_tools.insert(qualified_name, tool);
|
||||
}
|
||||
|
||||
qualified_tools
|
||||
}
|
||||
|
||||
pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> {
|
||||
let (server, tool) = fq_name.split_once(MCP_TOOL_NAME_DELIMITER)?;
|
||||
if server.is_empty() || tool.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((server.to_string(), tool.to_string()))
|
||||
struct ToolInfo {
|
||||
server_name: String,
|
||||
tool_name: String,
|
||||
tool: Tool,
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`McpClient`] instances.
|
||||
@@ -57,7 +89,7 @@ pub(crate) struct McpConnectionManager {
|
||||
clients: HashMap<String, std::sync::Arc<McpClient>>,
|
||||
|
||||
/// Fully qualified tool name -> tool instance.
|
||||
tools: HashMap<String, Tool>,
|
||||
tools: HashMap<String, ToolInfo>,
|
||||
}
|
||||
|
||||
impl McpConnectionManager {
|
||||
@@ -141,7 +173,9 @@ impl McpConnectionManager {
|
||||
}
|
||||
}
|
||||
|
||||
let tools = list_all_tools(&clients).await?;
|
||||
let all_tools = list_all_tools(&clients).await?;
|
||||
|
||||
let tools = qualify_tools(all_tools);
|
||||
|
||||
Ok((Self { clients, tools }, errors))
|
||||
}
|
||||
@@ -149,7 +183,10 @@ impl McpConnectionManager {
|
||||
/// Returns a single map that contains **all** tools. Each key is the
|
||||
/// fully-qualified name for the tool.
|
||||
pub fn list_all_tools(&self) -> HashMap<String, Tool> {
|
||||
self.tools.clone()
|
||||
self.tools
|
||||
.iter()
|
||||
.map(|(name, tool)| (name.clone(), tool.tool.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Invoke the tool indicated by the (server, tool) pair.
|
||||
@@ -171,13 +208,19 @@ impl McpConnectionManager {
|
||||
.await
|
||||
.with_context(|| format!("tool call failed for `{server}/{tool}`"))
|
||||
}
|
||||
|
||||
pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
|
||||
self.tools
|
||||
.get(tool_name)
|
||||
.map(|tool| (tool.server_name.clone(), tool.tool_name.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Query every server for its available tools and return a single map that
|
||||
/// contains **all** tools. Each key is the fully-qualified name for the tool.
|
||||
pub async fn list_all_tools(
|
||||
async fn list_all_tools(
|
||||
clients: &HashMap<String, std::sync::Arc<McpClient>>,
|
||||
) -> Result<HashMap<String, Tool>> {
|
||||
) -> Result<Vec<ToolInfo>> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn one task per server so we can query them concurrently. This
|
||||
@@ -194,18 +237,19 @@ pub async fn list_all_tools(
|
||||
});
|
||||
}
|
||||
|
||||
let mut aggregated: HashMap<String, Tool> = HashMap::with_capacity(join_set.len());
|
||||
let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());
|
||||
|
||||
while let Some(join_res) = join_set.join_next().await {
|
||||
let (server_name, list_result) = join_res?;
|
||||
let list_result = list_result?;
|
||||
|
||||
for tool in list_result.tools {
|
||||
// TODO(mbolin): escape tool names that contain invalid characters.
|
||||
let fq_name = fully_qualified_tool_name(&server_name, &tool.name);
|
||||
if aggregated.insert(fq_name.clone(), tool).is_some() {
|
||||
panic!("tool name collision for '{fq_name}': suspicious");
|
||||
}
|
||||
let tool_info = ToolInfo {
|
||||
server_name: server_name.clone(),
|
||||
tool_name: tool.name.clone(),
|
||||
tool,
|
||||
};
|
||||
aggregated.push(tool_info);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,3 +268,90 @@ fn is_valid_mcp_server_name(server_name: &str) -> bool {
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use mcp_types::ToolInputSchema;
|
||||
|
||||
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
|
||||
ToolInfo {
|
||||
server_name: server_name.to_string(),
|
||||
tool_name: tool_name.to_string(),
|
||||
tool: Tool {
|
||||
annotations: None,
|
||||
description: Some(format!("Test tool: {tool_name}")),
|
||||
input_schema: ToolInputSchema {
|
||||
properties: None,
|
||||
required: None,
|
||||
r#type: "object".to_string(),
|
||||
},
|
||||
name: tool_name.to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qualify_tools_short_non_duplicated_names() {
|
||||
let tools = vec![
|
||||
create_test_tool("server1", "tool1"),
|
||||
create_test_tool("server1", "tool2"),
|
||||
];
|
||||
|
||||
let qualified_tools = qualify_tools(tools);
|
||||
|
||||
assert_eq!(qualified_tools.len(), 2);
|
||||
assert!(qualified_tools.contains_key("server1__tool1"));
|
||||
assert!(qualified_tools.contains_key("server1__tool2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qualify_tools_duplicated_names_skipped() {
|
||||
let tools = vec![
|
||||
create_test_tool("server1", "duplicate_tool"),
|
||||
create_test_tool("server1", "duplicate_tool"),
|
||||
];
|
||||
|
||||
let qualified_tools = qualify_tools(tools);
|
||||
|
||||
// Only the first tool should remain, the second is skipped
|
||||
assert_eq!(qualified_tools.len(), 1);
|
||||
assert!(qualified_tools.contains_key("server1__duplicate_tool"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_qualify_tools_long_names_same_server() {
|
||||
let server_name = "my_server";
|
||||
|
||||
let tools = vec![
|
||||
create_test_tool(
|
||||
server_name,
|
||||
"extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
|
||||
),
|
||||
create_test_tool(
|
||||
server_name,
|
||||
"yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
|
||||
),
|
||||
];
|
||||
|
||||
let qualified_tools = qualify_tools(tools);
|
||||
|
||||
assert_eq!(qualified_tools.len(), 2);
|
||||
|
||||
let mut keys: Vec<_> = qualified_tools.keys().cloned().collect();
|
||||
keys.sort();
|
||||
|
||||
assert_eq!(keys[0].len(), 64);
|
||||
assert_eq!(
|
||||
keys[0],
|
||||
"my_server__extremely_lena02e507efc5a9de88637e436690364fd4219e4ef"
|
||||
);
|
||||
|
||||
assert_eq!(keys[1].len(), 64);
|
||||
assert_eq!(
|
||||
keys[1],
|
||||
"my_server__yet_another_e1c3987bd9c50b826cbe1687966f79f0c602d19ca"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user