chore: refactor attempt_stream_responses() out of stream_responses() (#4194)
I would like to be able to swap in a different way to resolve model sampling requests, so this refactoring consolidates things behind `attempt_stream_responses()` to make that easier. Ideally, we would support an in-memory backend that we can use in our integration tests, for example.
This commit is contained in:
@@ -229,155 +229,169 @@ impl ModelClient {
|
|||||||
if azure_workaround {
|
if azure_workaround {
|
||||||
attach_item_ids(&mut payload_json, &input_with_instructions);
|
attach_item_ids(&mut payload_json, &input_with_instructions);
|
||||||
}
|
}
|
||||||
let payload_body = serde_json::to_string(&payload_json)?;
|
|
||||||
|
|
||||||
let mut attempt = 0;
|
let max_attempts = self.provider.request_max_retries();
|
||||||
let max_retries = self.provider.request_max_retries();
|
for attempt in 0..=max_attempts {
|
||||||
|
match self
|
||||||
loop {
|
.attempt_stream_responses(&payload_json, &auth_manager)
|
||||||
attempt += 1;
|
.await
|
||||||
|
|
||||||
// Always fetch the latest auth in case a prior attempt refreshed the token.
|
|
||||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
|
||||||
|
|
||||||
trace!(
|
|
||||||
"POST to {}: {}",
|
|
||||||
self.provider.get_full_url(&auth),
|
|
||||||
payload_body.as_str()
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut req_builder = self
|
|
||||||
.provider
|
|
||||||
.create_request_builder(&self.client, &auth)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
req_builder = req_builder
|
|
||||||
.header("OpenAI-Beta", "responses=experimental")
|
|
||||||
// Send session_id for compatibility.
|
|
||||||
.header("conversation_id", self.conversation_id.to_string())
|
|
||||||
.header("session_id", self.conversation_id.to_string())
|
|
||||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
|
||||||
.json(&payload_json);
|
|
||||||
|
|
||||||
if let Some(auth) = auth.as_ref()
|
|
||||||
&& auth.mode == AuthMode::ChatGPT
|
|
||||||
&& let Some(account_id) = auth.get_account_id()
|
|
||||||
{
|
{
|
||||||
req_builder = req_builder.header("chatgpt-account-id", account_id);
|
Ok(stream) => {
|
||||||
}
|
return Ok(stream);
|
||||||
|
|
||||||
let res = req_builder.send().await;
|
|
||||||
if let Ok(resp) = &res {
|
|
||||||
trace!(
|
|
||||||
"Response status: {}, cf-ray: {}",
|
|
||||||
resp.status(),
|
|
||||||
resp.headers()
|
|
||||||
.get("cf-ray")
|
|
||||||
.map(|v| v.to_str().unwrap_or_default())
|
|
||||||
.unwrap_or_default()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
match res {
|
|
||||||
Ok(resp) if resp.status().is_success() => {
|
|
||||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
|
||||||
|
|
||||||
if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers())
|
|
||||||
&& tx_event
|
|
||||||
.send(Ok(ResponseEvent::RateLimits(snapshot)))
|
|
||||||
.await
|
|
||||||
.is_err()
|
|
||||||
{
|
|
||||||
debug!("receiver dropped rate limit snapshot event");
|
|
||||||
}
|
|
||||||
|
|
||||||
// spawn task to process SSE
|
|
||||||
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
|
||||||
tokio::spawn(process_sse(
|
|
||||||
stream,
|
|
||||||
tx_event,
|
|
||||||
self.provider.stream_idle_timeout(),
|
|
||||||
));
|
|
||||||
|
|
||||||
return Ok(ResponseStream { rx_event });
|
|
||||||
}
|
}
|
||||||
Ok(res) => {
|
Err(StreamAttemptError::Fatal(e)) => {
|
||||||
let status = res.status();
|
return Err(e);
|
||||||
|
}
|
||||||
// Pull out Retry‑After header if present.
|
Err(retryable_attempt_error) => {
|
||||||
let retry_after_secs = res
|
if attempt == max_attempts {
|
||||||
.headers()
|
return Err(retryable_attempt_error.into_error());
|
||||||
.get(reqwest::header::RETRY_AFTER)
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
.and_then(|s| s.parse::<u64>().ok());
|
|
||||||
|
|
||||||
if status == StatusCode::UNAUTHORIZED
|
|
||||||
&& let Some(manager) = auth_manager.as_ref()
|
|
||||||
&& manager.auth().is_some()
|
|
||||||
{
|
|
||||||
let _ = manager.refresh_token().await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx
|
tokio::time::sleep(retryable_attempt_error.delay(attempt)).await;
|
||||||
// errors. When we bubble early with only the HTTP status the caller sees an opaque
|
}
|
||||||
// "unexpected status 400 Bad Request" which makes debugging nearly impossible.
|
}
|
||||||
// Instead, read (and include) the response text so higher layers and users see the
|
}
|
||||||
// exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is
|
|
||||||
// small and this branch only runs on error paths so the extra allocation is
|
|
||||||
// negligible.
|
|
||||||
if !(status == StatusCode::TOO_MANY_REQUESTS
|
|
||||||
|| status == StatusCode::UNAUTHORIZED
|
|
||||||
|| status.is_server_error())
|
|
||||||
{
|
|
||||||
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
|
|
||||||
let body = res.text().await.unwrap_or_default();
|
|
||||||
return Err(CodexErr::UnexpectedStatus(status, body));
|
|
||||||
}
|
|
||||||
|
|
||||||
if status == StatusCode::TOO_MANY_REQUESTS {
|
unreachable!("stream_responses_attempt should always return");
|
||||||
let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers());
|
}
|
||||||
let body = res.json::<ErrorResponse>().await.ok();
|
|
||||||
if let Some(ErrorResponse { error }) = body {
|
/// Single attempt to start a streaming Responses API call.
|
||||||
if error.r#type.as_deref() == Some("usage_limit_reached") {
|
async fn attempt_stream_responses(
|
||||||
// Prefer the plan_type provided in the error message if present
|
&self,
|
||||||
// because it's more up to date than the one encoded in the auth
|
payload_json: &Value,
|
||||||
// token.
|
auth_manager: &Option<Arc<AuthManager>>,
|
||||||
let plan_type = error
|
) -> std::result::Result<ResponseStream, StreamAttemptError> {
|
||||||
.plan_type
|
// Always fetch the latest auth in case a prior attempt refreshed the token.
|
||||||
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
|
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||||
let resets_in_seconds = error.resets_in_seconds;
|
|
||||||
return Err(CodexErr::UsageLimitReached(UsageLimitReachedError {
|
trace!(
|
||||||
plan_type,
|
"POST to {}: {:?}",
|
||||||
resets_in_seconds,
|
self.provider.get_full_url(&auth),
|
||||||
rate_limits: rate_limit_snapshot,
|
serde_json::to_string(payload_json)
|
||||||
}));
|
);
|
||||||
} else if error.r#type.as_deref() == Some("usage_not_included") {
|
|
||||||
return Err(CodexErr::UsageNotIncluded);
|
let mut req_builder = self
|
||||||
}
|
.provider
|
||||||
|
.create_request_builder(&self.client, &auth)
|
||||||
|
.await
|
||||||
|
.map_err(StreamAttemptError::Fatal)?;
|
||||||
|
|
||||||
|
req_builder = req_builder
|
||||||
|
.header("OpenAI-Beta", "responses=experimental")
|
||||||
|
// Send session_id for compatibility.
|
||||||
|
.header("conversation_id", self.conversation_id.to_string())
|
||||||
|
.header("session_id", self.conversation_id.to_string())
|
||||||
|
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||||
|
.json(payload_json);
|
||||||
|
|
||||||
|
if let Some(auth) = auth.as_ref()
|
||||||
|
&& auth.mode == AuthMode::ChatGPT
|
||||||
|
&& let Some(account_id) = auth.get_account_id()
|
||||||
|
{
|
||||||
|
req_builder = req_builder.header("chatgpt-account-id", account_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
let res = req_builder.send().await;
|
||||||
|
if let Ok(resp) = &res {
|
||||||
|
trace!(
|
||||||
|
"Response status: {}, cf-ray: {}",
|
||||||
|
resp.status(),
|
||||||
|
resp.headers()
|
||||||
|
.get("cf-ray")
|
||||||
|
.map(|v| v.to_str().unwrap_or_default())
|
||||||
|
.unwrap_or_default()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
match res {
|
||||||
|
Ok(resp) if resp.status().is_success() => {
|
||||||
|
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||||
|
|
||||||
|
if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers())
|
||||||
|
&& tx_event
|
||||||
|
.send(Ok(ResponseEvent::RateLimits(snapshot)))
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
debug!("receiver dropped rate limit snapshot event");
|
||||||
|
}
|
||||||
|
|
||||||
|
// spawn task to process SSE
|
||||||
|
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
|
||||||
|
tokio::spawn(process_sse(
|
||||||
|
stream,
|
||||||
|
tx_event,
|
||||||
|
self.provider.stream_idle_timeout(),
|
||||||
|
));
|
||||||
|
|
||||||
|
Ok(ResponseStream { rx_event })
|
||||||
|
}
|
||||||
|
Ok(res) => {
|
||||||
|
let status = res.status();
|
||||||
|
|
||||||
|
// Pull out Retry‑After header if present.
|
||||||
|
let retry_after_secs = res
|
||||||
|
.headers()
|
||||||
|
.get(reqwest::header::RETRY_AFTER)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|s| s.parse::<u64>().ok());
|
||||||
|
let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000));
|
||||||
|
|
||||||
|
if status == StatusCode::UNAUTHORIZED
|
||||||
|
&& let Some(manager) = auth_manager.as_ref()
|
||||||
|
&& manager.auth().is_some()
|
||||||
|
{
|
||||||
|
let _ = manager.refresh_token().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The OpenAI Responses endpoint returns structured JSON bodies even for 4xx/5xx
|
||||||
|
// errors. When we bubble early with only the HTTP status the caller sees an opaque
|
||||||
|
// "unexpected status 400 Bad Request" which makes debugging nearly impossible.
|
||||||
|
// Instead, read (and include) the response text so higher layers and users see the
|
||||||
|
// exact error message (e.g. "Unknown parameter: 'input[0].metadata'"). The body is
|
||||||
|
// small and this branch only runs on error paths so the extra allocation is
|
||||||
|
// negligible.
|
||||||
|
if !(status == StatusCode::TOO_MANY_REQUESTS
|
||||||
|
|| status == StatusCode::UNAUTHORIZED
|
||||||
|
|| status.is_server_error())
|
||||||
|
{
|
||||||
|
// Surface the error body to callers. Use `unwrap_or_default` per Clippy.
|
||||||
|
let body = res.text().await.unwrap_or_default();
|
||||||
|
return Err(StreamAttemptError::Fatal(CodexErr::UnexpectedStatus(
|
||||||
|
status, body,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if status == StatusCode::TOO_MANY_REQUESTS {
|
||||||
|
let rate_limit_snapshot = parse_rate_limit_snapshot(res.headers());
|
||||||
|
let body = res.json::<ErrorResponse>().await.ok();
|
||||||
|
if let Some(ErrorResponse { error }) = body {
|
||||||
|
if error.r#type.as_deref() == Some("usage_limit_reached") {
|
||||||
|
// Prefer the plan_type provided in the error message if present
|
||||||
|
// because it's more up to date than the one encoded in the auth
|
||||||
|
// token.
|
||||||
|
let plan_type = error
|
||||||
|
.plan_type
|
||||||
|
.or_else(|| auth.as_ref().and_then(CodexAuth::get_plan_type));
|
||||||
|
let resets_in_seconds = error.resets_in_seconds;
|
||||||
|
let codex_err = CodexErr::UsageLimitReached(UsageLimitReachedError {
|
||||||
|
plan_type,
|
||||||
|
resets_in_seconds,
|
||||||
|
rate_limits: rate_limit_snapshot,
|
||||||
|
});
|
||||||
|
return Err(StreamAttemptError::Fatal(codex_err));
|
||||||
|
} else if error.r#type.as_deref() == Some("usage_not_included") {
|
||||||
|
return Err(StreamAttemptError::Fatal(CodexErr::UsageNotIncluded));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if attempt > max_retries {
|
|
||||||
if status == StatusCode::INTERNAL_SERVER_ERROR {
|
|
||||||
return Err(CodexErr::InternalServerError);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Err(CodexErr::RetryLimit(status));
|
|
||||||
}
|
|
||||||
|
|
||||||
let delay = retry_after_secs
|
|
||||||
.map(|s| Duration::from_millis(s * 1_000))
|
|
||||||
.unwrap_or_else(|| backoff(attempt));
|
|
||||||
tokio::time::sleep(delay).await;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
if attempt > max_retries {
|
|
||||||
return Err(e.into());
|
|
||||||
}
|
|
||||||
let delay = backoff(attempt);
|
|
||||||
tokio::time::sleep(delay).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Err(StreamAttemptError::RetryableHttpError {
|
||||||
|
status,
|
||||||
|
retry_after,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
Err(e) => Err(StreamAttemptError::RetryableTransportError(e.into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -410,6 +424,47 @@ impl ModelClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum StreamAttemptError {
|
||||||
|
RetryableHttpError {
|
||||||
|
status: StatusCode,
|
||||||
|
retry_after: Option<Duration>,
|
||||||
|
},
|
||||||
|
RetryableTransportError(CodexErr),
|
||||||
|
Fatal(CodexErr),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamAttemptError {
|
||||||
|
/// attempt is 0-based.
|
||||||
|
fn delay(&self, attempt: u64) -> Duration {
|
||||||
|
// backoff() uses 1-based attempts.
|
||||||
|
let backoff_attempt = attempt + 1;
|
||||||
|
match self {
|
||||||
|
Self::RetryableHttpError { retry_after, .. } => {
|
||||||
|
retry_after.unwrap_or_else(|| backoff(backoff_attempt))
|
||||||
|
}
|
||||||
|
Self::RetryableTransportError { .. } => backoff(backoff_attempt),
|
||||||
|
Self::Fatal(_) => {
|
||||||
|
// Should not be called on Fatal errors.
|
||||||
|
Duration::from_secs(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_error(self) -> CodexErr {
|
||||||
|
match self {
|
||||||
|
Self::RetryableHttpError { status, .. } => {
|
||||||
|
if status == StatusCode::INTERNAL_SERVER_ERROR {
|
||||||
|
CodexErr::InternalServerError
|
||||||
|
} else {
|
||||||
|
CodexErr::RetryLimit(status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self::RetryableTransportError(error) => error,
|
||||||
|
Self::Fatal(error) => error,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
struct SseEvent {
|
struct SseEvent {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
|
|||||||
Reference in New Issue
Block a user