diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 36d4f119..cb749fac 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -5,6 +5,7 @@ use std::path::Path; use std::path::PathBuf; use std::sync::Arc; use std::sync::Mutex; +use std::time::Duration; use anyhow::Context; use async_channel::Receiver; @@ -396,9 +397,10 @@ impl Session { server: &str, tool: &str, arguments: Option, + timeout: Option, ) -> anyhow::Result { self.mcp_connection_manager - .call_tool(server, tool, arguments) + .call_tool(server, tool, arguments, timeout) .await } @@ -1194,7 +1196,12 @@ async fn handle_function_call( _ => { match try_parse_fully_qualified_tool_name(&name) { Some((server, tool_name)) => { - handle_mcp_tool_call(sess, &sub_id, call_id, server, tool_name, arguments).await + // TODO(mbolin): Determine appropriate timeout for tool call. + let timeout = None; + handle_mcp_tool_call( + sess, &sub_id, call_id, server, tool_name, arguments, timeout, + ) + .await } None => { // Unknown function: reply with structured failure so the model can adapt. diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index f03b9f20..734c3514 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -7,6 +7,7 @@ //! `""` as the key. use std::collections::HashMap; +use std::time::Duration; use anyhow::Context; use anyhow::Result; @@ -25,6 +26,9 @@ use crate::mcp_server_config::McpServerConfig; /// choose a delimiter from this character set. const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__"; +/// Timeout for the `tools/list` request. +const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10); + fn fully_qualified_tool_name(server: &str, tool: &str) -> String { format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}") } @@ -104,6 +108,7 @@ impl McpConnectionManager { server: &str, tool: &str, arguments: Option, + timeout: Option, ) -> Result { let client = self .clients @@ -112,7 +117,7 @@ impl McpConnectionManager { .clone(); client - .call_tool(tool.to_string(), arguments) + .call_tool(tool.to_string(), arguments, timeout) .await .with_context(|| format!("tool call failed for `{server}/{tool}`")) } @@ -132,7 +137,9 @@ pub async fn list_all_tools( let server_name_cloned = server_name.clone(); let client_clone = client.clone(); join_set.spawn(async move { - let res = client_clone.list_tools(None).await; + let res = client_clone + .list_tools(None, Some(LIST_TOOLS_TIMEOUT)) + .await; (server_name_cloned, res) }); } diff --git a/codex-rs/core/src/mcp_tool_call.rs b/codex-rs/core/src/mcp_tool_call.rs index 0b6401f7..7cbbad7e 100644 --- a/codex-rs/core/src/mcp_tool_call.rs +++ b/codex-rs/core/src/mcp_tool_call.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use tracing::error; use crate::codex::Session; @@ -15,6 +17,7 @@ pub(crate) async fn handle_mcp_tool_call( server: String, tool_name: String, arguments: String, + timeout: Option, ) -> ResponseInputItem { // Parse the `arguments` as JSON. An empty string is OK, but invalid JSON // is not. @@ -45,25 +48,27 @@ pub(crate) async fn handle_mcp_tool_call( notify_mcp_tool_call_event(sess, sub_id, tool_call_begin_event).await; // Perform the tool call. - let (tool_call_end_event, tool_call_err) = - match sess.call_tool(&server, &tool_name, arguments_value).await { - Ok(result) => ( - EventMsg::McpToolCallEnd { - call_id, - success: !result.is_error.unwrap_or(false), - result: Some(result), - }, - None, - ), - Err(e) => ( - EventMsg::McpToolCallEnd { - call_id, - success: false, - result: None, - }, - Some(e), - ), - }; + let (tool_call_end_event, tool_call_err) = match sess + .call_tool(&server, &tool_name, arguments_value, timeout) + .await + { + Ok(result) => ( + EventMsg::McpToolCallEnd { + call_id, + success: !result.is_error.unwrap_or(false), + result: Some(result), + }, + None, + ), + Err(e) => ( + EventMsg::McpToolCallEnd { + call_id, + success: false, + result: None, + }, + Some(e), + ), + }; notify_mcp_tool_call_event(sess, sub_id, tool_call_end_event.clone()).await; let EventMsg::McpToolCallEnd { diff --git a/codex-rs/mcp-client/Cargo.toml b/codex-rs/mcp-client/Cargo.toml index b98eccab..81f4b85e 100644 --- a/codex-rs/mcp-client/Cargo.toml +++ b/codex-rs/mcp-client/Cargo.toml @@ -16,6 +16,7 @@ tokio = { version = "1", features = [ "process", "rt-multi-thread", "sync", + "time", ] } [dev-dependencies] diff --git a/codex-rs/mcp-client/src/main.rs b/codex-rs/mcp-client/src/main.rs index 1e4ead98..eb784252 100644 --- a/codex-rs/mcp-client/src/main.rs +++ b/codex-rs/mcp-client/src/main.rs @@ -34,8 +34,9 @@ async fn main() -> Result<()> { .with_context(|| format!("failed to spawn subprocess: {original_args:?}"))?; // Issue `tools/list` request (no params). + let timeout = None; let tools = client - .list_tools(None::) + .list_tools(None::, timeout) .await .context("tools/list request failed")?; diff --git a/codex-rs/mcp-client/src/mcp_client.rs b/codex-rs/mcp-client/src/mcp_client.rs index b36f78b3..1c6a765c 100644 --- a/codex-rs/mcp-client/src/mcp_client.rs +++ b/codex-rs/mcp-client/src/mcp_client.rs @@ -15,6 +15,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::sync::atomic::AtomicI64; use std::sync::atomic::Ordering; +use std::time::Duration; use anyhow::Result; use anyhow::anyhow; @@ -39,6 +40,7 @@ use tokio::process::Command; use tokio::sync::Mutex; use tokio::sync::mpsc; use tokio::sync::oneshot; +use tokio::time; use tracing::debug; use tracing::error; use tracing::info; @@ -175,7 +177,15 @@ impl McpClient { } /// Send an arbitrary MCP request and await the typed result. - pub async fn send_request(&self, params: R::Params) -> Result + /// + /// If `timeout` is `None` the call waits indefinitely. If `Some(duration)` + /// is supplied and no response is received within the given period, a + /// timeout error is returned. + pub async fn send_request( + &self, + params: R::Params, + timeout: Option, + ) -> Result where R: ModelContextProtocolRequest, R::Params: Serialize, @@ -220,10 +230,31 @@ impl McpClient { )); } - // Await the response. - let msg = rx - .await - .map_err(|_| anyhow!("response channel closed before a reply was received"))?; + // Await the response, optionally bounded by a timeout. + let msg = match timeout { + Some(duration) => { + match time::timeout(duration, rx).await { + Ok(Ok(msg)) => msg, + Ok(Err(_)) => { + // Channel closed without a reply – remove the pending entry. + let mut guard = self.pending.lock().await; + guard.remove(&id); + return Err(anyhow!( + "response channel closed before a reply was received" + )); + } + Err(_) => { + // Timed out. Remove the pending entry so we don't leak. + let mut guard = self.pending.lock().await; + guard.remove(&id); + return Err(anyhow!("request timed out")); + } + } + } + None => rx + .await + .map_err(|_| anyhow!("response channel closed before a reply was received"))?, + }; match msg { JSONRPCMessage::Response(JSONRPCResponse { result, .. }) => { @@ -245,8 +276,9 @@ impl McpClient { pub async fn list_tools( &self, params: Option, + timeout: Option, ) -> Result { - self.send_request::(params).await + self.send_request::(params, timeout).await } /// Convenience wrapper around `tools/call`. @@ -254,10 +286,11 @@ impl McpClient { &self, name: String, arguments: Option, + timeout: Option, ) -> Result { let params = CallToolRequestParams { name, arguments }; debug!("MCP tool call: {params:?}"); - self.send_request::(params).await + self.send_request::(params, timeout).await } /// Internal helper: route a JSON-RPC *response* object to the pending map.