diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 39a6aff8..1d1082c2 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -229,155 +229,169 @@ impl ModelClient { if azure_workaround { attach_item_ids(&mut payload_json, &input_with_instructions); } - let payload_body = serde_json::to_string(&payload_json)?; - let mut attempt = 0; - let max_retries = self.provider.request_max_retries(); - - loop { - attempt += 1; - - // Always fetch the latest auth in case a prior attempt refreshed the token. - let auth = auth_manager.as_ref().and_then(|m| m.auth()); - - trace!( - "POST to {}: {}", - self.provider.get_full_url(&auth), - payload_body.as_str() - ); - - let mut req_builder = self - .provider - .create_request_builder(&self.client, &auth) - .await?; - - req_builder = req_builder - .header("OpenAI-Beta", "responses=experimental") - // Send session_id for compatibility. - .header("conversation_id", self.conversation_id.to_string()) - .header("session_id", self.conversation_id.to_string()) - .header(reqwest::header::ACCEPT, "text/event-stream") - .json(&payload_json); - - if let Some(auth) = auth.as_ref() - && auth.mode == AuthMode::ChatGPT - && let Some(account_id) = auth.get_account_id() + let max_attempts = self.provider.request_max_retries(); + for attempt in 0..=max_attempts { + match self + .attempt_stream_responses(&payload_json, &auth_manager) + .await { - req_builder = req_builder.header("chatgpt-account-id", account_id); - } - - let res = req_builder.send().await; - if let Ok(resp) = &res { - trace!( - "Response status: {}, cf-ray: {}", - resp.status(), - resp.headers() - .get("cf-ray") - .map(|v| v.to_str().unwrap_or_default()) - .unwrap_or_default() - ); - } - - match res { - Ok(resp) if resp.status().is_success() => { - let (tx_event, rx_event) = mpsc::channel::>(1600); - - if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers()) - && tx_event - .send(Ok(ResponseEvent::RateLimits(snapshot))) - .await - .is_err() - { - debug!("receiver dropped rate limit snapshot event"); - } - - // spawn task to process SSE - let stream = resp.bytes_stream().map_err(CodexErr::Reqwest); - tokio::spawn(process_sse( - stream, - tx_event, - self.provider.stream_idle_timeout(), - )); - - return Ok(ResponseStream { rx_event }); + Ok(stream) => { + return Ok(stream); } - Ok(res) => { - let status = res.status(); - - // Pull out Retry‑After header if present. - let retry_after_secs = res - .headers() - .get(reqwest::header::RETRY_AFTER) - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.parse::().ok()); - - if status == StatusCode::UNAUTHORIZED - && let Some(manager) = auth_manager.as_ref() - && manager.auth().is_some() - { - let _ = manager.refresh_token().await; + Err(StreamAttemptError::Fatal(e)) => { + return Err(e); + } + Err(retryable_attempt_error) => { + if attempt == max_attempts { + return Err(retryable_attempt_error.into_error()); } - // The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx - // errors. When we bubble early with only the HTTP status the caller sees an opaque - // "unexpected status 400 Bad Request" which makes debugging nearly impossible. - // Instead, read (and include) the response text so higher layers and users see the - // exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is - // small and this branch only runs on error paths so the extra allocation is - // negligible. - if !(status == StatusCode::TOO_MANY_REQUESTS - || status == StatusCode::UNAUTHORIZED - || status.is_server_error()) - { - // Surface the error body to callers. Use `unwrap_or_default` per Clippy. - let body = res.text().await.unwrap_or_default(); - return Err(CodexErr::UnexpectedStatus(status, body)); - } + tokio::time::sleep(retryable_attempt_error.delay(attempt)).await; + } + } + } - if status == StatusCode::TOO_MANY_REQUESTS { - let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers()); - let body = res.json::().await.ok(); - if let Some(ErrorResponse { error }) = body { - if error.r#type.as_deref() == Some("usage_limit_reached") { - // Prefer the plan_type provided in the error message if present - // because it's more up to date than the one encoded in the auth - // token. - let plan_type = error - .plan_type - .or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type)); - let resets_in_seconds = error.resets_in_seconds; - return Err(CodexErr::UsageLimitReached(UsageLimitReachedError { - plan_type, - resets_in_seconds, - rate_limits: rate_limit_snapshot, - })); - } else if error.r#type.as_deref() == Some("usage_not_included") { - return Err(CodexErr::UsageNotIncluded); - } + unreachable!("stream_responses_attempt should always return"); + } + + /// Single attempt to start a streaming Responses API call. + async fn attempt_stream_responses( + &self, + payload_json: &Value, + auth_manager: &Option>, + ) -> std::result::Result { + // Always fetch the latest auth in case a prior attempt refreshed the token. + let auth = auth_manager.as_ref().and_then(|m| m.auth()); + + trace!( + "POST to {}: {:?}", + self.provider.get_full_url(&auth), + serde_json::to_string(payload_json) + ); + + let mut req_builder = self + .provider + .create_request_builder(&self.client, &auth) + .await + .map_err(StreamAttemptError::Fatal)?; + + req_builder = req_builder + .header("OpenAI-Beta", "responses=experimental") + // Send session_id for compatibility. + .header("conversation_id", self.conversation_id.to_string()) + .header("session_id", self.conversation_id.to_string()) + .header(reqwest::header::ACCEPT, "text/event-stream") + .json(payload_json); + + if let Some(auth) = auth.as_ref() + && auth.mode == AuthMode::ChatGPT + && let Some(account_id) = auth.get_account_id() + { + req_builder = req_builder.header("chatgpt-account-id", account_id); + } + + let res = req_builder.send().await; + if let Ok(resp) = &res { + trace!( + "Response status: {}, cf-ray: {}", + resp.status(), + resp.headers() + .get("cf-ray") + .map(|v| v.to_str().unwrap_or_default()) + .unwrap_or_default() + ); + } + + match res { + Ok(resp) if resp.status().is_success() => { + let (tx_event, rx_event) = mpsc::channel::>(1600); + + if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers()) + && tx_event + .send(Ok(ResponseEvent::RateLimits(snapshot))) + .await + .is_err() + { + debug!("receiver dropped rate limit snapshot event"); + } + + // spawn task to process SSE + let stream = resp.bytes_stream().map_err(CodexErr::Reqwest); + tokio::spawn(process_sse( + stream, + tx_event, + self.provider.stream_idle_timeout(), + )); + + Ok(ResponseStream { rx_event }) + } + Ok(res) => { + let status = res.status(); + + // Pull out Retry‑After header if present. + let retry_after_secs = res + .headers() + .get(reqwest::header::RETRY_AFTER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000)); + + if status == StatusCode::UNAUTHORIZED + && let Some(manager) = auth_manager.as_ref() + && manager.auth().is_some() + { + let _ = manager.refresh_token().await; + } + + // The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx + // errors. When we bubble early with only the HTTP status the caller sees an opaque + // "unexpected status 400 Bad Request" which makes debugging nearly impossible. + // Instead, read (and include) the response text so higher layers and users see the + // exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is + // small and this branch only runs on error paths so the extra allocation is + // negligible. + if !(status == StatusCode::TOO_MANY_REQUESTS + || status == StatusCode::UNAUTHORIZED + || status.is_server_error()) + { + // Surface the error body to callers. Use `unwrap_or_default` per Clippy. + let body = res.text().await.unwrap_or_default(); + return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus( + status, body, + ))); + } + + if status == StatusCode::TOO_MANY_REQUESTS { + let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers()); + let body = res.json::().await.ok(); + if let Some(ErrorResponse { error }) = body { + if error.r#type.as_deref() == Some("usage_limit_reached") { + // Prefer the plan_type provided in the error message if present + // because it's more up to date than the one encoded in the auth + // token. + let plan_type = error + .plan_type + .or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type)); + let resets_in_seconds = error.resets_in_seconds; + let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError { + plan_type, + resets_in_seconds, + rate_limits: rate_limit_snapshot, + }); + return Err(StreamAttemptError::Fatal(codex_err)); + } else if error.r#type.as_deref() == Some("usage_not_included") { + return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded)); } } - - if attempt > max_retries { - if status == StatusCode::INTERNAL_SERVER_ERROR { - return Err(CodexErr::InternalServerError); - } - - return Err(CodexErr::RetryLimit(status)); - } - - let delay = retry_after_secs - .map(|s| Duration::from_millis(s * 1_000)) - .unwrap_or_else(|| backoff(attempt)); - tokio::time::sleep(delay).await; - } - Err(e) => { - if attempt > max_retries { - return Err(e.into()); - } - let delay = backoff(attempt); - tokio::time::sleep(delay).await; } + + Err(StreamAttemptError::RetryableHttpError { + status, + retry_after, + }) } + Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())), } } @@ -410,6 +424,47 @@ impl ModelClient { } } +enum StreamAttemptError { + RetryableHttpError { + status: StatusCode, + retry_after: Option, + }, + RetryableTransportError(CodexErr), + Fatal(CodexErr), +} + +impl StreamAttemptError { + /// attempt is 0-based. + fn delay(&self, attempt: u64) -> Duration { + // backoff() uses 1-based attempts. + let backoff_attempt = attempt + 1; + match self { + Self::RetryableHttpError { retry_after, .. } => { + retry_after.unwrap_or_else(|| backoff(backoff_attempt)) + } + Self::RetryableTransportError { .. } => backoff(backoff_attempt), + Self::Fatal(_) => { + // Should not be called on Fatal errors. + Duration::from_secs(0) + } + } + } + + fn into_error(self) -> CodexErr { + match self { + Self::RetryableHttpError { status, .. } => { + if status == StatusCode::INTERNAL_SERVER_ERROR { + CodexErr::InternalServerError + } else { + CodexErr::RetryLimit(status) + } + } + Self::RetryableTransportError(error) => error, + Self::Fatal(error) => error, + } + } +} + #[derive(Debug, Deserialize, Serialize)] struct SseEvent { #[serde(rename = "type")]