fix: add optional timeout to McpClient::send_request() (#852)
We now impose a 10s timeout on the initial `tools/list` request to an MCP server. We do not apply a timeout for other types of requests yet, but we should start enforcing those, as well.
This commit is contained in:
@@ -5,6 +5,7 @@ use std::path::Path;
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use async_channel::Receiver;
|
use async_channel::Receiver;
|
||||||
@@ -396,9 +397,10 @@ impl Session {
|
|||||||
server: &str,
|
server: &str,
|
||||||
tool: &str,
|
tool: &str,
|
||||||
arguments: Option<serde_json::Value>,
|
arguments: Option<serde_json::Value>,
|
||||||
|
timeout: Option<Duration>,
|
||||||
) -> anyhow::Result<mcp_types::CallToolResult> {
|
) -> anyhow::Result<mcp_types::CallToolResult> {
|
||||||
self.mcp_connection_manager
|
self.mcp_connection_manager
|
||||||
.call_tool(server, tool, arguments)
|
.call_tool(server, tool, arguments, timeout)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1194,7 +1196,12 @@ async fn handle_function_call(
|
|||||||
_ => {
|
_ => {
|
||||||
match try_parse_fully_qualified_tool_name(&name) {
|
match try_parse_fully_qualified_tool_name(&name) {
|
||||||
Some((server, tool_name)) => {
|
Some((server, tool_name)) => {
|
||||||
handle_mcp_tool_call(sess, &sub_id, call_id, server, tool_name, arguments).await
|
// TODO(mbolin): Determine appropriate timeout for tool call.
|
||||||
|
let timeout = None;
|
||||||
|
handle_mcp_tool_call(
|
||||||
|
sess, &sub_id, call_id, server, tool_name, arguments, timeout,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
// Unknown function: reply with structured failure so the model can adapt.
|
// Unknown function: reply with structured failure so the model can adapt.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.
|
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
@@ -25,6 +26,9 @@ use crate::mcp_server_config::McpServerConfig;
|
|||||||
/// choose a delimiter from this character set.
|
/// choose a delimiter from this character set.
|
||||||
const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__";
|
const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__";
|
||||||
|
|
||||||
|
/// Timeout for the `tools/list` request.
|
||||||
|
const LIST_TOOLS_TIMEOUT: Duration = Duration::from_secs(10);
|
||||||
|
|
||||||
fn fully_qualified_tool_name(server: &str, tool: &str) -> String {
|
fn fully_qualified_tool_name(server: &str, tool: &str) -> String {
|
||||||
format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}")
|
format!("{server}{MCP_TOOL_NAME_DELIMITER}{tool}")
|
||||||
}
|
}
|
||||||
@@ -104,6 +108,7 @@ impl McpConnectionManager {
|
|||||||
server: &str,
|
server: &str,
|
||||||
tool: &str,
|
tool: &str,
|
||||||
arguments: Option<serde_json::Value>,
|
arguments: Option<serde_json::Value>,
|
||||||
|
timeout: Option<Duration>,
|
||||||
) -> Result<mcp_types::CallToolResult> {
|
) -> Result<mcp_types::CallToolResult> {
|
||||||
let client = self
|
let client = self
|
||||||
.clients
|
.clients
|
||||||
@@ -112,7 +117,7 @@ impl McpConnectionManager {
|
|||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
client
|
client
|
||||||
.call_tool(tool.to_string(), arguments)
|
.call_tool(tool.to_string(), arguments, timeout)
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("tool call failed for `{server}/{tool}`"))
|
.with_context(|| format!("tool call failed for `{server}/{tool}`"))
|
||||||
}
|
}
|
||||||
@@ -132,7 +137,9 @@ pub async fn list_all_tools(
|
|||||||
let server_name_cloned = server_name.clone();
|
let server_name_cloned = server_name.clone();
|
||||||
let client_clone = client.clone();
|
let client_clone = client.clone();
|
||||||
join_set.spawn(async move {
|
join_set.spawn(async move {
|
||||||
let res = client_clone.list_tools(None).await;
|
let res = client_clone
|
||||||
|
.list_tools(None, Some(LIST_TOOLS_TIMEOUT))
|
||||||
|
.await;
|
||||||
(server_name_cloned, res)
|
(server_name_cloned, res)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::codex::Session;
|
use crate::codex::Session;
|
||||||
@@ -15,6 +17,7 @@ pub(crate) async fn handle_mcp_tool_call(
|
|||||||
server: String,
|
server: String,
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
arguments: String,
|
arguments: String,
|
||||||
|
timeout: Option<Duration>,
|
||||||
) -> ResponseInputItem {
|
) -> ResponseInputItem {
|
||||||
// Parse the `arguments` as JSON. An empty string is OK, but invalid JSON
|
// Parse the `arguments` as JSON. An empty string is OK, but invalid JSON
|
||||||
// is not.
|
// is not.
|
||||||
@@ -45,25 +48,27 @@ pub(crate) async fn handle_mcp_tool_call(
|
|||||||
notify_mcp_tool_call_event(sess, sub_id, tool_call_begin_event).await;
|
notify_mcp_tool_call_event(sess, sub_id, tool_call_begin_event).await;
|
||||||
|
|
||||||
// Perform the tool call.
|
// Perform the tool call.
|
||||||
let (tool_call_end_event, tool_call_err) =
|
let (tool_call_end_event, tool_call_err) = match sess
|
||||||
match sess.call_tool(&server, &tool_name, arguments_value).await {
|
.call_tool(&server, &tool_name, arguments_value, timeout)
|
||||||
Ok(result) => (
|
.await
|
||||||
EventMsg::McpToolCallEnd {
|
{
|
||||||
call_id,
|
Ok(result) => (
|
||||||
success: !result.is_error.unwrap_or(false),
|
EventMsg::McpToolCallEnd {
|
||||||
result: Some(result),
|
call_id,
|
||||||
},
|
success: !result.is_error.unwrap_or(false),
|
||||||
None,
|
result: Some(result),
|
||||||
),
|
},
|
||||||
Err(e) => (
|
None,
|
||||||
EventMsg::McpToolCallEnd {
|
),
|
||||||
call_id,
|
Err(e) => (
|
||||||
success: false,
|
EventMsg::McpToolCallEnd {
|
||||||
result: None,
|
call_id,
|
||||||
},
|
success: false,
|
||||||
Some(e),
|
result: None,
|
||||||
),
|
},
|
||||||
};
|
Some(e),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
notify_mcp_tool_call_event(sess, sub_id, tool_call_end_event.clone()).await;
|
notify_mcp_tool_call_event(sess, sub_id, tool_call_end_event.clone()).await;
|
||||||
let EventMsg::McpToolCallEnd {
|
let EventMsg::McpToolCallEnd {
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ tokio = { version = "1", features = [
|
|||||||
"process",
|
"process",
|
||||||
"rt-multi-thread",
|
"rt-multi-thread",
|
||||||
"sync",
|
"sync",
|
||||||
|
"time",
|
||||||
] }
|
] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|||||||
@@ -34,8 +34,9 @@ async fn main() -> Result<()> {
|
|||||||
.with_context(|| format!("failed to spawn subprocess: {original_args:?}"))?;
|
.with_context(|| format!("failed to spawn subprocess: {original_args:?}"))?;
|
||||||
|
|
||||||
// Issue `tools/list` request (no params).
|
// Issue `tools/list` request (no params).
|
||||||
|
let timeout = None;
|
||||||
let tools = client
|
let tools = client
|
||||||
.list_tools(None::<ListToolsRequestParams>)
|
.list_tools(None::<ListToolsRequestParams>, timeout)
|
||||||
.await
|
.await
|
||||||
.context("tools/list request failed")?;
|
.context("tools/list request failed")?;
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::AtomicI64;
|
use std::sync::atomic::AtomicI64;
|
||||||
use std::sync::atomic::Ordering;
|
use std::sync::atomic::Ordering;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
@@ -39,6 +40,7 @@ use tokio::process::Command;
|
|||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
|
use tokio::time;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
@@ -175,7 +177,15 @@ impl McpClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Send an arbitrary MCP request and await the typed result.
|
/// Send an arbitrary MCP request and await the typed result.
|
||||||
pub async fn send_request<R>(&self, params: R::Params) -> Result<R::Result>
|
///
|
||||||
|
/// If `timeout` is `None` the call waits indefinitely. If `Some(duration)`
|
||||||
|
/// is supplied and no response is received within the given period, a
|
||||||
|
/// timeout error is returned.
|
||||||
|
pub async fn send_request<R>(
|
||||||
|
&self,
|
||||||
|
params: R::Params,
|
||||||
|
timeout: Option<Duration>,
|
||||||
|
) -> Result<R::Result>
|
||||||
where
|
where
|
||||||
R: ModelContextProtocolRequest,
|
R: ModelContextProtocolRequest,
|
||||||
R::Params: Serialize,
|
R::Params: Serialize,
|
||||||
@@ -220,10 +230,31 @@ impl McpClient {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Await the response.
|
// Await the response, optionally bounded by a timeout.
|
||||||
let msg = rx
|
let msg = match timeout {
|
||||||
.await
|
Some(duration) => {
|
||||||
.map_err(|_| anyhow!("response channel closed before a reply was received"))?;
|
match time::timeout(duration, rx).await {
|
||||||
|
Ok(Ok(msg)) => msg,
|
||||||
|
Ok(Err(_)) => {
|
||||||
|
// Channel closed without a reply – remove the pending entry.
|
||||||
|
let mut guard = self.pending.lock().await;
|
||||||
|
guard.remove(&id);
|
||||||
|
return Err(anyhow!(
|
||||||
|
"response channel closed before a reply was received"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// Timed out. Remove the pending entry so we don't leak.
|
||||||
|
let mut guard = self.pending.lock().await;
|
||||||
|
guard.remove(&id);
|
||||||
|
return Err(anyhow!("request timed out"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => rx
|
||||||
|
.await
|
||||||
|
.map_err(|_| anyhow!("response channel closed before a reply was received"))?,
|
||||||
|
};
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
JSONRPCMessage::Response(JSONRPCResponse { result, .. }) => {
|
JSONRPCMessage::Response(JSONRPCResponse { result, .. }) => {
|
||||||
@@ -245,8 +276,9 @@ impl McpClient {
|
|||||||
pub async fn list_tools(
|
pub async fn list_tools(
|
||||||
&self,
|
&self,
|
||||||
params: Option<ListToolsRequestParams>,
|
params: Option<ListToolsRequestParams>,
|
||||||
|
timeout: Option<Duration>,
|
||||||
) -> Result<ListToolsResult> {
|
) -> Result<ListToolsResult> {
|
||||||
self.send_request::<ListToolsRequest>(params).await
|
self.send_request::<ListToolsRequest>(params, timeout).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convenience wrapper around `tools/call`.
|
/// Convenience wrapper around `tools/call`.
|
||||||
@@ -254,10 +286,11 @@ impl McpClient {
|
|||||||
&self,
|
&self,
|
||||||
name: String,
|
name: String,
|
||||||
arguments: Option<serde_json::Value>,
|
arguments: Option<serde_json::Value>,
|
||||||
|
timeout: Option<Duration>,
|
||||||
) -> Result<mcp_types::CallToolResult> {
|
) -> Result<mcp_types::CallToolResult> {
|
||||||
let params = CallToolRequestParams { name, arguments };
|
let params = CallToolRequestParams { name, arguments };
|
||||||
debug!("MCP tool call: {params:?}");
|
debug!("MCP tool call: {params:?}");
|
||||||
self.send_request::<CallToolRequest>(params).await
|
self.send_request::<CallToolRequest>(params, timeout).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Internal helper: route a JSON-RPC *response* object to the pending map.
|
/// Internal helper: route a JSON-RPC *response* object to the pending map.
|
||||||
|
|||||||
Reference in New Issue
Block a user