Follow-up PR to #5569. Adds keyring support for auth storage in Codex CLI, plus a hybrid mode that persists credentials in the keychain by default but falls back to a file when the keychain is unavailable. It also refactors the KeyringStore implementation out of rmcp-client [here](https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs) into a new keyring-store crate. A follow-up will pick the credential mode based on the config instead of hardcoding `AuthCredentialsStoreMode::File`.
815 lines
27 KiB
Rust
815 lines
27 KiB
Rust
//! This file handles all logic related to managing MCP OAuth credentials.
|
|
//! Credentials are stored either with the keyring crate, which uses OS-specific keyring services, or in a JSON fallback file, as selected by `OAuthCredentialsStoreMode`.
|
|
//! https://crates.io/crates/keyring
|
|
//! macOS: macOS keychain.
|
|
//! Windows: Windows Credential Manager
|
|
//! Linux: DBus-based Secret Service, the kernel keyutils, and a combo of the two
|
|
//! FreeBSD, OpenBSD: DBus-based Secret Service
|
|
//!
|
|
//! For Linux, we use linux-native-async-persistent which uses both keyutils and async-secret-service (see below) for storage.
|
|
//! See the docs for the keyutils_persistent module for a full explanation of why both are used. Because this store uses the
|
|
//! async-secret-service, you must specify the additional features required by that store.
|
|
//!
|
|
//! async-secret-service provides access to the DBus-based Secret Service storage on Linux, FreeBSD, and OpenBSD. This is an asynchronous
|
|
//! keystore that always encrypts secrets when they are transferred across the bus. If DBus isn't installed the keystore will fall back to the json
|
|
//! file because we don't use the "vendored" feature.
|
|
//!
|
|
//! In `Auto` mode, if the keyring is unavailable or fails, we fall back to CODEX_HOME/.credentials.json, which is consistent with other coding CLI agents.
|
|
|
|
use anyhow::Context;
|
|
use anyhow::Error;
|
|
use anyhow::Result;
|
|
use oauth2::AccessToken;
|
|
use oauth2::EmptyExtraTokenFields;
|
|
use oauth2::RefreshToken;
|
|
use oauth2::Scope;
|
|
use oauth2::TokenResponse;
|
|
use oauth2::basic::BasicTokenType;
|
|
use rmcp::transport::auth::OAuthTokenResponse;
|
|
use serde::Deserialize;
|
|
use serde::Serialize;
|
|
use serde_json::Value;
|
|
use serde_json::map::Map as JsonMap;
|
|
use sha2::Digest;
|
|
use sha2::Sha256;
|
|
use std::collections::BTreeMap;
|
|
use std::fs;
|
|
use std::io::ErrorKind;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use std::time::SystemTime;
|
|
use std::time::UNIX_EPOCH;
|
|
use tracing::warn;
|
|
|
|
use codex_keyring_store::DefaultKeyringStore;
|
|
use codex_keyring_store::KeyringStore;
|
|
use rmcp::transport::auth::AuthorizationManager;
|
|
use tokio::sync::Mutex;
|
|
|
|
use crate::find_codex_home::find_codex_home;
|
|
|
|
/// Keyring service name under which all MCP OAuth credentials are saved; individual
/// entries within the service are keyed by `compute_store_key(server_name, url)`.
const KEYRING_SERVICE: &str = "Codex MCP Credentials";
|
|
|
|
/// OAuth credentials for a single MCP server.
///
/// Serialized as JSON when stored in the keyring, and converted to/from a flattened
/// `FallbackTokenEntry` when stored in the fallback credentials file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
|
pub struct StoredOAuthTokens {
|
|
/// MCP server name; combined with `url` to derive the storage key.
pub server_name: String,
|
|
/// MCP server URL; hashed into the storage key.
pub url: String,
|
|
/// OAuth client id reported by the authorization manager.
pub client_id: String,
|
|
/// Token response from the authorization server, wrapped so this struct can derive `PartialEq`.
pub token_response: WrappedOAuthTokenResponse,
|
|
}
|
|
|
|
/// Selects where Codex stores and reads MCP OAuth credentials.
/// Serialized in lowercase: `auto`, `file`, `keyring`.
|
|
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
|
#[serde(rename_all = "lowercase")]
|
|
pub enum OAuthCredentialsStoreMode {
|
|
/// `Keyring` when available; otherwise, `File`. Reads fall back to the file when the keyring has no entry or errors; writes fall back when the keyring write fails.
|
|
/// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
|
|
#[default]
|
|
Auto,
|
|
/// CODEX_HOME/.credentials.json (deleting in this mode still also attempts to remove any keyring entry).
|
|
/// This file will be readable by Codex and other applications running as the same user.
|
|
File,
|
|
/// Keyring only: keyring errors are returned to the caller instead of falling back to the file.
|
|
Keyring,
|
|
}
|
|
|
|
/// Newtype around rmcp's `OAuthTokenResponse` that exists so `PartialEq` can be
/// implemented for it (equality compares JSON serializations; see the impl below).
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct WrappedOAuthTokenResponse(pub OAuthTokenResponse);
|
|
|
|
impl PartialEq for WrappedOAuthTokenResponse {
|
|
fn eq(&self, other: &Self) -> bool {
|
|
match (serde_json::to_string(self), serde_json::to_string(other)) {
|
|
(Ok(s1), Ok(s2)) => s1 == s2,
|
|
_ => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(crate) fn load_oauth_tokens(
|
|
server_name: &str,
|
|
url: &str,
|
|
store_mode: OAuthCredentialsStoreMode,
|
|
) -> Result<Option<StoredOAuthTokens>> {
|
|
let keyring_store = DefaultKeyringStore;
|
|
match store_mode {
|
|
OAuthCredentialsStoreMode::Auto => {
|
|
load_oauth_tokens_from_keyring_with_fallback_to_file(&keyring_store, server_name, url)
|
|
}
|
|
OAuthCredentialsStoreMode::File => load_oauth_tokens_from_file(server_name, url),
|
|
OAuthCredentialsStoreMode::Keyring => {
|
|
load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
|
|
.with_context(|| "failed to read OAuth tokens from keyring".to_string())
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(crate) fn has_oauth_tokens(
|
|
server_name: &str,
|
|
url: &str,
|
|
store_mode: OAuthCredentialsStoreMode,
|
|
) -> Result<bool> {
|
|
Ok(load_oauth_tokens(server_name, url, store_mode)?.is_some())
|
|
}
|
|
|
|
fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
|
|
keyring_store: &K,
|
|
server_name: &str,
|
|
url: &str,
|
|
) -> Result<Option<StoredOAuthTokens>> {
|
|
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
|
|
Ok(Some(tokens)) => Ok(Some(tokens)),
|
|
Ok(None) => load_oauth_tokens_from_file(server_name, url),
|
|
Err(error) => {
|
|
warn!("failed to read OAuth tokens from keyring: {error}");
|
|
load_oauth_tokens_from_file(server_name, url)
|
|
.with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn load_oauth_tokens_from_keyring<K: KeyringStore>(
|
|
keyring_store: &K,
|
|
server_name: &str,
|
|
url: &str,
|
|
) -> Result<Option<StoredOAuthTokens>> {
|
|
let key = compute_store_key(server_name, url)?;
|
|
match keyring_store.load(KEYRING_SERVICE, &key) {
|
|
Ok(Some(serialized)) => {
|
|
let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
|
|
.context("failed to deserialize OAuth tokens from keyring")?;
|
|
Ok(Some(tokens))
|
|
}
|
|
Ok(None) => Ok(None),
|
|
Err(error) => Err(Error::new(error.into_error())),
|
|
}
|
|
}
|
|
|
|
pub fn save_oauth_tokens(
|
|
server_name: &str,
|
|
tokens: &StoredOAuthTokens,
|
|
store_mode: OAuthCredentialsStoreMode,
|
|
) -> Result<()> {
|
|
let keyring_store = DefaultKeyringStore;
|
|
match store_mode {
|
|
OAuthCredentialsStoreMode::Auto => save_oauth_tokens_with_keyring_with_fallback_to_file(
|
|
&keyring_store,
|
|
server_name,
|
|
tokens,
|
|
),
|
|
OAuthCredentialsStoreMode::File => save_oauth_tokens_to_file(tokens),
|
|
OAuthCredentialsStoreMode::Keyring => {
|
|
save_oauth_tokens_with_keyring(&keyring_store, server_name, tokens)
|
|
}
|
|
}
|
|
}
|
|
|
|
fn save_oauth_tokens_with_keyring<K: KeyringStore>(
|
|
keyring_store: &K,
|
|
server_name: &str,
|
|
tokens: &StoredOAuthTokens,
|
|
) -> Result<()> {
|
|
let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?;
|
|
|
|
let key = compute_store_key(server_name, &tokens.url)?;
|
|
match keyring_store.save(KEYRING_SERVICE, &key, &serialized) {
|
|
Ok(()) => {
|
|
if let Err(error) = delete_oauth_tokens_from_file(&key) {
|
|
warn!("failed to remove OAuth tokens from fallback storage: {error:?}");
|
|
}
|
|
Ok(())
|
|
}
|
|
Err(error) => {
|
|
let message = format!(
|
|
"failed to write OAuth tokens to keyring: {}",
|
|
error.message()
|
|
);
|
|
warn!("{message}");
|
|
Err(Error::new(error.into_error()).context(message))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn save_oauth_tokens_with_keyring_with_fallback_to_file<K: KeyringStore>(
|
|
keyring_store: &K,
|
|
server_name: &str,
|
|
tokens: &StoredOAuthTokens,
|
|
) -> Result<()> {
|
|
match save_oauth_tokens_with_keyring(keyring_store, server_name, tokens) {
|
|
Ok(()) => Ok(()),
|
|
Err(error) => {
|
|
let message = error.to_string();
|
|
warn!("falling back to file storage for OAuth tokens: {message}");
|
|
save_oauth_tokens_to_file(tokens)
|
|
.with_context(|| format!("failed to write OAuth tokens to keyring: {message}"))
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn delete_oauth_tokens(
|
|
server_name: &str,
|
|
url: &str,
|
|
store_mode: OAuthCredentialsStoreMode,
|
|
) -> Result<bool> {
|
|
let keyring_store = DefaultKeyringStore;
|
|
delete_oauth_tokens_from_keyring_and_file(&keyring_store, store_mode, server_name, url)
|
|
}
|
|
|
|
fn delete_oauth_tokens_from_keyring_and_file<K: KeyringStore>(
|
|
keyring_store: &K,
|
|
store_mode: OAuthCredentialsStoreMode,
|
|
server_name: &str,
|
|
url: &str,
|
|
) -> Result<bool> {
|
|
let key = compute_store_key(server_name, url)?;
|
|
let keyring_result = keyring_store.delete(KEYRING_SERVICE, &key);
|
|
let keyring_removed = match keyring_result {
|
|
Ok(removed) => removed,
|
|
Err(error) => {
|
|
let message = error.message();
|
|
warn!("failed to delete OAuth tokens from keyring: {message}");
|
|
match store_mode {
|
|
OAuthCredentialsStoreMode::Auto | OAuthCredentialsStoreMode::Keyring => {
|
|
return Err(error.into_error())
|
|
.context("failed to delete OAuth tokens from keyring");
|
|
}
|
|
OAuthCredentialsStoreMode::File => false,
|
|
}
|
|
}
|
|
};
|
|
|
|
let file_removed = delete_oauth_tokens_from_file(&key)?;
|
|
Ok(keyring_removed || file_removed)
|
|
}
|
|
|
|
/// Handle that writes an MCP server's current OAuth credentials to the configured
/// store when they change. Clones share one inner state through an `Arc`.
#[derive(Clone)]
|
|
pub(crate) struct OAuthPersistor {
|
|
inner: Arc<OAuthPersistorInner>,
|
|
}
|
|
|
|
/// Shared state behind an `OAuthPersistor`.
struct OAuthPersistorInner {
|
|
/// MCP server name; copied into stored records and used for the storage key.
server_name: String,
|
|
/// MCP server URL; copied into stored records and used for the storage key.
url: String,
|
|
/// Source of the current credentials (read via `get_credentials`).
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
|
/// Backend selection passed to `save_oauth_tokens` / `delete_oauth_tokens`.
store_mode: OAuthCredentialsStoreMode,
|
|
/// Last credentials known to be persisted (seeded by the constructor); used to
/// skip redundant saves, and reset to `None` once credentials disappear.
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
|
}
|
|
|
|
impl OAuthPersistor {
|
|
pub(crate) fn new(
|
|
server_name: String,
|
|
url: String,
|
|
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
|
store_mode: OAuthCredentialsStoreMode,
|
|
initial_credentials: Option<StoredOAuthTokens>,
|
|
) -> Self {
|
|
Self {
|
|
inner: Arc::new(OAuthPersistorInner {
|
|
server_name,
|
|
url,
|
|
authorization_manager,
|
|
store_mode,
|
|
last_credentials: Mutex::new(initial_credentials),
|
|
}),
|
|
}
|
|
}
|
|
|
|
/// Persists the latest stored credentials if they have changed.
|
|
/// Deletes the credentials if they are no longer present.
|
|
pub(crate) async fn persist_if_needed(&self) -> Result<()> {
|
|
let (client_id, maybe_credentials) = {
|
|
let manager = self.inner.authorization_manager.clone();
|
|
let guard = manager.lock().await;
|
|
guard.get_credentials().await
|
|
}?;
|
|
|
|
match maybe_credentials {
|
|
Some(credentials) => {
|
|
let stored = StoredOAuthTokens {
|
|
server_name: self.inner.server_name.clone(),
|
|
url: self.inner.url.clone(),
|
|
client_id,
|
|
token_response: WrappedOAuthTokenResponse(credentials.clone()),
|
|
};
|
|
let mut last_credentials = self.inner.last_credentials.lock().await;
|
|
if last_credentials.as_ref() != Some(&stored) {
|
|
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
|
|
*last_credentials = Some(stored);
|
|
}
|
|
}
|
|
None => {
|
|
let mut last_serialized = self.inner.last_credentials.lock().await;
|
|
if last_serialized.take().is_some()
|
|
&& let Err(error) = delete_oauth_tokens(
|
|
&self.inner.server_name,
|
|
&self.inner.url,
|
|
self.inner.store_mode,
|
|
)
|
|
{
|
|
warn!(
|
|
"failed to remove OAuth tokens for server {}: {error}",
|
|
self.inner.server_name
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// File name, inside CODEX_HOME, of the plaintext JSON fallback credentials store.
const FALLBACK_FILENAME: &str = ".credentials.json";
|
|
/// Transport type recorded in the payload hashed by `compute_store_key`.
const MCP_SERVER_TYPE: &str = "http";
|
|
|
|
/// On-disk layout of the fallback file: storage key (from `compute_store_key`)
/// mapped to its token entry, kept sorted by key.
type FallbackFile = BTreeMap<String, FallbackTokenEntry>;
|
|
|
|
/// Flattened token record stored in the fallback credentials file.
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct FallbackTokenEntry {
|
|
server_name: String,
|
|
server_url: String,
|
|
client_id: String,
|
|
access_token: String,
|
|
/// Absolute expiry in milliseconds since the Unix epoch (see
/// `compute_expires_at_millis`); `None` when the response had no `expires_in`.
/// `#[serde(default)]` lets entries omit this field.
#[serde(default)]
|
|
expires_at: Option<u64>,
|
|
#[serde(default)]
|
|
refresh_token: Option<String>,
|
|
/// Granted scopes; empty when the token response carried none.
#[serde(default)]
|
|
scopes: Vec<String>,
|
|
}
|
|
|
|
fn load_oauth_tokens_from_file(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
|
|
let Some(store) = read_fallback_file()? else {
|
|
return Ok(None);
|
|
};
|
|
|
|
let key = compute_store_key(server_name, url)?;
|
|
|
|
for entry in store.values() {
|
|
let entry_key = compute_store_key(&entry.server_name, &entry.server_url)?;
|
|
if entry_key != key {
|
|
continue;
|
|
}
|
|
|
|
let mut token_response = OAuthTokenResponse::new(
|
|
AccessToken::new(entry.access_token.clone()),
|
|
BasicTokenType::Bearer,
|
|
EmptyExtraTokenFields {},
|
|
);
|
|
|
|
if let Some(refresh) = entry.refresh_token.clone() {
|
|
token_response.set_refresh_token(Some(RefreshToken::new(refresh)));
|
|
}
|
|
|
|
let scopes = entry.scopes.clone();
|
|
if !scopes.is_empty() {
|
|
token_response.set_scopes(Some(scopes.into_iter().map(Scope::new).collect()));
|
|
}
|
|
|
|
if let Some(expires_at) = entry.expires_at
|
|
&& let Some(seconds) = expires_in_from_timestamp(expires_at)
|
|
{
|
|
let duration = Duration::from_secs(seconds);
|
|
token_response.set_expires_in(Some(&duration));
|
|
}
|
|
|
|
let stored = StoredOAuthTokens {
|
|
server_name: entry.server_name.clone(),
|
|
url: entry.server_url.clone(),
|
|
client_id: entry.client_id.clone(),
|
|
token_response: WrappedOAuthTokenResponse(token_response),
|
|
};
|
|
|
|
return Ok(Some(stored));
|
|
}
|
|
|
|
Ok(None)
|
|
}
|
|
|
|
fn save_oauth_tokens_to_file(tokens: &StoredOAuthTokens) -> Result<()> {
|
|
let key = compute_store_key(&tokens.server_name, &tokens.url)?;
|
|
let mut store = read_fallback_file()?.unwrap_or_default();
|
|
|
|
let token_response = &tokens.token_response.0;
|
|
let refresh_token = token_response
|
|
.refresh_token()
|
|
.map(|token| token.secret().to_string());
|
|
let scopes = token_response
|
|
.scopes()
|
|
.map(|s| s.iter().map(|s| s.to_string()).collect())
|
|
.unwrap_or_default();
|
|
let entry = FallbackTokenEntry {
|
|
server_name: tokens.server_name.clone(),
|
|
server_url: tokens.url.clone(),
|
|
client_id: tokens.client_id.clone(),
|
|
access_token: token_response.access_token().secret().to_string(),
|
|
expires_at: compute_expires_at_millis(token_response),
|
|
refresh_token,
|
|
scopes,
|
|
};
|
|
|
|
store.insert(key, entry);
|
|
write_fallback_file(&store)
|
|
}
|
|
|
|
fn delete_oauth_tokens_from_file(key: &str) -> Result<bool> {
|
|
let mut store = match read_fallback_file()? {
|
|
Some(store) => store,
|
|
None => return Ok(false),
|
|
};
|
|
|
|
let removed = store.remove(key).is_some();
|
|
|
|
if removed {
|
|
write_fallback_file(&store)?;
|
|
}
|
|
|
|
Ok(removed)
|
|
}
|
|
|
|
fn compute_expires_at_millis(response: &OAuthTokenResponse) -> Option<u64> {
|
|
let expires_in = response.expires_in()?;
|
|
let now = SystemTime::now()
|
|
.duration_since(UNIX_EPOCH)
|
|
.unwrap_or_else(|_| Duration::from_secs(0));
|
|
let expiry = now.checked_add(expires_in)?;
|
|
let millis = expiry.as_millis();
|
|
if millis > u128::from(u64::MAX) {
|
|
Some(u64::MAX)
|
|
} else {
|
|
Some(millis as u64)
|
|
}
|
|
}
|
|
|
|
/// Returns the whole seconds remaining until `expires_at` (milliseconds since the
/// Unix epoch), rounded down, or `None` if that instant is now or in the past.
fn expires_in_from_timestamp(expires_at: u64) -> Option<u64> {
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO)
        .as_millis() as u64;
    expires_at
        .checked_sub(now_ms)
        .filter(|&remaining_ms| remaining_ms > 0)
        .map(|remaining_ms| remaining_ms / 1000)
}
|
|
|
|
fn compute_store_key(server_name: &str, server_url: &str) -> Result<String> {
|
|
let mut payload = JsonMap::new();
|
|
payload.insert(
|
|
"type".to_string(),
|
|
Value::String(MCP_SERVER_TYPE.to_string()),
|
|
);
|
|
payload.insert("url".to_string(), Value::String(server_url.to_string()));
|
|
payload.insert("headers".to_string(), Value::Object(JsonMap::new()));
|
|
|
|
let truncated = sha_256_prefix(&Value::Object(payload))?;
|
|
Ok(format!("{server_name}|{truncated}"))
|
|
}
|
|
|
|
fn fallback_file_path() -> Result<PathBuf> {
|
|
let mut path = find_codex_home()?;
|
|
path.push(FALLBACK_FILENAME);
|
|
Ok(path)
|
|
}
|
|
|
|
fn read_fallback_file() -> Result<Option<FallbackFile>> {
|
|
let path = fallback_file_path()?;
|
|
let contents = match fs::read_to_string(&path) {
|
|
Ok(contents) => contents,
|
|
Err(err) if err.kind() == ErrorKind::NotFound => return Ok(None),
|
|
Err(err) => {
|
|
return Err(err).context(format!(
|
|
"failed to read credentials file at {}",
|
|
path.display()
|
|
));
|
|
}
|
|
};
|
|
|
|
match serde_json::from_str::<FallbackFile>(&contents) {
|
|
Ok(store) => Ok(Some(store)),
|
|
Err(e) => Err(e).context(format!(
|
|
"failed to parse credentials file at {}",
|
|
path.display()
|
|
)),
|
|
}
|
|
}
|
|
|
|
fn write_fallback_file(store: &FallbackFile) -> Result<()> {
|
|
let path = fallback_file_path()?;
|
|
|
|
if store.is_empty() {
|
|
if path.exists() {
|
|
fs::remove_file(path)?;
|
|
}
|
|
return Ok(());
|
|
}
|
|
|
|
if let Some(parent) = path.parent() {
|
|
fs::create_dir_all(parent)?;
|
|
}
|
|
|
|
let serialized = serde_json::to_string(store)?;
|
|
fs::write(&path, serialized)?;
|
|
|
|
#[cfg(unix)]
|
|
{
|
|
use std::os::unix::fs::PermissionsExt;
|
|
let perms = fs::Permissions::from_mode(0o600);
|
|
fs::set_permissions(&path, perms)?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn sha_256_prefix(value: &Value) -> Result<String> {
|
|
let serialized =
|
|
serde_json::to_string(&value).context("failed to serialize MCP OAuth key payload")?;
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(serialized.as_bytes());
|
|
let digest = hasher.finalize();
|
|
let hex = format!("{digest:x}");
|
|
let truncated = &hex[..16];
|
|
Ok(truncated.to_string())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use keyring::Error as KeyringError;
    use pretty_assertions::assert_eq;
    use std::sync::Mutex;
    use std::sync::MutexGuard;
    use std::sync::OnceLock;
    use std::sync::PoisonError;
    use tempfile::tempdir;

    use codex_keyring_store::tests::MockKeyringStore;

    /// Points `CODEX_HOME` at a fresh temporary directory for as long as the value
    /// lives. Every test here mutates the same environment variable, so a
    /// process-wide lock serializes them.
    struct TempCodexHome {
        _guard: MutexGuard<'static, ()>,
        _dir: tempfile::TempDir,
    }

    impl TempCodexHome {
        fn new() -> Self {
            static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
            // A panicking test poisons the lock; the guarded data is `()`, so recover.
            let guard = LOCK
                .get_or_init(Mutex::default)
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            let dir = tempdir().expect("create CODEX_HOME temp dir");
            // SAFETY: mutating the environment is only sound while no other thread
            // reads or writes it. Every test in this module holds `LOCK` for as long
            // as it touches `CODEX_HOME`; this assumes no other test in the crate
            // reads or writes the environment concurrently.
            unsafe {
                std::env::set_var("CODEX_HOME", dir.path());
            }
            Self {
                _guard: guard,
                _dir: dir,
            }
        }
    }

    impl Drop for TempCodexHome {
        fn drop(&mut self) {
            // SAFETY: same invariant as in `new`. `_guard` is still held here because
            // fields are dropped only after `drop` returns.
            unsafe {
                std::env::remove_var("CODEX_HOME");
            }
        }
    }

    #[test]
    fn load_oauth_tokens_reads_from_keyring_when_available() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let expected = tokens.clone();
        let serialized = serde_json::to_string(&tokens)?;
        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
        store.save(KEYRING_SERVICE, &key, &serialized)?;

        let loaded =
            super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?;
        assert_eq!(loaded, Some(expected));
        Ok(())
    }

    #[test]
    fn load_oauth_tokens_falls_back_when_missing_in_keyring() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let expected = tokens.clone();

        super::save_oauth_tokens_to_file(&tokens)?;

        let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
            &store,
            &tokens.server_name,
            &tokens.url,
        )?
        .expect("tokens should load from fallback");
        assert_tokens_match_without_expiry(&loaded, &expected);
        Ok(())
    }

    #[test]
    fn load_oauth_tokens_falls_back_when_keyring_errors() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let expected = tokens.clone();
        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
        store.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));

        super::save_oauth_tokens_to_file(&tokens)?;

        let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
            &store,
            &tokens.server_name,
            &tokens.url,
        )?
        .expect("tokens should load from fallback");
        assert_tokens_match_without_expiry(&loaded, &expected);
        Ok(())
    }

    #[test]
    fn save_oauth_tokens_prefers_keyring_when_available() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;

        super::save_oauth_tokens_to_file(&tokens)?;

        super::save_oauth_tokens_with_keyring_with_fallback_to_file(
            &store,
            &tokens.server_name,
            &tokens,
        )?;

        let fallback_path = super::fallback_file_path()?;
        assert!(!fallback_path.exists(), "fallback file should be removed");
        let stored = store.saved_value(&key).expect("value saved to keyring");
        assert_eq!(serde_json::from_str::<StoredOAuthTokens>(&stored)?, tokens);
        Ok(())
    }

    #[test]
    fn save_oauth_tokens_writes_fallback_when_keyring_fails() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
        store.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));

        super::save_oauth_tokens_with_keyring_with_fallback_to_file(
            &store,
            &tokens.server_name,
            &tokens,
        )?;

        let fallback_path = super::fallback_file_path()?;
        assert!(fallback_path.exists(), "fallback file should be created");
        let saved = super::read_fallback_file()?.expect("fallback file should load");
        let entry = saved.get(&key).expect("entry for key");
        assert_eq!(entry.server_name, tokens.server_name);
        assert_eq!(entry.server_url, tokens.url);
        assert_eq!(entry.client_id, tokens.client_id);
        assert_eq!(
            entry.access_token,
            tokens.token_response.0.access_token().secret().as_str()
        );
        assert!(store.saved_value(&key).is_none());
        Ok(())
    }

    #[test]
    fn delete_oauth_tokens_removes_all_storage() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let serialized = serde_json::to_string(&tokens)?;
        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
        store.save(KEYRING_SERVICE, &key, &serialized)?;
        super::save_oauth_tokens_to_file(&tokens)?;

        let removed = super::delete_oauth_tokens_from_keyring_and_file(
            &store,
            OAuthCredentialsStoreMode::Auto,
            &tokens.server_name,
            &tokens.url,
        )?;
        assert!(removed);
        assert!(!store.contains(&key));
        assert!(!super::fallback_file_path()?.exists());
        Ok(())
    }

    /// Even in `File` mode, deletion also clears an entry that exists only in the keyring.
    #[test]
    fn delete_oauth_tokens_file_mode_removes_keyring_only_entry() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let serialized = serde_json::to_string(&tokens)?;
        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
        store.save(KEYRING_SERVICE, &key, &serialized)?;
        assert!(store.contains(&key));

        let removed = super::delete_oauth_tokens_from_keyring_and_file(
            &store,
            OAuthCredentialsStoreMode::File,
            &tokens.server_name,
            &tokens.url,
        )?;
        assert!(removed);
        assert!(!store.contains(&key));
        assert!(!super::fallback_file_path()?.exists());
        Ok(())
    }

    /// In `File` mode a keyring delete failure is ignored and the file entry is still removed.
    #[test]
    fn delete_oauth_tokens_file_mode_ignores_keyring_errors() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
        store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
        super::save_oauth_tokens_to_file(&tokens)?;

        let removed = super::delete_oauth_tokens_from_keyring_and_file(
            &store,
            OAuthCredentialsStoreMode::File,
            &tokens.server_name,
            &tokens.url,
        )?;
        assert!(removed);
        assert!(!super::fallback_file_path()?.exists());
        Ok(())
    }

    #[test]
    fn delete_oauth_tokens_propagates_keyring_errors() -> Result<()> {
        let _env = TempCodexHome::new();
        let store = MockKeyringStore::default();
        let tokens = sample_tokens();
        let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
        store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
        super::save_oauth_tokens_to_file(&tokens)?;

        let result = super::delete_oauth_tokens_from_keyring_and_file(
            &store,
            OAuthCredentialsStoreMode::Auto,
            &tokens.server_name,
            &tokens.url,
        );
        assert!(result.is_err());
        assert!(super::fallback_file_path()?.exists());
        Ok(())
    }

    /// Compares two token sets field by field, checking only that `expires_in` is
    /// present on both sides (its exact value depends on when each was computed).
    fn assert_tokens_match_without_expiry(
        actual: &StoredOAuthTokens,
        expected: &StoredOAuthTokens,
    ) {
        assert_eq!(actual.server_name, expected.server_name);
        assert_eq!(actual.url, expected.url);
        assert_eq!(actual.client_id, expected.client_id);
        assert_token_response_match_without_expiry(
            &actual.token_response,
            &expected.token_response,
        );
    }

    fn assert_token_response_match_without_expiry(
        actual: &WrappedOAuthTokenResponse,
        expected: &WrappedOAuthTokenResponse,
    ) {
        let actual_response = &actual.0;
        let expected_response = &expected.0;

        assert_eq!(
            actual_response.access_token().secret(),
            expected_response.access_token().secret()
        );
        assert_eq!(actual_response.token_type(), expected_response.token_type());
        assert_eq!(
            actual_response.refresh_token().map(RefreshToken::secret),
            expected_response.refresh_token().map(RefreshToken::secret),
        );
        assert_eq!(actual_response.scopes(), expected_response.scopes());
        assert_eq!(
            actual_response.extra_fields(),
            expected_response.extra_fields()
        );
        assert_eq!(
            actual_response.expires_in().is_some(),
            expected_response.expires_in().is_some()
        );
    }

    fn sample_tokens() -> StoredOAuthTokens {
        let mut response = OAuthTokenResponse::new(
            AccessToken::new("access-token".to_string()),
            BasicTokenType::Bearer,
            EmptyExtraTokenFields {},
        );
        response.set_refresh_token(Some(RefreshToken::new("refresh-token".to_string())));
        response.set_scopes(Some(vec![
            Scope::new("scope-a".to_string()),
            Scope::new("scope-b".to_string()),
        ]));
        let expires_in = Duration::from_secs(3600);
        response.set_expires_in(Some(&expires_in));

        StoredOAuthTokens {
            server_name: "test-server".to_string(),
            url: "https://example.test".to_string(),
            client_id: "client-id".to_string(),
            token_response: WrappedOAuthTokenResponse(response),
        }
    }
}
|