Files
llmx/codex-rs/rmcp-client/src/rmcp_client.rs
Gabriel Peal 1d17ca1fa3 [MCP] Add support for MCP Oauth credentials (#4517)
This PR adds OAuth login support to streamable HTTP servers when
`experimental_use_rmcp_client` is enabled.

This PR is large but represents the minimal amount of work required for
this to work. To keep this PR smaller, login can only be done with
`codex mcp login` and `codex mcp logout` but it doesn't appear in `/mcp`
or `codex mcp list` yet. Fingers crossed that this is the last large MCP
PR and that subsequent PRs can be smaller.

Under the hood, credentials are stored using platform credential
managers using the [keyring crate](https://crates.io/crates/keyring).
When the keyring isn't available, it falls back to storing credentials
in `CODEX_HOME/.credentials.json` which is consistent with how other
coding agents handle authentication.

I tested this on macOS, Windows, WSL (Ubuntu), and Linux. I wasn't able
to test the D-Bus store on Linux but did verify that the fallback works.

One quirk is that, during development, every build produces its own ad-hoc
binary, so if you have stored credentials the keyring won't recognize the
reader as being the same as the writer and may ask for the user's
password. I may add an override to disable this or allow
users/enterprises to opt out of the keyring storage if it causes issues.

<img width="5064" height="686" alt="CleanShot 2025-09-30 at 19 31 40"
src="https://github.com/user-attachments/assets/9573f9b4-07f1-4160-83b8-2920db287e2d"
/>
<img width="745" height="486" alt="image"
src="https://github.com/user-attachments/assets/9562649b-ea5f-4f22-ace2-d0cb438b143e"
/>
2025-10-03 13:43:12 -04:00

325 lines
11 KiB
Rust

use std::collections::HashMap;
use std::ffi::OsString;
use std::io;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use anyhow::anyhow;
use futures::FutureExt;
use mcp_types::CallToolRequestParams;
use mcp_types::CallToolResult;
use mcp_types::InitializeRequestParams;
use mcp_types::InitializeResult;
use mcp_types::ListToolsRequestParams;
use mcp_types::ListToolsResult;
use rmcp::model::CallToolRequestParam;
use rmcp::model::InitializeRequestParam;
use rmcp::model::PaginatedRequestParam;
use rmcp::service::RoleClient;
use rmcp::service::RunningService;
use rmcp::service::{self};
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::auth::AuthClient;
use rmcp::transport::auth::OAuthState;
use rmcp::transport::child_process::TokioChildProcess;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::time;
use tracing::info;
use tracing::warn;
use crate::load_oauth_tokens;
use crate::logging_client_handler::LoggingClientHandler;
use crate::oauth::OAuthPersistor;
use crate::oauth::StoredOAuthTokens;
use crate::utils::convert_call_tool_result;
use crate::utils::convert_to_mcp;
use crate::utils::convert_to_rmcp;
use crate::utils::create_env_for_mcp_server;
use crate::utils::run_with_timeout;
/// A transport that has been constructed but not yet handed to
/// `service::serve_client` (that happens in `RmcpClient::initialize`).
enum PendingTransport {
/// Stdio transport backed by a spawned child process.
ChildProcess(TokioChildProcess),
/// Streamable HTTP transport over a plain `reqwest::Client`.
StreamableHttp {
transport: StreamableHttpClientTransport<reqwest::Client>,
},
/// Streamable HTTP transport whose HTTP client is wrapped in an
/// `AuthClient`, together with the persistor used to save OAuth tokens.
StreamableHttpWithOAuth {
transport: StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
oauth_persistor: OAuthPersistor,
},
}
/// Connection lifecycle of an `RmcpClient`.
enum ClientState {
/// The handshake has not completed. `transport` is `Some` until
/// `initialize` takes it; afterwards it is `None`.
Connecting {
transport: Option<PendingTransport>,
},
/// The handshake completed and the rmcp service is running.
Ready {
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
// `Some` only when the client was created from stored OAuth tokens.
oauth: Option<OAuthPersistor>,
},
}
/// MCP client implemented on top of the official `rmcp` SDK.
/// <https://github.com/modelcontextprotocol/rust-sdk>
pub struct RmcpClient {
// Lifecycle state, guarded by an async (tokio) mutex.
state: Mutex<ClientState>,
}
impl RmcpClient {
pub async fn new_stdio_client(
program: OsString,
args: Vec<OsString>,
env: Option<HashMap<String, String>>,
) -> io::Result<Self> {
let program_name = program.to_string_lossy().into_owned();
let mut command = Command::new(&program);
command
.kill_on_drop(true)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.env_clear()
.envs(create_env_for_mcp_server(env))
.args(&args);
let (transport, stderr) = TokioChildProcess::builder(command)
.stderr(Stdio::piped())
.spawn()?;
if let Some(stderr) = stderr {
tokio::spawn(async move {
let mut reader = BufReader::new(stderr).lines();
loop {
match reader.next_line().await {
Ok(Some(line)) => {
info!("MCP server stderr ({program_name}): {line}");
}
Ok(None) => break,
Err(error) => {
warn!("Failed to read MCP server stderr ({program_name}): {error}");
break;
}
}
}
});
}
Ok(Self {
state: Mutex::new(ClientState::Connecting {
transport: Some(PendingTransport::ChildProcess(transport)),
}),
})
}
pub async fn new_streamable_http_client(
server_name: &str,
url: &str,
bearer_token: Option<String>,
) -> Result<Self> {
let initial_tokens = match load_oauth_tokens(server_name, url) {
Ok(tokens) => tokens,
Err(err) => {
warn!("failed to read tokens for server `{server_name}`: {err}");
None
}
};
let transport = if let Some(initial_tokens) = initial_tokens.clone() {
let (transport, oauth_persistor) =
create_oauth_transport_and_runtime(server_name, url, initial_tokens).await?;
PendingTransport::StreamableHttpWithOAuth {
transport,
oauth_persistor,
}
} else {
let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
if let Some(bearer_token) = bearer_token {
http_config = http_config.auth_header(format!("Bearer {bearer_token}"));
}
let transport = StreamableHttpClientTransport::from_config(http_config);
PendingTransport::StreamableHttp { transport }
};
Ok(Self {
state: Mutex::new(ClientState::Connecting {
transport: Some(transport),
}),
})
}
/// Perform the initialization handshake with the MCP server.
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization
pub async fn initialize(
&self,
params: InitializeRequestParams,
timeout: Option<Duration>,
) -> Result<InitializeResult> {
let rmcp_params: InitializeRequestParam = convert_to_rmcp(params.clone())?;
let client_handler = LoggingClientHandler::new(rmcp_params);
let (transport, oauth_persistor) = {
let mut guard = self.state.lock().await;
match &mut *guard {
ClientState::Connecting { transport } => match transport.take() {
Some(PendingTransport::ChildProcess(transport)) => (
service::serve_client(client_handler.clone(), transport).boxed(),
None,
),
Some(PendingTransport::StreamableHttp { transport }) => (
service::serve_client(client_handler.clone(), transport).boxed(),
None,
),
Some(PendingTransport::StreamableHttpWithOAuth {
transport,
oauth_persistor,
}) => (
service::serve_client(client_handler.clone(), transport).boxed(),
Some(oauth_persistor),
),
None => return Err(anyhow!("client already initializing")),
},
ClientState::Ready { .. } => return Err(anyhow!("client already initialized")),
}
};
let service = match timeout {
Some(duration) => time::timeout(duration, transport)
.await
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
None => transport
.await
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
};
let initialize_result_rmcp = service
.peer()
.peer_info()
.ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?;
let initialize_result = convert_to_mcp(initialize_result_rmcp)?;
{
let mut guard = self.state.lock().await;
*guard = ClientState::Ready {
service: Arc::new(service),
oauth: oauth_persistor.clone(),
};
}
if let Some(runtime) = oauth_persistor
&& let Err(error) = runtime.persist_if_needed().await
{
warn!("failed to persist OAuth tokens after initialize: {error}");
}
Ok(initialize_result)
}
pub async fn list_tools(
&self,
params: Option<ListToolsRequestParams>,
timeout: Option<Duration>,
) -> Result<ListToolsResult> {
let service = self.service().await?;
let rmcp_params = params
.map(convert_to_rmcp::<_, PaginatedRequestParam>)
.transpose()?;
let fut = service.list_tools(rmcp_params);
let result = run_with_timeout(fut, timeout, "tools/list").await?;
let converted = convert_to_mcp(result)?;
self.persist_oauth_tokens().await;
Ok(converted)
}
pub async fn call_tool(
&self,
name: String,
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> Result<CallToolResult> {
let service = self.service().await?;
let params = CallToolRequestParams { arguments, name };
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
let fut = service.call_tool(rmcp_params);
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
let converted = convert_call_tool_result(rmcp_result)?;
self.persist_oauth_tokens().await;
Ok(converted)
}
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
let guard = self.state.lock().await;
match &*guard {
ClientState::Ready { service, .. } => Ok(Arc::clone(service)),
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
}
}
async fn oauth_persistor(&self) -> Option<OAuthPersistor> {
let guard = self.state.lock().await;
match &*guard {
ClientState::Ready {
oauth: Some(runtime),
service: _,
} => Some(runtime.clone()),
_ => None,
}
}
async fn persist_oauth_tokens(&self) {
if let Some(runtime) = self.oauth_persistor().await
&& let Err(error) = runtime.persist_if_needed().await
{
warn!("failed to persist OAuth tokens: {error}");
}
}
}
/// Build an OAuth-enabled streamable HTTP transport for `url`, seeded with
/// previously stored credentials, plus the persistor that saves tokens for
/// `server_name`.
///
/// # Errors
/// Fails when the HTTP client cannot be built, when creating the OAuth state
/// or installing the stored credentials fails, or when the resulting OAuth
/// state is neither `Authorized` nor `Unauthorized`.
async fn create_oauth_transport_and_runtime(
    server_name: &str,
    url: &str,
    initial_tokens: StoredOAuthTokens,
) -> Result<(
    StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
    OAuthPersistor,
)> {
    let client = reqwest::Client::builder().build()?;

    // Seed a fresh OAuth state with the stored client id and token response.
    let mut state = OAuthState::new(url.to_owned(), Some(client.clone())).await?;
    let stored_response = initial_tokens.token_response.0.clone();
    state
        .set_credentials(&initial_tokens.client_id, stored_response)
        .await?;

    // Both of these states carry the authorization manager we need.
    let manager = match state {
        OAuthState::Authorized(manager) | OAuthState::Unauthorized(manager) => manager,
        OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
            return Err(anyhow!("unexpected OAuth state during client setup"));
        }
    };

    let auth_client = AuthClient::new(client, manager);
    // Grab the manager handle before `auth_client` is moved into the transport.
    let auth_manager = auth_client.auth_manager.clone();

    let config = StreamableHttpClientTransportConfig::with_uri(url.to_owned());
    let transport = StreamableHttpClientTransport::with_client(auth_client, config);

    let persistor = OAuthPersistor::new(
        server_name.to_owned(),
        url.to_owned(),
        auth_manager,
        Some(initial_tokens),
    );

    Ok((transport, persistor))
}