chore: refactor attempt_stream_responses() out of stream_responses() (#4194)
I would like to be able to swap in a different way to resolve model sampling requests, so this refactoring consolidates things behind `attempt_stream_responses()` to make that easier. Ideally, we would support an in-memory backend that we can use in our integration tests, for example.
This commit is contained in:
@@ -229,27 +229,52 @@ impl ModelClient {
|
|||||||
if azure_workaround {
|
if azure_workaround {
|
||||||
attach_item_ids(&mut payload_json, &input_with_instructions);
|
attach_item_ids(&mut payload_json, &input_with_instructions);
|
||||||
}
|
}
|
||||||
let payload_body = serde_json::to_string(&payload_json)?;
|
|
||||||
|
|
||||||
let mut attempt = 0;
|
let max_attempts = self.provider.request_max_retries();
|
||||||
let max_retries = self.provider.request_max_retries();
|
for attempt in 0..=max_attempts {
|
||||||
|
match self
|
||||||
|
.attempt_stream_responses(&payload_json, &auth_manager)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(stream) => {
|
||||||
|
return Ok(stream);
|
||||||
|
}
|
||||||
|
Err(StreamAttemptError::Fatal(e)) => {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
Err(retryable_attempt_error) => {
|
||||||
|
if attempt == max_attempts {
|
||||||
|
return Err(retryable_attempt_error.into_error());
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
tokio::time::sleep(retryable_attempt_error.delay(attempt)).await;
|
||||||
attempt += 1;
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unreachable!("stream_responses_attempt should always return");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Single attempt to start a streaming Responses API call.
|
||||||
|
async fn attempt_stream_responses(
|
||||||
|
&self,
|
||||||
|
payload_json: &Value,
|
||||||
|
auth_manager: &Option<Arc<AuthManager>>,
|
||||||
|
) -> std::result::Result<ResponseStream, StreamAttemptError> {
|
||||||
// Always fetch the latest auth in case a prior attempt refreshed the token.
|
// Always fetch the latest auth in case a prior attempt refreshed the token.
|
||||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||||
|
|
||||||
trace!(
|
trace!(
|
||||||
"POST to {}: {}",
|
"POST to {}: {:?}",
|
||||||
self.provider.get_full_url(&auth),
|
self.provider.get_full_url(&auth),
|
||||||
payload_body.as_str()
|
serde_json::to_string(payload_json)
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut req_builder = self
|
let mut req_builder = self
|
||||||
.provider
|
.provider
|
||||||
.create_request_builder(&self.client, &auth)
|
.create_request_builder(&self.client, &auth)
|
||||||
.await?;
|
.await
|
||||||
|
.map_err(StreamAttemptError::Fatal)?;
|
||||||
|
|
||||||
req_builder = req_builder
|
req_builder = req_builder
|
||||||
.header("OpenAI-Beta", "responses=experimental")
|
.header("OpenAI-Beta", "responses=experimental")
|
||||||
@@ -257,7 +282,7 @@ impl ModelClient {
|
|||||||
.header("conversation_id", self.conversation_id.to_string())
|
.header("conversation_id", self.conversation_id.to_string())
|
||||||
.header("session_id", self.conversation_id.to_string())
|
.header("session_id", self.conversation_id.to_string())
|
||||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||||
.json(&payload_json);
|
.json(payload_json);
|
||||||
|
|
||||||
if let Some(auth) = auth.as_ref()
|
if let Some(auth) = auth.as_ref()
|
||||||
&& auth.mode == AuthMode::ChatGPT
|
&& auth.mode == AuthMode::ChatGPT
|
||||||
@@ -299,7 +324,7 @@ impl ModelClient {
|
|||||||
self.provider.stream_idle_timeout(),
|
self.provider.stream_idle_timeout(),
|
||||||
));
|
));
|
||||||
|
|
||||||
return Ok(ResponseStream { rx_event });
|
Ok(ResponseStream { rx_event })
|
||||||
}
|
}
|
||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
let status = res.status();
|
let status = res.status();
|
||||||
@@ -310,6 +335,7 @@ impl ModelClient {
|
|||||||
.get(reqwest::header::RETRY_AFTER)
|
.get(reqwest::header::RETRY_AFTER)
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.and_then(|s| s.parse::<u64>().ok());
|
.and_then(|s| s.parse::<u64>().ok());
|
||||||
|
let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000));
|
||||||
|
|
||||||
if status == StatusCode::UNAUTHORIZED
|
if status == StatusCode::UNAUTHORIZED
|
||||||
&& let Some(manager) = auth_manager.as_ref()
|
&& let Some(manager) = auth_manager.as_ref()
|
||||||
@@ -331,7 +357,9 @@ impl ModelClient {
|
|||||||
{
|
{
|
||||||
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
|
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
|
||||||
let body = res.text().await.unwrap_or_default();
|
let body = res.text().await.unwrap_or_default();
|
||||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus(
|
||||||
|
status, body,
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if status == StatusCode::TOO_MANY_REQUESTS {
|
if status == StatusCode::TOO_MANY_REQUESTS {
|
||||||
@@ -346,38 +374,24 @@ impl ModelClient {
|
|||||||
.plan_type
|
.plan_type
|
||||||
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
|
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
|
||||||
let resets_in_seconds = error.resets_in_seconds;
|
let resets_in_seconds = error.resets_in_seconds;
|
||||||
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
|
let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||||
plan_type,
|
plan_type,
|
||||||
resets_in_seconds,
|
resets_in_seconds,
|
||||||
rate_limits: rate_limit_snapshot,
|
rate_limits: rate_limit_snapshot,
|
||||||
}));
|
});
|
||||||
|
return Err(StreamAttemptError::Fatal(codex_err));
|
||||||
} else if error.r#type.as_deref() == Some("usage_not_included") {
|
} else if error.r#type.as_deref() == Some("usage_not_included") {
|
||||||
return Err(CodexErr::UsageNotIncluded);
|
return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if attempt > max_retries {
|
Err(StreamAttemptError::RetryableHttpError {
|
||||||
if status == StatusCode::INTERNAL_SERVER_ERROR {
|
status,
|
||||||
return Err(CodexErr::InternalServerError);
|
retry_after,
|
||||||
}
|
})
|
||||||
|
|
||||||
return Err(CodexErr::RetryLimit(status));
|
|
||||||
}
|
|
||||||
|
|
||||||
let delay = retry_after_secs
|
|
||||||
.map(|s| Duration::from_millis(s * 1_000))
|
|
||||||
.unwrap_or_else(|| backoff(attempt));
|
|
||||||
tokio::time::sleep(delay).await;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
if attempt > max_retries {
|
|
||||||
return Err(e.into());
|
|
||||||
}
|
|
||||||
let delay = backoff(attempt);
|
|
||||||
tokio::time::sleep(delay).await;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -410,6 +424,47 @@ impl ModelClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum StreamAttemptError {
|
||||||
|
RetryableHttpError {
|
||||||
|
status: StatusCode,
|
||||||
|
retry_after: Option<Duration>,
|
||||||
|
},
|
||||||
|
RetryableTransportError(CodexErr),
|
||||||
|
Fatal(CodexErr),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamAttemptError {
|
||||||
|
/// attempt is 0-based.
|
||||||
|
fn delay(&self, attempt: u64) -> Duration {
|
||||||
|
// backoff() uses 1-based attempts.
|
||||||
|
let backoff_attempt = attempt + 1;
|
||||||
|
match self {
|
||||||
|
Self::RetryableHttpError { retry_after, .. } => {
|
||||||
|
retry_after.unwrap_or_else(|| backoff(backoff_attempt))
|
||||||
|
}
|
||||||
|
Self::RetryableTransportError { .. } => backoff(backoff_attempt),
|
||||||
|
Self::Fatal(_) => {
|
||||||
|
// Should not be called on Fatal errors.
|
||||||
|
Duration::from_secs(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_error(self) -> CodexErr {
|
||||||
|
match self {
|
||||||
|
Self::RetryableHttpError { status, .. } => {
|
||||||
|
if status == StatusCode::INTERNAL_SERVER_ERROR {
|
||||||
|
CodexErr::InternalServerError
|
||||||
|
} else {
|
||||||
|
CodexErr::RetryLimit(status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self::RetryableTransportError(error) => error,
|
||||||
|
Self::Fatal(error) => error,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
struct SseEvent {
|
struct SseEvent {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
|
|||||||
Reference in New Issue
Block a user