timeouts for mcp tool calls (#3959)

defaults to 60sec, overridable with MCP_TOOL_TIMEOUT or on a per-server
basis in the config.
This commit is contained in:
Jeremy Rose
2025-09-22 10:30:59 -07:00
committed by GitHub
parent e258ca61b4
commit 19f46439ae
7 changed files with 166 additions and 50 deletions

View File

@@ -1008,10 +1008,9 @@ impl Session {
server: &str,
tool: &str,
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> anyhow::Result<CallToolResult> {
self.mcp_connection_manager
.call_tool(server, tool, arguments, timeout)
.call_tool(server, tool, arguments)
.await
}
@@ -2596,12 +2595,7 @@ async fn handle_function_call(
_ => {
match sess.mcp_connection_manager.parse_tool_name(&name) {
Some((server, tool_name)) => {
// TODO(mbolin): Determine appropriate timeout for tool call.
let timeout = None;
handle_mcp_tool_call(
sess, &sub_id, call_id, server, tool_name, arguments, timeout,
)
.await
handle_mcp_tool_call(sess, &sub_id, call_id, server, tool_name, arguments).await
}
None => {
// Unknown function: reply with structured failure so the model can adapt.

View File

@@ -333,14 +333,12 @@ pub fn write_global_mcp_servers(
entry["env"] = TomlItem::Table(env_table);
}
if let Some(timeout) = config.startup_timeout_ms {
let timeout = i64::try_from(timeout).map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"startup_timeout_ms exceeds supported range",
)
})?;
entry["startup_timeout_ms"] = toml_edit::value(timeout);
if let Some(timeout) = config.startup_timeout_sec {
entry["startup_timeout_sec"] = toml_edit::value(timeout.as_secs_f64());
}
if let Some(timeout) = config.tool_timeout_sec {
entry["tool_timeout_sec"] = toml_edit::value(timeout.as_secs_f64());
}
doc["mcp_servers"][name.as_str()] = TomlItem::Table(entry);
@@ -1168,6 +1166,7 @@ mod tests {
use super::*;
use pretty_assertions::assert_eq;
use std::time::Duration;
use tempfile::TempDir;
#[test]
@@ -1292,7 +1291,8 @@ exclude_slash_tmp = true
command: "echo".to_string(),
args: vec!["hello".to_string()],
env: None,
startup_timeout_ms: None,
startup_timeout_sec: Some(Duration::from_secs(3)),
tool_timeout_sec: Some(Duration::from_secs(5)),
},
);
@@ -1303,6 +1303,8 @@ exclude_slash_tmp = true
let docs = loaded.get("docs").expect("docs entry");
assert_eq!(docs.command, "echo");
assert_eq!(docs.args, vec!["hello".to_string()]);
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(3)));
assert_eq!(docs.tool_timeout_sec, Some(Duration::from_secs(5)));
let empty = BTreeMap::new();
write_global_mcp_servers(codex_home.path(), &empty)?;
@@ -1312,6 +1314,28 @@ exclude_slash_tmp = true
Ok(())
}
#[test]
fn load_global_mcp_servers_accepts_legacy_ms_field() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
std::fs::write(
&config_path,
r#"
[mcp_servers]
[mcp_servers.docs]
command = "echo"
startup_timeout_ms = 2500
"#,
)?;
let servers = load_global_mcp_servers(codex_home.path())?;
let docs = servers.get("docs").expect("docs entry");
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_millis(2500)));
Ok(())
}
#[tokio::test]
async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;

View File

@@ -5,11 +5,15 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use wildmatch::WildMatchPattern;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::de::Error as SerdeError;
#[derive(Deserialize, Debug, Clone, PartialEq)]
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct McpServerConfig {
pub command: String,
@@ -19,9 +23,84 @@ pub struct McpServerConfig {
#[serde(default)]
pub env: Option<HashMap<String, String>>,
/// Startup timeout in milliseconds for initializing MCP server & initially listing tools.
#[serde(default)]
pub startup_timeout_ms: Option<u64>,
/// Startup timeout in seconds for initializing MCP server & initially listing tools.
#[serde(
default,
with = "option_duration_secs",
skip_serializing_if = "Option::is_none"
)]
pub startup_timeout_sec: Option<Duration>,
/// Default timeout for MCP tool calls initiated via this server.
#[serde(default, with = "option_duration_secs")]
pub tool_timeout_sec: Option<Duration>,
}
impl<'de> Deserialize<'de> for McpServerConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct RawMcpServerConfig {
command: String,
#[serde(default)]
args: Vec<String>,
#[serde(default)]
env: Option<HashMap<String, String>>,
#[serde(default)]
startup_timeout_sec: Option<f64>,
#[serde(default)]
startup_timeout_ms: Option<u64>,
#[serde(default, with = "option_duration_secs")]
tool_timeout_sec: Option<Duration>,
}
let raw = RawMcpServerConfig::deserialize(deserializer)?;
let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) {
(Some(sec), _) => {
let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?;
Some(duration)
}
(None, Some(ms)) => Some(Duration::from_millis(ms)),
(None, None) => None,
};
Ok(Self {
command: raw.command,
args: raw.args,
env: raw.env,
startup_timeout_sec,
tool_timeout_sec: raw.tool_timeout_sec,
})
}
}
mod option_duration_secs {
use serde::Deserialize;
use serde::Deserializer;
use serde::Serializer;
use std::time::Duration;
pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(duration) => serializer.serialize_some(&duration.as_secs_f64()),
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let secs = Option::<f64>::deserialize(deserializer)?;
secs.map(|secs| Duration::try_from_secs_f64(secs).map_err(serde::de::Error::custom))
.transpose()
}
}
#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]

View File

@@ -40,6 +40,9 @@ const MAX_TOOL_NAME_LENGTH: usize = 64;
/// Default timeout for initializing MCP server & initially listing tools.
const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);
/// Default timeout for individual tool calls.
const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(60);
/// Map that holds a startup error for every MCP server that could **not** be
/// spawned successfully.
pub type ClientStartErrors = HashMap<String, anyhow::Error>;
@@ -85,6 +88,7 @@ struct ToolInfo {
struct ManagedClient {
client: Arc<McpClient>,
startup_timeout: Duration,
tool_timeout: Option<Duration>,
}
/// A thin wrapper around a set of running [`McpClient`] instances.
@@ -132,10 +136,9 @@ impl McpConnectionManager {
continue;
}
let startup_timeout = cfg
.startup_timeout_ms
.map(Duration::from_millis)
.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
join_set.spawn(async move {
let McpServerConfig {
@@ -171,19 +174,19 @@ impl McpConnectionManager {
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
};
let initialize_notification_params = None;
match client
let init_result = client
.initialize(
params,
initialize_notification_params,
Some(startup_timeout),
)
.await
{
Ok(_response) => (server_name, Ok((client, startup_timeout))),
Err(e) => (server_name, Err(e)),
}
.await;
(
(server_name, tool_timeout),
init_result.map(|_| (client, startup_timeout)),
)
}
Err(e) => (server_name, Err(e.into())),
Err(e) => ((server_name, tool_timeout), Err(e.into())),
}
});
}
@@ -191,8 +194,8 @@ impl McpConnectionManager {
let mut clients: HashMap<String, ManagedClient> = HashMap::with_capacity(join_set.len());
while let Some(res) = join_set.join_next().await {
let (server_name, client_res) = match res {
Ok((server_name, client_res)) => (server_name, client_res),
let ((server_name, tool_timeout), client_res) = match res {
Ok(result) => result,
Err(e) => {
warn!("Task panic when starting MCP server: {e:#}");
continue;
@@ -206,6 +209,7 @@ impl McpConnectionManager {
ManagedClient {
client: Arc::new(client),
startup_timeout,
tool_timeout: Some(tool_timeout),
},
);
}
@@ -243,14 +247,13 @@ impl McpConnectionManager {
server: &str,
tool: &str,
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> Result<mcp_types::CallToolResult> {
let client = self
let managed = self
.clients
.get(server)
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
.client
.clone();
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
let client = managed.client.clone();
let timeout = managed.tool_timeout;
client
.call_tool(tool.to_string(), arguments, timeout)

View File

@@ -1,4 +1,3 @@
use std::time::Duration;
use std::time::Instant;
use tracing::error;
@@ -21,7 +20,6 @@ pub(crate) async fn handle_mcp_tool_call(
server: String,
tool_name: String,
arguments: String,
timeout: Option<Duration>,
) -> ResponseInputItem {
// Parse the `arguments` as JSON. An empty string is OK, but invalid JSON
// is not.
@@ -58,7 +56,7 @@ pub(crate) async fn handle_mcp_tool_call(
let start = Instant::now();
// Perform the tool call.
let result = sess
.call_tool(&server, &tool_name, arguments_value.clone(), timeout)
.call_tool(&server, &tool_name, arguments_value.clone())
.await
.map_err(|e| format!("tool call error: {e}"));
let tool_call_end_event = EventMsg::McpToolCallEnd(McpToolCallEndEvent {