diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index cb4cd68d..2fdf9873 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -774,6 +774,7 @@ dependencies = [ "codex-arg0", "codex-common", "codex-core", + "codex-login", "codex-ollama", "codex-protocol", "core_test_support", diff --git a/codex-rs/cli/src/proto.rs b/codex-rs/cli/src/proto.rs index 3bc4d816..9f8c4d3b 100644 --- a/codex-rs/cli/src/proto.rs +++ b/codex-rs/cli/src/proto.rs @@ -9,6 +9,7 @@ use codex_core::config::ConfigOverrides; use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::Submission; +use codex_login::AuthManager; use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tracing::error; @@ -36,7 +37,10 @@ pub async fn run_main(opts: ProtoCli) -> anyhow::Result<()> { let config = Config::load_with_cli_overrides(overrides_vec, ConfigOverrides::default())?; // Use conversation_manager API to start a conversation - let conversation_manager = ConversationManager::default(); + let conversation_manager = ConversationManager::new(AuthManager::shared( + config.codex_home.clone(), + config.preferred_auth_method, + )); let NewConversation { conversation_id: _, conversation, diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 471312d3..174ac58f 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -4,8 +4,8 @@ use std::sync::OnceLock; use std::time::Duration; use bytes::Bytes; +use codex_login::AuthManager; use codex_login::AuthMode; -use codex_login::CodexAuth; use eventsource_stream::Eventsource; use futures::prelude::*; use regex_lite::Regex; @@ -61,7 +61,7 @@ struct Error { #[derive(Debug, Clone)] pub struct ModelClient { config: Arc, - auth: Option, + auth_manager: Option>, client: reqwest::Client, provider: ModelProviderInfo, session_id: Uuid, @@ -72,7 +72,7 @@ pub struct ModelClient { impl ModelClient { pub fn new( config: Arc, - auth: Option, + auth_manager: Option>, provider: ModelProviderInfo, effort: ReasoningEffortConfig, summary: ReasoningSummaryConfig, @@ -80,7 +80,7 @@ impl ModelClient { ) -> Self { Self { config, - auth, + auth_manager, client: reqwest::Client::new(), provider, session_id, @@ -141,7 +141,8 @@ impl ModelClient { return stream_from_fixture(path, self.provider.clone()).await; } - let auth = self.auth.clone(); + let auth_manager = self.auth_manager.clone(); + let auth = auth_manager.as_ref().and_then(|m| m.auth()); let auth_mode = auth.as_ref().map(|a| a.mode); @@ -264,9 +265,10 @@ impl ModelClient { .and_then(|s| s.parse::().ok()); if status == StatusCode::UNAUTHORIZED - && let Some(a) = auth.as_ref() + && let Some(manager) = auth_manager.as_ref() + && manager.auth().is_some() { - let _ = a.refresh_token().await; + let _ = manager.refresh_token().await; } // The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx @@ -353,8 +355,8 @@ impl ModelClient { self.summary } - pub fn get_auth(&self) -> Option { - self.auth.clone() + pub fn get_auth_manager(&self) -> Option> { + self.auth_manager.clone() } } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 7d616f96..f30ebcd3 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -13,7 +13,7 @@ use async_channel::Sender; use codex_apply_patch::ApplyPatchAction; use codex_apply_patch::MaybeApplyPatchVerified; use codex_apply_patch::maybe_parse_apply_patch_verified; -use codex_login::CodexAuth; +use codex_login::AuthManager; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnAbortedEvent; use futures::prelude::*; @@ -144,7 +144,10 @@ pub(crate) const INITIAL_SUBMIT_ID: &str = ""; impl Codex { /// Spawn a new [`Codex`] and initialize the session. - pub async fn spawn(config: Config, auth: Option) -> CodexResult { + pub async fn spawn( + config: Config, + auth_manager: Arc, + ) -> CodexResult { let (tx_sub, rx_sub) = async_channel::bounded(64); let (tx_event, rx_event) = async_channel::unbounded(); @@ -169,13 +172,17 @@ impl Codex { }; // Generate a unique ID for the lifetime of this Codex session. - let (session, turn_context) = - Session::new(configure_session, config.clone(), auth, tx_event.clone()) - .await - .map_err(|e| { - error!("Failed to create session: {e:#}"); - CodexErr::InternalAgentDied - })?; + let (session, turn_context) = Session::new( + configure_session, + config.clone(), + auth_manager.clone(), + tx_event.clone(), + ) + .await + .map_err(|e| { + error!("Failed to create session: {e:#}"); + CodexErr::InternalAgentDied + })?; let session_id = session.session_id; // This task will run until Op::Shutdown is received. @@ -323,7 +330,7 @@ impl Session { async fn new( configure_session: ConfigureSession, config: Arc, - auth: Option, + auth_manager: Arc, tx_event: Sender, ) -> anyhow::Result<(Arc, TurnContext)> { let ConfigureSession { @@ -467,7 +474,7 @@ impl Session { // construct the model client. let client = ModelClient::new( config.clone(), - auth.clone(), + Some(auth_manager.clone()), provider.clone(), model_reasoning_effort, model_reasoning_summary, @@ -1034,7 +1041,8 @@ async fn submission_loop( let effective_effort = effort.unwrap_or(prev.client.get_reasoning_effort()); let effective_summary = summary.unwrap_or(prev.client.get_reasoning_summary()); - let auth = prev.client.get_auth(); + let auth_manager = prev.client.get_auth_manager(); + // Build updated config for the client let mut updated_config = (*config).clone(); updated_config.model = effective_model.clone(); @@ -1042,7 +1050,7 @@ async fn submission_loop( let client = ModelClient::new( Arc::new(updated_config), - auth, + auth_manager, provider, effective_effort, effective_summary, diff --git a/codex-rs/core/src/conversation_manager.rs b/codex-rs/core/src/conversation_manager.rs index 2dc69be4..b5538431 100644 --- a/codex-rs/core/src/conversation_manager.rs +++ b/codex-rs/core/src/conversation_manager.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; +use codex_login::AuthManager; use codex_login::CodexAuth; use tokio::sync::RwLock; use uuid::Uuid; @@ -28,33 +29,37 @@ pub struct NewConversation { /// maintaining them in memory. pub struct ConversationManager { conversations: Arc>>>, -} - -impl Default for ConversationManager { - fn default() -> Self { - Self { - conversations: Arc::new(RwLock::new(HashMap::new())), - } - } + auth_manager: Arc, } impl ConversationManager { - pub async fn new_conversation(&self, config: Config) -> CodexResult { - let auth = CodexAuth::from_codex_home(&config.codex_home, config.preferred_auth_method)?; - self.new_conversation_with_auth(config, auth).await + pub fn new(auth_manager: Arc) -> Self { + Self { + conversations: Arc::new(RwLock::new(HashMap::new())), + auth_manager, + } } - /// Used for integration tests: should not be used by ordinary business - /// logic. - pub async fn new_conversation_with_auth( + /// Construct with a dummy AuthManager containing the provided CodexAuth. + /// Used for integration tests: should not be used by ordinary business logic. + pub fn with_auth(auth: CodexAuth) -> Self { + Self::new(codex_login::AuthManager::from_auth_for_testing(auth)) + } + + pub async fn new_conversation(&self, config: Config) -> CodexResult { + self.spawn_conversation(config, self.auth_manager.clone()) + .await + } + + async fn spawn_conversation( &self, config: Config, - auth: Option, + auth_manager: Arc, ) -> CodexResult { let CodexSpawnOk { codex, session_id: conversation_id, - } = Codex::spawn(config, auth).await?; + } = Codex::spawn(config, auth_manager).await?; // The first event must be `SessionInitialized`. Validate and forward it // to the caller so that they can display it in the conversation diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs index 30ba62ee..629567a1 100644 --- a/codex-rs/core/tests/client.rs +++ b/codex-rs/core/tests/client.rs @@ -142,13 +142,14 @@ async fn includes_session_id_and_model_headers_in_request() { let mut config = load_default_config_for_test(&codex_home); config.model_provider = model_provider; - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let NewConversation { conversation: codex, conversation_id, session_configured: _, } = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation"); @@ -207,9 +208,10 @@ async fn includes_base_instructions_override_in_request() { config.base_instructions = Some("test instructions".to_string()); config.model_provider = model_provider; - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation") .conversation; @@ -262,9 +264,10 @@ async fn originator_config_override_is_used() { config.model_provider = model_provider; config.responses_originator_header = "my_override".to_owned(); - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation") .conversation; @@ -318,13 +321,13 @@ async fn chatgpt_auth_sends_correct_request() { let codex_home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&codex_home); config.model_provider = model_provider; - let conversation_manager = ConversationManager::default(); + let conversation_manager = ConversationManager::with_auth(create_dummy_codex_auth()); let NewConversation { conversation: codex, conversation_id, session_configured: _, } = conversation_manager - .new_conversation_with_auth(config, Some(create_dummy_codex_auth())) + .new_conversation(config) .await .expect("create new conversation"); @@ -411,7 +414,13 @@ async fn prefers_chatgpt_token_when_config_prefers_chatgpt() { config.model_provider = model_provider; config.preferred_auth_method = AuthMode::ChatGPT; - let conversation_manager = ConversationManager::default(); + let auth_manager = + match CodexAuth::from_codex_home(codex_home.path(), config.preferred_auth_method) { + Ok(Some(auth)) => codex_login::AuthManager::from_auth_for_testing(auth), + Ok(None) => panic!("No CodexAuth found in codex_home"), + Err(e) => panic!("Failed to load CodexAuth: {}", e), + }; + let conversation_manager = ConversationManager::new(auth_manager); let NewConversation { conversation: codex, .. @@ -486,7 +495,13 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() { config.model_provider = model_provider; config.preferred_auth_method = AuthMode::ApiKey; - let conversation_manager = ConversationManager::default(); + let auth_manager = + match CodexAuth::from_codex_home(codex_home.path(), config.preferred_auth_method) { + Ok(Some(auth)) => codex_login::AuthManager::from_auth_for_testing(auth), + Ok(None) => panic!("No CodexAuth found in codex_home"), + Err(e) => panic!("Failed to load CodexAuth: {}", e), + }; + let conversation_manager = ConversationManager::new(auth_manager); let NewConversation { conversation: codex, .. @@ -540,9 +555,10 @@ async fn includes_user_instructions_message_in_request() { config.model_provider = model_provider; config.user_instructions = Some("be nice".to_string()); - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation") .conversation; @@ -632,9 +648,9 @@ async fn azure_overrides_assign_properties_used_for_responses_url() { let mut config = load_default_config_for_test(&codex_home); config.model_provider = provider; - let conversation_manager = ConversationManager::default(); + let conversation_manager = ConversationManager::with_auth(create_dummy_codex_auth()); let codex = conversation_manager - .new_conversation_with_auth(config, None) + .new_conversation(config) .await .expect("create new conversation") .conversation; @@ -708,9 +724,9 @@ async fn env_var_overrides_loaded_auth() { let mut config = load_default_config_for_test(&codex_home); config.model_provider = provider; - let conversation_manager = ConversationManager::default(); + let conversation_manager = ConversationManager::with_auth(create_dummy_codex_auth()); let codex = conversation_manager - .new_conversation_with_auth(config, Some(create_dummy_codex_auth())) + .new_conversation(config) .await .expect("create new conversation") .conversation; diff --git a/codex-rs/core/tests/compact.rs b/codex-rs/core/tests/compact.rs index 28b1ca8d..404a88e8 100644 --- a/codex-rs/core/tests/compact.rs +++ b/codex-rs/core/tests/compact.rs @@ -141,9 +141,9 @@ async fn summarize_context_three_requests_and_instructions() { let home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&home); config.model_provider = model_provider; - let conversation_manager = ConversationManager::default(); + let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("dummy"))) + .new_conversation(config) .await .unwrap() .conversation; diff --git a/codex-rs/core/tests/prompt_caching.rs b/codex-rs/core/tests/prompt_caching.rs index ac1dbcd9..958aff7b 100644 --- a/codex-rs/core/tests/prompt_caching.rs +++ b/codex-rs/core/tests/prompt_caching.rs @@ -56,9 +56,10 @@ async fn default_system_instructions_contain_apply_patch() { config.model_provider = model_provider; config.user_instructions = Some("be consistent and helpful".to_string()); - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation") .conversation; @@ -137,9 +138,10 @@ async fn prompt_tools_are_consistent_across_requests() { config.include_apply_patch_tool = true; config.include_plan_tool = true; - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation") .conversation; @@ -229,9 +231,10 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests config.model_provider = model_provider; config.user_instructions = Some("be consistent and helpful".to_string()); - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation") .conversation; @@ -350,9 +353,10 @@ async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() { config.model_provider = model_provider; config.user_instructions = Some("be consistent and helpful".to_string()); - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation") .conversation; @@ -472,9 +476,10 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() { config.model_provider = model_provider; config.user_instructions = Some("be consistent and helpful".to_string()); - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .expect("create new conversation") .conversation; diff --git a/codex-rs/core/tests/stream_error_allows_next_turn.rs b/codex-rs/core/tests/stream_error_allows_next_turn.rs index 415e75a4..8d4b2c99 100644 --- a/codex-rs/core/tests/stream_error_allows_next_turn.rs +++ b/codex-rs/core/tests/stream_error_allows_next_turn.rs @@ -88,9 +88,10 @@ async fn continue_after_stream_error() { config.base_instructions = Some("You are a helpful assistant".to_string()); config.model_provider = provider; - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .unwrap() .conversation; diff --git a/codex-rs/core/tests/stream_no_completed.rs b/codex-rs/core/tests/stream_no_completed.rs index 3fb3f642..a425cfa7 100644 --- a/codex-rs/core/tests/stream_no_completed.rs +++ b/codex-rs/core/tests/stream_no_completed.rs @@ -93,9 +93,10 @@ async fn retries_on_early_close() { let codex_home = TempDir::new().unwrap(); let mut config = load_default_config_for_test(&codex_home); config.model_provider = model_provider; - let conversation_manager = ConversationManager::default(); + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); let codex = conversation_manager - .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .new_conversation(config) .await .unwrap() .conversation; diff --git a/codex-rs/exec/Cargo.toml b/codex-rs/exec/Cargo.toml index 89dc3951..a270b587 100644 --- a/codex-rs/exec/Cargo.toml +++ b/codex-rs/exec/Cargo.toml @@ -25,6 +25,7 @@ codex-common = { path = "../common", features = [ "sandbox_summary", ] } codex-core = { path = "../core" } +codex-login = { path = "../login" } codex-ollama = { path = "../ollama" } codex-protocol = { path = "../protocol" } owo-colors = "4.2.0" diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index e18314fb..d403cb79 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -20,6 +20,7 @@ use codex_core::protocol::InputItem; use codex_core::protocol::Op; use codex_core::protocol::TaskCompleteEvent; use codex_core::util::is_inside_git_repo; +use codex_login::AuthManager; use codex_ollama::DEFAULT_OSS_MODEL; use codex_protocol::config_types::SandboxMode; use event_processor_with_human_output::EventProcessorWithHumanOutput; @@ -185,7 +186,10 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any std::process::exit(1); } - let conversation_manager = ConversationManager::default(); + let conversation_manager = ConversationManager::new(AuthManager::shared( + config.codex_home.clone(), + config.preferred_auth_method, + )); let NewConversation { conversation_id: _, conversation, diff --git a/codex-rs/login/src/auth_manager.rs b/codex-rs/login/src/auth_manager.rs new file mode 100644 index 00000000..5e892b28 --- /dev/null +++ b/codex-rs/login/src/auth_manager.rs @@ -0,0 +1,129 @@ +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::RwLock; + +use crate::AuthMode; +use crate::CodexAuth; + +/// Internal cached auth state. +#[derive(Clone, Debug)] +struct CachedAuth { + preferred_auth_mode: AuthMode, + auth: Option, +} + +/// Central manager providing a single source of truth for auth.json derived +/// authentication data. It loads once (or on preference change) and then +/// hands out cloned `CodexAuth` values so the rest of the program has a +/// consistent snapshot. +/// +/// External modifications to `auth.json` will NOT be observed until +/// `reload()` is called explicitly. This matches the design goal of avoiding +/// different parts of the program seeing inconsistent auth data mid‑run. +#[derive(Debug)] +pub struct AuthManager { + codex_home: PathBuf, + inner: RwLock, +} + +impl AuthManager { + /// Create a new manager loading the initial auth using the provided + /// preferred auth method. Errors loading auth are swallowed; `auth()` will + /// simply return `None` in that case so callers can treat it as an + /// unauthenticated state. + pub fn new(codex_home: PathBuf, preferred_auth_mode: AuthMode) -> Self { + let auth = crate::CodexAuth::from_codex_home(&codex_home, preferred_auth_mode) + .ok() + .flatten(); + Self { + codex_home, + inner: RwLock::new(CachedAuth { + preferred_auth_mode, + auth, + }), + } + } + + /// Create an AuthManager with a specific CodexAuth, for testing only. + pub fn from_auth_for_testing(auth: CodexAuth) -> Arc { + let preferred_auth_mode = auth.mode; + let cached = CachedAuth { + preferred_auth_mode, + auth: Some(auth), + }; + Arc::new(Self { + codex_home: PathBuf::new(), + inner: RwLock::new(cached), + }) + } + + /// Current cached auth (clone). May be `None` if not logged in or load failed. + pub fn auth(&self) -> Option { + self.inner.read().ok().and_then(|c| c.auth.clone()) + } + + /// Preferred auth method used when (re)loading. + pub fn preferred_auth_method(&self) -> AuthMode { + self.inner + .read() + .map(|c| c.preferred_auth_mode) + .unwrap_or(AuthMode::ApiKey) + } + + /// Force a reload using the existing preferred auth method. Returns + /// whether the auth value changed. + pub fn reload(&self) -> bool { + let preferred = self.preferred_auth_method(); + let new_auth = crate::CodexAuth::from_codex_home(&self.codex_home, preferred) + .ok() + .flatten(); + if let Ok(mut guard) = self.inner.write() { + let changed = !AuthManager::auths_equal(&guard.auth, &new_auth); + guard.auth = new_auth; + changed + } else { + false + } + } + + fn auths_equal(a: &Option, b: &Option) -> bool { + match (a, b) { + (None, None) => true, + (Some(a), Some(b)) => a == b, + _ => false, + } + } + + /// Convenience constructor returning an `Arc` wrapper. + pub fn shared(codex_home: PathBuf, preferred_auth_mode: AuthMode) -> Arc { + Arc::new(Self::new(codex_home, preferred_auth_mode)) + } + + /// Attempt to refresh the current auth token (if any). On success, reload + /// the auth state from disk so other components observe refreshed token. + pub async fn refresh_token(&self) -> std::io::Result> { + let auth = match self.auth() { + Some(a) => a, + None => return Ok(None), + }; + match auth.refresh_token().await { + Ok(token) => { + // Reload to pick up persisted changes. + self.reload(); + Ok(Some(token)) + } + Err(e) => Err(e), + } + } + + /// Log out by deleting the on‑disk auth.json (if present). Returns Ok(true) + /// if a file was removed, Ok(false) if no auth file existed. On success, + /// reloads the in‑memory auth cache so callers immediately observe the + /// unauthenticated state. + pub fn logout(&self) -> std::io::Result { + let removed = crate::logout(&self.codex_home)?; + // Always reload to clear any cached auth (even if file absent). + self.reload(); + Ok(removed) + } +} diff --git a/codex-rs/login/src/lib.rs b/codex-rs/login/src/lib.rs index 1f118823..6d5297ea 100644 --- a/codex-rs/login/src/lib.rs +++ b/codex-rs/login/src/lib.rs @@ -23,12 +23,14 @@ pub use crate::server::run_login_server; pub use crate::token_data::TokenData; use crate::token_data::parse_id_token; +mod auth_manager; mod pkce; mod server; mod token_data; pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; pub const OPENAI_API_KEY_ENV_VAR: &str = "OPENAI_API_KEY"; +pub use auth_manager::AuthManager; pub use codex_protocol::mcp_protocol::AuthMode; #[derive(Debug, Clone)] diff --git a/codex-rs/mcp-server/src/codex_message_processor.rs b/codex-rs/mcp-server/src/codex_message_processor.rs index 657cda25..0bbf6ff8 100644 --- a/codex-rs/mcp-server/src/codex_message_processor.rs +++ b/codex-rs/mcp-server/src/codex_message_processor.rs @@ -14,6 +14,7 @@ use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::ReviewDecision; +use codex_login::AuthManager; use codex_protocol::mcp_protocol::AuthMode; use codex_protocol::mcp_protocol::GitDiffToRemoteResponse; use mcp_types::JSONRPCErrorError; @@ -31,10 +32,8 @@ use crate::outgoing_message::OutgoingNotification; use codex_core::protocol::InputItem as CoreInputItem; use codex_core::protocol::Op; use codex_login::CLIENT_ID; -use codex_login::CodexAuth; use codex_login::ServerOptions as LoginServerOptions; use codex_login::ShutdownHandle; -use codex_login::logout; use codex_login::run_login_server; use codex_protocol::mcp_protocol::APPLY_PATCH_APPROVAL_METHOD; use codex_protocol::mcp_protocol::AddConversationListenerParams; @@ -78,6 +77,7 @@ impl ActiveLogin { /// Handles JSON-RPC messages for Codex conversations. pub(crate) struct CodexMessageProcessor { + auth_manager: Arc, conversation_manager: Arc, outgoing: Arc, codex_linux_sandbox_exe: Option, @@ -90,12 +90,14 @@ pub(crate) struct CodexMessageProcessor { impl CodexMessageProcessor { pub fn new( + auth_manager: Arc, conversation_manager: Arc, outgoing: Arc, codex_linux_sandbox_exe: Option, config: Arc, ) -> Self { Self { + auth_manager, conversation_manager, outgoing, codex_linux_sandbox_exe, @@ -129,6 +131,9 @@ impl CodexMessageProcessor { ClientRequest::RemoveConversationListener { request_id, params } => { self.remove_conversation_listener(request_id, params).await; } + ClientRequest::GitDiffToRemote { request_id, params } => { + self.git_diff_to_origin(request_id, params.cwd).await; + } ClientRequest::LoginChatGpt { request_id } => { self.login_chatgpt(request_id).await; } @@ -138,11 +143,8 @@ impl CodexMessageProcessor { ClientRequest::LogoutChatGpt { request_id } => { self.logout_chatgpt(request_id).await; } - ClientRequest::GetAuthStatus { request_id } => { - self.get_auth_status(request_id).await; - } - ClientRequest::GitDiffToRemote { request_id, params } => { - self.git_diff_to_origin(request_id, params.cwd).await; + ClientRequest::GetAuthStatus { request_id, params } => { + self.get_auth_status(request_id, params).await; } } } @@ -185,6 +187,7 @@ impl CodexMessageProcessor { // Spawn background task to monitor completion. let outgoing_clone = self.outgoing.clone(); let active_login = self.active_login.clone(); + let auth_manager = self.auth_manager.clone(); tokio::spawn(async move { let (success, error_msg) = match tokio::time::timeout( LOGIN_CHATGPT_TIMEOUT, @@ -211,8 +214,13 @@ impl CodexMessageProcessor { // Send an auth status change notification. if success { + // Update in-memory auth cache now that login completed. + auth_manager.reload(); + + // Notify clients with the actual current auth mode. + let current_auth_method = auth_manager.auth().map(|a| a.mode); let payload = AuthStatusChangeNotification { - auth_method: Some(AuthMode::ChatGPT), + auth_method: current_auth_method, }; outgoing_clone .send_server_notification(ServerNotification::AuthStatusChange(payload)) @@ -276,10 +284,7 @@ impl CodexMessageProcessor { } } - // Load config to locate codex_home for persistent logout. - let config = self.config.as_ref(); - - if let Err(err) = logout(&config.codex_home) { + if let Err(err) = self.auth_manager.logout() { let error = JSONRPCErrorError { code: INTERNAL_ERROR_CODE, message: format!("logout failed: {err}"), @@ -296,45 +301,55 @@ impl CodexMessageProcessor { ) .await; - // Send auth status change notification. - let payload = AuthStatusChangeNotification { auth_method: None }; + // Send auth status change notification reflecting the current auth mode + // after logout (which may fall back to API key via env var). + let current_auth_method = self.auth_manager.auth().map(|auth| auth.mode); + let payload = AuthStatusChangeNotification { + auth_method: current_auth_method, + }; self.outgoing .send_server_notification(ServerNotification::AuthStatusChange(payload)) .await; } - async fn get_auth_status(&self, request_id: RequestId) { - // Load config to determine codex_home and preferred auth method. - let config = self.config.as_ref(); + async fn get_auth_status( + &self, + request_id: RequestId, + params: codex_protocol::mcp_protocol::GetAuthStatusParams, + ) { + let preferred_auth_method: AuthMode = self.auth_manager.preferred_auth_method(); + let include_token = params.include_token.unwrap_or(false); + let do_refresh = params.refresh_token.unwrap_or(false); - let preferred_auth_method: AuthMode = config.preferred_auth_method; - let response = - match CodexAuth::from_codex_home(&config.codex_home, config.preferred_auth_method) { - Ok(Some(auth)) => { - // Verify that the current auth mode has a valid, non-empty token. - // If token acquisition fails or is empty, treat as unauthenticated. - let reported_auth_method = match auth.get_token().await { - Ok(token) if !token.is_empty() => Some(auth.mode), - Ok(_) => None, // Empty token - Err(err) => { - tracing::warn!("failed to get token for auth status: {err}"); - None - } - }; - codex_protocol::mcp_protocol::GetAuthStatusResponse { - auth_method: reported_auth_method, - preferred_auth_method, + if do_refresh && let Err(err) = self.auth_manager.refresh_token().await { + tracing::warn!("failed to refresh token while getting auth status: {err}"); + } + + let response = match self.auth_manager.auth() { + Some(auth) => { + let (reported_auth_method, token_opt) = match auth.get_token().await { + Ok(token) if !token.is_empty() => { + let tok = if include_token { Some(token) } else { None }; + (Some(auth.mode), tok) } + Ok(_) => (None, None), + Err(err) => { + tracing::warn!("failed to get token for auth status: {err}"); + (None, None) + } + }; + codex_protocol::mcp_protocol::GetAuthStatusResponse { + auth_method: reported_auth_method, + preferred_auth_method, + auth_token: token_opt, } - Ok(None) => codex_protocol::mcp_protocol::GetAuthStatusResponse { - auth_method: None, - preferred_auth_method, - }, - Err(_) => codex_protocol::mcp_protocol::GetAuthStatusResponse { - auth_method: None, - preferred_auth_method, - }, - }; + } + None => codex_protocol::mcp_protocol::GetAuthStatusResponse { + auth_method: None, + preferred_auth_method, + auth_token: None, + }, + }; self.outgoing.send_response(request_id, response).await; } diff --git a/codex-rs/mcp-server/src/message_processor.rs b/codex-rs/mcp-server/src/message_processor.rs index a22f9c5b..6be60151 100644 --- a/codex-rs/mcp-server/src/message_processor.rs +++ b/codex-rs/mcp-server/src/message_processor.rs @@ -13,6 +13,7 @@ use codex_protocol::mcp_protocol::ClientRequest; use codex_core::ConversationManager; use codex_core::config::Config; use codex_core::protocol::Submission; +use codex_login::AuthManager; use mcp_types::CallToolRequestParams; use mcp_types::CallToolResult; use mcp_types::ClientRequest as McpClientRequest; @@ -52,8 +53,11 @@ impl MessageProcessor { config: Arc, ) -> Self { let outgoing = Arc::new(outgoing); - let conversation_manager = Arc::new(ConversationManager::default()); + let auth_manager = + AuthManager::shared(config.codex_home.clone(), config.preferred_auth_method); + let conversation_manager = Arc::new(ConversationManager::new(auth_manager.clone())); let codex_message_processor = CodexMessageProcessor::new( + auth_manager, conversation_manager.clone(), outgoing.clone(), codex_linux_sandbox_exe.clone(), diff --git a/codex-rs/mcp-server/tests/auth.rs b/codex-rs/mcp-server/tests/auth.rs new file mode 100644 index 00000000..533cb903 --- /dev/null +++ b/codex-rs/mcp-server/tests/auth.rs @@ -0,0 +1,142 @@ +use std::path::Path; + +use codex_login::login_with_api_key; +use codex_protocol::mcp_protocol::AuthMode; +use codex_protocol::mcp_protocol::GetAuthStatusParams; +use codex_protocol::mcp_protocol::GetAuthStatusResponse; +use mcp_test_support::McpProcess; +use mcp_test_support::to_response; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +// Helper to create a config.toml; mirrors create_conversation.rs +fn create_config_toml(codex_home: &Path) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "danger-full-access" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "http://127.0.0.1:0/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"#, + ) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn get_auth_status_no_auth() { + let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}")); + create_config_toml(codex_home.path()).expect("write config.toml"); + + let mut mcp = McpProcess::new(codex_home.path()) + .await + .expect("spawn mcp process"); + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()) + .await + .expect("init timeout") + .expect("init failed"); + + let request_id = mcp + .send_get_auth_status_request(GetAuthStatusParams { + include_token: Some(true), + refresh_token: Some(false), + }) + .await + .expect("send getAuthStatus"); + + let resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(request_id)), + ) + .await + .expect("getAuthStatus timeout") + .expect("getAuthStatus response"); + let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status"); + assert_eq!(status.auth_method, None, "expected no auth method"); + assert_eq!(status.auth_token, None, "expected no token"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn get_auth_status_with_api_key() { + let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}")); + create_config_toml(codex_home.path()).expect("write config.toml"); + login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key"); + + let mut mcp = McpProcess::new(codex_home.path()) + .await + .expect("spawn mcp process"); + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()) + .await + .expect("init timeout") + .expect("init failed"); + + let request_id = mcp + .send_get_auth_status_request(GetAuthStatusParams { + include_token: Some(true), + refresh_token: Some(false), + }) + .await + .expect("send getAuthStatus"); + + let resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(request_id)), + ) + .await + .expect("getAuthStatus timeout") + .expect("getAuthStatus response"); + let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status"); + assert_eq!(status.auth_method, Some(AuthMode::ApiKey)); + assert_eq!(status.auth_token, Some("sk-test-key".to_string())); + assert_eq!(status.preferred_auth_method, AuthMode::ChatGPT); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn get_auth_status_with_api_key_no_include_token() { + let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}")); + create_config_toml(codex_home.path()).expect("write config.toml"); + login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key"); + + let mut mcp = McpProcess::new(codex_home.path()) + .await + .expect("spawn mcp process"); + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()) + .await + .expect("init timeout") + .expect("init failed"); + + // Build params via struct so None field is omitted in wire JSON. + let params = GetAuthStatusParams { + include_token: None, + refresh_token: Some(false), + }; + let request_id = mcp + .send_get_auth_status_request(params) + .await + .expect("send getAuthStatus"); + + let resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(request_id)), + ) + .await + .expect("getAuthStatus timeout") + .expect("getAuthStatus response"); + let status: GetAuthStatusResponse = to_response(resp).expect("deserialize status"); + assert_eq!(status.auth_method, Some(AuthMode::ApiKey)); + assert!(status.auth_token.is_none(), "token must be omitted"); + assert_eq!(status.preferred_auth_method, AuthMode::ChatGPT); +} diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs index 6793dcaf..bcc37843 100644 --- a/codex-rs/mcp-server/tests/common/mcp_process.rs +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -13,6 +13,8 @@ use anyhow::Context; use assert_cmd::prelude::*; use codex_mcp_server::CodexToolCallParam; use codex_protocol::mcp_protocol::AddConversationListenerParams; +use codex_protocol::mcp_protocol::CancelLoginChatGptParams; +use codex_protocol::mcp_protocol::GetAuthStatusParams; use codex_protocol::mcp_protocol::InterruptConversationParams; use codex_protocol::mcp_protocol::NewConversationParams; use codex_protocol::mcp_protocol::RemoveConversationListenerParams; @@ -217,6 +219,34 @@ impl McpProcess { self.send_request("interruptConversation", params).await } + /// Send a `getAuthStatus` JSON-RPC request. + pub async fn send_get_auth_status_request( + &mut self, + params: GetAuthStatusParams, + ) -> anyhow::Result { + let params = Some(serde_json::to_value(params)?); + self.send_request("getAuthStatus", params).await + } + + /// Send a `loginChatGpt` JSON-RPC request. + pub async fn send_login_chat_gpt_request(&mut self) -> anyhow::Result { + self.send_request("loginChatGpt", None).await + } + + /// Send a `cancelLoginChatGpt` JSON-RPC request. + pub async fn send_cancel_login_chat_gpt_request( + &mut self, + params: CancelLoginChatGptParams, + ) -> anyhow::Result { + let params = Some(serde_json::to_value(params)?); + self.send_request("cancelLoginChatGpt", params).await + } + + /// Send a `logoutChatGpt` JSON-RPC request. + pub async fn send_logout_chat_gpt_request(&mut self) -> anyhow::Result { + self.send_request("logoutChatGpt", None).await + } + async fn send_request( &mut self, method: &str, diff --git a/codex-rs/mcp-server/tests/login.rs b/codex-rs/mcp-server/tests/login.rs new file mode 100644 index 00000000..7a796c01 --- /dev/null +++ b/codex-rs/mcp-server/tests/login.rs @@ -0,0 +1,146 @@ +use std::path::Path; +use std::time::Duration; + +use codex_login::login_with_api_key; +use codex_protocol::mcp_protocol::CancelLoginChatGptParams; +use codex_protocol::mcp_protocol::CancelLoginChatGptResponse; +use codex_protocol::mcp_protocol::GetAuthStatusParams; +use codex_protocol::mcp_protocol::GetAuthStatusResponse; +use codex_protocol::mcp_protocol::LoginChatGptResponse; +use codex_protocol::mcp_protocol::LogoutChatGptResponse; +use mcp_test_support::McpProcess; +use mcp_test_support::to_response; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +// Helper to create a config.toml; mirrors create_conversation.rs +fn create_config_toml(codex_home: &Path) -> std::io::Result<()> { + let config_toml = codex_home.join("config.toml"); + std::fs::write( + config_toml, + r#" +model = "mock-model" +approval_policy = "never" +sandbox_mode = "danger-full-access" + +model_provider = "mock_provider" + +[model_providers.mock_provider] +name = "Mock provider for test" +base_url = "http://127.0.0.1:0/v1" +wire_api = "chat" +request_max_retries = 0 +stream_max_retries = 0 +"#, + ) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn logout_chatgpt_removes_auth() { + let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}")); + create_config_toml(codex_home.path()).expect("write config.toml"); + login_with_api_key(codex_home.path(), "sk-test-key").expect("seed api key"); + assert!(codex_home.path().join("auth.json").exists()); + + let mut mcp = McpProcess::new(codex_home.path()) + .await + .expect("spawn mcp process"); + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()) + .await + .expect("init timeout") + .expect("init failed"); + + let id = mcp + .send_logout_chat_gpt_request() + .await + .expect("send logoutChatGpt"); + let resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(id)), + ) + .await + .expect("logoutChatGpt timeout") + .expect("logoutChatGpt response"); + let _ok: LogoutChatGptResponse = to_response(resp).expect("deserialize logout response"); + + assert!( + !codex_home.path().join("auth.json").exists(), + "auth.json should be deleted" + ); + + // Verify status reflects signed-out state. + let status_id = mcp + .send_get_auth_status_request(GetAuthStatusParams { + include_token: Some(true), + refresh_token: Some(false), + }) + .await + .expect("send getAuthStatus"); + let status_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(status_id)), + ) + .await + .expect("getAuthStatus timeout") + .expect("getAuthStatus response"); + let status: GetAuthStatusResponse = to_response(status_resp).expect("deserialize status"); + assert_eq!(status.auth_method, None); + assert_eq!(status.auth_token, None); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn login_and_cancel_chatgpt() { + let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}")); + create_config_toml(codex_home.path()).expect("write config.toml"); + + let mut mcp = McpProcess::new(codex_home.path()) + .await + .expect("spawn mcp process"); + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()) + .await + .expect("init timeout") + .expect("init failed"); + + let login_id = mcp + .send_login_chat_gpt_request() + .await + .expect("send loginChatGpt"); + let login_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(login_id)), + ) + .await + .expect("loginChatGpt timeout") + .expect("loginChatGpt response"); + let login: LoginChatGptResponse = to_response(login_resp).expect("deserialize login resp"); + + let cancel_id = mcp + .send_cancel_login_chat_gpt_request(CancelLoginChatGptParams { + login_id: login.login_id, + }) + .await + .expect("send cancelLoginChatGpt"); + let cancel_resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(cancel_id)), + ) + .await + .expect("cancelLoginChatGpt timeout") + .expect("cancelLoginChatGpt response"); + let _ok: CancelLoginChatGptResponse = + to_response(cancel_resp).expect("deserialize cancel response"); + + // Optionally observe the completion notification; do not fail if it races. + let maybe_note = timeout( + Duration::from_secs(2), + mcp.read_stream_until_notification_message("codex/event/login_chat_gpt_complete"), + ) + .await; + if maybe_note.is_err() { + eprintln!("warning: did not observe login_chat_gpt_complete notification after cancel"); + } +} diff --git a/codex-rs/protocol-ts/src/lib.rs b/codex-rs/protocol-ts/src/lib.rs index 2366ae86..1fbcc7bd 100644 --- a/codex-rs/protocol-ts/src/lib.rs +++ b/codex-rs/protocol-ts/src/lib.rs @@ -32,11 +32,16 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { codex_protocol::mcp_protocol::SendUserTurnResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::InterruptConversationParams::export_all_to(out_dir)?; codex_protocol::mcp_protocol::InterruptConversationResponse::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::GitDiffToRemoteParams::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::GitDiffToRemoteResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::LoginChatGptResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::LoginChatGptCompleteNotification::export_all_to(out_dir)?; codex_protocol::mcp_protocol::CancelLoginChatGptParams::export_all_to(out_dir)?; codex_protocol::mcp_protocol::CancelLoginChatGptResponse::export_all_to(out_dir)?; - codex_protocol::mcp_protocol::GitDiffToRemoteParams::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::LogoutChatGptParams::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::LogoutChatGptResponse::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::GetAuthStatusParams::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::GetAuthStatusResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ApplyPatchApprovalParams::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ExecCommandApprovalParams::export_all_to(out_dir)?; diff --git a/codex-rs/protocol/src/mcp_protocol.rs b/codex-rs/protocol/src/mcp_protocol.rs index 7cb38e15..7fb087cf 100644 --- a/codex-rs/protocol/src/mcp_protocol.rs +++ b/codex-rs/protocol/src/mcp_protocol.rs @@ -78,6 +78,11 @@ pub enum ClientRequest { request_id: RequestId, params: RemoveConversationListenerParams, }, + GitDiffToRemote { + #[serde(rename = "id")] + request_id: RequestId, + params: GitDiffToRemoteParams, + }, LoginChatGpt { #[serde(rename = "id")] request_id: RequestId, @@ -94,11 +99,7 @@ pub enum ClientRequest { GetAuthStatus { #[serde(rename = "id")] request_id: RequestId, - }, - GitDiffToRemote { - #[serde(rename = "id")] - request_id: RequestId, - params: GitDiffToRemoteParams, + params: GetAuthStatusParams, }, } @@ -195,9 +196,7 @@ pub struct CancelLoginChatGptResponse {} #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[serde(rename_all = "camelCase")] -pub struct LogoutChatGptParams { - pub login_id: Uuid, -} +pub struct LogoutChatGptParams {} #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[serde(rename_all = "camelCase")] @@ -206,7 +205,12 @@ pub struct LogoutChatGptResponse {} #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[serde(rename_all = "camelCase")] pub struct GetAuthStatusParams { - pub login_id: Uuid, + /// If true, include the current auth token (if available) in the response. + #[serde(skip_serializing_if = "Option::is_none")] + pub include_token: Option, + /// If true, attempt to refresh the token before returning status. + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] @@ -215,6 +219,8 @@ pub struct GetAuthStatusResponse { #[serde(skip_serializing_if = "Option::is_none")] pub auth_method: Option, pub preferred_auth_method: AuthMode, + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_token: Option, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] diff --git a/codex-rs/tui/src/app.rs b/codex-rs/tui/src/app.rs index a532ba71..341eaf4d 100644 --- a/codex-rs/tui/src/app.rs +++ b/codex-rs/tui/src/app.rs @@ -9,6 +9,7 @@ use codex_ansi_escape::ansi_escape_line; use codex_core::ConversationManager; use codex_core::config::Config; use codex_core::protocol::TokenUsage; +use codex_login::AuthManager; use color_eyre::eyre::Result; use crossterm::event::KeyCode; use crossterm::event::KeyEvent; @@ -50,6 +51,7 @@ pub(crate) struct App { impl App { pub async fn run( tui: &mut tui::Tui, + auth_manager: Arc, config: Config, initial_prompt: Option, initial_images: Vec, @@ -58,7 +60,7 @@ impl App { let (app_event_tx, mut app_event_rx) = unbounded_channel(); let app_event_tx = AppEventSender::new(app_event_tx); - let conversation_manager = Arc::new(ConversationManager::default()); + let conversation_manager = Arc::new(ConversationManager::new(auth_manager.clone())); let enhanced_keys_supported = supports_keyboard_enhancement().unwrap_or(false); diff --git a/codex-rs/tui/src/chatwidget/tests.rs b/codex-rs/tui/src/chatwidget/tests.rs index d981f06d..e43dfc78 100644 --- a/codex-rs/tui/src/chatwidget/tests.rs +++ b/codex-rs/tui/src/chatwidget/tests.rs @@ -21,6 +21,7 @@ use codex_core::protocol::PatchApplyBeginEvent; use codex_core::protocol::PatchApplyEndEvent; use codex_core::protocol::StreamErrorEvent; use codex_core::protocol::TaskCompleteEvent; +use codex_login::CodexAuth; use crossterm::event::KeyCode; use crossterm::event::KeyEvent; use crossterm::event::KeyModifiers; @@ -104,7 +105,9 @@ async fn helpers_are_available_and_do_not_panic() { let (tx_raw, _rx) = unbounded_channel::(); let tx = AppEventSender::new(tx_raw); let cfg = test_config(); - let conversation_manager = Arc::new(ConversationManager::default()); + let conversation_manager = Arc::new(ConversationManager::with_auth(CodexAuth::from_api_key( + "test", + ))); let mut w = ChatWidget::new( cfg, conversation_manager, diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index bce0d899..d586c202 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -12,6 +12,7 @@ use codex_core::config::find_codex_home; use codex_core::config::load_config_as_toml_with_cli_overrides; use codex_core::protocol::AskForApproval; use codex_core::protocol::SandboxPolicy; +use codex_login::AuthManager; use codex_login::AuthMode; use codex_login::CodexAuth; use codex_ollama::DEFAULT_OSS_MODEL; @@ -300,6 +301,7 @@ async fn run_ratatui_app( let Cli { prompt, images, .. } = cli; + let auth_manager = AuthManager::shared(config.codex_home.clone(), config.preferred_auth_method); let login_status = get_login_status(&config); let should_show_onboarding = should_show_onboarding(login_status, &config, should_show_trust_screen); @@ -312,6 +314,7 @@ async fn run_ratatui_app( show_trust_screen: should_show_trust_screen, login_status, preferred_auth_method: config.preferred_auth_method, + auth_manager: auth_manager.clone(), }, &mut tui, ) @@ -322,7 +325,7 @@ async fn run_ratatui_app( } } - let app_result = App::run(&mut tui, config, prompt, images).await; + let app_result = App::run(&mut tui, auth_manager, config, prompt, images).await; restore(); // Mark the end of the recorded session. diff --git a/codex-rs/tui/src/onboarding/auth.rs b/codex-rs/tui/src/onboarding/auth.rs index 6f653cd0..347289d2 100644 --- a/codex-rs/tui/src/onboarding/auth.rs +++ b/codex-rs/tui/src/onboarding/auth.rs @@ -1,5 +1,6 @@ #![allow(clippy::unwrap_used)] +use codex_login::AuthManager; use codex_login::CLIENT_ID; use codex_login::ServerOptions; use codex_login::ShutdownHandle; @@ -112,6 +113,7 @@ pub(crate) struct AuthModeWidget { pub codex_home: PathBuf, pub login_status: LoginStatus, pub preferred_auth_method: AuthMode, + pub auth_manager: Arc, } impl AuthModeWidget { @@ -338,6 +340,7 @@ impl AuthModeWidget { Ok(child) => { let sign_in_state = self.sign_in_state.clone(); let request_frame = self.request_frame.clone(); + let auth_manager = self.auth_manager.clone(); tokio::spawn(async move { let auth_url = child.auth_url.clone(); { @@ -351,6 +354,9 @@ impl AuthModeWidget { let r = child.block_until_done().await; match r { Ok(()) => { + // Force the auth manager to reload the new auth information. + auth_manager.reload(); + *sign_in_state.write().unwrap() = SignInState::ChatGptSuccessMessage; request_frame.schedule_frame(); } diff --git a/codex-rs/tui/src/onboarding/onboarding_screen.rs b/codex-rs/tui/src/onboarding/onboarding_screen.rs index 5721430c..f0009bb9 100644 --- a/codex-rs/tui/src/onboarding/onboarding_screen.rs +++ b/codex-rs/tui/src/onboarding/onboarding_screen.rs @@ -1,4 +1,5 @@ use codex_core::util::is_inside_git_repo; +use codex_login::AuthManager; use crossterm::event::KeyCode; use crossterm::event::KeyEvent; use crossterm::event::KeyEventKind; @@ -58,6 +59,7 @@ pub(crate) struct OnboardingScreenArgs { pub show_login_screen: bool, pub login_status: LoginStatus, pub preferred_auth_method: AuthMode, + pub auth_manager: Arc, } impl OnboardingScreen { @@ -69,6 +71,7 @@ impl OnboardingScreen { show_login_screen, login_status, preferred_auth_method, + auth_manager, } = args; let mut steps: Vec = vec![Step::Welcome(WelcomeWidget { is_logged_in: !matches!(login_status, LoginStatus::NotAuthenticated), @@ -82,6 +85,7 @@ impl OnboardingScreen { codex_home: codex_home.clone(), login_status, preferred_auth_method, + auth_manager, })) } let is_git_repo = is_inside_git_repo(&cwd);