chore: refactor attempt_stream_responses() out of stream_responses() (#4194)

I would like to be able to swap in a different way to resolve model
sampling requests, so this refactoring consolidates things behind
`attempt_stream_responses()` to make that easier. Ideally, we would
support an in-memory backend that we can use in our integration tests,
for example.
This commit is contained in:
Michael Bolin
2025-09-25 10:34:07 -07:00
committed by GitHub
parent 103adcdf2d
commit a0c37f5d07

View File

@@ -229,27 +229,52 @@ impl ModelClient {
if azure_workaround { if azure_workaround {
attach_item_ids(&mut payload_json, &input_with_instructions); attach_item_ids(&mut payload_json, &input_with_instructions);
} }
let payload_body = serde_json::to_string(&payload_json)?;
let mut attempt = 0; let max_attempts = self.provider.request_max_retries();
let max_retries = self.provider.request_max_retries(); for attempt in 0..=max_attempts {
match self
.attempt_stream_responses(&payload_json, &auth_manager)
.await
{
Ok(stream) => {
return Ok(stream);
}
Err(StreamAttemptError::Fatal(e)) => {
return Err(e);
}
Err(retryable_attempt_error) => {
if attempt == max_attempts {
return Err(retryable_attempt_error.into_error());
}
loop { tokio::time::sleep(retryable_attempt_error.delay(attempt)).await;
attempt += 1; }
}
}
unreachable!("stream_responses_attempt should always return");
}
/// Single attempt to start a streaming Responses API call.
async fn attempt_stream_responses(
&self,
payload_json: &Value,
auth_manager: &Option<Arc<AuthManager>>,
) -> std::result::Result<ResponseStream, StreamAttemptError> {
// Always fetch the latest auth in case a prior attempt refreshed the token. // Always fetch the latest auth in case a prior attempt refreshed the token.
let auth = auth_manager.as_ref().and_then(|m| m.auth()); let auth = auth_manager.as_ref().and_then(|m| m.auth());
trace!( trace!(
"POST to {}: {}", "POST to {}: {:?}",
self.provider.get_full_url(&auth), self.provider.get_full_url(&auth),
payload_body.as_str() serde_json::to_string(payload_json)
); );
let mut req_builder = self let mut req_builder = self
.provider .provider
.create_request_builder(&self.client, &auth) .create_request_builder(&self.client, &auth)
.await?; .await
.map_err(StreamAttemptError::Fatal)?;
req_builder = req_builder req_builder = req_builder
.header("OpenAI-Beta", "responses=experimental") .header("OpenAI-Beta", "responses=experimental")
@@ -257,7 +282,7 @@ impl ModelClient {
.header("conversation_id", self.conversation_id.to_string()) .header("conversation_id", self.conversation_id.to_string())
.header("session_id", self.conversation_id.to_string()) .header("session_id", self.conversation_id.to_string())
.header(reqwest::header::ACCEPT, "text/event-stream") .header(reqwest::header::ACCEPT, "text/event-stream")
.json(&payload_json); .json(payload_json);
if let Some(auth) = auth.as_ref() if let Some(auth) = auth.as_ref()
&& auth.mode == AuthMode::ChatGPT && auth.mode == AuthMode::ChatGPT
@@ -299,7 +324,7 @@ impl ModelClient {
self.provider.stream_idle_timeout(), self.provider.stream_idle_timeout(),
)); ));
return Ok(ResponseStream { rx_event }); Ok(ResponseStream { rx_event })
} }
Ok(res) => { Ok(res) => {
let status = res.status(); let status = res.status();
@@ -310,6 +335,7 @@ impl ModelClient {
.get(reqwest::header::RETRY_AFTER) .get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok()); .and_then(|s| s.parse::<u64>().ok());
let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000));
if status == StatusCode::UNAUTHORIZED if status == StatusCode::UNAUTHORIZED
&& let Some(manager) = auth_manager.as_ref() && let Some(manager) = auth_manager.as_ref()
@@ -331,7 +357,9 @@ impl ModelClient {
{ {
// Surface the error body to callers. Use `unwrap_or_default` per Clippy. // Surface the error body to callers. Use `unwrap_or_default` per Clippy.
let body = res.text().await.unwrap_or_default(); let body = res.text().await.unwrap_or_default();
return Err(CodexErr::UnexpectedStatus(status, body)); return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus(
status, body,
)));
} }
if status == StatusCode::TOO_MANY_REQUESTS { if status == StatusCode::TOO_MANY_REQUESTS {
@@ -346,38 +374,24 @@ impl ModelClient {
.plan_type .plan_type
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type)); .or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
let resets_in_seconds = error.resets_in_seconds; let resets_in_seconds = error.resets_in_seconds;
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError { let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError {
plan_type, plan_type,
resets_in_seconds, resets_in_seconds,
rate_limits: rate_limit_snapshot, rate_limits: rate_limit_snapshot,
})); });
return Err(StreamAttemptError::Fatal(codex_err));
} else if error.r#type.as_deref() == Some("usage_not_included") { } else if error.r#type.as_deref() == Some("usage_not_included") {
return Err(CodexErr::UsageNotIncluded); return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded));
} }
} }
} }
if attempt > max_retries { Err(StreamAttemptError::RetryableHttpError {
if status == StatusCode::INTERNAL_SERVER_ERROR { status,
return Err(CodexErr::InternalServerError); retry_after,
} })
return Err(CodexErr::RetryLimit(status));
}
let delay = retry_after_secs
.map(|s| Duration::from_millis(s * 1_000))
.unwrap_or_else(|| backoff(attempt));
tokio::time::sleep(delay).await;
}
Err(e) => {
if attempt > max_retries {
return Err(e.into());
}
let delay = backoff(attempt);
tokio::time::sleep(delay).await;
}
} }
Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())),
} }
} }
@@ -410,6 +424,47 @@ impl ModelClient {
} }
} }
/// Outcome of a single failed attempt to start a streaming Responses API
/// call, classified by whether the caller's retry loop may try again.
enum StreamAttemptError {
/// The server answered with a non-success HTTP status that is eligible
/// for another attempt.
RetryableHttpError {
/// Status code of the failed response.
status: StatusCode,
/// Wait time parsed from the response's `Retry-After` header
/// (interpreted as whole seconds), if the header was present and
/// numeric.
retry_after: Option<Duration>,
},
/// The request failed before an HTTP response was obtained; the
/// underlying error has already been converted into a `CodexErr`.
RetryableTransportError(CodexErr),
/// An error that must be returned to the caller immediately, without
/// any further attempts.
Fatal(CodexErr),
}
impl StreamAttemptError {
/// attempt is 0-based.
fn delay(&self, attempt: u64) -> Duration {
// backoff() uses 1-based attempts.
let backoff_attempt = attempt + 1;
match self {
Self::RetryableHttpError { retry_after, .. } => {
retry_after.unwrap_or_else(|| backoff(backoff_attempt))
}
Self::RetryableTransportError { .. } => backoff(backoff_attempt),
Self::Fatal(_) => {
// Should not be called on Fatal errors.
Duration::from_secs(0)
}
}
}
fn into_error(self) -> CodexErr {
match self {
Self::RetryableHttpError { status, .. } => {
if status == StatusCode::INTERNAL_SERVER_ERROR {
CodexErr::InternalServerError
} else {
CodexErr::RetryLimit(status)
}
}
Self::RetryableTransportError(error) => error,
Self::Fatal(error) => error,
}
}
}
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
struct SseEvent { struct SseEvent {
#[serde(rename = "type")] #[serde(rename = "type")]