[MCP] Add support for MCP Oauth credentials (#4517)
This PR adds oauth login support to streamable http servers when `experimental_use_rmcp_client` is enabled. This PR is large but represents the minimal amount of work required for this to work. To keep this PR smaller, login can only be done with `codex mcp login` and `codex mcp logout` but it doesn't appear in `/mcp` or `codex mcp list` yet. Fingers crossed that this is the last large MCP PR and that subsequent PRs can be smaller. Under the hood, credentials are stored using platform credential managers using the [keyring crate](https://crates.io/crates/keyring). When the keyring isn't available, it falls back to storing credentials in `CODEX_HOME/.credentials.json` which is consistent with how other coding agents handle authentication. I tested this on macOS, Windows, WSL (ubuntu), and Linux. I wasn't able to test the dbus store on linux but did verify that the fallback works. One quirk is that if you have credentials, during development, every build will have its own ad-hoc binary so the keyring won't recognize the reader as being the same as the write so it may ask for the user's password. I may add an override to disable this or allow users/enterprises to opt-out of the keyring storage if it causes issues. <img width="5064" height="686" alt="CleanShot 2025-09-30 at 19 31 40" src="https://github.com/user-attachments/assets/9573f9b4-07f1-4160-83b8-2920db287e2d" /> <img width="745" height="486" alt="image" src="https://github.com/user-attachments/assets/9562649b-ea5f-4f22-ace2-d0cb438b143e" />
This commit is contained in:
@@ -21,6 +21,8 @@ use rmcp::service::RoleClient;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::service::{self};
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::child_process::TokioChildProcess;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
@@ -31,7 +33,10 @@ use tokio::time;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::logging_client_handler::LoggingClientHandler;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::utils::convert_call_tool_result;
|
||||
use crate::utils::convert_to_mcp;
|
||||
use crate::utils::convert_to_rmcp;
|
||||
@@ -40,7 +45,13 @@ use crate::utils::run_with_timeout;
|
||||
|
||||
enum PendingTransport {
|
||||
ChildProcess(TokioChildProcess),
|
||||
StreamableHttp(StreamableHttpClientTransport<reqwest::Client>),
|
||||
StreamableHttp {
|
||||
transport: StreamableHttpClientTransport<reqwest::Client>,
|
||||
},
|
||||
StreamableHttpWithOAuth {
|
||||
transport: StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
oauth_persistor: OAuthPersistor,
|
||||
},
|
||||
}
|
||||
|
||||
enum ClientState {
|
||||
@@ -49,6 +60,7 @@ enum ClientState {
|
||||
},
|
||||
Ready {
|
||||
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
|
||||
oauth: Option<OAuthPersistor>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -103,17 +115,37 @@ impl RmcpClient {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_streamable_http_client(url: String, bearer_token: Option<String>) -> Result<Self> {
|
||||
let mut config = StreamableHttpClientTransportConfig::with_uri(url);
|
||||
if let Some(token) = bearer_token {
|
||||
config = config.auth_header(format!("Bearer {token}"));
|
||||
}
|
||||
|
||||
let transport = StreamableHttpClientTransport::from_config(config);
|
||||
pub async fn new_streamable_http_client(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let initial_tokens = match load_oauth_tokens(server_name, url) {
|
||||
Ok(tokens) => tokens,
|
||||
Err(err) => {
|
||||
warn!("failed to read tokens for server `{server_name}`: {err}");
|
||||
None
|
||||
}
|
||||
};
|
||||
let transport = if let Some(initial_tokens) = initial_tokens.clone() {
|
||||
let (transport, oauth_persistor) =
|
||||
create_oauth_transport_and_runtime(server_name, url, initial_tokens).await?;
|
||||
PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
}
|
||||
} else {
|
||||
let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
|
||||
if let Some(bearer_token) = bearer_token {
|
||||
http_config = http_config.auth_header(format!("Bearer {bearer_token}"));
|
||||
}
|
||||
|
||||
let transport = StreamableHttpClientTransport::from_config(http_config);
|
||||
PendingTransport::StreamableHttp { transport }
|
||||
};
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(PendingTransport::StreamableHttp(transport)),
|
||||
transport: Some(transport),
|
||||
}),
|
||||
})
|
||||
}
|
||||
@@ -125,35 +157,40 @@ impl RmcpClient {
|
||||
params: InitializeRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<InitializeResult> {
|
||||
let transport = {
|
||||
let rmcp_params: InitializeRequestParam = convert_to_rmcp(params.clone())?;
|
||||
let client_handler = LoggingClientHandler::new(rmcp_params);
|
||||
|
||||
let (transport, oauth_persistor) = {
|
||||
let mut guard = self.state.lock().await;
|
||||
match &mut *guard {
|
||||
ClientState::Connecting { transport } => transport
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("client already initializing"))?,
|
||||
ClientState::Ready { .. } => {
|
||||
return Err(anyhow!("client already initialized"));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?;
|
||||
let client_handler = LoggingClientHandler::new(client_info);
|
||||
let service_future = match transport {
|
||||
PendingTransport::ChildProcess(transport) => {
|
||||
service::serve_client(client_handler.clone(), transport).boxed()
|
||||
}
|
||||
PendingTransport::StreamableHttp(transport) => {
|
||||
service::serve_client(client_handler, transport).boxed()
|
||||
ClientState::Connecting { transport } => match transport.take() {
|
||||
Some(PendingTransport::ChildProcess(transport)) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
None,
|
||||
),
|
||||
Some(PendingTransport::StreamableHttp { transport }) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
None,
|
||||
),
|
||||
Some(PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
}) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
Some(oauth_persistor),
|
||||
),
|
||||
None => return Err(anyhow!("client already initializing")),
|
||||
},
|
||||
ClientState::Ready { .. } => return Err(anyhow!("client already initialized")),
|
||||
}
|
||||
};
|
||||
|
||||
let service = match timeout {
|
||||
Some(duration) => time::timeout(duration, service_future)
|
||||
Some(duration) => time::timeout(duration, transport)
|
||||
.await
|
||||
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
None => service_future
|
||||
None => transport
|
||||
.await
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
};
|
||||
@@ -168,9 +205,16 @@ impl RmcpClient {
|
||||
let mut guard = self.state.lock().await;
|
||||
*guard = ClientState::Ready {
|
||||
service: Arc::new(service),
|
||||
oauth: oauth_persistor.clone(),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(runtime) = oauth_persistor
|
||||
&& let Err(error) = runtime.persist_if_needed().await
|
||||
{
|
||||
warn!("failed to persist OAuth tokens after initialize: {error}");
|
||||
}
|
||||
|
||||
Ok(initialize_result)
|
||||
}
|
||||
|
||||
@@ -186,7 +230,9 @@ impl RmcpClient {
|
||||
|
||||
let fut = service.list_tools(rmcp_params);
|
||||
let result = run_with_timeout(fut, timeout, "tools/list").await?;
|
||||
convert_to_mcp(result)
|
||||
let converted = convert_to_mcp(result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
pub async fn call_tool(
|
||||
@@ -200,14 +246,79 @@ impl RmcpClient {
|
||||
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
|
||||
let fut = service.call_tool(rmcp_params);
|
||||
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
|
||||
convert_call_tool_result(rmcp_result)
|
||||
let converted = convert_call_tool_result(rmcp_result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
|
||||
let guard = self.state.lock().await;
|
||||
match &*guard {
|
||||
ClientState::Ready { service } => Ok(Arc::clone(service)),
|
||||
ClientState::Ready { service, .. } => Ok(Arc::clone(service)),
|
||||
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth_persistor(&self) -> Option<OAuthPersistor> {
|
||||
let guard = self.state.lock().await;
|
||||
match &*guard {
|
||||
ClientState::Ready {
|
||||
oauth: Some(runtime),
|
||||
service: _,
|
||||
} => Some(runtime.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn persist_oauth_tokens(&self) {
|
||||
if let Some(runtime) = self.oauth_persistor().await
|
||||
&& let Err(error) = runtime.persist_if_needed().await
|
||||
{
|
||||
warn!("failed to persist OAuth tokens: {error}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_oauth_transport_and_runtime(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
initial_tokens: StoredOAuthTokens,
|
||||
) -> Result<(
|
||||
StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
OAuthPersistor,
|
||||
)> {
|
||||
let http_client = reqwest::Client::builder().build()?;
|
||||
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
|
||||
|
||||
oauth_state
|
||||
.set_credentials(
|
||||
&initial_tokens.client_id,
|
||||
initial_tokens.token_response.0.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let manager = match oauth_state {
|
||||
OAuthState::Authorized(manager) => manager,
|
||||
OAuthState::Unauthorized(manager) => manager,
|
||||
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
|
||||
return Err(anyhow!("unexpected OAuth state during client setup"));
|
||||
}
|
||||
};
|
||||
|
||||
let auth_client = AuthClient::new(http_client, manager);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
auth_client,
|
||||
StreamableHttpClientTransportConfig::with_uri(url.to_string()),
|
||||
);
|
||||
|
||||
let runtime = OAuthPersistor::new(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
auth_manager,
|
||||
Some(initial_tokens),
|
||||
);
|
||||
|
||||
Ok((transport, runtime))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user