[MCP] Add support for MCP Oauth credentials (#4517)
This PR adds oauth login support to streamable http servers when `experimental_use_rmcp_client` is enabled. This PR is large but represents the minimal amount of work required for this to work. To keep this PR smaller, login can only be done with `codex mcp login` and `codex mcp logout` but it doesn't appear in `/mcp` or `codex mcp list` yet. Fingers crossed that this is the last large MCP PR and that subsequent PRs can be smaller. Under the hood, credentials are stored using platform credential managers using the [keyring crate](https://crates.io/crates/keyring). When the keyring isn't available, it falls back to storing credentials in `CODEX_HOME/.credentials.json` which is consistent with how other coding agents handle authentication. I tested this on macOS, Windows, WSL (ubuntu), and Linux. I wasn't able to test the dbus store on linux but did verify that the fallback works. One quirk is that if you have credentials, during development, every build will have its own ad-hoc binary so the keyring won't recognize the reader as being the same as the write so it may ask for the user's password. I may add an override to disable this or allow users/enterprises to opt-out of the keyring storage if it causes issues. <img width="5064" height="686" alt="CleanShot 2025-09-30 at 19 31 40" src="https://github.com/user-attachments/assets/9573f9b4-07f1-4160-83b8-2920db287e2d" /> <img width="745" height="486" alt="image" src="https://github.com/user-attachments/assets/9562649b-ea5f-4f22-ace2-d0cb438b143e" />
This commit is contained in:
822
codex-rs/rmcp-client/src/oauth.rs
Normal file
822
codex-rs/rmcp-client/src/oauth.rs
Normal file
@@ -0,0 +1,822 @@
|
||||
//! This file handles all logic related to managing MCP OAuth credentials.
|
||||
//! All credentials are stored using the keyring crate which uses os-specific keyring services.
|
||||
//! https://crates.io/crates/keyring
|
||||
//! macOS: macOS keychain.
|
||||
//! Windows: Windows Credential Manager
|
||||
//! Linux: DBus-based Secret Service, the kernel keyutils, and a combo of the two
|
||||
//! FreeBSD, OpenBSD: DBus-based Secret Service
|
||||
//!
|
||||
//! For Linux, we use linux-native-async-persistent which uses both keyutils and async-secret-service (see below) for storage.
|
||||
//! See the docs for the keyutils_persistent module for a full explanation of why both are used. Because this store uses the
|
||||
//! async-secret-service, you must specify the additional features required by that store
|
||||
//!
|
||||
//! async-secret-service provides access to the DBus-based Secret Service storage on Linux, FreeBSD, and OpenBSD. This is an asynchronous
|
||||
//! keystore that always encrypts secrets when they are transferred across the bus. If DBus isn't installed the keystore will fall back to the json
|
||||
//! file because we don't use the "vendored" feature.
|
||||
//!
|
||||
//! If the keyring is not available or fails, we fall back to CODEX_HOME/.credentials.json which is consistent with other coding CLI agents.
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use keyring::Entry;
|
||||
use oauth2::AccessToken;
|
||||
use oauth2::EmptyExtraTokenFields;
|
||||
use oauth2::RefreshToken;
|
||||
use oauth2::Scope;
|
||||
use oauth2::TokenResponse;
|
||||
use oauth2::basic::BasicTokenType;
|
||||
use rmcp::transport::auth::OAuthTokenResponse;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::map::Map as JsonMap;
|
||||
use sha2::Digest;
|
||||
use sha2::Sha256;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt;
|
||||
use std::fs;
|
||||
use std::io::ErrorKind;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
use tracing::warn;
|
||||
|
||||
use rmcp::transport::auth::AuthorizationManager;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::find_codex_home::find_codex_home;
|
||||
|
||||
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct StoredOAuthTokens {
|
||||
pub server_name: String,
|
||||
pub url: String,
|
||||
pub client_id: String,
|
||||
pub token_response: WrappedOAuthTokenResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CredentialStoreError(anyhow::Error);
|
||||
|
||||
impl CredentialStoreError {
|
||||
fn new(error: impl Into<anyhow::Error>) -> Self {
|
||||
Self(error.into())
|
||||
}
|
||||
|
||||
fn message(&self) -> String {
|
||||
self.0.to_string()
|
||||
}
|
||||
|
||||
fn into_error(self) -> anyhow::Error {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for CredentialStoreError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CredentialStoreError {}
|
||||
|
||||
trait CredentialStore {
|
||||
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError>;
|
||||
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>;
|
||||
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError>;
|
||||
}
|
||||
|
||||
struct KeyringCredentialStore;
|
||||
|
||||
impl CredentialStore for KeyringCredentialStore {
|
||||
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError> {
|
||||
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
|
||||
match entry.get_password() {
|
||||
Ok(password) => Ok(Some(password)),
|
||||
Err(keyring::Error::NoEntry) => Ok(None),
|
||||
Err(error) => Err(CredentialStoreError::new(error)),
|
||||
}
|
||||
}
|
||||
|
||||
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError> {
|
||||
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
|
||||
entry.set_password(value).map_err(CredentialStoreError::new)
|
||||
}
|
||||
|
||||
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError> {
|
||||
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
|
||||
match entry.delete_credential() {
|
||||
Ok(()) => Ok(true),
|
||||
Err(keyring::Error::NoEntry) => Ok(false),
|
||||
Err(error) => Err(CredentialStoreError::new(error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap OAuthTokenResponse to allow for partial equality comparison.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WrappedOAuthTokenResponse(pub OAuthTokenResponse);
|
||||
|
||||
impl PartialEq for WrappedOAuthTokenResponse {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (serde_json::to_string(self), serde_json::to_string(other)) {
|
||||
(Ok(s1), Ok(s2)) => s1 == s2,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn load_oauth_tokens(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
|
||||
let store = KeyringCredentialStore;
|
||||
load_oauth_tokens_with_store(&store, server_name, url)
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
match store.load(KEYRING_SERVICE, &key) {
|
||||
Ok(Some(serialized)) => {
|
||||
let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
|
||||
.context("failed to deserialize OAuth tokens from keyring")?;
|
||||
Ok(Some(tokens))
|
||||
}
|
||||
Ok(None) => load_oauth_tokens_from_file(server_name, url),
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to read OAuth tokens from keyring: {message}");
|
||||
load_oauth_tokens_from_file(server_name, url)
|
||||
.with_context(|| format!("failed to read OAuth tokens from keyring: {message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_oauth_tokens(server_name: &str, tokens: &StoredOAuthTokens) -> Result<()> {
|
||||
let store = KeyringCredentialStore;
|
||||
save_oauth_tokens_with_store(&store, server_name, tokens)
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?;
|
||||
|
||||
let key = compute_store_key(server_name, &tokens.url)?;
|
||||
match store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
Ok(()) => {
|
||||
if let Err(error) = delete_oauth_tokens_from_file(&key) {
|
||||
warn!("failed to remove OAuth tokens from fallback storage: {error:?}");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to write OAuth tokens to keyring: {message}");
|
||||
save_oauth_tokens_to_file(tokens)
|
||||
.with_context(|| format!("failed to write OAuth tokens to keyring: {message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn delete_oauth_tokens(server_name: &str, url: &str) -> Result<bool> {
|
||||
let store = KeyringCredentialStore;
|
||||
delete_oauth_tokens_with_store(&store, server_name, url)
|
||||
}
|
||||
|
||||
fn delete_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<bool> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
let keyring_removed = match store.delete(KEYRING_SERVICE, &key) {
|
||||
Ok(removed) => removed,
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to delete OAuth tokens from keyring: {message}");
|
||||
return Err(error.into_error()).context("failed to delete OAuth tokens from keyring");
|
||||
}
|
||||
};
|
||||
|
||||
let file_removed = delete_oauth_tokens_from_file(&key)?;
|
||||
Ok(keyring_removed || file_removed)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct OAuthPersistor {
|
||||
inner: Arc<OAuthPersistorInner>,
|
||||
}
|
||||
|
||||
struct OAuthPersistorInner {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
||||
}
|
||||
|
||||
impl OAuthPersistor {
|
||||
pub(crate) fn new(
|
||||
server_name: String,
|
||||
url: String,
|
||||
manager: Arc<Mutex<AuthorizationManager>>,
|
||||
initial_credentials: Option<StoredOAuthTokens>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(OAuthPersistorInner {
|
||||
server_name,
|
||||
url,
|
||||
authorization_manager: manager,
|
||||
last_credentials: Mutex::new(initial_credentials),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Persists the latest stored credentials if they have changed.
|
||||
/// Deletes the credentials if they are no longer present.
|
||||
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
|
||||
let (client_id, maybe_credentials) = {
|
||||
let manager = self.inner.authorization_manager.clone();
|
||||
let guard = manager.lock().await;
|
||||
guard.get_credentials().await
|
||||
}?;
|
||||
|
||||
match maybe_credentials {
|
||||
Some(credentials) => {
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: self.inner.server_name.clone(),
|
||||
url: self.inner.url.clone(),
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials.clone()),
|
||||
};
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
if last_credentials.as_ref() != Some(&stored) {
|
||||
save_oauth_tokens(&self.inner.server_name, &stored)?;
|
||||
*last_credentials = Some(stored);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut last_serialized = self.inner.last_credentials.lock().await;
|
||||
if last_serialized.take().is_some()
|
||||
&& let Err(error) =
|
||||
delete_oauth_tokens(&self.inner.server_name, &self.inner.url)
|
||||
{
|
||||
warn!(
|
||||
"failed to remove OAuth tokens for server {}: {error}",
|
||||
self.inner.server_name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
const FALLBACK_FILENAME: &str = ".credentials.json";
|
||||
const MCP_SERVER_TYPE: &str = "http";
|
||||
|
||||
type FallbackFile = BTreeMap<String, FallbackTokenEntry>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FallbackTokenEntry {
|
||||
server_name: String,
|
||||
server_url: String,
|
||||
client_id: String,
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
expires_at: Option<u64>,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default)]
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_file(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
|
||||
let Some(store) = read_fallback_file()? else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
|
||||
for entry in store.values() {
|
||||
let entry_key = compute_store_key(&entry.server_name, &entry.server_url)?;
|
||||
if entry_key != key {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut token_response = OAuthTokenResponse::new(
|
||||
AccessToken::new(entry.access_token.clone()),
|
||||
BasicTokenType::Bearer,
|
||||
EmptyExtraTokenFields {},
|
||||
);
|
||||
|
||||
if let Some(refresh) = entry.refresh_token.clone() {
|
||||
token_response.set_refresh_token(Some(RefreshToken::new(refresh)));
|
||||
}
|
||||
|
||||
let scopes = entry.scopes.clone();
|
||||
if !scopes.is_empty() {
|
||||
token_response.set_scopes(Some(scopes.into_iter().map(Scope::new).collect()));
|
||||
}
|
||||
|
||||
if let Some(expires_at) = entry.expires_at
|
||||
&& let Some(seconds) = expires_in_from_timestamp(expires_at)
|
||||
{
|
||||
let duration = Duration::from_secs(seconds);
|
||||
token_response.set_expires_in(Some(&duration));
|
||||
}
|
||||
|
||||
let stored = StoredOAuthTokens {
|
||||
server_name: entry.server_name.clone(),
|
||||
url: entry.server_url.clone(),
|
||||
client_id: entry.client_id.clone(),
|
||||
token_response: WrappedOAuthTokenResponse(token_response),
|
||||
};
|
||||
|
||||
return Ok(Some(stored));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> {
|
||||
let key = compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
let mut store = read_fallback_file()?.unwrap_or_default();
|
||||
|
||||
let token_response = &tokens.token_response.0;
|
||||
let refresh_token = token_response
|
||||
.refresh_token()
|
||||
.map(|token| token.secret().to_string());
|
||||
let scopes = token_response
|
||||
.scopes()
|
||||
.map(|s| s.iter().map(|s| s.to_string()).collect())
|
||||
.unwrap_or_default();
|
||||
let entry = FallbackTokenEntry {
|
||||
server_name: tokens.server_name.clone(),
|
||||
server_url: tokens.url.clone(),
|
||||
client_id: tokens.client_id.clone(),
|
||||
access_token: token_response.access_token().secret().to_string(),
|
||||
expires_at: compute_expires_at_millis(token_response),
|
||||
refresh_token,
|
||||
scopes,
|
||||
};
|
||||
|
||||
store.insert(key, entry);
|
||||
write_fallback_file(&store)
|
||||
}
|
||||
|
||||
fn delete_oauth_tokens_from_file(key: &str) -> Result<bool> {
|
||||
let mut store = match read_fallback_file()? {
|
||||
Some(store) => store,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
let removed = store.remove(key).is_some();
|
||||
|
||||
if removed {
|
||||
write_fallback_file(&store)?;
|
||||
}
|
||||
|
||||
Ok(removed)
|
||||
}
|
||||
|
||||
fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option<u64> {
|
||||
let expires_in = response.expires_in()?;
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
let expiry = now.checked_add(expires_in)?;
|
||||
let millis = expiry.as_millis();
|
||||
if millis > u128::from(u64::MAX) {
|
||||
Some(u64::MAX)
|
||||
} else {
|
||||
Some(millis as u64)
|
||||
}
|
||||
}
|
||||
|
||||
fn expires_in_from_timestamp(expires_at: u64) -> Option<u64> {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| Duration::from_secs(0));
|
||||
let now_ms = now.as_millis() as u64;
|
||||
|
||||
if expires_at <= now_ms {
|
||||
None
|
||||
} else {
|
||||
Some((expires_at - now_ms) / 1000)
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_store_key(server_name: &str, server_url: &str) -> Result<String> {
|
||||
let mut payload = JsonMap::new();
|
||||
payload.insert(
|
||||
"type".to_string(),
|
||||
Value::String(MCP_SERVER_TYPE.to_string()),
|
||||
);
|
||||
payload.insert("url".to_string(), Value::String(server_url.to_string()));
|
||||
payload.insert("headers".to_string(), Value::Object(JsonMap::new()));
|
||||
|
||||
let truncated = sha_256_prefix(&Value::Object(payload))?;
|
||||
Ok(format!("{server_name}|{truncated}"))
|
||||
}
|
||||
|
||||
fn fallback_file_path() -> Result<PathBuf> {
|
||||
let mut path = find_codex_home()?;
|
||||
path.push(FALLBACK_FILENAME);
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
fn read_fallback_file() -> Result<Option<FallbackFile>> {
|
||||
let path = fallback_file_path()?;
|
||||
let contents = match fs::read_to_string(&path) {
|
||||
Ok(contents) => contents,
|
||||
Err(err) if err.kind() == ErrorKind::NotFound => return Ok(None),
|
||||
Err(err) => {
|
||||
return Err(err).context(format!(
|
||||
"failed to read credentials file at {}",
|
||||
path.display()
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
match serde_json::from_str::<FallbackFile>(&contents) {
|
||||
Ok(store) => Ok(Some(store)),
|
||||
Err(e) => Err(e).context(format!(
|
||||
"failed to parse credentials file at {}",
|
||||
path.display()
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_fallback_file(store: &FallbackFile) -> Result<()> {
|
||||
let path = fallback_file_path()?;
|
||||
|
||||
if store.is_empty() {
|
||||
if path.exists() {
|
||||
fs::remove_file(path)?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let serialized = serde_json::to_string(store)?;
|
||||
fs::write(&path, serialized)?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let perms = fs::Permissions::from_mode(0o600);
|
||||
fs::set_permissions(&path, perms)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sha_256_prefix(value: &Value) -> Result<String> {
|
||||
let serialized =
|
||||
serde_json::to_string(&value).context("failed to serialize MCP OAuth key payload")?;
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(serialized.as_bytes());
|
||||
let digest = hasher.finalize();
|
||||
let hex = format!("{digest:x}");
|
||||
let truncated = &hex[..16];
|
||||
Ok(truncated.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use anyhow::Result;
|
||||
use keyring::Error as KeyringError;
|
||||
use keyring::credential::CredentialApi as _;
|
||||
use keyring::mock::MockCredential;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::MutexGuard;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::PoisonError;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct MockCredentialStore {
|
||||
credentials: Arc<Mutex<HashMap<String, Arc<MockCredential>>>>,
|
||||
}
|
||||
|
||||
impl MockCredentialStore {
|
||||
fn credential(&self, account: &str) -> Arc<MockCredential> {
|
||||
let mut guard = self.credentials.lock().unwrap();
|
||||
guard
|
||||
.entry(account.to_string())
|
||||
.or_insert_with(|| Arc::new(MockCredential::default()))
|
||||
.clone()
|
||||
}
|
||||
|
||||
fn saved_value(&self, account: &str) -> Option<String> {
|
||||
let credential = {
|
||||
let guard = self.credentials.lock().unwrap();
|
||||
guard.get(account).cloned()
|
||||
}?;
|
||||
credential.get_password().ok()
|
||||
}
|
||||
|
||||
fn set_error(&self, account: &str, error: KeyringError) {
|
||||
let credential = self.credential(account);
|
||||
credential.set_error(error);
|
||||
}
|
||||
|
||||
fn contains(&self, account: &str) -> bool {
|
||||
let guard = self.credentials.lock().unwrap();
|
||||
guard.contains_key(account)
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialStore for MockCredentialStore {
|
||||
fn load(
|
||||
&self,
|
||||
_service: &str,
|
||||
account: &str,
|
||||
) -> Result<Option<String>, CredentialStoreError> {
|
||||
let credential = {
|
||||
let guard = self.credentials.lock().unwrap();
|
||||
guard.get(account).cloned()
|
||||
};
|
||||
|
||||
let Some(credential) = credential else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
match credential.get_password() {
|
||||
Ok(password) => Ok(Some(password)),
|
||||
Err(KeyringError::NoEntry) => Ok(None),
|
||||
Err(error) => Err(CredentialStoreError::new(error)),
|
||||
}
|
||||
}
|
||||
|
||||
fn save(
|
||||
&self,
|
||||
_service: &str,
|
||||
account: &str,
|
||||
value: &str,
|
||||
) -> Result<(), CredentialStoreError> {
|
||||
let credential = self.credential(account);
|
||||
credential
|
||||
.set_password(value)
|
||||
.map_err(CredentialStoreError::new)
|
||||
}
|
||||
|
||||
fn delete(&self, _service: &str, account: &str) -> Result<bool, CredentialStoreError> {
|
||||
let credential = {
|
||||
let guard = self.credentials.lock().unwrap();
|
||||
guard.get(account).cloned()
|
||||
};
|
||||
|
||||
let Some(credential) = credential else {
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
match credential.delete_credential() {
|
||||
Ok(()) => {
|
||||
let mut guard = self.credentials.lock().unwrap();
|
||||
guard.remove(account);
|
||||
Ok(true)
|
||||
}
|
||||
Err(KeyringError::NoEntry) => {
|
||||
let mut guard = self.credentials.lock().unwrap();
|
||||
guard.remove(account);
|
||||
Ok(false)
|
||||
}
|
||||
Err(error) => Err(CredentialStoreError::new(error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TempCodexHome {
|
||||
_guard: MutexGuard<'static, ()>,
|
||||
_dir: tempfile::TempDir,
|
||||
}
|
||||
|
||||
impl TempCodexHome {
|
||||
fn new() -> Self {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
let guard = LOCK
|
||||
.get_or_init(Mutex::default)
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner);
|
||||
let dir = tempdir().expect("create CODEX_HOME temp dir");
|
||||
unsafe {
|
||||
std::env::set_var("CODEX_HOME", dir.path());
|
||||
}
|
||||
Self {
|
||||
_guard: guard,
|
||||
_dir: dir,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TempCodexHome {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
std::env::remove_var("CODEX_HOME");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_reads_from_keyring_when_available() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
|
||||
assert_eq!(loaded, Some(expected));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_falls_back_when_missing_in_keyring() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_oauth_tokens_falls_back_when_keyring_errors() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let expected = tokens.clone();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(!fallback_path.exists(), "fallback file should be removed");
|
||||
let stored = store.saved_value(&key).expect("value saved to keyring");
|
||||
assert_eq!(serde_json::from_str::<StoredOAuthTokens>(&stored)?, tokens);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn save_oauth_tokens_writes_fallback_when_keyring_fails() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));
|
||||
|
||||
super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(fallback_path.exists(), "fallback file should be created");
|
||||
let saved = super::read_fallback_file()?.expect("fallback file should load");
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
let entry = saved.get(&key).expect("entry for key");
|
||||
assert_eq!(entry.server_name, tokens.server_name);
|
||||
assert_eq!(entry.server_url, tokens.url);
|
||||
assert_eq!(entry.client_id, tokens.client_id);
|
||||
assert_eq!(
|
||||
entry.access_token,
|
||||
tokens.token_response.0.access_token().secret().as_str()
|
||||
);
|
||||
assert!(store.saved_value(&key).is_none());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_removes_all_storage() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let removed =
|
||||
super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
|
||||
assert!(removed);
|
||||
assert!(!store.contains(&key));
|
||||
assert!(!super::fallback_file_path()?.exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_propagates_keyring_errors() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
|
||||
super::save_oauth_tokens_to_file(&tokens).unwrap();
|
||||
|
||||
let result =
|
||||
super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url);
|
||||
assert!(result.is_err());
|
||||
assert!(super::fallback_file_path().unwrap().exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn assert_tokens_match_without_expiry(
|
||||
actual: &StoredOAuthTokens,
|
||||
expected: &StoredOAuthTokens,
|
||||
) {
|
||||
assert_eq!(actual.server_name, expected.server_name);
|
||||
assert_eq!(actual.url, expected.url);
|
||||
assert_eq!(actual.client_id, expected.client_id);
|
||||
assert_token_response_match_without_expiry(
|
||||
&actual.token_response,
|
||||
&expected.token_response,
|
||||
);
|
||||
}
|
||||
|
||||
fn assert_token_response_match_without_expiry(
|
||||
actual: &WrappedOAuthTokenResponse,
|
||||
expected: &WrappedOAuthTokenResponse,
|
||||
) {
|
||||
let actual_response = &actual.0;
|
||||
let expected_response = &expected.0;
|
||||
|
||||
assert_eq!(
|
||||
actual_response.access_token().secret(),
|
||||
expected_response.access_token().secret()
|
||||
);
|
||||
assert_eq!(actual_response.token_type(), expected_response.token_type());
|
||||
assert_eq!(
|
||||
actual_response.refresh_token().map(RefreshToken::secret),
|
||||
expected_response.refresh_token().map(RefreshToken::secret),
|
||||
);
|
||||
assert_eq!(actual_response.scopes(), expected_response.scopes());
|
||||
assert_eq!(
|
||||
actual_response.extra_fields(),
|
||||
expected_response.extra_fields()
|
||||
);
|
||||
assert_eq!(
|
||||
actual_response.expires_in().is_some(),
|
||||
expected_response.expires_in().is_some()
|
||||
);
|
||||
}
|
||||
|
||||
fn sample_tokens() -> StoredOAuthTokens {
|
||||
let mut response = OAuthTokenResponse::new(
|
||||
AccessToken::new("access-token".to_string()),
|
||||
BasicTokenType::Bearer,
|
||||
EmptyExtraTokenFields {},
|
||||
);
|
||||
response.set_refresh_token(Some(RefreshToken::new("refresh-token".to_string())));
|
||||
response.set_scopes(Some(vec![
|
||||
Scope::new("scope-a".to_string()),
|
||||
Scope::new("scope-b".to_string()),
|
||||
]));
|
||||
let expires_in = Duration::from_secs(3600);
|
||||
response.set_expires_in(Some(&expires_in));
|
||||
|
||||
StoredOAuthTokens {
|
||||
server_name: "test-server".to_string(),
|
||||
url: "https://example.test".to_string(),
|
||||
client_id: "client-id".to_string(),
|
||||
token_response: WrappedOAuthTokenResponse(response),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user