diff --git a/codex-rs/core/src/auth.rs b/codex-rs/core/src/auth.rs index b18cae5f..655d8230 100644 --- a/codex-rs/core/src/auth.rs +++ b/codex-rs/core/src/auth.rs @@ -1,12 +1,14 @@ mod storage; use chrono::Utc; +use reqwest::StatusCode; use serde::Deserialize; use serde::Serialize; #[cfg(test)] use serial_test::serial; use std::env; use std::fmt::Debug; +use std::io::ErrorKind; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; @@ -22,10 +24,14 @@ use crate::auth::storage::AuthStorageBackend; use crate::auth::storage::create_auth_storage; use crate::config::Config; use crate::default_client::CodexHttpClient; +use crate::error::RefreshTokenFailedError; +use crate::error::RefreshTokenFailedReason; use crate::token_data::PlanType; use crate::token_data::TokenData; use crate::token_data::parse_id_token; use crate::util::try_parse_error_message; +use serde_json::Value; +use thiserror::Error; #[derive(Debug, Clone)] pub struct CodexAuth { @@ -46,18 +52,54 @@ impl PartialEq for CodexAuth { // TODO(pakrym): use token exp field to check for expiration instead const TOKEN_REFRESH_INTERVAL: i64 = 8; +const REFRESH_TOKEN_EXPIRED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token has expired. Please log out and sign in again."; +const REFRESH_TOKEN_REUSED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token was already used. Please log out and sign in again."; +const REFRESH_TOKEN_INVALIDATED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token was revoked. Please log out and sign in again."; +const REFRESH_TOKEN_UNKNOWN_MESSAGE: &str = + "Your access token could not be refreshed. Please log out and sign in again."; +const REFRESH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; +pub const REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR: &str = "CODEX_REFRESH_TOKEN_URL_OVERRIDE"; + +#[derive(Debug, Error)] +pub enum RefreshTokenError { + #[error("{0}")] + Permanent(#[from] RefreshTokenFailedError), + #[error(transparent)] + Transient(#[from] std::io::Error), +} + +impl RefreshTokenError { + pub fn failed_reason(&self) -> Option { + match self { + Self::Permanent(error) => Some(error.reason), + Self::Transient(_) => None, + } + } + + fn other_with_message(message: impl Into) -> Self { + Self::Transient(std::io::Error::other(message.into())) + } +} + +impl From for std::io::Error { + fn from(err: RefreshTokenError) -> Self { + match err { + RefreshTokenError::Permanent(failed) => std::io::Error::other(failed), + RefreshTokenError::Transient(inner) => inner, + } + } +} + impl CodexAuth { - pub async fn refresh_token(&self) -> Result { + pub async fn refresh_token(&self) -> Result { tracing::info!("Refreshing token"); - let token_data = self - .get_current_token_data() - .ok_or(std::io::Error::other("Token data is not available."))?; + let token_data = self.get_current_token_data().ok_or_else(|| { + RefreshTokenError::Transient(std::io::Error::other("Token data is not available.")) + })?; let token = token_data.refresh_token; - let refresh_response = try_refresh_token(token, &self.client) - .await - .map_err(std::io::Error::other)?; + let refresh_response = try_refresh_token(token, &self.client).await?; let updated = update_tokens( &self.storage, @@ -65,7 +107,8 @@ impl CodexAuth { refresh_response.access_token, refresh_response.refresh_token, ) - .await?; + .await + .map_err(RefreshTokenError::from)?; if let Ok(mut auth_lock) = self.auth_dot_json.lock() { *auth_lock = Some(updated.clone()); @@ -74,7 +117,7 @@ impl CodexAuth { let access = match updated.tokens { Some(t) => t.access_token, None => { - return Err(std::io::Error::other( + return Err(RefreshTokenError::other_with_message( "Token data is not available after refresh.", )); } @@ -99,15 +142,21 @@ impl CodexAuth { .. }) => { if last_refresh < Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL) { - let refresh_response = tokio::time::timeout( + let refresh_result = tokio::time::timeout( Duration::from_secs(60), try_refresh_token(tokens.refresh_token.clone(), &self.client), ) - .await - .map_err(|_| { - std::io::Error::other("timed out while refreshing OpenAI API key") - })? - .map_err(std::io::Error::other)?; + .await; + let refresh_response = match refresh_result { + Ok(Ok(response)) => response, + Ok(Err(err)) => return Err(err.into()), + Err(_) => { + return Err(std::io::Error::new( + ErrorKind::TimedOut, + "timed out while refreshing OpenAI API key", + )); + } + }; let updated_auth_dot_json = update_tokens( &self.storage, @@ -425,7 +474,7 @@ async fn update_tokens( async fn try_refresh_token( refresh_token: String, client: &CodexHttpClient, -) -> std::io::Result { +) -> Result { let refresh_request = RefreshRequest { client_id: CLIENT_ID, grant_type: "refresh_token", @@ -433,30 +482,93 @@ async fn try_refresh_token( scope: "openid profile email", }; + let endpoint = refresh_token_endpoint(); + // Use shared client factory to include standard headers let response = client - .post("https://auth.openai.com/oauth/token") + .post(endpoint.as_str()) .header("Content-Type", "application/json") .json(&refresh_request) .send() .await - .map_err(std::io::Error::other)?; + .map_err(|err| RefreshTokenError::Transient(std::io::Error::other(err)))?; - if response.status().is_success() { + let status = response.status(); + if status.is_success() { let refresh_response = response .json::() .await - .map_err(std::io::Error::other)?; + .map_err(|err| RefreshTokenError::Transient(std::io::Error::other(err)))?; Ok(refresh_response) } else { - Err(std::io::Error::other(format!( - "Failed to refresh token: {}: {}", - response.status(), - try_parse_error_message(&response.text().await.unwrap_or_default()), - ))) + let body = response.text().await.unwrap_or_default(); + if status == StatusCode::UNAUTHORIZED { + let failed = classify_refresh_token_failure(&body); + Err(RefreshTokenError::Permanent(failed)) + } else { + let message = try_parse_error_message(&body); + Err(RefreshTokenError::Transient(std::io::Error::other( + format!("Failed to refresh token: {status}: {message}"), + ))) + } } } +fn classify_refresh_token_failure(body: &str) -> RefreshTokenFailedError { + let code = extract_refresh_token_error_code(body); + + let normalized_code = code.as_deref().map(str::to_ascii_lowercase); + let reason = match normalized_code.as_deref() { + Some("refresh_token_expired") => RefreshTokenFailedReason::Expired, + Some("refresh_token_reused") => RefreshTokenFailedReason::Exhausted, + Some("refresh_token_invalidated") => RefreshTokenFailedReason::Revoked, + _ => RefreshTokenFailedReason::Other, + }; + + if reason == RefreshTokenFailedReason::Other { + tracing::warn!( + backend_code = normalized_code.as_deref(), + backend_body = body, + "Encountered unknown 401 response while refreshing token" + ); + } + + let message = match reason { + RefreshTokenFailedReason::Expired => REFRESH_TOKEN_EXPIRED_MESSAGE.to_string(), + RefreshTokenFailedReason::Exhausted => REFRESH_TOKEN_REUSED_MESSAGE.to_string(), + RefreshTokenFailedReason::Revoked => REFRESH_TOKEN_INVALIDATED_MESSAGE.to_string(), + RefreshTokenFailedReason::Other => REFRESH_TOKEN_UNKNOWN_MESSAGE.to_string(), + }; + + RefreshTokenFailedError::new(reason, message) +} + +fn extract_refresh_token_error_code(body: &str) -> Option { + if body.trim().is_empty() { + return None; + } + + let Value::Object(map) = serde_json::from_str::(body).ok()? else { + return None; + }; + + if let Some(error_value) = map.get("error") { + match error_value { + Value::Object(obj) => { + if let Some(code) = obj.get("code").and_then(Value::as_str) { + return Some(code.to_string()); + } + } + Value::String(code) => { + return Some(code.to_string()); + } + _ => {} + } + } + + map.get("code").and_then(Value::as_str).map(str::to_string) +} + #[derive(Serialize)] struct RefreshRequest { client_id: &'static str, @@ -475,6 +587,11 @@ struct RefreshResponse { // Shared constant for token refresh (client id used for oauth token refresh flow) pub const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; +fn refresh_token_endpoint() -> String { + std::env::var(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR) + .unwrap_or_else(|_| REFRESH_TOKEN_URL.to_string()) +} + use std::sync::RwLock; /// Internal cached auth state. @@ -965,7 +1082,9 @@ impl AuthManager { /// Attempt to refresh the current auth token (if any). On success, reload /// the auth state from disk so other components observe refreshed token. - pub async fn refresh_token(&self) -> std::io::Result> { + /// If the token refresh fails in a permanent (non‑transient) way, logs out + /// to clear invalid auth state. + pub async fn refresh_token(&self) -> Result, RefreshTokenError> { let auth = match self.auth() { Some(a) => a, None => return Ok(None), diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 9dfa3a13..8cd1bca5 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -31,6 +31,7 @@ use tracing::warn; use crate::AuthManager; use crate::auth::CodexAuth; +use crate::auth::RefreshTokenError; use crate::chat_completions::AggregateStreamExt; use crate::chat_completions::stream_chat_completions; use crate::client_common::Prompt; @@ -389,12 +390,17 @@ impl ModelClient { && let Some(manager) = auth_manager.as_ref() && let Some(auth) = auth.as_ref() && auth.mode == AuthMode::ChatGPT + && let Err(err) = manager.refresh_token().await { - manager.refresh_token().await.map_err(|err| { - StreamAttemptError::Fatal(CodexErr::Fatal(format!( - "Failed to refresh ChatGPT credentials: {err}" - ))) - })?; + let stream_error = match err { + RefreshTokenError::Permanent(failed) => { + StreamAttemptError::Fatal(CodexErr::RefreshTokenFailed(failed)) + } + RefreshTokenError::Transient(other) => { + StreamAttemptError::RetryableTransportError(CodexErr::Io(other)) + } + }; + return Err(stream_error); } // The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 5bf826fc..64cd0023 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1928,6 +1928,7 @@ async fn run_turn( return Err(CodexErr::UsageLimitReached(e)); } Err(CodexErr::UsageNotIncluded) => return Err(CodexErr::UsageNotIncluded), + Err(e @ CodexErr::RefreshTokenFailed(_)) => return Err(e), Err(e) => { // Use the configured provider-specific stream retry budget. let max_retries = turn_context.client.get_provider().stream_max_retries(); @@ -1946,7 +1947,7 @@ async fn run_turn( // at a seemingly frozen screen. sess.notify_stream_error( &turn_context, - format!("Re-connecting... {retries}/{max_retries}"), + format!("Reconnecting... {retries}/{max_retries}"), ) .await; diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs index 64683275..6ca8970e 100644 --- a/codex-rs/core/src/error.rs +++ b/codex-rs/core/src/error.rs @@ -135,6 +135,9 @@ pub enum CodexErr { #[error("unsupported operation: {0}")] UnsupportedOperation(String), + #[error("{0}")] + RefreshTokenFailed(RefreshTokenFailedError), + #[error("Fatal error: {0}")] Fatal(String), @@ -201,6 +204,30 @@ impl std::fmt::Display for ResponseStreamFailed { } } +#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[error("{message}")] +pub struct RefreshTokenFailedError { + pub reason: RefreshTokenFailedReason, + pub message: String, +} + +impl RefreshTokenFailedError { + pub fn new(reason: RefreshTokenFailedReason, message: impl Into) -> Self { + Self { + reason, + message: message.into(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RefreshTokenFailedReason { + Expired, + Exhausted, + Revoked, + Other, +} + #[derive(Debug)] pub struct UnexpectedResponseError { pub status: StatusCode, diff --git a/codex-rs/core/tests/suite/auth_refresh.rs b/codex-rs/core/tests/suite/auth_refresh.rs new file mode 100644 index 00000000..6daaf70b --- /dev/null +++ b/codex-rs/core/tests/suite/auth_refresh.rs @@ -0,0 +1,272 @@ +use anyhow::Context; +use anyhow::Result; +use base64::Engine; +use chrono::Duration; +use chrono::Utc; +use codex_core::CodexAuth; +use codex_core::auth::AuthCredentialsStoreMode; +use codex_core::auth::AuthDotJson; +use codex_core::auth::REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR; +use codex_core::auth::RefreshTokenError; +use codex_core::auth::load_auth_dot_json; +use codex_core::auth::save_auth; +use codex_core::error::RefreshTokenFailedReason; +use codex_core::token_data::IdTokenInfo; +use codex_core::token_data::TokenData; +use core_test_support::skip_if_no_network; +use pretty_assertions::assert_eq; +use serde::Serialize; +use serde_json::json; +use std::ffi::OsString; +use tempfile::TempDir; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; + +const INITIAL_ACCESS_TOKEN: &str = "initial-access-token"; +const INITIAL_REFRESH_TOKEN: &str = "initial-refresh-token"; + +#[serial_test::serial(auth_refresh)] +#[tokio::test] +async fn refresh_token_succeeds_updates_storage() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token" + }))) + .expect(1) + .mount(&server) + .await; + + let ctx = RefreshTokenTestContext::new(&server)?; + let auth = ctx.auth.clone(); + + let access = auth + .refresh_token() + .await + .context("refresh should succeed")?; + assert_eq!(access, "new-access-token"); + + let stored = ctx.load_auth()?; + let tokens = stored.tokens.as_ref().context("tokens should exist")?; + assert_eq!(tokens.access_token, "new-access-token"); + assert_eq!(tokens.refresh_token, "new-refresh-token"); + let refreshed_at = stored + .last_refresh + .as_ref() + .context("last_refresh should be recorded")?; + assert!( + *refreshed_at >= ctx.initial_last_refresh, + "last_refresh should advance" + ); + + let cached = auth + .get_token_data() + .await + .context("token data should be cached")?; + assert_eq!(cached.access_token, "new-access-token"); + + server.verify().await; + Ok(()) +} + +#[serial_test::serial(auth_refresh)] +#[tokio::test] +async fn refresh_token_returns_permanent_error_for_expired_refresh_token() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .respond_with(ResponseTemplate::new(401).set_body_json(json!({ + "error": { + "code": "refresh_token_expired" + } + }))) + .expect(1) + .mount(&server) + .await; + + let ctx = RefreshTokenTestContext::new(&server)?; + let auth = ctx.auth.clone(); + + let err = auth + .refresh_token() + .await + .err() + .context("refresh should fail")?; + assert_eq!(err.failed_reason(), Some(RefreshTokenFailedReason::Expired)); + + let stored = ctx.load_auth()?; + let tokens = stored.tokens.as_ref().context("tokens should remain")?; + assert_eq!(tokens.access_token, INITIAL_ACCESS_TOKEN); + assert_eq!(tokens.refresh_token, INITIAL_REFRESH_TOKEN); + assert_eq!( + *stored + .last_refresh + .as_ref() + .context("last_refresh should remain unchanged")?, + ctx.initial_last_refresh, + ); + + server.verify().await; + Ok(()) +} + +#[serial_test::serial(auth_refresh)] +#[tokio::test] +async fn refresh_token_returns_transient_error_on_server_failure() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/oauth/token")) + .respond_with(ResponseTemplate::new(500).set_body_json(json!({ + "error": "temporary-failure" + }))) + .expect(1) + .mount(&server) + .await; + + let ctx = RefreshTokenTestContext::new(&server)?; + let auth = ctx.auth.clone(); + + let err = auth + .refresh_token() + .await + .err() + .context("refresh should fail")?; + assert!(matches!(err, RefreshTokenError::Transient(_))); + assert_eq!(err.failed_reason(), None); + + let stored = ctx.load_auth()?; + let tokens = stored.tokens.as_ref().context("tokens should remain")?; + assert_eq!(tokens.access_token, INITIAL_ACCESS_TOKEN); + assert_eq!(tokens.refresh_token, INITIAL_REFRESH_TOKEN); + assert_eq!( + *stored + .last_refresh + .as_ref() + .context("last_refresh should remain unchanged")?, + ctx.initial_last_refresh, + ); + + server.verify().await; + Ok(()) +} + +struct RefreshTokenTestContext { + codex_home: TempDir, + auth: CodexAuth, + initial_last_refresh: chrono::DateTime, + _env_guard: EnvGuard, +} + +impl RefreshTokenTestContext { + fn new(server: &MockServer) -> Result { + let codex_home = TempDir::new()?; + let initial_last_refresh = Utc::now() - Duration::days(1); + let mut id_token = IdTokenInfo::default(); + id_token.raw_jwt = minimal_jwt(); + let tokens = TokenData { + id_token, + access_token: INITIAL_ACCESS_TOKEN.to_string(), + refresh_token: INITIAL_REFRESH_TOKEN.to_string(), + account_id: Some("account-id".to_string()), + }; + let auth_dot_json = AuthDotJson { + openai_api_key: None, + tokens: Some(tokens), + last_refresh: Some(initial_last_refresh), + }; + save_auth( + codex_home.path(), + &auth_dot_json, + AuthCredentialsStoreMode::File, + )?; + + let endpoint = format!("{}/oauth/token", server.uri()); + let env_guard = EnvGuard::set(REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR, endpoint); + + let auth = CodexAuth::from_auth_storage(codex_home.path(), AuthCredentialsStoreMode::File)? + .context("auth should load from storage")?; + + Ok(Self { + codex_home, + auth, + initial_last_refresh, + _env_guard: env_guard, + }) + } + + fn load_auth(&self) -> Result { + load_auth_dot_json(self.codex_home.path(), AuthCredentialsStoreMode::File) + .context("load auth.json")? + .context("auth.json should exist") + } +} + +struct EnvGuard { + key: &'static str, + original: Option, +} + +impl EnvGuard { + fn set(key: &'static str, value: String) -> Self { + let original = std::env::var_os(key); + // SAFETY: these tests execute serially, so updating the process environment is safe. + unsafe { + std::env::set_var(key, &value); + } + Self { key, original } + } +} + +impl Drop for EnvGuard { + fn drop(&mut self) { + // SAFETY: the guard restores the original environment value before other tests run. + unsafe { + match &self.original { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } + } +} + +fn minimal_jwt() -> String { + #[derive(Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = json!({ "sub": "user-123" }); + + fn b64(data: &[u8]) -> String { + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data) + } + + let header_bytes = match serde_json::to_vec(&header) { + Ok(bytes) => bytes, + Err(err) => panic!("serialize header: {err}"), + }; + let payload_bytes = match serde_json::to_vec(&payload) { + Ok(bytes) => bytes, + Err(err) => panic!("serialize payload: {err}"), + }; + let header_b64 = b64(&header_bytes); + let payload_b64 = b64(&payload_bytes); + let signature_b64 = b64(b"sig"); + format!("{header_b64}.{payload_b64}.{signature_b64}") +} diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index e5978511..bec4f942 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -8,6 +8,7 @@ mod apply_patch_cli; mod apply_patch_freeform; #[cfg(not(target_os = "windows"))] mod approvals; +mod auth_refresh; mod cli_stream; mod client; mod codex_delegate;