chore: refactor attempt_stream_responses() out of stream_responses() (#4194)

I would like to be able to swap in a different way to resolve model
sampling requests, so this refactoring consolidates things behind
`attempt_stream_responses()` to make that easier. Ideally, we would
support an in-memory backend that we can use in our integration tests,
for example.
This commit is contained in:
Michael Bolin
2025-09-25 10:34:07 -07:00
committed by GitHub
parent 103adcdf2d
commit a0c37f5d07

View File

@@ -229,27 +229,52 @@ impl ModelClient {
if azure_workaround {
attach_item_ids(&mut payload_json, &input_with_instructions);
}
let payload_body = serde_json::to_string(&payload_json)?;
let mut attempt = 0;
let max_retries = self.provider.request_max_retries();
let max_attempts = self.provider.request_max_retries();
for attempt in 0..=max_attempts {
match self
.attempt_stream_responses(&payload_json, &auth_manager)
.await
{
Ok(stream) => {
return Ok(stream);
}
Err(StreamAttemptError::Fatal(e)) => {
return Err(e);
}
Err(retryable_attempt_error) => {
if attempt == max_attempts {
return Err(retryable_attempt_error.into_error());
}
loop {
attempt += 1;
tokio::time::sleep(retryable_attempt_error.delay(attempt)).await;
}
}
}
unreachable!("stream_responses_attempt should always return");
}
/// Single attempt to start a streaming Responses API call.
async fn attempt_stream_responses(
&self,
payload_json: &Value,
auth_manager: &Option<Arc<AuthManager>>,
) -> std::result::Result<ResponseStream, StreamAttemptError> {
// Always fetch the latest auth in case a prior attempt refreshed the token.
let auth = auth_manager.as_ref().and_then(|m| m.auth());
trace!(
"POST to {}: {}",
"POST to {}: {:?}",
self.provider.get_full_url(&auth),
payload_body.as_str()
serde_json::to_string(payload_json)
);
let mut req_builder = self
.provider
.create_request_builder(&self.client, &auth)
.await?;
.await
.map_err(StreamAttemptError::Fatal)?;
req_builder = req_builder
.header("OpenAI-Beta", "responses=experimental")
@@ -257,7 +282,7 @@ impl ModelClient {
.header("conversation_id", self.conversation_id.to_string())
.header("session_id", self.conversation_id.to_string())
.header(reqwest::header::ACCEPT, "text/event-stream")
.json(&payload_json);
.json(payload_json);
if let Some(auth) = auth.as_ref()
&& auth.mode == AuthMode::ChatGPT
@@ -299,7 +324,7 @@ impl ModelClient {
self.provider.stream_idle_timeout(),
));
return Ok(ResponseStream { rx_event });
Ok(ResponseStream { rx_event })
}
Ok(res) => {
let status = res.status();
@@ -310,6 +335,7 @@ impl ModelClient {
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok());
let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000));
if status == StatusCode::UNAUTHORIZED
&& let Some(manager) = auth_manager.as_ref()
@@ -331,7 +357,9 @@ impl ModelClient {
{
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
let body = res.text().await.unwrap_or_default();
return Err(CodexErr::UnexpectedStatus(status, body));
return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus(
status, body,
)));
}
if status == StatusCode::TOO_MANY_REQUESTS {
@@ -346,38 +374,24 @@ impl ModelClient {
.plan_type
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
let resets_in_seconds = error.resets_in_seconds;
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError {
plan_type,
resets_in_seconds,
rate_limits: rate_limit_snapshot,
}));
});
return Err(StreamAttemptError::Fatal(codex_err));
} else if error.r#type.as_deref() == Some("usage_not_included") {
return Err(CodexErr::UsageNotIncluded);
return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded));
}
}
}
if attempt > max_retries {
if status == StatusCode::INTERNAL_SERVER_ERROR {
return Err(CodexErr::InternalServerError);
}
return Err(CodexErr::RetryLimit(status));
}
let delay = retry_after_secs
.map(|s| Duration::from_millis(s * 1_000))
.unwrap_or_else(|| backoff(attempt));
tokio::time::sleep(delay).await;
}
Err(e) => {
if attempt > max_retries {
return Err(e.into());
}
let delay = backoff(attempt);
tokio::time::sleep(delay).await;
}
Err(StreamAttemptError::RetryableHttpError {
status,
retry_after,
})
}
Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())),
}
}
@@ -410,6 +424,47 @@ impl ModelClient {
}
}
enum StreamAttemptError {
RetryableHttpError {
status: StatusCode,
retry_after: Option<Duration>,
},
RetryableTransportError(CodexErr),
Fatal(CodexErr),
}
impl StreamAttemptError {
/// attempt is 0-based.
fn delay(&self, attempt: u64) -> Duration {
// backoff() uses 1-based attempts.
let backoff_attempt = attempt + 1;
match self {
Self::RetryableHttpError { retry_after, .. } => {
retry_after.unwrap_or_else(|| backoff(backoff_attempt))
}
Self::RetryableTransportError { .. } => backoff(backoff_attempt),
Self::Fatal(_) => {
// Should not be called on Fatal errors.
Duration::from_secs(0)
}
}
}
fn into_error(self) -> CodexErr {
match self {
Self::RetryableHttpError { status, .. } => {
if status == StatusCode::INTERNAL_SERVER_ERROR {
CodexErr::InternalServerError
} else {
CodexErr::RetryLimit(status)
}
}
Self::RetryableTransportError(error) => error,
Self::Fatal(error) => error,
}
}
}
#[derive(Debug, Deserialize, Serialize)]
struct SseEvent {
#[serde(rename = "type")]