[MCP] Add support for MCP Oauth credentials (#4517)
This PR adds oauth login support to streamable http servers when `experimental_use_rmcp_client` is enabled. This PR is large but represents the minimal amount of work required for this to work. To keep this PR smaller, login can only be done with `codex mcp login` and `codex mcp logout` but it doesn't appear in `/mcp` or `codex mcp list` yet. Fingers crossed that this is the last large MCP PR and that subsequent PRs can be smaller. Under the hood, credentials are stored using platform credential managers using the [keyring crate](https://crates.io/crates/keyring). When the keyring isn't available, it falls back to storing credentials in `CODEX_HOME/.credentials.json` which is consistent with how other coding agents handle authentication. I tested this on macOS, Windows, WSL (ubuntu), and Linux. I wasn't able to test the dbus store on linux but did verify that the fallback works. One quirk is that if you have credentials, during development, every build will have its own ad-hoc binary so the keyring won't recognize the reader as being the same as the write so it may ask for the user's password. I may add an override to disable this or allow users/enterprises to opt-out of the keyring storage if it causes issues. <img width="5064" height="686" alt="CleanShot 2025-09-30 at 19 31 40" src="https://github.com/user-attachments/assets/9573f9b4-07f1-4160-83b8-2920db287e2d" /> <img width="745" height="486" alt="image" src="https://github.com/user-attachments/assets/9562649b-ea5f-4f22-ace2-d0cb438b143e" />
This commit is contained in:
@@ -5,6 +5,14 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use axum::body::Body;
|
||||
use axum::extract::State;
|
||||
use axum::http::Request;
|
||||
use axum::http::StatusCode;
|
||||
use axum::http::header::AUTHORIZATION;
|
||||
use axum::middleware;
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
use rmcp::ErrorData as McpError;
|
||||
use rmcp::handler::server::ServerHandler;
|
||||
use rmcp::model::CallToolRequestParam;
|
||||
@@ -161,7 +169,30 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
),
|
||||
);
|
||||
|
||||
let router = if let Ok(token) = std::env::var("MCP_EXPECT_BEARER") {
|
||||
let expected = Arc::new(format!("Bearer {token}"));
|
||||
router.layer(middleware::from_fn_with_state(expected, require_bearer))
|
||||
} else {
|
||||
router
|
||||
};
|
||||
|
||||
axum::serve(listener, router).await?;
|
||||
task::yield_now().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn require_bearer(
|
||||
State(expected): State<Arc<String>>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if request
|
||||
.headers()
|
||||
.get(AUTHORIZATION)
|
||||
.is_some_and(|value| value.as_bytes() == expected.as_bytes())
|
||||
{
|
||||
Ok(next.run(request).await)
|
||||
} else {
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
|
||||
33
codex-rs/rmcp-client/src/find_codex_home.rs
Normal file
33
codex-rs/rmcp-client/src/find_codex_home.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use dirs::home_dir;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// This was copied from codex-core but codex-core depends on this crate.
|
||||
/// TODO: move this to a shared crate lower in the dependency tree.
|
||||
///
|
||||
///
|
||||
/// Returns the path to the Codex configuration directory, which can be
|
||||
/// specified by the `CODEX_HOME` environment variable. If not set, defaults to
|
||||
/// `~/.codex`.
|
||||
///
|
||||
/// - If `CODEX_HOME` is set, the value will be canonicalized and this
|
||||
/// function will Err if the path does not exist.
|
||||
/// - If `CODEX_HOME` is not set, this function does not verify that the
|
||||
/// directory exists.
|
||||
pub(crate) fn find_codex_home() -> std::io::Result<PathBuf> {
|
||||
// Honor the `CODEX_HOME` environment variable when it is set to allow users
|
||||
// (and tests) to override the default location.
|
||||
if let Ok(val) = std::env::var("CODEX_HOME")
|
||||
&& !val.is_empty()
|
||||
{
|
||||
return PathBuf::from(val).canonicalize();
|
||||
}
|
||||
|
||||
let mut p = home_dir().ok_or_else(|| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
"Could not find home directory",
|
||||
)
|
||||
})?;
|
||||
p.push(".codex");
|
||||
Ok(p)
|
||||
}
|
||||
@@ -1,5 +1,14 @@
|
||||
mod find_codex_home;
|
||||
mod logging_client_handler;
|
||||
mod oauth;
|
||||
mod perform_oauth_login;
|
||||
mod rmcp_client;
|
||||
mod utils;
|
||||
|
||||
pub use oauth::StoredOAuthTokens;
|
||||
pub use oauth::WrappedOAuthTokenResponse;
|
||||
pub use oauth::delete_oauth_tokens;
|
||||
pub(crate) use oauth::load_oauth_tokens;
|
||||
pub use oauth::save_oauth_tokens;
|
||||
pub use perform_oauth_login::perform_oauth_login;
|
||||
pub use rmcp_client::RmcpClient;
|
||||
|
||||
822
codex-rs/rmcp-client/src/oauth.rs
Normal file
822
codex-rs/rmcp-client/src/oauth.rs
Normal file
@@ -0,0 +1,822 @@
|
||||
//! This file handles all logic related to managing MCP OAuth credentials.
|
||||
//! All credentials are stored using the keyring crate which uses os-specific keyring services.
|
||||
//! https://crates.io/crates/keyring
|
||||
//! macOS: macOS keychain.
|
||||
//! Windows: Windows Credential Manager
|
||||
//! Linux: DBus-based Secret Service, the kernel keyutils, and a combo of the two
|
||||
//! FreeBSD, OpenBSD: DBus-based Secret Service
|
||||
//!
|
||||
//! For Linux, we use linux-native-async-persistent which uses both keyutils and async-secret-service (see below) for storage.
|
||||
//! See the docs for the keyutils_persistent module for a full explanation of why both are used. Because this store uses the
|
||||
//! async-secret-service, you must specify the additional features required by that store
|
||||
//!
|
||||
//! async-secret-service provides access to the DBus-based Secret Service storage on Linux, FreeBSD, and OpenBSD. This is an asynchronous
|
||||
//! keystore that always encrypts secrets when they are transferred across the bus. If DBus isn't installed the keystore will fall back to the json
|
||||
//! file because we don't use the "vendored" feature.
|
||||
//!
|
||||
//! If the keyring is not available or fails, we fall back to CODEX_HOME/.credentials.json which is consistent with other coding CLI agents.
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use keyring::Entry;
|
||||
use oauth2::AccessToken;
|
||||
use oauth2::EmptyExtraTokenFields;
|
||||
use oauth2::RefreshToken;
|
||||
use oauth2::Scope;
|
||||
use oauth2::TokenResponse;
|
||||
use oauth2::basic::BasicTokenType;
|
||||
use rmcp::transport::auth::OAuthTokenResponse;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::map::Map as JsonMap;
|
||||
use sha2::Digest;
|
||||
use sha2::Sha256;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt;
|
||||
use std::fs;
|
||||
use std::io::ErrorKind;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
use tracing::warn;
|
||||
|
||||
use rmcp::transport::auth::AuthorizationManager;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::find_codex_home::find_codex_home;
|
||||
|
||||
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct StoredOAuthTokens {
|
||||
pub server_name: String,
|
||||
pub url: String,
|
||||
pub client_id: String,
|
||||
pub token_response: WrappedOAuthTokenResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CredentialStoreError(anyhow::Error);
|
||||
|
||||
impl CredentialStoreError {
|
||||
fn new(error: impl Into<anyhow::Error>) -> Self {
|
||||
Self(error.into())
|
||||
}
|
||||
|
||||
fn message(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
fn into_error(self) -> anyhow::Error {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for CredentialStoreError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CredentialStoreError {}
|
||||
|
||||
trait CredentialStore {
|
||||
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError>;
|
||||
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>;
|
||||
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError>;
|
||||
}
|
||||
|
||||
struct KeyringCredentialStore;
|
||||
|
||||
impl CredentialStore for KeyringCredentialStore {
|
||||
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError> {
|
||||
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
|
||||
match entry.get_password() {
|
||||
Ok(password) => Ok(Some(password)),
|
||||
Err(keyring::Error::NoEntry) => Ok(None),
|
||||
Err(error) => Err(CredentialStoreError::new(error)),
|
||||
}
|
||||
}
|
||||
|
||||
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError> {
|
||||
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
|
||||
entry.set_password(value).map_err(CredentialStoreError::new)
|
||||
}
|
||||
|
||||
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError> {
|
||||
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
|
||||
match entry.delete_credential() {
|
||||
Ok(()) => Ok(true),
|
||||
Err(keyring::Error::NoEntry) => Ok(false),
|
||||
Err(error) => Err(CredentialStoreError::new(error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap OAuthTokenResponse to allow for partial equality comparison.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WrappedOAuthTokenResponse(pub OAuthTokenResponse);
|
||||
|
||||
impl PartialEq for WrappedOAuthTokenResponse {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (serde_json::to_string(self), serde_json::to_string(other)) {
|
||||
(Ok(s1), Ok(s2)) => s1 == s2,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn load_oauth_tokens(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
|
||||
let store = KeyringCredentialStore;
|
||||
load_oauth_tokens_with_store(&store, server_name, url)
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
match store.load(KEYRING_SERVICE, &key) {
|
||||
Ok(Some(serialized)) => {
|
||||
let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
|
||||
.context("failed to deserialize OAuth tokens from keyring")?;
|
||||
Ok(Some(tokens))
|
||||
}
|
||||
Ok(None) => load_oauth_tokens_from_file(server_name, url),
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to read OAuth tokens from keyring: {message}");
|
||||
load_oauth_tokens_from_file(server_name, url)
|
||||
.with_context(|| format!("failed to read OAuth tokens from keyring: {message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_oauth_tokens(server_name: &str, tokens: &StoredOAuthTokens) -> Result<()> {
|
||||
let store = KeyringCredentialStore;
|
||||
save_oauth_tokens_with_store(&store, server_name, tokens)
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?;
|
||||
|
||||
let key = compute_store_key(server_name, &tokens.url)?;
|
||||
match store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
Ok(()) => {
|
||||
if let Err(error) = delete_oauth_tokens_from_file(&key) {
|
||||
warn!("failed to remove OAuth tokens from fallback storage: {error:?}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to write OAuth tokens to keyring: {message}");
|
||||
save_oauth_tokens_to_file(tokens)
|
||||
.with_context(|| format!("failed to write OAuth tokens to keyring: {message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn delete_oauth_tokens(server_name: &str, url: &str) -> Result<bool> {
|
||||
let store = KeyringCredentialStore;
|
||||
delete_oauth_tokens_with_store(&store, server_name, url)
|
||||
}
|
||||
|
||||
fn delete_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<bool> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
let keyring_removed = match store.delete(KEYRING_SERVICE, &key) {
|
||||
Ok(removed) => removed,
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to delete OAuth tokens from keyring: {message}");
|
||||
return Err(error.into_error()).context("failed to delete OAuth tokens from keyring");
|
||||
}
|
||||
};
|
||||
|
||||
let file_removed = delete_oauth_tokens_from_file(&key)?;
|
||||
Ok(keyring_removed || file_removed)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct OAuthPersistor {
|
||||
inner: Arc<OAuthPersistorInner>,
|
||||
}
|
||||
|
||||
struct OAuthPersistorInner {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
||||
}
|
||||
|
||||
impl OAuthPersistor {
|
||||
pub(crate) fn new(
|
||||
server_name: String,
|
||||
url: String,
|
||||
manager: Arc<Mutex<AuthorizationManager>>,
|
||||
initial_credentials: Option<StoredOAuthTokens>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(OAuthPersistorInner {
|
||||
server_name,
|
||||
url,
|
||||
authorization_manager: manager,
|
||||
last_credentials: Mutex::new(initial_credentials),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Persists the latest stored credentials if they have changed.
|
||||
/// Deletes the credentials if they are no longer present.
|
||||
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
|
||||
let (client_id, maybe_credentials) = {
|
||||
let manager = self.inner.authorization_manager.clone();
|
||||
let guard = manager.lock().await;
|
||||
guard.get_credentials().await
|
||||
}?;
|
||||
|
||||
match maybe_credentials {
|
||||
Some(credentials) => {
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.inner.server_name.clone(),
|
||||
url: self.inner.url.clone(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials.clone()),
|
||||
};
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
if last_credentials.as_ref() != Some(&stored) {
|
||||
save_oauth_tokens(&self.inner.server_name, &stored)?;
|
||||
*last_credentials = Some(stored);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut last_serialized = self.inner.last_credentials.lock().await;
|
||||
if last_serialized.take().is_some()
|
||||
&& let Err(error) =
|
||||
delete_oauth_tokens(&self.inner.server_name, &self.inner.url)
|
||||
{
|
||||
warn!(
|
||||
"failed to remove OAuth tokens for server {}: {error}",
|
||||
self.inner.server_name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
const FALLBACK_FILENAME: &str = ".credentials.json";
|
||||
const MCP_SERVER_TYPE: &str = "http";
|
||||
|
||||
type FallbackFile = BTreeMap<String, FallbackTokenEntry>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FallbackTokenEntry {
|
||||
server_name: String,
|
||||
server_url: String,
|
||||
client_id: String,
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
expires_at: Option<u64>,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default)]
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_file(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
|
||||
let Some(store) = read_fallback_file()? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
|
||||
for entry in store.values() {
|
||||
let entry_key = compute_store_key(&entry.server_name, &entry.server_url)?;
|
||||
if entry_key != key {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut token_response = OAuthTokenResponse::new(
|
||||
AccessToken::new(entry.access_token.clone()),
|
||||
BasicTokenType::Bearer,
|
||||
EmptyExtraTokenFields {},
|
||||
);
|
||||
|
||||
if let Some(refresh) = entry.refresh_token.clone() {
|
||||
token_response.set_refresh_token(Some(RefreshToken::new(refresh)));
|
||||
}
|
||||
|
||||
let scopes = entry.scopes.clone();
|
||||
if !scopes.is_empty() {
|
||||
token_response.set_scopes(Some(scopes.into_iter().map(Scope::new).collect()));
|
||||
}
|
||||
|
||||
if let Some(expires_at) = entry.expires_at
|
||||
&& let Some(seconds) = expires_in_from_timestamp(expires_at)
|
||||
{
|
||||
let duration = Duration::from_secs(seconds);
|
||||
token_response.set_expires_in(Some(&duration));
|
||||
}
|
||||
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: entry.server_name.clone(),
|
||||
url: entry.server_url.clone(),
|
||||
client_id: entry.client_id.clone(),
|
||||
token_response: WrappedOAuthTokenResponse(token_response),
|
||||
};
|
||||
|
||||
return Ok(Some(stored));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> {
|
||||
let key = compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
let mut store = read_fallback_file()?.unwrap_or_default();
|
||||
|
||||
let token_response = &tokens.token_response.0;
|
||||
let refresh_token = token_response
|
||||
.refresh_token()
|
||||
.map(|token| token.secret().to_string());
|
||||
let scopes = token_response
|
||||
.scopes()
|
||||
.map(|s| s.iter().map(|s| s.to_string()).collect())
|
||||
.unwrap_or_default();
|
||||
let entry = FallbackTokenEntry {
|
||||
server_name: tokens.server_name.clone(),
|
||||
server_url: tokens.url.clone(),
|
||||
client_id: tokens.client_id.clone(),
|
||||
access_token: token_response.access_token().secret().to_string(),
|
||||
expires_at: compute_expires_at_millis(token_response),
|
||||
refresh_token,
|
||||
scopes,
|
||||
};
|
||||
|
||||
store.insert(key, entry);
|
||||
write_fallback_file(&store)
|
||||
}
|
||||
|
||||
fn delete_oauth_tokens_from_file(key: &str) -> Result<bool> {
|
||||
let mut store = match read_fallback_file()? {
|
||||
Some(store) => store,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
let removed = store.remove(key).is_some();
|
||||
|
||||
if removed {
|
||||
write_fallback_file(&store)?;
|
||||
}
|
||||
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option<u64> {
|
||||
let expires_in = response.expires_in()?;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
let expiry = now.checked_add(expires_in)?;
|
||||
let millis = expiry.as_millis();
|
||||
if millis > u128::from(u64::MAX) {
|
||||
Some(u64::MAX)
|
||||
} else {
|
||||
Some(millis as u64)
|
||||
}
|
||||
}
|
||||
|
||||
fn expires_in_from_timestamp(expires_at: u64) -> Option<u64> {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
let now_ms = now.as_millis() as u64;
|
||||
|
||||
if expires_at <= now_ms {
|
||||
None
|
||||
} else {
|
||||
Some((expires_at - now_ms) / 1000)
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_store_key(server_name: &str, server_url: &str) -> Result<String> {
|
||||
let mut payload = JsonMap::new();
|
||||
payload.insert(
|
||||
"type".to_string(),
|
||||
Value::String(MCP_SERVER_TYPE.to_string()),
|
||||
);
|
||||
payload.insert("url".to_string(), Value::String(server_url.to_string()));
|
||||
payload.insert("headers".to_string(), Value::Object(JsonMap::new()));
|
||||
|
||||
let truncated = sha_256_prefix(&Value::Object(payload))?;
|
||||
Ok(format!("{server_name}|{truncated}"))
|
||||
}
|
||||
|
||||
fn fallback_file_path() -> Result<PathBuf> {
|
||||
let mut path = find_codex_home()?;
|
||||
path.push(FALLBACK_FILENAME);
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
fn read_fallback_file() -> Result<Option<FallbackFile>> {
|
||||
let path = fallback_file_path()?;
|
||||
let contents = match fs::read_to_string(&path) {
|
||||
Ok(contents) => contents,
|
||||
Err(err) if err.kind() == ErrorKind::NotFound => return Ok(None),
|
||||
Err(err) => {
|
||||
return Err(err).context(format!(
|
||||
"failed to read credentials file at {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
match serde_json::from_str::<FallbackFile>(&contents) {
|
||||
Ok(store) => Ok(Some(store)),
|
||||
Err(e) => Err(e).context(format!(
|
||||
"failed to parse credentials file at {}",
|
||||
path.display()
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_fallback_file(store: &FallbackFile) -> Result<()> {
|
||||
let path = fallback_file_path()?;
|
||||
|
||||
if store.is_empty() {
|
||||
if path.exists() {
|
||||
fs::remove_file(path)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let serialized = serde_json::to_string(store)?;
|
||||
fs::write(&path, serialized)?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let perms = fs::Permissions::from_mode(0o600);
|
||||
fs::set_permissions(&path, perms)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sha_256_prefix(value: &Value) -> Result<String> {
|
||||
let serialized =
|
||||
serde_json::to_string(&value).context("failed to serialize MCP OAuth key payload")?;
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(serialized.as_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let hex = format!("{digest:x}");
|
||||
let truncated = &hex[..16];
|
||||
Ok(truncated.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::Result;
|
||||
use keyring::Error as KeyringError;
|
||||
use keyring::credential::CredentialApi as _;
|
||||
use keyring::mock::MockCredential;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::MutexGuard;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::PoisonError;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct MockCredentialStore {
|
||||
credentials: Arc<Mutex<HashMap<String, Arc<MockCredential>>>>,
|
||||
}
|
||||
|
||||
impl MockCredentialStore {
|
||||
fn credential(&self, account: &str) -> Arc<MockCredential> {
|
||||
let mut guard = self.credentials.lock().unwrap();
|
||||
guard
|
||||
.entry(account.to_string())
|
||||
.or_insert_with(|| Arc::new(MockCredential::default()))
|
||||
.clone()
|
||||
}
|
||||
|
||||
fn saved_value(&self, account: &str) -> Option<String> {
|
||||
let credential = {
|
||||
let guard = self.credentials.lock().unwrap();
|
||||
guard.get(account).cloned()
|
||||
}?;
|
||||
credential.get_password().ok()
|
||||
}
|
||||
|
||||
fn set_error(&self, account: &str, error: KeyringError) {
|
||||
let credential = self.credential(account);
|
||||
credential.set_error(error);
|
||||
}
|
||||
|
||||
fn contains(&self, account: &str) -> bool {
|
||||
let guard = self.credentials.lock().unwrap();
|
||||
guard.contains_key(account)
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialStore for MockCredentialStore {
|
||||
fn load(
|
||||
&self,
|
||||
_service: &str,
|
||||
account: &str,
|
||||
) -> Result<Option<String>, CredentialStoreError> {
|
||||
let credential = {
|
||||
let guard = self.credentials.lock().unwrap();
|
||||
guard.get(account).cloned()
|
||||
};
|
||||
|
||||
let Some(credential) = credential else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
match credential.get_password() {
|
||||
Ok(password) => Ok(Some(password)),
|
||||
Err(KeyringError::NoEntry) => Ok(None),
|
||||
Err(error) => Err(CredentialStoreError::new(error)),
|
||||
}
|
||||
}
|
||||
|
||||
fn save(
|
||||
&self,
|
||||
_service: &str,
|
||||
account: &str,
|
||||
value: &str,
|
||||
) -> Result<(), CredentialStoreError> {
|
||||
let credential = self.credential(account);
|
||||
credential
|
||||
.set_password(value)
|
||||
.map_err(CredentialStoreError::new)
|
||||
}
|
||||
|
||||
fn delete(&self, _service: &str, account: &str) -> Result<bool, CredentialStoreError> {
|
||||
let credential = {
|
||||
let guard = self.credentials.lock().unwrap();
|
||||
guard.get(account).cloned()
|
||||
};
|
||||
|
||||
let Some(credential) = credential else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
match credential.delete_credential() {
|
||||
Ok(()) => {
|
||||
let mut guard = self.credentials.lock().unwrap();
|
||||
guard.remove(account);
|
||||
Ok(true)
|
||||
}
|
||||
Err(KeyringError::NoEntry) => {
|
||||
let mut guard = self.credentials.lock().unwrap();
|
||||
guard.remove(account);
|
||||
Ok(false)
|
||||
}
|
||||
Err(error) => Err(CredentialStoreError::new(error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TempCodexHome {
|
||||
_guard: MutexGuard<'static, ()>,
|
||||
_dir: tempfile::TempDir,
|
||||
}
|
||||
|
||||
impl TempCodexHome {
|
||||
fn new() -> Self {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
let guard = LOCK
|
||||
.get_or_init(Mutex::default)
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner);
|
||||
let dir = tempdir().expect("create CODEX_HOME temp dir");
|
||||
unsafe {
|
||||
std::env::set_var("CODEX_HOME", dir.path());
|
||||
}
|
||||
Self {
|
||||
_guard: guard,
|
||||
_dir: dir,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TempCodexHome {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
std::env::remove_var("CODEX_HOME");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_reads_from_keyring_when_available() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
|
||||
assert_eq!(loaded, Some(expected));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_falls_back_when_missing_in_keyring() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_falls_back_when_keyring_errors() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(!fallback_path.exists(), "fallback file should be removed");
|
||||
let stored = store.saved_value(&key).expect("value saved to keyring");
|
||||
assert_eq!(serde_json::from_str::<StoredOAuthTokens>(&stored)?, tokens);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_oauth_tokens_writes_fallback_when_keyring_fails() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));
|
||||
|
||||
super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(fallback_path.exists(), "fallback file should be created");
|
||||
let saved = super::read_fallback_file()?.expect("fallback file should load");
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
let entry = saved.get(&key).expect("entry for key");
|
||||
assert_eq!(entry.server_name, tokens.server_name);
|
||||
assert_eq!(entry.server_url, tokens.url);
|
||||
assert_eq!(entry.client_id, tokens.client_id);
|
||||
assert_eq!(
|
||||
entry.access_token,
|
||||
tokens.token_response.0.access_token().secret().as_str()
|
||||
);
|
||||
assert!(store.saved_value(&key).is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_removes_all_storage() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let removed =
|
||||
super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
|
||||
assert!(removed);
|
||||
assert!(!store.contains(&key));
|
||||
assert!(!super::fallback_file_path()?.exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_propagates_keyring_errors() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
|
||||
super::save_oauth_tokens_to_file(&tokens).unwrap();
|
||||
|
||||
let result =
|
||||
super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url);
|
||||
assert!(result.is_err());
|
||||
assert!(super::fallback_file_path().unwrap().exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assert_tokens_match_without_expiry(
|
||||
actual: &StoredOAuthTokens,
|
||||
expected: &StoredOAuthTokens,
|
||||
) {
|
||||
assert_eq!(actual.server_name, expected.server_name);
|
||||
assert_eq!(actual.url, expected.url);
|
||||
assert_eq!(actual.client_id, expected.client_id);
|
||||
assert_token_response_match_without_expiry(
|
||||
&actual.token_response,
|
||||
&expected.token_response,
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_token_response_match_without_expiry(
|
||||
actual: &WrappedOAuthTokenResponse,
|
||||
expected: &WrappedOAuthTokenResponse,
|
||||
) {
|
||||
let actual_response = &actual.0;
|
||||
let expected_response = &expected.0;
|
||||
|
||||
assert_eq!(
|
||||
actual_response.access_token().secret(),
|
||||
expected_response.access_token().secret()
|
||||
);
|
||||
assert_eq!(actual_response.token_type(), expected_response.token_type());
|
||||
assert_eq!(
|
||||
actual_response.refresh_token().map(RefreshToken::secret),
|
||||
expected_response.refresh_token().map(RefreshToken::secret),
|
||||
);
|
||||
assert_eq!(actual_response.scopes(), expected_response.scopes());
|
||||
assert_eq!(
|
||||
actual_response.extra_fields(),
|
||||
expected_response.extra_fields()
|
||||
);
|
||||
assert_eq!(
|
||||
actual_response.expires_in().is_some(),
|
||||
expected_response.expires_in().is_some()
|
||||
);
|
||||
}
|
||||
|
||||
fn sample_tokens() -> StoredOAuthTokens {
|
||||
let mut response = OAuthTokenResponse::new(
|
||||
AccessToken::new("access-token".to_string()),
|
||||
BasicTokenType::Bearer,
|
||||
EmptyExtraTokenFields {},
|
||||
);
|
||||
response.set_refresh_token(Some(RefreshToken::new("refresh-token".to_string())));
|
||||
response.set_scopes(Some(vec![
|
||||
Scope::new("scope-a".to_string()),
|
||||
Scope::new("scope-b".to_string()),
|
||||
]));
|
||||
let expires_in = Duration::from_secs(3600);
|
||||
response.set_expires_in(Some(&expires_in));
|
||||
|
||||
StoredOAuthTokens {
|
||||
server_name: "test-server".to_string(),
|
||||
url: "https://example.test".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
token_response: WrappedOAuthTokenResponse(response),
|
||||
}
|
||||
}
|
||||
}
|
||||
141
codex-rs/rmcp-client/src/perform_oauth_login.rs
Normal file
141
codex-rs/rmcp-client/src/perform_oauth_login.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use std::string::String;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use tiny_http::Response;
|
||||
use tiny_http::Server;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use urlencoding::decode;
|
||||
|
||||
use crate::StoredOAuthTokens;
|
||||
use crate::WrappedOAuthTokenResponse;
|
||||
use crate::save_oauth_tokens;
|
||||
|
||||
struct CallbackServerGuard {
|
||||
server: Arc<Server>,
|
||||
}
|
||||
|
||||
impl Drop for CallbackServerGuard {
|
||||
fn drop(&mut self) {
|
||||
self.server.unblock();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<()> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
};
|
||||
|
||||
let redirect_uri = match server.server_addr() {
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
|
||||
format!("http://{}:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
|
||||
format!("http://[{}]:{}/callback", addr.ip(), addr.port())
|
||||
}
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
_ => return Err(anyhow!("unable to determine callback address")),
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
spawn_callback_server(server, tx);
|
||||
|
||||
let mut oauth_state = OAuthState::new(server_url, None).await?;
|
||||
oauth_state.start_authorization(&[], &redirect_uri).await?;
|
||||
let auth_url = oauth_state.get_authorization_url().await?;
|
||||
|
||||
println!("Authorize `{server_name}` by opening this URL in your browser:\n{auth_url}\n");
|
||||
|
||||
if webbrowser::open(&auth_url).is_err() {
|
||||
println!("(Browser launch failed; please copy the URL above manually.)");
|
||||
}
|
||||
|
||||
let (code, csrf_state) = timeout(Duration::from_secs(300), rx)
|
||||
.await
|
||||
.context("timed out waiting for OAuth callback")?
|
||||
.context("OAuth callback was cancelled")?;
|
||||
|
||||
oauth_state
|
||||
.handle_callback(&code, &csrf_state)
|
||||
.await
|
||||
.context("failed to handle OAuth callback")?;
|
||||
|
||||
let (client_id, credentials_opt) = oauth_state
|
||||
.get_credentials()
|
||||
.await
|
||||
.context("failed to retrieve OAuth credentials")?;
|
||||
let credentials =
|
||||
credentials_opt.ok_or_else(|| anyhow!("OAuth provider did not return credentials"))?;
|
||||
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: server_name.to_string(),
|
||||
url: server_url.to_string(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
};
|
||||
save_oauth_tokens(server_name, &stored)?;
|
||||
|
||||
drop(guard);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spawn_callback_server(server: Arc<Server>, tx: oneshot::Sender<(String, String)>) {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
while let Ok(request) = server.recv() {
|
||||
let path = request.url().to_string();
|
||||
if let Some(OauthCallbackResult { code, state }) = parse_oauth_callback(&path) {
|
||||
let response =
|
||||
Response::from_string("Authentication complete. You may close this window.");
|
||||
if let Err(err) = request.respond(response) {
|
||||
eprintln!("Failed to respond to OAuth callback: {err}");
|
||||
}
|
||||
if let Err(err) = tx.send((code, state)) {
|
||||
eprintln!("Failed to send OAuth callback: {err:?}");
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
let response =
|
||||
Response::from_string("Invalid OAuth callback").with_status_code(400);
|
||||
if let Err(err) = request.respond(response) {
|
||||
eprintln!("Failed to respond to OAuth callback: {err}");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
struct OauthCallbackResult {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
fn parse_oauth_callback(path: &str) -> Option<OauthCallbackResult> {
|
||||
let (route, query) = path.split_once('?')?;
|
||||
if route != "/callback" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut code = None;
|
||||
let mut state = None;
|
||||
|
||||
for pair in query.split('&') {
|
||||
let (key, value) = pair.split_once('=')?;
|
||||
let decoded = decode(value).ok()?.into_owned();
|
||||
match key {
|
||||
"code" => code = Some(decoded),
|
||||
"state" => state = Some(decoded),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Some(OauthCallbackResult {
|
||||
code: code?,
|
||||
state: state?,
|
||||
})
|
||||
}
|
||||
@@ -21,6 +21,8 @@ use rmcp::service::RoleClient;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::service::{self};
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::child_process::TokioChildProcess;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
@@ -31,7 +33,10 @@ use tokio::time;
|
||||
use tracing::info;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::logging_client_handler::LoggingClientHandler;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::utils::convert_call_tool_result;
|
||||
use crate::utils::convert_to_mcp;
|
||||
use crate::utils::convert_to_rmcp;
|
||||
@@ -40,7 +45,13 @@ use crate::utils::run_with_timeout;
|
||||
|
||||
enum PendingTransport {
|
||||
ChildProcess(TokioChildProcess),
|
||||
StreamableHttp(StreamableHttpClientTransport<reqwest::Client>),
|
||||
StreamableHttp {
|
||||
transport: StreamableHttpClientTransport<reqwest::Client>,
|
||||
},
|
||||
StreamableHttpWithOAuth {
|
||||
transport: StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
oauth_persistor: OAuthPersistor,
|
||||
},
|
||||
}
|
||||
|
||||
enum ClientState {
|
||||
@@ -49,6 +60,7 @@ enum ClientState {
|
||||
},
|
||||
Ready {
|
||||
service: Arc<RunningService<RoleClient, LoggingClientHandler>>,
|
||||
oauth: Option<OAuthPersistor>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -103,17 +115,37 @@ impl RmcpClient {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new_streamable_http_client(url: String, bearer_token: Option<String>) -> Result<Self> {
|
||||
let mut config = StreamableHttpClientTransportConfig::with_uri(url);
|
||||
if let Some(token) = bearer_token {
|
||||
config = config.auth_header(format!("Bearer {token}"));
|
||||
}
|
||||
|
||||
let transport = StreamableHttpClientTransport::from_config(config);
|
||||
pub async fn new_streamable_http_client(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let initial_tokens = match load_oauth_tokens(server_name, url) {
|
||||
Ok(tokens) => tokens,
|
||||
Err(err) => {
|
||||
warn!("failed to read tokens for server `{server_name}`: {err}");
|
||||
None
|
||||
}
|
||||
};
|
||||
let transport = if let Some(initial_tokens) = initial_tokens.clone() {
|
||||
let (transport, oauth_persistor) =
|
||||
create_oauth_transport_and_runtime(server_name, url, initial_tokens).await?;
|
||||
PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
}
|
||||
} else {
|
||||
let mut http_config = StreamableHttpClientTransportConfig::with_uri(url.to_string());
|
||||
if let Some(bearer_token) = bearer_token {
|
||||
http_config = http_config.auth_header(format!("Bearer {bearer_token}"));
|
||||
}
|
||||
|
||||
let transport = StreamableHttpClientTransport::from_config(http_config);
|
||||
PendingTransport::StreamableHttp { transport }
|
||||
};
|
||||
Ok(Self {
|
||||
state: Mutex::new(ClientState::Connecting {
|
||||
transport: Some(PendingTransport::StreamableHttp(transport)),
|
||||
transport: Some(transport),
|
||||
}),
|
||||
})
|
||||
}
|
||||
@@ -125,35 +157,40 @@ impl RmcpClient {
|
||||
params: InitializeRequestParams,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<InitializeResult> {
|
||||
let transport = {
|
||||
let rmcp_params: InitializeRequestParam = convert_to_rmcp(params.clone())?;
|
||||
let client_handler = LoggingClientHandler::new(rmcp_params);
|
||||
|
||||
let (transport, oauth_persistor) = {
|
||||
let mut guard = self.state.lock().await;
|
||||
match &mut *guard {
|
||||
ClientState::Connecting { transport } => transport
|
||||
.take()
|
||||
.ok_or_else(|| anyhow!("client already initializing"))?,
|
||||
ClientState::Ready { .. } => {
|
||||
return Err(anyhow!("client already initialized"));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?;
|
||||
let client_handler = LoggingClientHandler::new(client_info);
|
||||
let service_future = match transport {
|
||||
PendingTransport::ChildProcess(transport) => {
|
||||
service::serve_client(client_handler.clone(), transport).boxed()
|
||||
}
|
||||
PendingTransport::StreamableHttp(transport) => {
|
||||
service::serve_client(client_handler, transport).boxed()
|
||||
ClientState::Connecting { transport } => match transport.take() {
|
||||
Some(PendingTransport::ChildProcess(transport)) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
None,
|
||||
),
|
||||
Some(PendingTransport::StreamableHttp { transport }) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
None,
|
||||
),
|
||||
Some(PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
}) => (
|
||||
service::serve_client(client_handler.clone(), transport).boxed(),
|
||||
Some(oauth_persistor),
|
||||
),
|
||||
None => return Err(anyhow!("client already initializing")),
|
||||
},
|
||||
ClientState::Ready { .. } => return Err(anyhow!("client already initialized")),
|
||||
}
|
||||
};
|
||||
|
||||
let service = match timeout {
|
||||
Some(duration) => time::timeout(duration, service_future)
|
||||
Some(duration) => time::timeout(duration, transport)
|
||||
.await
|
||||
.map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))?
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
None => service_future
|
||||
None => transport
|
||||
.await
|
||||
.map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?,
|
||||
};
|
||||
@@ -168,9 +205,16 @@ impl RmcpClient {
|
||||
let mut guard = self.state.lock().await;
|
||||
*guard = ClientState::Ready {
|
||||
service: Arc::new(service),
|
||||
oauth: oauth_persistor.clone(),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(runtime) = oauth_persistor
|
||||
&& let Err(error) = runtime.persist_if_needed().await
|
||||
{
|
||||
warn!("failed to persist OAuth tokens after initialize: {error}");
|
||||
}
|
||||
|
||||
Ok(initialize_result)
|
||||
}
|
||||
|
||||
@@ -186,7 +230,9 @@ impl RmcpClient {
|
||||
|
||||
let fut = service.list_tools(rmcp_params);
|
||||
let result = run_with_timeout(fut, timeout, "tools/list").await?;
|
||||
convert_to_mcp(result)
|
||||
let converted = convert_to_mcp(result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
pub async fn call_tool(
|
||||
@@ -200,14 +246,79 @@ impl RmcpClient {
|
||||
let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?;
|
||||
let fut = service.call_tool(rmcp_params);
|
||||
let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?;
|
||||
convert_call_tool_result(rmcp_result)
|
||||
let converted = convert_call_tool_result(rmcp_result)?;
|
||||
self.persist_oauth_tokens().await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
async fn service(&self) -> Result<Arc<RunningService<RoleClient, LoggingClientHandler>>> {
|
||||
let guard = self.state.lock().await;
|
||||
match &*guard {
|
||||
ClientState::Ready { service } => Ok(Arc::clone(service)),
|
||||
ClientState::Ready { service, .. } => Ok(Arc::clone(service)),
|
||||
ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")),
|
||||
}
|
||||
}
|
||||
|
||||
async fn oauth_persistor(&self) -> Option<OAuthPersistor> {
|
||||
let guard = self.state.lock().await;
|
||||
match &*guard {
|
||||
ClientState::Ready {
|
||||
oauth: Some(runtime),
|
||||
service: _,
|
||||
} => Some(runtime.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn persist_oauth_tokens(&self) {
|
||||
if let Some(runtime) = self.oauth_persistor().await
|
||||
&& let Err(error) = runtime.persist_if_needed().await
|
||||
{
|
||||
warn!("failed to persist OAuth tokens: {error}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_oauth_transport_and_runtime(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
initial_tokens: StoredOAuthTokens,
|
||||
) -> Result<(
|
||||
StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
OAuthPersistor,
|
||||
)> {
|
||||
let http_client = reqwest::Client::builder().build()?;
|
||||
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
|
||||
|
||||
oauth_state
|
||||
.set_credentials(
|
||||
&initial_tokens.client_id,
|
||||
initial_tokens.token_response.0.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let manager = match oauth_state {
|
||||
OAuthState::Authorized(manager) => manager,
|
||||
OAuthState::Unauthorized(manager) => manager,
|
||||
OAuthState::Session(_) | OAuthState::AuthorizedHttpClient(_) => {
|
||||
return Err(anyhow!("unexpected OAuth state during client setup"));
|
||||
}
|
||||
};
|
||||
|
||||
let auth_client = AuthClient::new(http_client, manager);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
auth_client,
|
||||
StreamableHttpClientTransportConfig::with_uri(url.to_string()),
|
||||
);
|
||||
|
||||
let runtime = OAuthPersistor::new(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
auth_manager,
|
||||
Some(initial_tokens),
|
||||
);
|
||||
|
||||
Ok((transport, runtime))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user