feat: support mcp_servers in config.toml (#829)

This adds initial support for MCP servers in the style of Claude Desktop
and Cursor. Note this PR is the bare minimum to get things working end
to end: all configured MCP servers are launched every time Codex is run,
there is no recovery for MCP servers that crash, etc.

(Also, I took some shortcuts to change some fields of `Session` to be
`pub(crate)`, which also means there are circular deps between
`codex.rs` and `mcp_tool_call.rs`, but I will clean that up in a
subsequent PR.)

`codex-rs/README.md` is updated as part of this PR to explain how to use
this feature. There is a bit of plumbing to route the new settings from
`Config` to the business logic in `codex.rs`. The most significant
chunks for new code are in `mcp_connection_manager.rs` (which defines
the `McpConnectionManager` struct) and `mcp_tool_call.rs`, which is
responsible for tool calls.

This PR also introduces new `McpToolCallBegin` and `McpToolCallEnd`
event types to the protocol, but does not add any handlers for them.
(See https://github.com/openai/codex/pull/836 for initial usage.)

To test, I added the following to my `~/.codex/config.toml`:

```toml
# Local build of https://github.com/hideya/mcp-server-weather-js
[mcp_servers.weather]
command = "/Users/mbolin/code/mcp-server-weather-js/dist/index.js"
args = []
```

And then I ran the following:

```
codex-rs$ cargo run --bin codex exec 'what is the weather in san francisco'
[2025-05-06T22:40:05] Task started: 1
[2025-05-06T22:40:18] Agent message: Here’s the latest National Weather Service forecast for San Francisco (downtown, near 37.77° N, 122.42° W):

This Afternoon (Tue):
• Sunny, high near 69 °F
• West-southwest wind around 12 mph

Tonight:
• Partly cloudy, low around 52 °F
• SW wind 7–10 mph
...
```

Note that Codex itself is not able to make network calls, so it would
not normally be able to get live weather information like this. However,
the weather MCP server is currently not run under the Codex sandbox, so it is
able to hit `api.weather.gov` and fetch current weather information.

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with [ReviewStack](https://reviewstack.dev/openai/codex/pull/829).
* #836
* __->__ #829
This commit is contained in:
Michael Bolin
2025-05-06 15:47:59 -07:00
committed by GitHub
parent 49d040215a
commit 147a940449
11 changed files with 453 additions and 18 deletions

View File

@@ -1,4 +1,5 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::io::BufRead;
use std::path::Path;
use std::pin::Pin;
@@ -13,6 +14,7 @@ use futures::prelude::*;
use reqwest::StatusCode;
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use serde_json::Value;
use tokio::sync::mpsc;
use tokio::time::timeout;
@@ -42,6 +44,11 @@ pub struct Prompt {
pub instructions: Option<String>,
/// Whether to store response on server side (disable_response_storage = !store).
pub store: bool,
/// Additional tools sourced from external MCP servers. Note each key is
/// the "fully qualified" tool name (i.e., prefixed with the server name),
/// which should be reported to the model in place of Tool::name.
pub extra_tools: HashMap<String, mcp_types::Tool>,
}
#[derive(Debug)]
@@ -59,7 +66,7 @@ struct Payload<'a> {
// we code defensively to avoid this case, but perhaps we should use a
// separate enum for serialization.
input: &'a Vec<ResponseItem>,
tools: &'a [Tool],
tools: &'a [serde_json::Value],
tool_choice: &'static str,
parallel_tool_calls: bool,
reasoning: Option<Reasoning>,
@@ -77,11 +84,12 @@ struct Reasoning {
generate_summary: Option<bool>,
}
/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
#[derive(Debug, Serialize)]
struct Tool {
struct ResponsesApiTool {
name: &'static str,
#[serde(rename = "type")]
kind: &'static str, // "function"
r#type: &'static str, // "function"
description: &'static str,
strict: bool,
parameters: JsonSchema,
@@ -105,7 +113,7 @@ enum JsonSchema {
}
/// Tool usage specification
static TOOLS: LazyLock<Vec<Tool>> = LazyLock::new(|| {
static DEFAULT_TOOLS: LazyLock<Vec<ResponsesApiTool>> = LazyLock::new(|| {
let mut properties = BTreeMap::new();
properties.insert(
"command".to_string(),
@@ -116,9 +124,9 @@ static TOOLS: LazyLock<Vec<Tool>> = LazyLock::new(|| {
properties.insert("workdir".to_string(), JsonSchema::String);
properties.insert("timeout".to_string(), JsonSchema::Number);
vec![Tool {
vec![ResponsesApiTool {
name: "shell",
kind: "function",
r#type: "function",
description: "Runs a shell command, and returns its output.",
strict: false,
parameters: JsonSchema::Object {
@@ -149,11 +157,26 @@ impl ModelClient {
return stream_from_fixture(path).await;
}
// Assemble tool list: built-in tools + any extra tools from the prompt.
let mut tools_json: Vec<serde_json::Value> = DEFAULT_TOOLS
.iter()
.map(|t| serde_json::to_value(t).expect("serialize builtin tool"))
.collect();
tools_json.extend(
prompt
.extra_tools
.clone()
.into_iter()
.map(|(name, tool)| mcp_tool_to_openai_tool(name, tool)),
);
debug!("tools_json: {}", serde_json::to_string_pretty(&tools_json)?);
let payload = Payload {
model: &self.model,
instructions: prompt.instructions.as_ref(),
input: &prompt.input,
tools: &TOOLS,
tools: &tools_json,
tool_choice: "auto",
parallel_tool_calls: false,
reasoning: Some(Reasoning {
@@ -235,6 +258,20 @@ impl ModelClient {
}
}
/// Convert an MCP tool definition into the JSON value the OpenAI Responses
/// API expects for a "function" tool entry.
// TODO(mbolin): Change the contract of this function to return
// ResponsesApiTool.
fn mcp_tool_to_openai_tool(
    fully_qualified_name: String,
    tool: mcp_types::Tool,
) -> serde_json::Value {
    // Pull out only the fields the Responses API cares about.
    let mcp_types::Tool {
        description,
        input_schema,
        ..
    } = tool;
    json!({
        "name": fully_qualified_name,
        "description": description,
        "parameters": input_schema,
        "type": "function",
    })
}
#[derive(Debug, Deserialize, Serialize)]
struct SseEvent {
#[serde(rename = "type")]

View File

@@ -31,6 +31,8 @@ use tracing::warn;
use crate::client::ModelClient;
use crate::client::Prompt;
use crate::client::ResponseEvent;
use crate::config::Config;
use crate::config::ConfigOverrides;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::exec::process_exec_tool_call;
@@ -38,6 +40,9 @@ use crate::exec::ExecParams;
use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::flags::OPENAI_STREAM_MAX_RETRIES;
use crate::mcp_connection_manager::try_parse_fully_qualified_tool_name;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::mcp_tool_call::handle_mcp_tool_call;
use crate::models::ContentItem;
use crate::models::FunctionCallOutputPayload;
use crate::models::ResponseInputItem;
@@ -188,9 +193,9 @@ impl Recorder {
/// Context for an initialized model agent
///
/// A session has at most 1 running task at a time, and can be interrupted by user input.
struct Session {
pub(crate) struct Session {
client: ModelClient,
tx_event: Sender<Event>,
pub(crate) tx_event: Sender<Event>,
ctrl_c: Arc<Notify>,
/// The session's current working directory. All relative paths provided by
@@ -202,6 +207,9 @@ struct Session {
sandbox_policy: SandboxPolicy,
writable_roots: Mutex<Vec<PathBuf>>,
/// Manager for external MCP servers/tools.
pub(crate) mcp_connection_manager: McpConnectionManager,
/// External notifier command (will be passed as args to exec()). When
/// `None` this feature is disabled.
notify: Option<Vec<String>>,
@@ -433,7 +441,7 @@ impl State {
}
/// A series of Turns in response to user input.
struct AgentTask {
pub(crate) struct AgentTask {
sess: Arc<Session>,
sub_id: String,
handle: AbortHandle,
@@ -554,6 +562,26 @@ async fn submission_loop(
};
let writable_roots = Mutex::new(get_writable_roots(&cwd));
// Load config to initialize the MCP connection manager.
let config = match Config::load_with_overrides(ConfigOverrides::default()) {
Ok(cfg) => cfg,
Err(e) => {
error!("Failed to load config for MCP servers: {e:#}");
// Fall back to empty server map so the session can still proceed.
Config::load_default_config_for_test()
}
};
let mcp_connection_manager =
match McpConnectionManager::new(config.mcp_servers.clone()).await {
Ok(mgr) => mgr,
Err(e) => {
error!("Failed to create MCP connection manager: {e:#}");
McpConnectionManager::default()
}
};
sess = Some(Arc::new(Session {
client,
tx_event: tx_event.clone(),
@@ -563,6 +591,7 @@ async fn submission_loop(
sandbox_policy,
cwd,
writable_roots,
mcp_connection_manager,
notify,
state: Mutex::new(state),
}));
@@ -753,11 +782,14 @@ async fn run_turn(
} else {
None
};
let extra_tools = sess.mcp_connection_manager.list_all_tools();
let prompt = Prompt {
input,
prev_id,
instructions,
store,
extra_tools,
};
let mut retries = 0;
@@ -1141,13 +1173,20 @@ async fn handle_function_call(
}
}
_ => {
// Unknown function: reply with structured failure so the model can adapt.
ResponseInputItem::FunctionCallOutput {
call_id,
output: crate::models::FunctionCallOutputPayload {
content: format!("unsupported call: {}", name),
success: None,
},
match try_parse_fully_qualified_tool_name(&name) {
Some((server, tool_name)) => {
handle_mcp_tool_call(sess, &sub_id, call_id, server, tool_name, arguments).await
}
None => {
// Unknown function: reply with structured failure so the model can adapt.
ResponseInputItem::FunctionCallOutput {
call_id,
output: crate::models::FunctionCallOutputPayload {
content: format!("unsupported call: {}", name),
success: None,
},
}
}
}
}
}

View File

@@ -1,9 +1,11 @@
use crate::flags::OPENAI_DEFAULT_MODEL;
use crate::mcp_server_config::McpServerConfig;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPermission;
use crate::protocol::SandboxPolicy;
use dirs::home_dir;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::PathBuf;
/// Embedded fallback instructions that mirror the TypeScript CLI's default
@@ -56,6 +58,9 @@ pub struct Config {
/// for the session. All relative paths inside the business-logic layer are
/// resolved against this path.
pub cwd: PathBuf,
/// Definition for MCP servers that Codex can reach out to for tool calls.
pub mcp_servers: HashMap<String, McpServerConfig>,
}
/// Base config deserialized from ~/.codex/config.toml.
@@ -84,6 +89,10 @@ pub struct ConfigToml {
/// System instructions.
pub instructions: Option<String>,
/// Definition for MCP servers that Codex can reach out to for tool calls.
#[serde(default)]
pub mcp_servers: HashMap<String, McpServerConfig>,
}
impl ConfigToml {
@@ -212,6 +221,7 @@ impl Config {
.unwrap_or(false),
notify: cfg.notify,
instructions,
mcp_servers: cfg.mcp_servers,
}
}

View File

@@ -15,6 +15,9 @@ mod flags;
mod is_safe_command;
#[cfg(target_os = "linux")]
pub mod linux;
mod mcp_connection_manager;
pub mod mcp_server_config;
mod mcp_tool_call;
mod models;
pub mod protocol;
mod safety;

View File

@@ -0,0 +1,162 @@
//! Connection manager for Model Context Protocol (MCP) servers.
//!
//! The [`McpConnectionManager`] owns one [`codex_mcp_client::McpClient`] per
//! configured server (keyed by the *server name*). It offers convenience
//! helpers to query the available tools across *all* servers and returns them
//! in a single aggregated map using the fully-qualified tool name
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.
use std::collections::HashMap;
use anyhow::anyhow;
use anyhow::Context;
use anyhow::Result;
use codex_mcp_client::McpClient;
use mcp_types::Tool;
use tokio::task::JoinSet;
use tracing::info;
use crate::mcp_server_config::McpServerConfig;
/// Delimiter separating the server name from the tool name in a fully
/// qualified tool name.
///
/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must
/// choose a delimiter from this character set.
const MCP_TOOL_NAME_DELIMITER: &str = "__OAI_CODEX_MCP__";

/// Build the fully qualified name `"<server><DELIMITER><tool>"`.
fn fully_qualified_tool_name(server: &str, tool: &str) -> String {
    let mut fq_name =
        String::with_capacity(server.len() + MCP_TOOL_NAME_DELIMITER.len() + tool.len());
    fq_name.push_str(server);
    fq_name.push_str(MCP_TOOL_NAME_DELIMITER);
    fq_name.push_str(tool);
    fq_name
}

/// Split a fully qualified tool name back into its `(server, tool)` parts.
///
/// Returns `None` when the delimiter is absent or either part is empty.
pub(crate) fn try_parse_fully_qualified_tool_name(fq_name: &str) -> Option<(String, String)> {
    match fq_name.split_once(MCP_TOOL_NAME_DELIMITER) {
        Some((server, tool)) if !server.is_empty() && !tool.is_empty() => {
            Some((server.to_owned(), tool.to_owned()))
        }
        _ => None,
    }
}
/// A thin wrapper around a set of running [`McpClient`] instances.
///
/// Both maps are populated once during [`McpConnectionManager::new`]; the
/// tool inventory is a startup-time snapshot, not a live view.
#[derive(Default)]
pub(crate) struct McpConnectionManager {
    /// Server-name -> client instance.
    ///
    /// The server name originates from the keys of the `mcp_servers` map in
    /// the user configuration.
    clients: HashMap<String, std::sync::Arc<McpClient>>,

    /// Fully qualified tool name -> tool instance.
    ///
    /// Keys are built with `fully_qualified_tool_name`, i.e. prefixed with
    /// the owning server's name, so tools from different servers cannot
    /// shadow one another.
    tools: HashMap<String, Tool>,
}
impl McpConnectionManager {
    /// Spawn a [`McpClient`] for each configured server.
    ///
    /// * `mcp_servers` Map loaded from the user configuration where *keys*
    ///   are human-readable server identifiers and *values* are the spawn
    ///   instructions.
    ///
    /// # Errors
    ///
    /// Fails if any configured server cannot be spawned or if the initial
    /// tool aggregation fails; no partially-initialized manager is returned.
    pub async fn new(mcp_servers: HashMap<String, McpServerConfig>) -> Result<Self> {
        // Early exit if no servers are configured.
        if mcp_servers.is_empty() {
            return Ok(Self::default());
        }

        // Spin up all servers concurrently.
        let mut join_set = JoinSet::new();

        // Spawn tasks to launch each server.
        for (server_name, cfg) in mcp_servers {
            // TODO: Verify server name: require `^[a-zA-Z0-9_-]+$`?
            join_set.spawn(async move {
                let McpServerConfig { command, args, env } = cfg;
                let client_res = McpClient::new_stdio_client(command, args, env).await;
                // Pair the result with the server name: tasks complete in
                // arbitrary order, so the name cannot be recovered otherwise.
                (server_name, client_res)
            });
        }

        let mut clients: HashMap<String, std::sync::Arc<McpClient>> =
            HashMap::with_capacity(join_set.len());
        // Collect spawn results as they complete; the first failure aborts
        // construction via `?`.
        while let Some(res) = join_set.join_next().await {
            let (server_name, client_res) = res?;
            let client = client_res
                .with_context(|| format!("failed to spawn MCP server `{server_name}`"))?;
            clients.insert(server_name, std::sync::Arc::new(client));
        }

        // Snapshot every server's tool list once; `list_all_tools` (the
        // method below) serves from this cache afterwards.
        let tools = list_all_tools(&clients).await?;

        Ok(Self { clients, tools })
    }

    /// Returns a single map that contains **all** tools. Each key is the
    /// fully-qualified name for the tool.
    ///
    /// This clones the snapshot taken at construction time; it does not
    /// re-query the servers.
    pub fn list_all_tools(&self) -> HashMap<String, Tool> {
        self.tools.clone()
    }

    /// Invoke the tool indicated by the (server, tool) pair.
    ///
    /// `arguments` is forwarded as the tool-call params; `None` invokes the
    /// tool without arguments.
    ///
    /// # Errors
    ///
    /// Fails if `server` is not a configured server name or if the
    /// underlying RPC fails.
    pub async fn call_tool(
        &self,
        server: &str,
        tool: &str,
        arguments: Option<serde_json::Value>,
    ) -> Result<mcp_types::CallToolResult> {
        // Clone the Arc so the borrow of `self.clients` ends before awaiting.
        let client = self
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
            .clone();

        client
            .call_tool(tool.to_string(), arguments)
            .await
            .with_context(|| format!("tool call failed for `{server}/{tool}`"))
    }
}
/// Query every server for its available tools and return a single map that
/// contains **all** tools. Each key is the fully-qualified name for the tool.
pub async fn list_all_tools(
clients: &HashMap<String, std::sync::Arc<McpClient>>,
) -> Result<HashMap<String, Tool>> {
let mut join_set = JoinSet::new();
// Spawn one task per server so we can query them concurrently. This
// keeps the overall latency roughly at the slowest server instead of
// the cumulative latency.
for (server_name, client) in clients {
let server_name_cloned = server_name.clone();
let client_clone = client.clone();
join_set.spawn(async move {
let res = client_clone.list_tools(None).await;
(server_name_cloned, res)
});
}
let mut aggregated: HashMap<String, Tool> = HashMap::with_capacity(join_set.len());
while let Some(join_res) = join_set.join_next().await {
let (server_name, list_result) = join_res?;
let list_result = list_result?;
for tool in list_result.tools {
// TODO(mbolin): escape tool names that contain invalid characters.
let fq_name = fully_qualified_tool_name(&server_name, &tool.name);
if aggregated.insert(fq_name.clone(), tool).is_some() {
panic!("tool name collision for '{fq_name}': suspicious");
}
}
}
info!(
"aggregated {} tools from {} servers",
aggregated.len(),
clients.len()
);
Ok(aggregated)
}

View File

@@ -0,0 +1,14 @@
use std::collections::HashMap;
use serde::Deserialize;
/// Launch specification for a single MCP server, deserialized from an
/// `[mcp_servers.<name>]` table in the user's `config.toml`.
#[derive(Deserialize, Debug, Clone)]
pub struct McpServerConfig {
    /// Program to execute for this server. The only required field.
    pub command: String,

    /// Command-line arguments passed to `command`; empty when omitted
    /// from the TOML.
    #[serde(default)]
    pub args: Vec<String>,

    /// Extra environment variables for the spawned process; `None` when
    /// omitted from the TOML.
    #[serde(default)]
    pub env: Option<HashMap<String, String>>,
}

View File

@@ -0,0 +1,107 @@
use tracing::error;
use crate::codex::Session;
use crate::models::FunctionCallOutputPayload;
use crate::models::ResponseInputItem;
use crate::protocol::Event;
use crate::protocol::EventMsg;
/// Handles the specified tool call and dispatches the appropriate
/// `McpToolCallBegin` and `McpToolCallEnd` events to the `Session`.
///
/// Returns a `FunctionCallOutput` item that reports the tool result (or the
/// failure) back to the model.
pub(crate) async fn handle_mcp_tool_call(
    sess: &Session,
    sub_id: &str,
    call_id: String,
    server: String,
    tool_name: String,
    arguments: String,
) -> ResponseInputItem {
    // Parse the `arguments` as JSON. An empty string is OK, but invalid JSON
    // is not.
    let arguments_value = if arguments.trim().is_empty() {
        None
    } else {
        match serde_json::from_str::<serde_json::Value>(&arguments) {
            Ok(value) => Some(value),
            Err(e) => {
                error!("failed to parse tool call arguments: {e}");
                return ResponseInputItem::FunctionCallOutput {
                    call_id,
                    output: FunctionCallOutputPayload {
                        content: format!("err: {e}"),
                        success: Some(false),
                    },
                };
            }
        }
    };

    // Announce the upcoming call so UIs can render progress.
    let tool_call_begin_event = EventMsg::McpToolCallBegin {
        call_id: call_id.clone(),
        server: server.clone(),
        tool: tool_name.clone(),
        arguments: arguments_value.clone(),
    };
    notify_mcp_tool_call_event(sess, sub_id, tool_call_begin_event).await;

    // Perform the tool call and summarize the outcome exactly once so the
    // end event and the function-call output cannot disagree. (The previous
    // version cloned the whole end event and destructured it back with an
    // unreachable `unimplemented!` fallback.)
    let (success, result, tool_call_err) = match sess
        .mcp_connection_manager
        .call_tool(&server, &tool_name, arguments_value)
        .await
    {
        // A server can report a *logical* failure via `is_error` even though
        // the RPC round trip itself succeeded.
        Ok(result) => (!result.is_error.unwrap_or(false), Some(result), None),
        Err(e) => (false, None, Some(e)),
    };

    let tool_call_end_event = EventMsg::McpToolCallEnd {
        call_id: call_id.clone(),
        success,
        result: result.clone(),
    };
    notify_mcp_tool_call_event(sess, sub_id, tool_call_end_event).await;

    ResponseInputItem::FunctionCallOutput {
        call_id,
        output: FunctionCallOutputPayload {
            // On success, serialize the full CallToolResult (which may itself
            // describe a tool-level error); otherwise report the RPC error.
            content: result.map_or_else(
                || format!("err: {tool_call_err:?}"),
                |result| {
                    serde_json::to_string(&result)
                        .unwrap_or_else(|e| format!("JSON serialization error: {e}"))
                },
            ),
            success: Some(success),
        },
    }
}
/// Forward `event` to the session's event channel under the submission id
/// `sub_id`. Delivery is best-effort: a send failure is logged, not
/// propagated.
async fn notify_mcp_tool_call_event(sess: &Session, sub_id: &str, event: EventMsg) {
    let event = Event {
        id: sub_id.to_string(),
        msg: event,
    };
    let send_result = sess.tx_event.send(event).await;
    if let Err(e) = send_result {
        error!("failed to send tool call event: {e}");
    }
}

View File

@@ -7,6 +7,7 @@ use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use mcp_types::CallToolResult;
use serde::Deserialize;
use serde::Serialize;
@@ -316,6 +317,32 @@ pub enum EventMsg {
model: String,
},
McpToolCallBegin {
/// Identifier so this can be paired with the McpToolCallEnd event.
call_id: String,
/// Name of the MCP server as defined in the config.
server: String,
/// Name of the tool as given by the MCP server.
tool: String,
/// Arguments to the tool call.
arguments: Option<serde_json::Value>,
},
McpToolCallEnd {
/// Identifier for the McpToolCallBegin that finished.
call_id: String,
/// Whether the tool call was successful. If `false`, `result` might
/// not be present.
success: bool,
/// Result of the tool call. Note this could be an error.
result: Option<CallToolResult>,
},
/// Notification that the server is about to execute a command.
ExecCommandBegin {
/// Identifier so this can be paired with the ExecCommandEnd event.