[MCP] Add the ability to explicitly specify a credentials store (#4857)
This lets users/companies explicitly choose whether to force/disallow the keyring/fallback file storage for mcp credentials. People who develop with Codex will want to use this until we sign binaries or else each ad-hoc debug builds will require keychain access on every build. I don't love this and am open to other ideas for how to handle that. ```toml mcp_oauth_credentials_store = "auto" mcp_oauth_credentials_store = "file" mcp_oauth_credentials_store = "keyrung" ``` Defaults to `auto`
This commit is contained in:
@@ -236,7 +236,7 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs)
|
||||
_ => bail!("OAuth login is only supported for streamable HTTP servers."),
|
||||
};
|
||||
|
||||
perform_oauth_login(&name, &url).await?;
|
||||
perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?;
|
||||
println!("Successfully logged in to MCP server '{name}'.");
|
||||
Ok(())
|
||||
}
|
||||
@@ -259,7 +259,7 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr
|
||||
_ => bail!("OAuth logout is only supported for streamable_http transports."),
|
||||
};
|
||||
|
||||
match delete_oauth_tokens(&name, &url) {
|
||||
match delete_oauth_tokens(&name, &url, config.mcp_oauth_credentials_store_mode) {
|
||||
Ok(true) => println!("Removed OAuth credentials for '{name}'."),
|
||||
Ok(false) => println!("No OAuth credentials stored for '{name}'."),
|
||||
Err(err) => return Err(anyhow!("failed to delete OAuth credentials: {err}")),
|
||||
|
||||
@@ -364,6 +364,7 @@ impl Session {
|
||||
let mcp_fut = McpConnectionManager::new(
|
||||
config.mcp_servers.clone(),
|
||||
config.use_experimental_use_rmcp_client,
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
);
|
||||
let default_shell_fut = shell::default_user_shell();
|
||||
let history_meta_fut = crate::message_history::history_metadata(&config);
|
||||
|
||||
@@ -33,6 +33,7 @@ use codex_protocol::config_types::ReasoningEffort;
|
||||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::config_types::SandboxMode;
|
||||
use codex_protocol::config_types::Verbosity;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use dirs::home_dir;
|
||||
use serde::Deserialize;
|
||||
use std::collections::BTreeMap;
|
||||
@@ -142,6 +143,15 @@ pub struct Config {
|
||||
/// Definition for MCP servers that Codex can reach out to for tool calls.
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// Preferred store for MCP OAuth credentials.
|
||||
/// keyring: Use an OS-specific keyring service.
|
||||
/// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
|
||||
/// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2
|
||||
/// file: CODEX_HOME/.credentials.json
|
||||
/// This file will be readable to Codex and other applications running as the same user.
|
||||
/// auto (default): keyring if available, otherwise file.
|
||||
pub mcp_oauth_credentials_store_mode: OAuthCredentialsStoreMode,
|
||||
|
||||
/// Combined provider map (defaults merged with user-defined overrides).
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
|
||||
@@ -694,6 +704,14 @@ pub struct ConfigToml {
|
||||
#[serde(default)]
|
||||
pub mcp_servers: HashMap<String, McpServerConfig>,
|
||||
|
||||
/// Preferred backend for storing MCP OAuth credentials.
|
||||
/// keyring: Use an OS-specific keyring service.
|
||||
/// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2
|
||||
/// file: Use a file in the Codex home directory.
|
||||
/// auto (default): Use the OS-specific keyring service if available, otherwise use a file.
|
||||
#[serde(default)]
|
||||
pub mcp_oauth_credentials_store: Option<OAuthCredentialsStoreMode>,
|
||||
|
||||
/// User-defined provider entries that extend/override the built-in list.
|
||||
#[serde(default)]
|
||||
pub model_providers: HashMap<String, ModelProviderInfo>,
|
||||
@@ -1074,6 +1092,9 @@ impl Config {
|
||||
user_instructions,
|
||||
base_instructions,
|
||||
mcp_servers: cfg.mcp_servers,
|
||||
// The config.toml omits "_mode" because it's a config file. However, "_mode"
|
||||
// is important in code to differentiate the mode from the store implementation.
|
||||
mcp_oauth_credentials_store_mode: cfg.mcp_oauth_credentials_store.unwrap_or_default(),
|
||||
model_providers,
|
||||
project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES),
|
||||
project_doc_fallback_filenames: cfg
|
||||
@@ -1364,6 +1385,85 @@ exclude_slash_tmp = true
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_defaults_to_auto_oauth_store_mode() -> std::io::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let cfg = ConfigToml::default();
|
||||
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_honors_explicit_file_oauth_store_mode() -> std::io::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let cfg = ConfigToml {
|
||||
mcp_oauth_credentials_store: Some(OAuthCredentialsStoreMode::File),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
|
||||
assert_eq!(
|
||||
config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn managed_config_overrides_oauth_store_mode() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let managed_path = codex_home.path().join("managed_config.toml");
|
||||
let config_path = codex_home.path().join(CONFIG_TOML_FILE);
|
||||
|
||||
std::fs::write(&config_path, "mcp_oauth_credentials_store = \"file\"\n")?;
|
||||
std::fs::write(&managed_path, "mcp_oauth_credentials_store = \"keyring\"\n")?;
|
||||
|
||||
let overrides = crate::config_loader::LoaderOverrides {
|
||||
managed_config_path: Some(managed_path.clone()),
|
||||
#[cfg(target_os = "macos")]
|
||||
managed_preferences_base64: None,
|
||||
};
|
||||
|
||||
let root_value = load_resolved_config(codex_home.path(), Vec::new(), overrides).await?;
|
||||
let cfg: ConfigToml = root_value.try_into().map_err(|e| {
|
||||
tracing::error!("Failed to deserialize overridden config: {e}");
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, e)
|
||||
})?;
|
||||
assert_eq!(
|
||||
cfg.mcp_oauth_credentials_store,
|
||||
Some(OAuthCredentialsStoreMode::Keyring),
|
||||
);
|
||||
|
||||
let final_config = Config::load_from_base_config_with_overrides(
|
||||
cfg,
|
||||
ConfigOverrides::default(),
|
||||
codex_home.path().to_path_buf(),
|
||||
)?;
|
||||
assert_eq!(
|
||||
final_config.mcp_oauth_credentials_store_mode,
|
||||
OAuthCredentialsStoreMode::Keyring,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_global_mcp_servers_returns_empty_if_missing() -> anyhow::Result<()> {
|
||||
let codex_home = TempDir::new()?;
|
||||
@@ -1896,6 +1996,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -1958,6 +2059,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -2035,6 +2137,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
@@ -2098,6 +2201,7 @@ model_verbosity = "high"
|
||||
notify: None,
|
||||
cwd: fixture.cwd(),
|
||||
mcp_servers: HashMap::new(),
|
||||
mcp_oauth_credentials_store_mode: Default::default(),
|
||||
model_providers: fixture.model_provider_map.clone(),
|
||||
project_doc_max_bytes: PROJECT_DOC_MAX_BYTES,
|
||||
project_doc_fallback_filenames: Vec::new(),
|
||||
|
||||
@@ -16,6 +16,7 @@ use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_mcp_client::McpClient;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use mcp_types::ClientCapabilities;
|
||||
use mcp_types::Implementation;
|
||||
@@ -125,9 +126,11 @@ impl McpClientAdapter {
|
||||
bearer_token: Option<String>,
|
||||
params: mcp_types::InitializeRequestParams,
|
||||
startup_timeout: Duration,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let client = Arc::new(
|
||||
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token).await?,
|
||||
RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode)
|
||||
.await?,
|
||||
);
|
||||
client.initialize(params, Some(startup_timeout)).await?;
|
||||
Ok(McpClientAdapter::Rmcp(client))
|
||||
@@ -182,6 +185,7 @@ impl McpConnectionManager {
|
||||
pub async fn new(
|
||||
mcp_servers: HashMap<String, McpServerConfig>,
|
||||
use_rmcp_client: bool,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<(Self, ClientStartErrors)> {
|
||||
// Early exit if no servers are configured.
|
||||
if mcp_servers.is_empty() {
|
||||
@@ -249,6 +253,7 @@ impl McpConnectionManager {
|
||||
bearer_token,
|
||||
params,
|
||||
startup_timeout,
|
||||
store_mode,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ mod perform_oauth_login;
|
||||
mod rmcp_client;
|
||||
mod utils;
|
||||
|
||||
pub use oauth::OAuthCredentialsStoreMode;
|
||||
pub use oauth::StoredOAuthTokens;
|
||||
pub use oauth::WrappedOAuthTokenResponse;
|
||||
pub use oauth::delete_oauth_tokens;
|
||||
|
||||
@@ -58,6 +58,21 @@ pub struct StoredOAuthTokens {
|
||||
pub token_response: WrappedOAuthTokenResponse,
|
||||
}
|
||||
|
||||
/// Determine where Codex should store and read MCP credentials.
|
||||
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OAuthCredentialsStoreMode {
|
||||
/// `Keyring` when available; otherwise, `File`.
|
||||
/// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access.
|
||||
#[default]
|
||||
Auto,
|
||||
/// CODEX_HOME/.credentials.json
|
||||
/// This file will be readable to Codex and other applications running as the same user.
|
||||
File,
|
||||
/// Keyring when available, otherwise fail.
|
||||
Keyring,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CredentialStoreError(anyhow::Error);
|
||||
|
||||
@@ -83,15 +98,15 @@ impl fmt::Display for CredentialStoreError {
|
||||
|
||||
impl std::error::Error for CredentialStoreError {}
|
||||
|
||||
trait CredentialStore {
|
||||
trait KeyringStore {
|
||||
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError>;
|
||||
fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>;
|
||||
fn delete(&self, service: &str, account: &str) -> Result<bool, CredentialStoreError>;
|
||||
}
|
||||
|
||||
struct KeyringCredentialStore;
|
||||
struct DefaultKeyringStore;
|
||||
|
||||
impl CredentialStore for KeyringCredentialStore {
|
||||
impl KeyringStore for DefaultKeyringStore {
|
||||
fn load(&self, service: &str, account: &str) -> Result<Option<String>, CredentialStoreError> {
|
||||
let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?;
|
||||
match entry.get_password() {
|
||||
@@ -129,47 +144,85 @@ impl PartialEq for WrappedOAuthTokenResponse {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn load_oauth_tokens(server_name: &str, url: &str) -> Result<Option<StoredOAuthTokens>> {
|
||||
let store = KeyringCredentialStore;
|
||||
load_oauth_tokens_with_store(&store, server_name, url)
|
||||
pub(crate) fn load_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto => {
|
||||
load_oauth_tokens_from_keyring_with_fallback_to_file(&keyring_store, server_name, url)
|
||||
}
|
||||
OAuthCredentialsStoreMode::File => load_oauth_tokens_from_file(server_name, url),
|
||||
OAuthCredentialsStoreMode::Keyring => {
|
||||
load_oauth_tokens_from_keyring(&keyring_store, server_name, url)
|
||||
.with_context(|| "failed to read OAuth tokens from keyring".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
fn load_oauth_tokens_from_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
match load_oauth_tokens_from_keyring(keyring_store, server_name, url) {
|
||||
Ok(Some(tokens)) => Ok(Some(tokens)),
|
||||
Ok(None) => load_oauth_tokens_from_file(server_name, url),
|
||||
Err(error) => {
|
||||
warn!("failed to read OAuth tokens from keyring: {error}");
|
||||
load_oauth_tokens_from_file(server_name, url)
|
||||
.with_context(|| format!("failed to read OAuth tokens from keyring: {error}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_oauth_tokens_from_keyring<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<Option<StoredOAuthTokens>> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
match store.load(KEYRING_SERVICE, &key) {
|
||||
match keyring_store.load(KEYRING_SERVICE, &key) {
|
||||
Ok(Some(serialized)) => {
|
||||
let tokens: StoredOAuthTokens = serde_json::from_str(&serialized)
|
||||
.context("failed to deserialize OAuth tokens from keyring")?;
|
||||
Ok(Some(tokens))
|
||||
}
|
||||
Ok(None) => load_oauth_tokens_from_file(server_name, url),
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to read OAuth tokens from keyring: {message}");
|
||||
load_oauth_tokens_from_file(server_name, url)
|
||||
.with_context(|| format!("failed to read OAuth tokens from keyring: {message}"))
|
||||
Ok(None) => Ok(None),
|
||||
Err(error) => Err(error.into_error()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_oauth_tokens(
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<()> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto => save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&keyring_store,
|
||||
server_name,
|
||||
tokens,
|
||||
),
|
||||
OAuthCredentialsStoreMode::File => save_oauth_tokens_to_file(tokens),
|
||||
OAuthCredentialsStoreMode::Keyring => {
|
||||
save_oauth_tokens_with_keyring(&keyring_store, server_name, tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn save_oauth_tokens(server_name: &str, tokens: &StoredOAuthTokens) -> Result<()> {
|
||||
let store = KeyringCredentialStore;
|
||||
save_oauth_tokens_with_store(&store, server_name, tokens)
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
fn save_oauth_tokens_with_keyring<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?;
|
||||
|
||||
let key = compute_store_key(server_name, &tokens.url)?;
|
||||
match store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
match keyring_store.save(KEYRING_SERVICE, &key, &serialized) {
|
||||
Ok(()) => {
|
||||
if let Err(error) = delete_oauth_tokens_from_file(&key) {
|
||||
warn!("failed to remove OAuth tokens from fallback storage: {error:?}");
|
||||
@@ -177,31 +230,61 @@ fn save_oauth_tokens_with_store<C: CredentialStore>(
|
||||
Ok(())
|
||||
}
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to write OAuth tokens to keyring: {message}");
|
||||
let message = format!(
|
||||
"failed to write OAuth tokens to keyring: {}",
|
||||
error.message()
|
||||
);
|
||||
warn!("{message}");
|
||||
Err(error.into_error().context(message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn save_oauth_tokens_with_keyring_with_fallback_to_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
server_name: &str,
|
||||
tokens: &StoredOAuthTokens,
|
||||
) -> Result<()> {
|
||||
match save_oauth_tokens_with_keyring(keyring_store, server_name, tokens) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(error) => {
|
||||
let message = error.to_string();
|
||||
warn!("falling back to file storage for OAuth tokens: {message}");
|
||||
save_oauth_tokens_to_file(tokens)
|
||||
.with_context(|| format!("failed to write OAuth tokens to keyring: {message}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn delete_oauth_tokens(server_name: &str, url: &str) -> Result<bool> {
|
||||
let store = KeyringCredentialStore;
|
||||
delete_oauth_tokens_with_store(&store, server_name, url)
|
||||
pub fn delete_oauth_tokens(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<bool> {
|
||||
let keyring_store = DefaultKeyringStore;
|
||||
delete_oauth_tokens_from_keyring_and_file(&keyring_store, store_mode, server_name, url)
|
||||
}
|
||||
|
||||
fn delete_oauth_tokens_with_store<C: CredentialStore>(
|
||||
store: &C,
|
||||
fn delete_oauth_tokens_from_keyring_and_file<K: KeyringStore>(
|
||||
keyring_store: &K,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
) -> Result<bool> {
|
||||
let key = compute_store_key(server_name, url)?;
|
||||
let keyring_removed = match store.delete(KEYRING_SERVICE, &key) {
|
||||
let keyring_result = keyring_store.delete(KEYRING_SERVICE, &key);
|
||||
let keyring_removed = match keyring_result {
|
||||
Ok(removed) => removed,
|
||||
Err(error) => {
|
||||
let message = error.message();
|
||||
warn!("failed to delete OAuth tokens from keyring: {message}");
|
||||
return Err(error.into_error()).context("failed to delete OAuth tokens from keyring");
|
||||
match store_mode {
|
||||
OAuthCredentialsStoreMode::Auto | OAuthCredentialsStoreMode::Keyring => {
|
||||
return Err(error.into_error())
|
||||
.context("failed to delete OAuth tokens from keyring");
|
||||
}
|
||||
OAuthCredentialsStoreMode::File => false,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -218,6 +301,7 @@ struct OAuthPersistorInner {
|
||||
server_name: String,
|
||||
url: String,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
last_credentials: Mutex<Option<StoredOAuthTokens>>,
|
||||
}
|
||||
|
||||
@@ -225,14 +309,16 @@ impl OAuthPersistor {
|
||||
pub(crate) fn new(
|
||||
server_name: String,
|
||||
url: String,
|
||||
manager: Arc<Mutex<AuthorizationManager>>,
|
||||
authorization_manager: Arc<Mutex<AuthorizationManager>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
initial_credentials: Option<StoredOAuthTokens>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(OAuthPersistorInner {
|
||||
server_name,
|
||||
url,
|
||||
authorization_manager: manager,
|
||||
authorization_manager,
|
||||
store_mode,
|
||||
last_credentials: Mutex::new(initial_credentials),
|
||||
}),
|
||||
}
|
||||
@@ -257,15 +343,18 @@ impl OAuthPersistor {
|
||||
};
|
||||
let mut last_credentials = self.inner.last_credentials.lock().await;
|
||||
if last_credentials.as_ref() != Some(&stored) {
|
||||
save_oauth_tokens(&self.inner.server_name, &stored)?;
|
||||
save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?;
|
||||
*last_credentials = Some(stored);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut last_serialized = self.inner.last_credentials.lock().await;
|
||||
if last_serialized.take().is_some()
|
||||
&& let Err(error) =
|
||||
delete_oauth_tokens(&self.inner.server_name, &self.inner.url)
|
||||
&& let Err(error) = delete_oauth_tokens(
|
||||
&self.inner.server_name,
|
||||
&self.inner.url,
|
||||
self.inner.store_mode,
|
||||
)
|
||||
{
|
||||
warn!(
|
||||
"failed to remove OAuth tokens for server {}: {error}",
|
||||
@@ -542,7 +631,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
impl CredentialStore for MockCredentialStore {
|
||||
impl KeyringStore for MockCredentialStore {
|
||||
fn load(
|
||||
&self,
|
||||
_service: &str,
|
||||
@@ -643,7 +732,8 @@ mod tests {
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
|
||||
let loaded =
|
||||
super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?;
|
||||
assert_eq!(loaded, Some(expected));
|
||||
Ok(())
|
||||
}
|
||||
@@ -657,8 +747,12 @@ mod tests {
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from fallback");
|
||||
let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
@@ -674,8 +768,12 @@ mod tests {
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?
|
||||
.expect("tokens should load from fallback");
|
||||
let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?
|
||||
.expect("tokens should load from fallback");
|
||||
assert_tokens_match_without_expiry(&loaded, &expected);
|
||||
Ok(())
|
||||
}
|
||||
@@ -689,7 +787,11 @@ mod tests {
|
||||
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
|
||||
super::save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens,
|
||||
)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(!fallback_path.exists(), "fallback file should be removed");
|
||||
@@ -706,7 +808,11 @@ mod tests {
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));
|
||||
|
||||
super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?;
|
||||
super::save_oauth_tokens_with_keyring_with_fallback_to_file(
|
||||
&store,
|
||||
&tokens.server_name,
|
||||
&tokens,
|
||||
)?;
|
||||
|
||||
let fallback_path = super::fallback_file_path()?;
|
||||
assert!(fallback_path.exists(), "fallback file should be created");
|
||||
@@ -734,8 +840,34 @@ mod tests {
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
super::save_oauth_tokens_to_file(&tokens)?;
|
||||
|
||||
let removed =
|
||||
super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?;
|
||||
let removed = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?;
|
||||
assert!(removed);
|
||||
assert!(!store.contains(&key));
|
||||
assert!(!super::fallback_file_path()?.exists());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delete_oauth_tokens_file_mode_removes_keyring_only_entry() -> Result<()> {
|
||||
let _env = TempCodexHome::new();
|
||||
let store = MockCredentialStore::default();
|
||||
let tokens = sample_tokens();
|
||||
let serialized = serde_json::to_string(&tokens)?;
|
||||
let key = super::compute_store_key(&tokens.server_name, &tokens.url)?;
|
||||
store.save(KEYRING_SERVICE, &key, &serialized)?;
|
||||
assert!(store.contains(&key));
|
||||
|
||||
let removed = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
)?;
|
||||
assert!(removed);
|
||||
assert!(!store.contains(&key));
|
||||
assert!(!super::fallback_file_path()?.exists());
|
||||
@@ -751,8 +883,12 @@ mod tests {
|
||||
store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into()));
|
||||
super::save_oauth_tokens_to_file(&tokens).unwrap();
|
||||
|
||||
let result =
|
||||
super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url);
|
||||
let result = super::delete_oauth_tokens_from_keyring_and_file(
|
||||
&store,
|
||||
OAuthCredentialsStoreMode::Auto,
|
||||
&tokens.server_name,
|
||||
&tokens.url,
|
||||
);
|
||||
assert!(result.is_err());
|
||||
assert!(super::fallback_file_path().unwrap().exists());
|
||||
Ok(())
|
||||
|
||||
@@ -12,6 +12,7 @@ use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use urlencoding::decode;
|
||||
|
||||
use crate::OAuthCredentialsStoreMode;
|
||||
use crate::StoredOAuthTokens;
|
||||
use crate::WrappedOAuthTokenResponse;
|
||||
use crate::save_oauth_tokens;
|
||||
@@ -26,7 +27,11 @@ impl Drop for CallbackServerGuard {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<()> {
|
||||
pub async fn perform_oauth_login(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<()> {
|
||||
let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?);
|
||||
let guard = CallbackServerGuard {
|
||||
server: Arc::clone(&server),
|
||||
@@ -81,7 +86,7 @@ pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<
|
||||
client_id,
|
||||
token_response: WrappedOAuthTokenResponse(credentials),
|
||||
};
|
||||
save_oauth_tokens(server_name, &stored)?;
|
||||
save_oauth_tokens(server_name, &stored, store_mode)?;
|
||||
|
||||
drop(guard);
|
||||
Ok(())
|
||||
|
||||
@@ -35,6 +35,7 @@ use tracing::warn;
|
||||
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::logging_client_handler::LoggingClientHandler;
|
||||
use crate::oauth::OAuthCredentialsStoreMode;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
use crate::utils::convert_call_tool_result;
|
||||
@@ -119,10 +120,11 @@ impl RmcpClient {
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
bearer_token: Option<String>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
) -> Result<Self> {
|
||||
let initial_oauth_tokens = match bearer_token {
|
||||
Some(_) => None,
|
||||
None => match load_oauth_tokens(server_name, url) {
|
||||
None => match load_oauth_tokens(server_name, url, store_mode) {
|
||||
Ok(tokens) => tokens,
|
||||
Err(err) => {
|
||||
warn!("failed to read tokens for server `{server_name}`: {err}");
|
||||
@@ -132,7 +134,8 @@ impl RmcpClient {
|
||||
};
|
||||
let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() {
|
||||
let (transport, oauth_persistor) =
|
||||
create_oauth_transport_and_runtime(server_name, url, initial_tokens).await?;
|
||||
create_oauth_transport_and_runtime(server_name, url, initial_tokens, store_mode)
|
||||
.await?;
|
||||
PendingTransport::StreamableHttpWithOAuth {
|
||||
transport,
|
||||
oauth_persistor,
|
||||
@@ -286,6 +289,7 @@ async fn create_oauth_transport_and_runtime(
|
||||
server_name: &str,
|
||||
url: &str,
|
||||
initial_tokens: StoredOAuthTokens,
|
||||
credentials_store: OAuthCredentialsStoreMode,
|
||||
) -> Result<(
|
||||
StreamableHttpClientTransport<AuthClient<reqwest::Client>>,
|
||||
OAuthPersistor,
|
||||
@@ -320,6 +324,7 @@ async fn create_oauth_transport_and_runtime(
|
||||
server_name.to_string(),
|
||||
url.to_string(),
|
||||
auth_manager,
|
||||
credentials_store,
|
||||
Some(initial_tokens),
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user