diff --git a/codex-rs/cli/src/mcp_cmd.rs b/codex-rs/cli/src/mcp_cmd.rs index a105dbba..c6add86f 100644 --- a/codex-rs/cli/src/mcp_cmd.rs +++ b/codex-rs/cli/src/mcp_cmd.rs @@ -236,7 +236,7 @@ async fn run_login(config_overrides: &CliConfigOverrides, login_args: LoginArgs) _ => bail!("OAuth login is only supported for streamable HTTP servers."), }; - perform_oauth_login(&name, &url).await?; + perform_oauth_login(&name, &url, config.mcp_oauth_credentials_store_mode).await?; println!("Successfully logged in to MCP server '{name}'."); Ok(()) } @@ -259,7 +259,7 @@ async fn run_logout(config_overrides: &CliConfigOverrides, logout_args: LogoutAr _ => bail!("OAuth logout is only supported for streamable_http transports."), }; - match delete_oauth_tokens(&name, &url) { + match delete_oauth_tokens(&name, &url, config.mcp_oauth_credentials_store_mode) { Ok(true) => println!("Removed OAuth credentials for '{name}'."), Ok(false) => println!("No OAuth credentials stored for '{name}'."), Err(err) => return Err(anyhow!("failed to delete OAuth credentials: {err}")), diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index d11a53b5..356e25ed 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -364,6 +364,7 @@ impl Session { let mcp_fut = McpConnectionManager::new( config.mcp_servers.clone(), config.use_experimental_use_rmcp_client, + config.mcp_oauth_credentials_store_mode, ); let default_shell_fut = shell::default_user_shell(); let history_meta_fut = crate::message_history::history_metadata(&config); diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 28ad84ba..43d3f996 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -33,6 +33,7 @@ use codex_protocol::config_types::ReasoningEffort; use codex_protocol::config_types::ReasoningSummary; use codex_protocol::config_types::SandboxMode; use codex_protocol::config_types::Verbosity; +use codex_rmcp_client::OAuthCredentialsStoreMode; use dirs::home_dir; use serde::Deserialize; use std::collections::BTreeMap; @@ -142,6 +143,15 @@ pub struct Config { /// Definition for MCP servers that Codex can reach out to for tool calls. pub mcp_servers: HashMap, + /// Preferred store for MCP OAuth credentials. + /// keyring: Use an OS-specific keyring service. + /// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access. + /// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2 + /// file: CODEX_HOME/.credentials.json + /// This file will be readable to Codex and other applications running as the same user. + /// auto (default): keyring if available, otherwise file. + pub mcp_oauth_credentials_store_mode: OAuthCredentialsStoreMode, + /// Combined provider map (defaults merged with user-defined overrides). pub model_providers: HashMap, @@ -694,6 +704,14 @@ pub struct ConfigToml { #[serde(default)] pub mcp_servers: HashMap, + /// Preferred backend for storing MCP OAuth credentials. + /// keyring: Use an OS-specific keyring service. + /// https://github.com/openai/codex/blob/main/codex-rs/rmcp-client/src/oauth.rs#L2 + /// file: Use a file in the Codex home directory. + /// auto (default): Use the OS-specific keyring service if available, otherwise use a file. + #[serde(default)] + pub mcp_oauth_credentials_store: Option, + /// User-defined provider entries that extend/override the built-in list. #[serde(default)] pub model_providers: HashMap, @@ -1074,6 +1092,9 @@ impl Config { user_instructions, base_instructions, mcp_servers: cfg.mcp_servers, + // The config.toml omits "_mode" because it's a config file. However, "_mode" + // is important in code to differentiate the mode from the store implementation. + mcp_oauth_credentials_store_mode: cfg.mcp_oauth_credentials_store.unwrap_or_default(), model_providers, project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES), project_doc_fallback_filenames: cfg @@ -1364,6 +1385,85 @@ exclude_slash_tmp = true ); } + #[test] + fn config_defaults_to_auto_oauth_store_mode() -> std::io::Result<()> { + let codex_home = TempDir::new()?; + let cfg = ConfigToml::default(); + + let config = Config::load_from_base_config_with_overrides( + cfg, + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + )?; + + assert_eq!( + config.mcp_oauth_credentials_store_mode, + OAuthCredentialsStoreMode::Auto, + ); + + Ok(()) + } + + #[test] + fn config_honors_explicit_file_oauth_store_mode() -> std::io::Result<()> { + let codex_home = TempDir::new()?; + let cfg = ConfigToml { + mcp_oauth_credentials_store: Some(OAuthCredentialsStoreMode::File), + ..Default::default() + }; + + let config = Config::load_from_base_config_with_overrides( + cfg, + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + )?; + + assert_eq!( + config.mcp_oauth_credentials_store_mode, + OAuthCredentialsStoreMode::File, + ); + + Ok(()) + } + + #[tokio::test] + async fn managed_config_overrides_oauth_store_mode() -> anyhow::Result<()> { + let codex_home = TempDir::new()?; + let managed_path = codex_home.path().join("managed_config.toml"); + let config_path = codex_home.path().join(CONFIG_TOML_FILE); + + std::fs::write(&config_path, "mcp_oauth_credentials_store = \"file\"\n")?; + std::fs::write(&managed_path, "mcp_oauth_credentials_store = \"keyring\"\n")?; + + let overrides = crate::config_loader::LoaderOverrides { + managed_config_path: Some(managed_path.clone()), + #[cfg(target_os = "macos")] + managed_preferences_base64: None, + }; + + let root_value = load_resolved_config(codex_home.path(), Vec::new(), overrides).await?; + let cfg: ConfigToml = root_value.try_into().map_err(|e| { + tracing::error!("Failed to deserialize overridden config: {e}"); + std::io::Error::new(std::io::ErrorKind::InvalidData, e) + })?; + assert_eq!( + cfg.mcp_oauth_credentials_store, + Some(OAuthCredentialsStoreMode::Keyring), + ); + + let final_config = Config::load_from_base_config_with_overrides( + cfg, + ConfigOverrides::default(), + codex_home.path().to_path_buf(), + )?; + assert_eq!( + final_config.mcp_oauth_credentials_store_mode, + OAuthCredentialsStoreMode::Keyring, + ); + + Ok(()) + } + #[tokio::test] async fn load_global_mcp_servers_returns_empty_if_missing() -> anyhow::Result<()> { let codex_home = TempDir::new()?; @@ -1896,6 +1996,7 @@ model_verbosity = "high" notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), + mcp_oauth_credentials_store_mode: Default::default(), model_providers: fixture.model_provider_map.clone(), project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), @@ -1958,6 +2059,7 @@ model_verbosity = "high" notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), + mcp_oauth_credentials_store_mode: Default::default(), model_providers: fixture.model_provider_map.clone(), project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), @@ -2035,6 +2137,7 @@ model_verbosity = "high" notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), + mcp_oauth_credentials_store_mode: Default::default(), model_providers: fixture.model_provider_map.clone(), project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), @@ -2098,6 +2201,7 @@ model_verbosity = "high" notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), + mcp_oauth_credentials_store_mode: Default::default(), model_providers: fixture.model_provider_map.clone(), project_doc_max_bytes: PROJECT_DOC_MAX_BYTES, project_doc_fallback_filenames: Vec::new(), diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index 76738a03..37986e4b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -16,6 +16,7 @@ use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use codex_mcp_client::McpClient; +use codex_rmcp_client::OAuthCredentialsStoreMode; use codex_rmcp_client::RmcpClient; use mcp_types::ClientCapabilities; use mcp_types::Implementation; @@ -125,9 +126,11 @@ impl McpClientAdapter { bearer_token: Option, params: mcp_types::InitializeRequestParams, startup_timeout: Duration, + store_mode: OAuthCredentialsStoreMode, ) -> Result { let client = Arc::new( - RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token).await?, + RmcpClient::new_streamable_http_client(&server_name, &url, bearer_token, store_mode) + .await?, ); client.initialize(params, Some(startup_timeout)).await?; Ok(McpClientAdapter::Rmcp(client)) @@ -182,6 +185,7 @@ impl McpConnectionManager { pub async fn new( mcp_servers: HashMap, use_rmcp_client: bool, + store_mode: OAuthCredentialsStoreMode, ) -> Result<(Self, ClientStartErrors)> { // Early exit if no servers are configured. if mcp_servers.is_empty() { @@ -249,6 +253,7 @@ impl McpConnectionManager { bearer_token, params, startup_timeout, + store_mode, ) .await } diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index ac69a100..0d15584f 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -5,6 +5,7 @@ mod perform_oauth_login; mod rmcp_client; mod utils; +pub use oauth::OAuthCredentialsStoreMode; pub use oauth::StoredOAuthTokens; pub use oauth::WrappedOAuthTokenResponse; pub use oauth::delete_oauth_tokens; diff --git a/codex-rs/rmcp-client/src/oauth.rs b/codex-rs/rmcp-client/src/oauth.rs index bb13b718..05348476 100644 --- a/codex-rs/rmcp-client/src/oauth.rs +++ b/codex-rs/rmcp-client/src/oauth.rs @@ -58,6 +58,21 @@ pub struct StoredOAuthTokens { pub token_response: WrappedOAuthTokenResponse, } +/// Determine where Codex should store and read MCP credentials. +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum OAuthCredentialsStoreMode { + /// `Keyring` when available; otherwise, `File`. + /// Credentials stored in the keyring will only be readable by Codex unless the user explicitly grants access via OS-level keyring access. + #[default] + Auto, + /// CODEX_HOME/.credentials.json + /// This file will be readable to Codex and other applications running as the same user. + File, + /// Keyring when available, otherwise fail. + Keyring, +} + #[derive(Debug)] struct CredentialStoreError(anyhow::Error); @@ -83,15 +98,15 @@ impl fmt::Display for CredentialStoreError { impl std::error::Error for CredentialStoreError {} -trait CredentialStore { +trait KeyringStore { fn load(&self, service: &str, account: &str) -> Result, CredentialStoreError>; fn save(&self, service: &str, account: &str, value: &str) -> Result<(), CredentialStoreError>; fn delete(&self, service: &str, account: &str) -> Result; } -struct KeyringCredentialStore; +struct DefaultKeyringStore; -impl CredentialStore for KeyringCredentialStore { +impl KeyringStore for DefaultKeyringStore { fn load(&self, service: &str, account: &str) -> Result, CredentialStoreError> { let entry = Entry::new(service, account).map_err(CredentialStoreError::new)?; match entry.get_password() { @@ -129,47 +144,85 @@ impl PartialEq for WrappedOAuthTokenResponse { } } -pub(crate) fn load_oauth_tokens(server_name: &str, url: &str) -> Result> { - let store = KeyringCredentialStore; - load_oauth_tokens_with_store(&store, server_name, url) +pub(crate) fn load_oauth_tokens( + server_name: &str, + url: &str, + store_mode: OAuthCredentialsStoreMode, +) -> Result> { + let keyring_store = DefaultKeyringStore; + match store_mode { + OAuthCredentialsStoreMode::Auto => { + load_oauth_tokens_from_keyring_with_fallback_to_file(&keyring_store, server_name, url) + } + OAuthCredentialsStoreMode::File => load_oauth_tokens_from_file(server_name, url), + OAuthCredentialsStoreMode::Keyring => { + load_oauth_tokens_from_keyring(&keyring_store, server_name, url) + .with_context(|| "failed to read OAuth tokens from keyring".to_string()) + } + } } -fn load_oauth_tokens_with_store( - store: &C, +fn load_oauth_tokens_from_keyring_with_fallback_to_file( + keyring_store: &K, + server_name: &str, + url: &str, +) -> Result> { + match load_oauth_tokens_from_keyring(keyring_store, server_name, url) { + Ok(Some(tokens)) => Ok(Some(tokens)), + Ok(None) => load_oauth_tokens_from_file(server_name, url), + Err(error) => { + warn!("failed to read OAuth tokens from keyring: {error}"); + load_oauth_tokens_from_file(server_name, url) + .with_context(|| format!("failed to read OAuth tokens from keyring: {error}")) + } + } +} + +fn load_oauth_tokens_from_keyring( + keyring_store: &K, server_name: &str, url: &str, ) -> Result> { let key = compute_store_key(server_name, url)?; - match store.load(KEYRING_SERVICE, &key) { + match keyring_store.load(KEYRING_SERVICE, &key) { Ok(Some(serialized)) => { let tokens: StoredOAuthTokens = serde_json::from_str(&serialized) .context("failed to deserialize OAuth tokens from keyring")?; Ok(Some(tokens)) } - Ok(None) => load_oauth_tokens_from_file(server_name, url), - Err(error) => { - let message = error.message(); - warn!("failed to read OAuth tokens from keyring: {message}"); - load_oauth_tokens_from_file(server_name, url) - .with_context(|| format!("failed to read OAuth tokens from keyring: {message}")) + Ok(None) => Ok(None), + Err(error) => Err(error.into_error()), + } +} + +pub fn save_oauth_tokens( + server_name: &str, + tokens: &StoredOAuthTokens, + store_mode: OAuthCredentialsStoreMode, +) -> Result<()> { + let keyring_store = DefaultKeyringStore; + match store_mode { + OAuthCredentialsStoreMode::Auto => save_oauth_tokens_with_keyring_with_fallback_to_file( + &keyring_store, + server_name, + tokens, + ), + OAuthCredentialsStoreMode::File => save_oauth_tokens_to_file(tokens), + OAuthCredentialsStoreMode::Keyring => { + save_oauth_tokens_with_keyring(&keyring_store, server_name, tokens) } } } -pub fn save_oauth_tokens(server_name: &str, tokens: &StoredOAuthTokens) -> Result<()> { - let store = KeyringCredentialStore; - save_oauth_tokens_with_store(&store, server_name, tokens) -} - -fn save_oauth_tokens_with_store( - store: &C, +fn save_oauth_tokens_with_keyring( + keyring_store: &K, server_name: &str, tokens: &StoredOAuthTokens, ) -> Result<()> { let serialized = serde_json::to_string(tokens).context("failed to serialize OAuth tokens")?; let key = compute_store_key(server_name, &tokens.url)?; - match store.save(KEYRING_SERVICE, &key, &serialized) { + match keyring_store.save(KEYRING_SERVICE, &key, &serialized) { Ok(()) => { if let Err(error) = delete_oauth_tokens_from_file(&key) { warn!("failed to remove OAuth tokens from fallback storage: {error:?}"); @@ -177,31 +230,61 @@ fn save_oauth_tokens_with_store( Ok(()) } Err(error) => { - let message = error.message(); - warn!("failed to write OAuth tokens to keyring: {message}"); + let message = format!( + "failed to write OAuth tokens to keyring: {}", + error.message() + ); + warn!("{message}"); + Err(error.into_error().context(message)) + } + } +} + +fn save_oauth_tokens_with_keyring_with_fallback_to_file( + keyring_store: &K, + server_name: &str, + tokens: &StoredOAuthTokens, +) -> Result<()> { + match save_oauth_tokens_with_keyring(keyring_store, server_name, tokens) { + Ok(()) => Ok(()), + Err(error) => { + let message = error.to_string(); + warn!("falling back to file storage for OAuth tokens: {message}"); save_oauth_tokens_to_file(tokens) .with_context(|| format!("failed to write OAuth tokens to keyring: {message}")) } } } -pub fn delete_oauth_tokens(server_name: &str, url: &str) -> Result { - let store = KeyringCredentialStore; - delete_oauth_tokens_with_store(&store, server_name, url) +pub fn delete_oauth_tokens( + server_name: &str, + url: &str, + store_mode: OAuthCredentialsStoreMode, +) -> Result { + let keyring_store = DefaultKeyringStore; + delete_oauth_tokens_from_keyring_and_file(&keyring_store, store_mode, server_name, url) } -fn delete_oauth_tokens_with_store( - store: &C, +fn delete_oauth_tokens_from_keyring_and_file( + keyring_store: &K, + store_mode: OAuthCredentialsStoreMode, server_name: &str, url: &str, ) -> Result { let key = compute_store_key(server_name, url)?; - let keyring_removed = match store.delete(KEYRING_SERVICE, &key) { + let keyring_result = keyring_store.delete(KEYRING_SERVICE, &key); + let keyring_removed = match keyring_result { Ok(removed) => removed, Err(error) => { let message = error.message(); warn!("failed to delete OAuth tokens from keyring: {message}"); - return Err(error.into_error()).context("failed to delete OAuth tokens from keyring"); + match store_mode { + OAuthCredentialsStoreMode::Auto | OAuthCredentialsStoreMode::Keyring => { + return Err(error.into_error()) + .context("failed to delete OAuth tokens from keyring"); + } + OAuthCredentialsStoreMode::File => false, + } } }; @@ -218,6 +301,7 @@ struct OAuthPersistorInner { server_name: String, url: String, authorization_manager: Arc>, + store_mode: OAuthCredentialsStoreMode, last_credentials: Mutex>, } @@ -225,14 +309,16 @@ impl OAuthPersistor { pub(crate) fn new( server_name: String, url: String, - manager: Arc>, + authorization_manager: Arc>, + store_mode: OAuthCredentialsStoreMode, initial_credentials: Option, ) -> Self { Self { inner: Arc::new(OAuthPersistorInner { server_name, url, - authorization_manager: manager, + authorization_manager, + store_mode, last_credentials: Mutex::new(initial_credentials), }), } @@ -257,15 +343,18 @@ impl OAuthPersistor { }; let mut last_credentials = self.inner.last_credentials.lock().await; if last_credentials.as_ref() != Some(&stored) { - save_oauth_tokens(&self.inner.server_name, &stored)?; + save_oauth_tokens(&self.inner.server_name, &stored, self.inner.store_mode)?; *last_credentials = Some(stored); } } None => { let mut last_serialized = self.inner.last_credentials.lock().await; if last_serialized.take().is_some() - && let Err(error) = - delete_oauth_tokens(&self.inner.server_name, &self.inner.url) + && let Err(error) = delete_oauth_tokens( + &self.inner.server_name, + &self.inner.url, + self.inner.store_mode, + ) { warn!( "failed to remove OAuth tokens for server {}: {error}", @@ -542,7 +631,7 @@ mod tests { } } - impl CredentialStore for MockCredentialStore { + impl KeyringStore for MockCredentialStore { fn load( &self, _service: &str, @@ -643,7 +732,8 @@ mod tests { let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; store.save(KEYRING_SERVICE, &key, &serialized)?; - let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?; + let loaded = + super::load_oauth_tokens_from_keyring(&store, &tokens.server_name, &tokens.url)?; assert_eq!(loaded, Some(expected)); Ok(()) } @@ -657,8 +747,12 @@ mod tests { super::save_oauth_tokens_to_file(&tokens)?; - let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)? - .expect("tokens should load from fallback"); + let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file( + &store, + &tokens.server_name, + &tokens.url, + )? + .expect("tokens should load from fallback"); assert_tokens_match_without_expiry(&loaded, &expected); Ok(()) } @@ -674,8 +768,12 @@ mod tests { super::save_oauth_tokens_to_file(&tokens)?; - let loaded = super::load_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)? - .expect("tokens should load from fallback"); + let loaded = super::load_oauth_tokens_from_keyring_with_fallback_to_file( + &store, + &tokens.server_name, + &tokens.url, + )? + .expect("tokens should load from fallback"); assert_tokens_match_without_expiry(&loaded, &expected); Ok(()) } @@ -689,7 +787,11 @@ mod tests { super::save_oauth_tokens_to_file(&tokens)?; - super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?; + super::save_oauth_tokens_with_keyring_with_fallback_to_file( + &store, + &tokens.server_name, + &tokens, + )?; let fallback_path = super::fallback_file_path()?; assert!(!fallback_path.exists(), "fallback file should be removed"); @@ -706,7 +808,11 @@ mod tests { let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; store.set_error(&key, KeyringError::Invalid("error".into(), "save".into())); - super::save_oauth_tokens_with_store(&store, &tokens.server_name, &tokens)?; + super::save_oauth_tokens_with_keyring_with_fallback_to_file( + &store, + &tokens.server_name, + &tokens, + )?; let fallback_path = super::fallback_file_path()?; assert!(fallback_path.exists(), "fallback file should be created"); @@ -734,8 +840,34 @@ mod tests { store.save(KEYRING_SERVICE, &key, &serialized)?; super::save_oauth_tokens_to_file(&tokens)?; - let removed = - super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url)?; + let removed = super::delete_oauth_tokens_from_keyring_and_file( + &store, + OAuthCredentialsStoreMode::Auto, + &tokens.server_name, + &tokens.url, + )?; + assert!(removed); + assert!(!store.contains(&key)); + assert!(!super::fallback_file_path()?.exists()); + Ok(()) + } + + #[test] + fn delete_oauth_tokens_file_mode_removes_keyring_only_entry() -> Result<()> { + let _env = TempCodexHome::new(); + let store = MockCredentialStore::default(); + let tokens = sample_tokens(); + let serialized = serde_json::to_string(&tokens)?; + let key = super::compute_store_key(&tokens.server_name, &tokens.url)?; + store.save(KEYRING_SERVICE, &key, &serialized)?; + assert!(store.contains(&key)); + + let removed = super::delete_oauth_tokens_from_keyring_and_file( + &store, + OAuthCredentialsStoreMode::Auto, + &tokens.server_name, + &tokens.url, + )?; assert!(removed); assert!(!store.contains(&key)); assert!(!super::fallback_file_path()?.exists()); @@ -751,8 +883,12 @@ mod tests { store.set_error(&key, KeyringError::Invalid("error".into(), "delete".into())); super::save_oauth_tokens_to_file(&tokens).unwrap(); - let result = - super::delete_oauth_tokens_with_store(&store, &tokens.server_name, &tokens.url); + let result = super::delete_oauth_tokens_from_keyring_and_file( + &store, + OAuthCredentialsStoreMode::Auto, + &tokens.server_name, + &tokens.url, + ); assert!(result.is_err()); assert!(super::fallback_file_path().unwrap().exists()); Ok(()) diff --git a/codex-rs/rmcp-client/src/perform_oauth_login.rs b/codex-rs/rmcp-client/src/perform_oauth_login.rs index b5a89361..c2d39a21 100644 --- a/codex-rs/rmcp-client/src/perform_oauth_login.rs +++ b/codex-rs/rmcp-client/src/perform_oauth_login.rs @@ -12,6 +12,7 @@ use tokio::sync::oneshot; use tokio::time::timeout; use urlencoding::decode; +use crate::OAuthCredentialsStoreMode; use crate::StoredOAuthTokens; use crate::WrappedOAuthTokenResponse; use crate::save_oauth_tokens; @@ -26,7 +27,11 @@ impl Drop for CallbackServerGuard { } } -pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result<()> { +pub async fn perform_oauth_login( + server_name: &str, + server_url: &str, + store_mode: OAuthCredentialsStoreMode, +) -> Result<()> { let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| anyhow!(err))?); let guard = CallbackServerGuard { server: Arc::clone(&server), @@ -81,7 +86,7 @@ pub async fn perform_oauth_login(server_name: &str, server_url: &str) -> Result< client_id, token_response: WrappedOAuthTokenResponse(credentials), }; - save_oauth_tokens(server_name, &stored)?; + save_oauth_tokens(server_name, &stored, store_mode)?; drop(guard); Ok(()) diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 676af9a5..3d12e508 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -35,6 +35,7 @@ use tracing::warn; use crate::load_oauth_tokens; use crate::logging_client_handler::LoggingClientHandler; +use crate::oauth::OAuthCredentialsStoreMode; use crate::oauth::OAuthPersistor; use crate::oauth::StoredOAuthTokens; use crate::utils::convert_call_tool_result; @@ -119,10 +120,11 @@ impl RmcpClient { server_name: &str, url: &str, bearer_token: Option, + store_mode: OAuthCredentialsStoreMode, ) -> Result { let initial_oauth_tokens = match bearer_token { Some(_) => None, - None => match load_oauth_tokens(server_name, url) { + None => match load_oauth_tokens(server_name, url, store_mode) { Ok(tokens) => tokens, Err(err) => { warn!("failed to read tokens for server `{server_name}`: {err}"); @@ -132,7 +134,8 @@ impl RmcpClient { }; let transport = if let Some(initial_tokens) = initial_oauth_tokens.clone() { let (transport, oauth_persistor) = - create_oauth_transport_and_runtime(server_name, url, initial_tokens).await?; + create_oauth_transport_and_runtime(server_name, url, initial_tokens, store_mode) + .await?; PendingTransport::StreamableHttpWithOAuth { transport, oauth_persistor, @@ -286,6 +289,7 @@ async fn create_oauth_transport_and_runtime( server_name: &str, url: &str, initial_tokens: StoredOAuthTokens, + credentials_store: OAuthCredentialsStoreMode, ) -> Result<( StreamableHttpClientTransport>, OAuthPersistor, @@ -320,6 +324,7 @@ async fn create_oauth_transport_and_runtime( server_name.to_string(), url.to_string(), auth_manager, + credentials_store, Some(initial_tokens), );