Send limits when getting rate limited (#4102)

Users need visibility on rate limits when they are rate limited.
This commit is contained in:
Ahmed Ibrahim
2025-09-23 15:56:34 -07:00
committed by GitHub
parent fdb8dadcae
commit 8227a5ba1b
8 changed files with 186 additions and 46 deletions

View File

@@ -100,7 +100,7 @@ use crate::protocol::ListCustomPromptsResponseEvent;
use crate::protocol::Op;
use crate::protocol::PatchApplyBeginEvent;
use crate::protocol::PatchApplyEndEvent;
use crate::protocol::RateLimitSnapshotEvent;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::ReviewDecision;
use crate::protocol::ReviewOutputEvent;
use crate::protocol::SandboxPolicy;
@@ -261,7 +261,7 @@ struct State {
pending_input: Vec<ResponseInputItem>,
history: ConversationHistory,
token_info: Option<TokenUsageInfo>,
latest_rate_limits: Option<RateLimitSnapshotEvent>,
latest_rate_limits: Option<RateLimitSnapshot>,
}
/// Context for an initialized model agent
@@ -739,31 +739,42 @@ impl Session {
async fn update_token_usage_info(
&self,
sub_id: &str,
turn_context: &TurnContext,
token_usage: Option<&TokenUsage>,
) {
let mut state = self.state.lock().await;
if let Some(token_usage) = token_usage {
let info = TokenUsageInfo::new_or_append(
&state.token_info,
&Some(token_usage.clone()),
turn_context.client.get_model_context_window(),
);
state.token_info = info;
{
let mut state = self.state.lock().await;
if let Some(token_usage) = token_usage {
let info = TokenUsageInfo::new_or_append(
&state.token_info,
&Some(token_usage.clone()),
turn_context.client.get_model_context_window(),
);
state.token_info = info;
}
}
self.send_token_count_event(sub_id).await;
}
async fn update_rate_limits(&self, new_rate_limits: RateLimitSnapshotEvent) {
let mut state = self.state.lock().await;
state.latest_rate_limits = Some(new_rate_limits);
async fn update_rate_limits(&self, sub_id: &str, new_rate_limits: RateLimitSnapshot) {
{
let mut state = self.state.lock().await;
state.latest_rate_limits = Some(new_rate_limits);
}
self.send_token_count_event(sub_id).await;
}
async fn get_token_count_event(&self) -> TokenCountEvent {
let state = self.state.lock().await;
TokenCountEvent {
info: state.token_info.clone(),
rate_limits: state.latest_rate_limits.clone(),
}
async fn send_token_count_event(&self, sub_id: &str) {
let (info, rate_limits) = {
let state = self.state.lock().await;
(state.token_info.clone(), state.latest_rate_limits.clone())
};
let event = Event {
id: sub_id.to_string(),
msg: EventMsg::TokenCount(TokenCountEvent { info, rate_limits }),
};
self.send_event(event).await;
}
/// Record a user input item to conversation history and also persist a
@@ -1957,9 +1968,14 @@ async fn run_turn(
Ok(output) => return Ok(output),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
Err(e @ (CodexErr::UsageLimitReached(_) | CodexErr::UsageNotIncluded)) => {
return Err(e);
Err(CodexErr::UsageLimitReached(e)) => {
let rate_limits = e.rate_limits.clone();
if let Some(rate_limits) = rate_limits {
sess.update_rate_limits(&sub_id, rate_limits).await;
}
return Err(CodexErr::UsageLimitReached(e));
}
Err(CodexErr::UsageNotIncluded) => return Err(CodexErr::UsageNotIncluded),
Err(e) => {
// Use the configured provider-specific stream retry budget.
let max_retries = turn_context.client.get_provider().stream_max_retries();
@@ -2132,20 +2148,13 @@ async fn try_run_turn(
ResponseEvent::RateLimits(snapshot) => {
// Update internal state with latest rate limits, but defer sending until
// token usage is available to avoid duplicate TokenCount events.
sess.update_rate_limits(snapshot).await;
sess.update_rate_limits(sub_id, snapshot).await;
}
ResponseEvent::Completed {
response_id: _,
token_usage,
} => {
sess.update_token_usage_info(turn_context, token_usage.as_ref())
.await;
let token_event = sess.get_token_count_event().await;
let _ = sess
.send_event(Event {
id: sub_id.to_string(),
msg: EventMsg::TokenCount(token_event),
})
sess.update_token_usage_info(sub_id, turn_context, token_usage.as_ref())
.await;
let unified_diff = turn_diff_tracker.get_unified_diff();