[MCP] Add experimental support for streamable HTTP MCP servers (#4317)
This PR adds support for streamable HTTP MCP servers when the `experimental_use_rmcp_client` is enabled. To set one up, simply add a new mcp server config with the url: ``` [mcp_servers.figma] url = "http://127.0.0.1:3845/mcp" ``` It also supports an optional `bearer_token` which will be provided in an authorization header. The full oauth flow is not supported yet. The config parsing will throw if it detects that the user mixed and matched config fields (like command + bearer token or url + env). The best way to review it is to review `core/src` and then `rmcp-client/src/rmcp_client.rs` first. The rest is tests and propagating the `Transport` struct around the codebase. Example with the Figma MCP: <img width="5084" height="1614" alt="CleanShot 2025-09-26 at 13 35 40" src="https://github.com/user-attachments/assets/eaf2771e-df3e-4300-816b-184d7dec5a28" />
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
use crate::config_profile::ConfigProfile;
|
||||
use crate::config_types::History;
|
||||
use crate::config_types::McpServerConfig;
|
||||
use crate::config_types::McpServerTransportConfig;
|
||||
use crate::config_types::Notifications;
|
||||
use crate::config_types::ReasoningSummaryFormat;
|
||||
use crate::config_types::SandboxWorkspaceWrite;
|
||||
@@ -314,27 +315,37 @@ pub fn write_global_mcp_servers(
|
||||
for (name, config) in servers {
|
||||
let mut entry = TomlTable::new();
|
||||
entry.set_implicit(false);
|
||||
entry["command"] = toml_edit::value(config.command.clone());
|
||||
match &config.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
entry["command"] = toml_edit::value(command.clone());
|
||||
|
||||
if !config.args.is_empty() {
|
||||
let mut args = TomlArray::new();
|
||||
for arg in &config.args {
|
||||
args.push(arg.clone());
|
||||
}
|
||||
entry["args"] = TomlItem::Value(args.into());
|
||||
}
|
||||
if !args.is_empty() {
|
||||
let mut args_array = TomlArray::new();
|
||||
for arg in args {
|
||||
args_array.push(arg.clone());
|
||||
}
|
||||
entry["args"] = TomlItem::Value(args_array.into());
|
||||
}
|
||||
|
||||
if let Some(env) = &config.env
|
||||
&& !env.is_empty()
|
||||
{
|
||||
let mut env_table = TomlTable::new();
|
||||
env_table.set_implicit(false);
|
||||
let mut pairs: Vec<_> = env.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
for (key, value) in pairs {
|
||||
env_table.insert(key, toml_edit::value(value.clone()));
|
||||
if let Some(env) = env
|
||||
&& !env.is_empty()
|
||||
{
|
||||
let mut env_table = TomlTable::new();
|
||||
env_table.set_implicit(false);
|
||||
let mut pairs: Vec<_> = env.iter().collect();
|
||||
pairs.sort_by(|(a, _), (b, _)| a.cmp(b));
|
||||
for (key, value) in pairs {
|
||||
env_table.insert(key, toml_edit::value(value.clone()));
|
||||
}
|
||||
entry["env"] = TomlItem::Table(env_table);
|
||||
}
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
entry["url"] = toml_edit::value(url.clone());
|
||||
if let Some(token) = bearer_token {
|
||||
entry["bearer_token"] = toml_edit::value(token.clone());
|
||||
}
|
||||
}
|
||||
entry["env"] = TomlItem::Table(env_table);
|
||||
}
|
||||
|
||||
if let Some(timeout) = config.startup_timeout_sec {
|
||||
@@ -1294,9 +1305,11 @@ exclude_slash_tmp = true
|
||||
servers.insert(
|
||||
"docs".to_string(),
|
||||
McpServerConfig {
|
||||
command: "echo".to_string(),
|
||||
args: vec!["hello".to_string()],
|
||||
env: None,
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec!["hello".to_string()],
|
||||
env: None,
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(3)),
|
||||
tool_timeout_sec: Some(Duration::from_secs(5)),
|
||||
},
|
||||
@@ -1307,8 +1320,14 @@ exclude_slash_tmp = true
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
assert_eq!(loaded.len(), 1);
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
assert_eq!(docs.command, "echo");
|
||||
assert_eq!(docs.args, vec!["hello".to_string()]);
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
assert_eq!(command, "echo");
|
||||
assert_eq!(args, &vec!["hello".to_string()]);
|
||||
assert!(env.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(3)));
|
||||
assert_eq!(docs.tool_timeout_sec, Some(Duration::from_secs(5)));
|
||||
|
||||
@@ -1342,6 +1361,134 @@ startup_timeout_ms = 2500
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_global_mcp_servers_serializes_env_sorted() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let servers = BTreeMap::from([(
|
||||
"docs".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: "docs-server".to_string(),
|
||||
args: vec!["--verbose".to_string()],
|
||||
env: Some(HashMap::from([
|
||||
("ZIG_VAR".to_string(), "3".to_string()),
|
||||
("ALPHA_VAR".to_string(), "1".to_string()),
|
||||
])),
|
||||
},
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
)]);
|
||||
|
||||
write_global_mcp_servers(codex_home.path(), &servers)?;
|
||||
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
let serialized = std::fs::read_to_string(&config_path)?;
|
||||
assert_eq!(
|
||||
serialized,
|
||||
r#"[mcp_servers.docs]
|
||||
command = "docs-server"
|
||||
args = ["--verbose"]
|
||||
|
||||
[mcp_servers.docs.env]
|
||||
ALPHA_VAR = "1"
|
||||
ZIG_VAR = "3"
|
||||
"#
|
||||
);
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
assert_eq!(command, "docs-server");
|
||||
assert_eq!(args, &vec!["--verbose".to_string()]);
|
||||
let env = env
|
||||
.as_ref()
|
||||
.expect("env should be preserved for stdio transport");
|
||||
assert_eq!(env.get("ALPHA_VAR"), Some(&"1".to_string()));
|
||||
assert_eq!(env.get("ZIG_VAR"), Some(&"3".to_string()));
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_global_mcp_servers_serializes_streamable_http() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
let mut servers = BTreeMap::from([(
|
||||
"docs".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: Some("secret-token".to_string()),
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(2)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
)]);
|
||||
|
||||
write_global_mcp_servers(codex_home.path(), &servers)?;
|
||||
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
let serialized = std::fs::read_to_string(&config_path)?;
|
||||
assert_eq!(
|
||||
serialized,
|
||||
r#"[mcp_servers.docs]
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret-token"
|
||||
startup_timeout_sec = 2.0
|
||||
"#
|
||||
);
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert_eq!(bearer_token.as_deref(), Some("secret-token"));
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(2)));
|
||||
|
||||
servers.insert(
|
||||
"docs".to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: None,
|
||||
},
|
||||
startup_timeout_sec: None,
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
);
|
||||
write_global_mcp_servers(codex_home.path(), &servers)?;
|
||||
|
||||
let serialized = std::fs::read_to_string(&config_path)?;
|
||||
assert_eq!(
|
||||
serialized,
|
||||
r#"[mcp_servers.docs]
|
||||
url = "https://example.com/mcp"
|
||||
"#
|
||||
);
|
||||
|
||||
let loaded = load_global_mcp_servers(codex_home.path())?;
|
||||
let docs = loaded.get("docs").expect("docs entry");
|
||||
match &docs.transport {
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
assert_eq!(url, "https://example.com/mcp");
|
||||
assert!(bearer_token.is_none());
|
||||
}
|
||||
other => panic!("unexpected transport {other:?}"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
|
||||
@@ -3,25 +3,20 @@
|
||||
// Note this file should generally be restricted to simple struct/enum
|
||||
// definitions that do not contain business logic.
|
||||
|
||||
use serde::Deserializer;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use wildmatch::WildMatchPattern;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
use serde::Serialize;
|
||||
use serde::de::Error as SerdeError;
|
||||
|
||||
#[derive(Serialize, Debug, Clone, PartialEq)]
|
||||
pub struct McpServerConfig {
|
||||
pub command: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub env: Option<HashMap<String, String>>,
|
||||
#[serde(flatten)]
|
||||
pub transport: McpServerTransportConfig,
|
||||
|
||||
/// Startup timeout in seconds for initializing MCP server & initially listing tools.
|
||||
#[serde(
|
||||
@@ -43,11 +38,15 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
struct RawMcpServerConfig {
|
||||
command: String,
|
||||
command: Option<String>,
|
||||
#[serde(default)]
|
||||
args: Vec<String>,
|
||||
args: Option<Vec<String>>,
|
||||
#[serde(default)]
|
||||
env: Option<HashMap<String, String>>,
|
||||
|
||||
url: Option<String>,
|
||||
bearer_token: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
startup_timeout_sec: Option<f64>,
|
||||
#[serde(default)]
|
||||
@@ -67,16 +66,81 @@ impl<'de> Deserialize<'de> for McpServerConfig {
|
||||
(None, None) => None,
|
||||
};
|
||||
|
||||
fn throw_if_set<E, T>(transport: &str, field: &str, value: Option<&T>) -> Result<(), E>
|
||||
where
|
||||
E: SerdeError,
|
||||
{
|
||||
if value.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
Err(E::custom(format!(
|
||||
"{field} is not supported for {transport}",
|
||||
)))
|
||||
}
|
||||
|
||||
let transport = match raw {
|
||||
RawMcpServerConfig {
|
||||
command: Some(command),
|
||||
args,
|
||||
env,
|
||||
url,
|
||||
bearer_token,
|
||||
..
|
||||
} => {
|
||||
throw_if_set("stdio", "url", url.as_ref())?;
|
||||
throw_if_set("stdio", "bearer_token", bearer_token.as_ref())?;
|
||||
McpServerTransportConfig::Stdio {
|
||||
command,
|
||||
args: args.unwrap_or_default(),
|
||||
env,
|
||||
}
|
||||
}
|
||||
RawMcpServerConfig {
|
||||
url: Some(url),
|
||||
bearer_token,
|
||||
command,
|
||||
args,
|
||||
env,
|
||||
..
|
||||
} => {
|
||||
throw_if_set("streamable_http", "command", command.as_ref())?;
|
||||
throw_if_set("streamable_http", "args", args.as_ref())?;
|
||||
throw_if_set("streamable_http", "env", env.as_ref())?;
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token }
|
||||
}
|
||||
_ => return Err(SerdeError::custom("invalid transport")),
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
command: raw.command,
|
||||
args: raw.args,
|
||||
env: raw.env,
|
||||
transport,
|
||||
startup_timeout_sec,
|
||||
tool_timeout_sec: raw.tool_timeout_sec,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")]
|
||||
pub enum McpServerTransportConfig {
|
||||
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio
|
||||
Stdio {
|
||||
command: String,
|
||||
#[serde(default)]
|
||||
args: Vec<String>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
env: Option<HashMap<String, String>>,
|
||||
},
|
||||
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http
|
||||
StreamableHttp {
|
||||
url: String,
|
||||
/// A plain text bearer token to use for authentication.
|
||||
/// This bearer token will be included in the HTTP request header as an `Authorization: Bearer <token>` header.
|
||||
/// This should be used with caution because it lives on disk in clear text.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
bearer_token: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
mod option_duration_secs {
|
||||
use serde::Deserialize;
|
||||
use serde::Deserializer;
|
||||
@@ -303,3 +367,139 @@ pub enum ReasoningSummaryFormat {
|
||||
None,
|
||||
Experimental,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn deserialize_stdio_command_server_config() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
command = "echo"
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize command config");
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec![],
|
||||
env: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_stdio_command_server_config_with_args() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
command = "echo"
|
||||
args = ["hello", "world"]
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize command config");
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec!["hello".to_string(), "world".to_string()],
|
||||
env: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_stdio_command_server_config_with_arg_with_args_and_env() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
command = "echo"
|
||||
args = ["hello", "world"]
|
||||
env = { "FOO" = "BAR" }
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize command config");
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::Stdio {
|
||||
command: "echo".to_string(),
|
||||
args: vec!["hello".to_string(), "world".to_string()],
|
||||
env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())]))
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_streamable_http_server_config() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
url = "https://example.com/mcp"
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize http config");
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: None
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_streamable_http_server_config_with_bearer_token() {
|
||||
let cfg: McpServerConfig = toml::from_str(
|
||||
r#"
|
||||
url = "https://example.com/mcp"
|
||||
bearer_token = "secret"
|
||||
"#,
|
||||
)
|
||||
.expect("should deserialize http config");
|
||||
|
||||
assert_eq!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp {
|
||||
url: "https://example.com/mcp".to_string(),
|
||||
bearer_token: Some("secret".to_string())
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_rejects_command_and_url() {
|
||||
toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
command = "echo"
|
||||
url = "https://example.com"
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject command+url");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_rejects_env_for_http_transport() {
|
||||
toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
url = "https://example.com"
|
||||
env = { "FOO" = "BAR" }
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject env for http transport");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_rejects_bearer_token_for_stdio_transport() {
|
||||
toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
command = "echo"
|
||||
bearer_token = "secret"
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject bearer token for stdio transport");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::config_types::McpServerConfig;
|
||||
use crate::config_types::McpServerTransportConfig;
|
||||
|
||||
/// Delimiter used to separate the server name from the tool name in a fully
|
||||
/// qualified tool name.
|
||||
@@ -121,6 +122,17 @@ impl McpClientAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
async fn new_streamable_http_client(
|
||||
url: String,
|
||||
bearer_token: Option<String>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
) -> Result<Self> {
|
||||
let client = Arc::new(RmcpClient::new_streamable_http_client(url, bearer_token)?);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
}
|
||||
|
||||
async fn list_tools(
|
||||
&self,
|
||||
params: Option<mcp_types::ListToolsRequestParams>,
|
||||
@@ -176,8 +188,6 @@ impl McpConnectionManager {
|
||||
return Ok((Self::default(), ClientStartErrors::default()));
|
||||
}
|
||||
|
||||
tracing::error!("new mcp_servers: {mcp_servers:?} use_rmcp_client: {use_rmcp_client}");
|
||||
|
||||
// Launch all configured servers concurrently.
|
||||
let mut join_set = JoinSet::new();
|
||||
let mut errors = ClientStartErrors::new();
|
||||
@@ -193,16 +203,24 @@ impl McpConnectionManager {
|
||||
continue;
|
||||
}
|
||||
|
||||
if matches!(
|
||||
cfg.transport,
|
||||
McpServerTransportConfig::StreamableHttp { .. }
|
||||
) && !use_rmcp_client
|
||||
{
|
||||
info!(
|
||||
"skipping MCP server `{}` configured with url because rmcp client is disabled",
|
||||
server_name
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
|
||||
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
|
||||
|
||||
let use_rmcp_client_flag = use_rmcp_client;
|
||||
join_set.spawn(async move {
|
||||
let McpServerConfig {
|
||||
command, args, env, ..
|
||||
} = cfg;
|
||||
let command_os: OsString = command.into();
|
||||
let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
|
||||
let McpServerConfig { transport, .. } = cfg;
|
||||
let params = mcp_types::InitializeRequestParams {
|
||||
capabilities: ClientCapabilities {
|
||||
experimental: None,
|
||||
@@ -224,15 +242,30 @@ impl McpConnectionManager {
|
||||
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
|
||||
};
|
||||
|
||||
let client = McpClientAdapter::new_stdio_client(
|
||||
use_rmcp_client_flag,
|
||||
command_os,
|
||||
args_os,
|
||||
env,
|
||||
params,
|
||||
startup_timeout,
|
||||
)
|
||||
.await
|
||||
let client = match transport {
|
||||
McpServerTransportConfig::Stdio { command, args, env } => {
|
||||
let command_os: OsString = command.into();
|
||||
let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
|
||||
McpClientAdapter::new_stdio_client(
|
||||
use_rmcp_client_flag,
|
||||
command_os,
|
||||
args_os,
|
||||
env,
|
||||
params.clone(),
|
||||
startup_timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
|
||||
McpClientAdapter::new_streamable_http_client(
|
||||
url,
|
||||
bearer_token,
|
||||
params,
|
||||
startup_timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
.map(|c| (c, startup_timeout));
|
||||
|
||||
((server_name, tool_timeout), client)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::TcpListener;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::config_types::McpServerConfig;
|
||||
use codex_core::config_types::McpServerTransportConfig;
|
||||
|
||||
use codex_core::protocol::AskForApproval;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::InputItem;
|
||||
@@ -16,10 +19,15 @@ use core_test_support::wait_for_event;
|
||||
use core_test_support::wait_for_event_with_timeout;
|
||||
use escargot::CargoBuild;
|
||||
use serde_json::Value;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::Instant;
|
||||
use tokio::time::sleep;
|
||||
use wiremock::matchers::any;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn stdio_server_round_trip() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
@@ -54,7 +62,7 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
let expected_env_value = "propagated-env";
|
||||
let rmcp_test_server_bin = CargoBuild::new()
|
||||
.package("codex-rmcp-client")
|
||||
.bin("rmcp_test_server")
|
||||
.bin("test_stdio_server")
|
||||
.run()?
|
||||
.path()
|
||||
.to_string_lossy()
|
||||
@@ -66,12 +74,14 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
config.mcp_servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
command: rmcp_test_server_bin.clone(),
|
||||
args: Vec::new(),
|
||||
env: Some(HashMap::from([(
|
||||
"MCP_TEST_VALUE".to_string(),
|
||||
expected_env_value.to_string(),
|
||||
)])),
|
||||
transport: McpServerTransportConfig::Stdio {
|
||||
command: rmcp_test_server_bin.clone(),
|
||||
args: Vec::new(),
|
||||
env: Some(HashMap::from([(
|
||||
"MCP_TEST_VALUE".to_string(),
|
||||
expected_env_value.to_string(),
|
||||
)])),
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
@@ -97,18 +107,13 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
})
|
||||
.await?;
|
||||
|
||||
eprintln!("waiting for mcp tool call begin event");
|
||||
let begin_event = wait_for_event_with_timeout(
|
||||
&fixture.codex,
|
||||
|ev| {
|
||||
eprintln!("ev: {ev:?}");
|
||||
matches!(ev, EventMsg::McpToolCallBegin(_))
|
||||
},
|
||||
|ev| matches!(ev, EventMsg::McpToolCallBegin(_)),
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
|
||||
eprintln!("mcp tool call begin event: {begin_event:?}");
|
||||
let EventMsg::McpToolCallBegin(begin) = begin_event else {
|
||||
unreachable!("event guard guarantees McpToolCallBegin");
|
||||
};
|
||||
@@ -119,7 +124,6 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
matches!(ev, EventMsg::McpToolCallEnd(_))
|
||||
})
|
||||
.await;
|
||||
eprintln!("end_event: {end_event:?}");
|
||||
let EventMsg::McpToolCallEnd(end) = end_event else {
|
||||
unreachable!("event guard guarantees McpToolCallEnd");
|
||||
};
|
||||
@@ -145,18 +149,223 @@ async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
.get("echo")
|
||||
.and_then(Value::as_str)
|
||||
.expect("echo payload present");
|
||||
assert_eq!(echo_value, "ping");
|
||||
assert_eq!(echo_value, "ECHOING: ping");
|
||||
let env_value = map
|
||||
.get("env")
|
||||
.and_then(Value::as_str)
|
||||
.expect("env snapshot inserted");
|
||||
assert_eq!(env_value, expected_env_value);
|
||||
|
||||
let task_complete_event =
|
||||
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
eprintln!("task_complete_event: {task_complete_event:?}");
|
||||
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
server.verify().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
|
||||
let call_id = "call-456";
|
||||
let server_name = "rmcp_http";
|
||||
let tool_name = format!("{server_name}__echo");
|
||||
|
||||
mount_sse_once(
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
serde_json::json!({
|
||||
"type": "response.created",
|
||||
"response": {"id": "resp-1"}
|
||||
}),
|
||||
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
|
||||
responses::ev_completed("resp-1"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
mount_sse_once(
|
||||
&server,
|
||||
any(),
|
||||
responses::sse(vec![
|
||||
responses::ev_assistant_message(
|
||||
"msg-1",
|
||||
"rmcp streamable http echo tool completed successfully.",
|
||||
),
|
||||
responses::ev_completed("resp-2"),
|
||||
]),
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected_env_value = "propagated-env-http";
|
||||
let rmcp_http_server_bin = CargoBuild::new()
|
||||
.package("codex-rmcp-client")
|
||||
.bin("test_streamable_http_server")
|
||||
.run()?
|
||||
.path()
|
||||
.to_string_lossy()
|
||||
.into_owned();
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
drop(listener);
|
||||
let bind_addr = format!("127.0.0.1:{port}");
|
||||
let server_url = format!("http://{bind_addr}/mcp");
|
||||
|
||||
let mut http_server_child = Command::new(&rmcp_http_server_bin)
|
||||
.kill_on_drop(true)
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
|
||||
.env("MCP_TEST_VALUE", expected_env_value)
|
||||
.spawn()?;
|
||||
|
||||
wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))
|
||||
.await?;
|
||||
|
||||
let fixture = test_codex()
|
||||
.with_config(move |config| {
|
||||
config.use_experimental_use_rmcp_client = true;
|
||||
config.mcp_servers.insert(
|
||||
server_name.to_string(),
|
||||
McpServerConfig {
|
||||
transport: McpServerTransportConfig::StreamableHttp {
|
||||
url: server_url,
|
||||
bearer_token: None,
|
||||
},
|
||||
startup_timeout_sec: Some(Duration::from_secs(10)),
|
||||
tool_timeout_sec: None,
|
||||
},
|
||||
);
|
||||
})
|
||||
.build(&server)
|
||||
.await?;
|
||||
let session_model = fixture.session_configured.model.clone();
|
||||
|
||||
fixture
|
||||
.codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "call the rmcp streamable http echo tool".into(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
cwd: fixture.cwd.path().to_path_buf(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: session_model,
|
||||
effort: None,
|
||||
summary: ReasoningSummary::Auto,
|
||||
})
|
||||
.await?;
|
||||
|
||||
let begin_event = wait_for_event_with_timeout(
|
||||
&fixture.codex,
|
||||
|ev| matches!(ev, EventMsg::McpToolCallBegin(_)),
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
|
||||
let EventMsg::McpToolCallBegin(begin) = begin_event else {
|
||||
unreachable!("event guard guarantees McpToolCallBegin");
|
||||
};
|
||||
assert_eq!(begin.invocation.server, server_name);
|
||||
assert_eq!(begin.invocation.tool, "echo");
|
||||
|
||||
let end_event = wait_for_event(&fixture.codex, |ev| {
|
||||
matches!(ev, EventMsg::McpToolCallEnd(_))
|
||||
})
|
||||
.await;
|
||||
let EventMsg::McpToolCallEnd(end) = end_event else {
|
||||
unreachable!("event guard guarantees McpToolCallEnd");
|
||||
};
|
||||
|
||||
let result = end
|
||||
.result
|
||||
.as_ref()
|
||||
.expect("rmcp echo tool should return success");
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
assert!(
|
||||
result.content.is_empty(),
|
||||
"content should default to an empty array"
|
||||
);
|
||||
|
||||
let structured = result
|
||||
.structured_content
|
||||
.as_ref()
|
||||
.expect("structured content");
|
||||
let Value::Object(map) = structured else {
|
||||
panic!("structured content should be an object: {structured:?}");
|
||||
};
|
||||
let echo_value = map
|
||||
.get("echo")
|
||||
.and_then(Value::as_str)
|
||||
.expect("echo payload present");
|
||||
assert_eq!(echo_value, "ECHOING: ping");
|
||||
let env_value = map
|
||||
.get("env")
|
||||
.and_then(Value::as_str)
|
||||
.expect("env snapshot inserted");
|
||||
assert_eq!(env_value, expected_env_value);
|
||||
|
||||
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
server.verify().await;
|
||||
|
||||
match http_server_child.try_wait() {
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => {
|
||||
let _ = http_server_child.kill().await;
|
||||
}
|
||||
Err(error) => {
|
||||
eprintln!("failed to check streamable http server status: {error}");
|
||||
let _ = http_server_child.kill().await;
|
||||
}
|
||||
}
|
||||
if let Err(error) = http_server_child.wait().await {
|
||||
eprintln!("failed to await streamable http server shutdown: {error}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_streamable_http_server(
|
||||
server_child: &mut Child,
|
||||
address: &str,
|
||||
timeout: Duration,
|
||||
) -> anyhow::Result<()> {
|
||||
let deadline = Instant::now() + timeout;
|
||||
|
||||
loop {
|
||||
if let Some(status) = server_child.try_wait()? {
|
||||
return Err(anyhow::anyhow!(
|
||||
"streamable HTTP server exited early with status {status}"
|
||||
));
|
||||
}
|
||||
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
|
||||
if remaining.is_zero() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: deadline reached"
|
||||
));
|
||||
}
|
||||
|
||||
match tokio::time::timeout(remaining, TcpStream::connect(address)).await {
|
||||
Ok(Ok(_)) => return Ok(()),
|
||||
Ok(Err(error)) => {
|
||||
if Instant::now() >= deadline {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: {error}"
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: connect call timed out"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user