chore: refactor attempt_stream_responses() out of stream_responses() (#4194)
I would like to be able to swap in a different way to resolve model sampling requests, so this refactoring consolidates things behind `attempt_stream_responses()` to make that easier. Ideally, we would support an in-memory backend that we can use in our integration tests, for example.
This commit is contained in:
@@ -229,27 +229,52 @@ impl ModelClient {
|
||||
if azure_workaround {
|
||||
attach_item_ids(&mut payload_json, &input_with_instructions);
|
||||
}
|
||||
let payload_body = serde_json::to_string(&payload_json)?;
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = self.provider.request_max_retries();
|
||||
let max_attempts = self.provider.request_max_retries();
|
||||
for attempt in 0..=max_attempts {
|
||||
match self
|
||||
.attempt_stream_responses(&payload_json, &auth_manager)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => {
|
||||
return Ok(stream);
|
||||
}
|
||||
Err(StreamAttemptError::Fatal(e)) => {
|
||||
return Err(e);
|
||||
}
|
||||
Err(retryable_attempt_error) => {
|
||||
if attempt == max_attempts {
|
||||
return Err(retryable_attempt_error.into_error());
|
||||
}
|
||||
|
||||
loop {
|
||||
attempt += 1;
|
||||
tokio::time::sleep(retryable_attempt_error.delay(attempt)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unreachable!("stream_responses_attempt should always return");
|
||||
}
|
||||
|
||||
/// Single attempt to start a streaming Responses API call.
|
||||
async fn attempt_stream_responses(
|
||||
&self,
|
||||
payload_json: &Value,
|
||||
auth_manager: &Option<Arc<AuthManager>>,
|
||||
) -> std::result::Result<ResponseStream, StreamAttemptError> {
|
||||
// Always fetch the latest auth in case a prior attempt refreshed the token.
|
||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||
|
||||
trace!(
|
||||
"POST to {}: {}",
|
||||
"POST to {}: {:?}",
|
||||
self.provider.get_full_url(&auth),
|
||||
payload_body.as_str()
|
||||
serde_json::to_string(payload_json)
|
||||
);
|
||||
|
||||
let mut req_builder = self
|
||||
.provider
|
||||
.create_request_builder(&self.client, &auth)
|
||||
.await?;
|
||||
.await
|
||||
.map_err(StreamAttemptError::Fatal)?;
|
||||
|
||||
req_builder = req_builder
|
||||
.header("OpenAI-Beta", "responses=experimental")
|
||||
@@ -257,7 +282,7 @@ impl ModelClient {
|
||||
.header("conversation_id", self.conversation_id.to_string())
|
||||
.header("session_id", self.conversation_id.to_string())
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload_json);
|
||||
.json(payload_json);
|
||||
|
||||
if let Some(auth) = auth.as_ref()
|
||||
&& auth.mode == AuthMode::ChatGPT
|
||||
@@ -299,7 +324,7 @@ impl ModelClient {
|
||||
self.provider.stream_idle_timeout(),
|
||||
));
|
||||
|
||||
return Ok(ResponseStream { rx_event });
|
||||
Ok(ResponseStream { rx_event })
|
||||
}
|
||||
Ok(res) => {
|
||||
let status = res.status();
|
||||
@@ -310,6 +335,7 @@ impl ModelClient {
|
||||
.get(reqwest::header::RETRY_AFTER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok());
|
||||
let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000));
|
||||
|
||||
if status == StatusCode::UNAUTHORIZED
|
||||
&& let Some(manager) = auth_manager.as_ref()
|
||||
@@ -331,7 +357,9 @@ impl ModelClient {
|
||||
{
|
||||
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
|
||||
let body = res.text().await.unwrap_or_default();
|
||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
||||
return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus(
|
||||
status, body,
|
||||
)));
|
||||
}
|
||||
|
||||
if status == StatusCode::TOO_MANY_REQUESTS {
|
||||
@@ -346,38 +374,24 @@ impl ModelClient {
|
||||
.plan_type
|
||||
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
|
||||
let resets_in_seconds = error.resets_in_seconds;
|
||||
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||
let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||
plan_type,
|
||||
resets_in_seconds,
|
||||
rate_limits: rate_limit_snapshot,
|
||||
}));
|
||||
});
|
||||
return Err(StreamAttemptError::Fatal(codex_err));
|
||||
} else if error.r#type.as_deref() == Some("usage_not_included") {
|
||||
return Err(CodexErr::UsageNotIncluded);
|
||||
return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if attempt > max_retries {
|
||||
if status == StatusCode::INTERNAL_SERVER_ERROR {
|
||||
return Err(CodexErr::InternalServerError);
|
||||
}
|
||||
|
||||
return Err(CodexErr::RetryLimit(status));
|
||||
}
|
||||
|
||||
let delay = retry_after_secs
|
||||
.map(|s| Duration::from_millis(s * 1_000))
|
||||
.unwrap_or_else(|| backoff(attempt));
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
return Err(e.into());
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(StreamAttemptError::RetryableHttpError {
|
||||
status,
|
||||
retry_after,
|
||||
})
|
||||
}
|
||||
Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -410,6 +424,47 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
enum StreamAttemptError {
|
||||
RetryableHttpError {
|
||||
status: StatusCode,
|
||||
retry_after: Option<Duration>,
|
||||
},
|
||||
RetryableTransportError(CodexErr),
|
||||
Fatal(CodexErr),
|
||||
}
|
||||
|
||||
impl StreamAttemptError {
|
||||
/// attempt is 0-based.
|
||||
fn delay(&self, attempt: u64) -> Duration {
|
||||
// backoff() uses 1-based attempts.
|
||||
let backoff_attempt = attempt + 1;
|
||||
match self {
|
||||
Self::RetryableHttpError { retry_after, .. } => {
|
||||
retry_after.unwrap_or_else(|| backoff(backoff_attempt))
|
||||
}
|
||||
Self::RetryableTransportError { .. } => backoff(backoff_attempt),
|
||||
Self::Fatal(_) => {
|
||||
// Should not be called on Fatal errors.
|
||||
Duration::from_secs(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn into_error(self) -> CodexErr {
|
||||
match self {
|
||||
Self::RetryableHttpError { status, .. } => {
|
||||
if status == StatusCode::INTERNAL_SERVER_ERROR {
|
||||
CodexErr::InternalServerError
|
||||
} else {
|
||||
CodexErr::RetryLimit(status)
|
||||
}
|
||||
}
|
||||
Self::RetryableTransportError(error) => error,
|
||||
Self::Fatal(error) => error,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct SseEvent {
|
||||
#[serde(rename = "type")]
|
||||
|
||||
Reference in New Issue
Block a user