[MCP] Add support for MCP Oauth credentials (#4517)

This PR adds oauth login support to streamable http servers when
`experimental_use_rmcp_client` is enabled.

This PR is large but represents the minimal amount of work required for
this to work. To keep this PR smaller, login can only be done with
`codex mcp login` and `codex mcp logout` but it doesn't appear in `/mcp`
or `codex mcp list` yet. Fingers crossed that this is the last large MCP
PR and that subsequent PRs can be smaller.

Under the hood, credentials are stored using platform credential
managers using the [keyring crate](https://crates.io/crates/keyring).
When the keyring isn't available, it falls back to storing credentials
in `CODEX_HOME/.credentials.json` which is consistent with how other
coding agents handle authentication.

I tested this on macOS, Windows, WSL (ubuntu), and Linux. I wasn't able
to test the dbus store on linux but did verify that the fallback works.

One quirk is that if you have credentials, during development, every
build will have its own ad-hoc binary so the keyring won't recognize the
reader as being the same as the write so it may ask for the user's
password. I may add an override to disable this or allow
users/enterprises to opt-out of the keyring storage if it causes issues.

<img width="5064" height="686" alt="CleanShot 2025-09-30 at 19 31 40"
src="https://github.com/user-attachments/assets/9573f9b4-07f1-4160-83b8-2920db287e2d"
/>
<img width="745" height="486" alt="image"
src="https://github.com/user-attachments/assets/9562649b-ea5f-4f22-ace2-d0cb438b143e"
/>
This commit is contained in:
Gabriel Peal
2025-10-03 10:43:12 -07:00
committed by GitHub
parent bfe3328129
commit 1d17ca1fa3
18 changed files with 2584 additions and 434 deletions

View File

@@ -91,11 +91,12 @@ escargot = { workspace = true }
maplit = { workspace = true }
predicates = { workspace = true }
pretty_assertions = { workspace = true }
serial_test = { workspace = true }
tempfile = { workspace = true }
tokio-test = { workspace = true }
tracing-test = { workspace = true, features = ["no-env-filter"] }
walkdir = { workspace = true }
wiremock = { workspace = true }
tracing-test = { workspace = true, features = ["no-env-filter"] }
[package.metadata.cargo-shear]
ignored = ["openssl-sys"]

View File

@@ -123,12 +123,15 @@ impl McpClientAdapter {
}
async fn new_streamable_http_client(
server_name: String,
url: String,
bearer_token: Option<String>,
params: mcp_types::InitializeRequestParams,
startup_timeout: Duration,
) -> Result<Self> {
let client = Arc::new(RmcpClient::new_streamable_http_client(url, bearer_token)?);
let client = Arc::new(
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token).await?,
);
client.initialize(params, Some(startup_timeout)).await?;
Ok(McpClientAdapter::Rmcp(client))
}
@@ -208,8 +211,7 @@ impl McpConnectionManager {
) && !use_rmcp_client
{
info!(
"skipping MCP server `{}` configured with url because rmcp client is disabled",
server_name
"skipping MCP server `{server_name}` because the legacy MCP client only supports stdio servers",
);
continue;
}
@@ -217,7 +219,6 @@ impl McpConnectionManager {
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
let use_rmcp_client_flag = use_rmcp_client;
join_set.spawn(async move {
let McpServerConfig { transport, .. } = cfg;
let params = mcp_types::InitializeRequestParams {
@@ -246,17 +247,18 @@ impl McpConnectionManager {
let command_os: OsString = command.into();
let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
McpClientAdapter::new_stdio_client(
use_rmcp_client_flag,
use_rmcp_client,
command_os,
args_os,
env,
params.clone(),
params,
startup_timeout,
)
.await
}
McpServerTransportConfig::StreamableHttp { url, bearer_token } => {
McpClientAdapter::new_streamable_http_client(
server_name.clone(),
url,
bearer_token,
params,

View File

@@ -13,7 +13,7 @@ use tempfile::TempDir;
use crate::load_default_config_for_test;
type ConfigMutator = dyn FnOnce(&mut Config);
type ConfigMutator = dyn FnOnce(&mut Config) + Send;
pub struct TestCodexBuilder {
config_mutators: Vec<Box<ConfigMutator>>,
@@ -22,7 +22,7 @@ pub struct TestCodexBuilder {
impl TestCodexBuilder {
pub fn with_config<T>(mut self, mutator: T) -> Self
where
T: FnOnce(&mut Config) + 'static,
T: FnOnce(&mut Config) + Send + 'static,
{
self.config_mutators.push(Box::new(mutator));
self

View File

@@ -1,6 +1,11 @@
use std::collections::HashMap;
use std::ffi::OsString;
use std::fs;
use std::net::TcpListener;
use std::path::Path;
use std::time::Duration;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use codex_core::config_types::McpServerConfig;
use codex_core::config_types::McpServerTransportConfig;
@@ -19,6 +24,8 @@ use core_test_support::wait_for_event;
use core_test_support::wait_for_event_with_timeout;
use escargot::CargoBuild;
use serde_json::Value;
use serial_test::serial;
use tempfile::tempdir;
use tokio::net::TcpStream;
use tokio::process::Child;
use tokio::process::Command;
@@ -328,6 +335,189 @@ async fn streamable_http_tool_call_round_trip() -> anyhow::Result<()> {
Ok(())
}
/// This test writes to a fallback credentials file in CODEX_HOME.
/// Ideally, we wouldn't need to serialize the test but it's much more cumbersome to wire CODEX_HOME through the code.
#[serial(codex_home)]
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_with_oauth_round_trip() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let server = responses::start_mock_server().await;
let call_id = "call-789";
let server_name = "rmcp_http_oauth";
let tool_name = format!("{server_name}__echo");
mount_sse_once_match(
&server,
any(),
responses::sse(vec![
serde_json::json!({
"type": "response.created",
"response": {"id": "resp-1"}
}),
responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"),
responses::ev_completed("resp-1"),
]),
)
.await;
mount_sse_once_match(
&server,
any(),
responses::sse(vec![
responses::ev_assistant_message(
"msg-1",
"rmcp streamable http oauth echo tool completed successfully.",
),
responses::ev_completed("resp-2"),
]),
)
.await;
let expected_env_value = "propagated-env-http-oauth";
let expected_token = "initial-access-token";
let client_id = "test-client-id";
let refresh_token = "initial-refresh-token";
let rmcp_http_server_bin = CargoBuild::new()
.package("codex-rmcp-client")
.bin("test_streamable_http_server")
.run()?
.path()
.to_string_lossy()
.into_owned();
let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
drop(listener);
let bind_addr = format!("127.0.0.1:{port}");
let server_url = format!("http://{bind_addr}/mcp");
let mut http_server_child = Command::new(&rmcp_http_server_bin)
.kill_on_drop(true)
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
.env("MCP_EXPECT_BEARER", expected_token)
.env("MCP_TEST_VALUE", expected_env_value)
.spawn()?;
wait_for_streamable_http_server(&mut http_server_child, &bind_addr, Duration::from_secs(5))
.await?;
let temp_home = tempdir()?;
let _guard = EnvVarGuard::set("CODEX_HOME", temp_home.path().as_os_str());
write_fallback_oauth_tokens(
temp_home.path(),
server_name,
&server_url,
client_id,
expected_token,
refresh_token,
)?;
let fixture = test_codex()
.with_config(move |config| {
config.use_experimental_use_rmcp_client = true;
config.mcp_servers.insert(
server_name.to_string(),
McpServerConfig {
transport: McpServerTransportConfig::StreamableHttp {
url: server_url,
bearer_token: None,
},
startup_timeout_sec: Some(Duration::from_secs(10)),
tool_timeout_sec: None,
},
);
})
.build(&server)
.await?;
let session_model = fixture.session_configured.model.clone();
fixture
.codex
.submit(Op::UserTurn {
items: vec![InputItem::Text {
text: "call the rmcp streamable http oauth echo tool".into(),
}],
final_output_json_schema: None,
cwd: fixture.cwd.path().to_path_buf(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: session_model,
effort: None,
summary: ReasoningSummary::Auto,
})
.await?;
let begin_event = wait_for_event_with_timeout(
&fixture.codex,
|ev| matches!(ev, EventMsg::McpToolCallBegin(_)),
Duration::from_secs(10),
)
.await;
let EventMsg::McpToolCallBegin(begin) = begin_event else {
unreachable!("event guard guarantees McpToolCallBegin");
};
assert_eq!(begin.invocation.server, server_name);
assert_eq!(begin.invocation.tool, "echo");
let end_event = wait_for_event(&fixture.codex, |ev| {
matches!(ev, EventMsg::McpToolCallEnd(_))
})
.await;
let EventMsg::McpToolCallEnd(end) = end_event else {
unreachable!("event guard guarantees McpToolCallEnd");
};
let result = end
.result
.as_ref()
.expect("rmcp echo tool should return success");
assert_eq!(result.is_error, Some(false));
assert!(
result.content.is_empty(),
"content should default to an empty array"
);
let structured = result
.structured_content
.as_ref()
.expect("structured content");
let Value::Object(map) = structured else {
panic!("structured content should be an object: {structured:?}");
};
let echo_value = map
.get("echo")
.and_then(Value::as_str)
.expect("echo payload present");
assert_eq!(echo_value, "ECHOING: ping");
let env_value = map
.get("env")
.and_then(Value::as_str)
.expect("env snapshot inserted");
assert_eq!(env_value, expected_env_value);
wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
server.verify().await;
match http_server_child.try_wait() {
Ok(Some(_)) => {}
Ok(None) => {
let _ = http_server_child.kill().await;
}
Err(error) => {
eprintln!("failed to check streamable http oauth server status: {error}");
let _ = http_server_child.kill().await;
}
}
if let Err(error) = http_server_child.wait().await {
eprintln!("failed to await streamable http oauth server shutdown: {error}");
}
Ok(())
}
async fn wait_for_streamable_http_server(
server_child: &mut Child,
address: &str,
@@ -369,3 +559,60 @@ async fn wait_for_streamable_http_server(
sleep(Duration::from_millis(50)).await;
}
}
fn write_fallback_oauth_tokens(
home: &Path,
server_name: &str,
server_url: &str,
client_id: &str,
access_token: &str,
refresh_token: &str,
) -> anyhow::Result<()> {
let expires_at = SystemTime::now()
.checked_add(Duration::from_secs(3600))
.ok_or_else(|| anyhow::anyhow!("failed to compute expiry time"))?
.duration_since(UNIX_EPOCH)?
.as_millis() as u64;
let store = serde_json::json!({
"stub": {
"server_name": server_name,
"server_url": server_url,
"client_id": client_id,
"access_token": access_token,
"expires_at": expires_at,
"refresh_token": refresh_token,
"scopes": ["profile"],
}
});
let file_path = home.join(".credentials.json");
fs::write(&file_path, serde_json::to_vec(&store)?)?;
Ok(())
}
struct EnvVarGuard {
key: &'static str,
original: Option<OsString>,
}
impl EnvVarGuard {
fn set(key: &'static str, value: &std::ffi::OsStr) -> Self {
let original = std::env::var_os(key);
unsafe {
std::env::set_var(key, value);
}
Self { key, original }
}
}
impl Drop for EnvVarGuard {
fn drop(&mut self) {
unsafe {
match &self.original {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}
}