From 90ef94d3b382037fe030dc229bd71d942a316497 Mon Sep 17 00:00:00 2001
From: Ahmed Ibrahim
Date: Sat, 4 Oct 2025 18:40:06 -0700
Subject: [PATCH] Surface context window error to the client (#4675)

Previously, we treated `input exceeded context window` as a streaming
error and retried on it. Retrying is pointless because the result does
not change. In this PR, we surface the error to the client without
retrying and also send a token count event indicating that the context
window is full.
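A rough sketch of how a client consuming these events might react
(illustrative only; the `handle_event` helper and the exact import path
are assumptions, not part of this change):

    use codex_core::protocol::{Event, EventMsg};

    // Hypothetical consumer-side handling: after this change, a context
    // window overflow yields an Error event (no retries) plus a TokenCount
    // event whose totals are pinned to the model's context window.
    fn handle_event(event: &Event) {
        match &event.msg {
            EventMsg::TokenCount(ev) => {
                if let Some(info) = &ev.info {
                    if Some(info.total_token_usage.total_tokens) == info.model_context_window {
                        eprintln!("context window is full");
                    }
                }
            }
            // Surfaced directly instead of being retried as a stream error.
            EventMsg::Error(err) => eprintln!("{}", err.message),
            _ => {}
        }
    }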
---
 codex-rs/core/src/client.rs             | 83 ++++++++++++++++++++-
 codex-rs/core/src/codex.rs              | 15 ++++
 codex-rs/core/src/codex/compact.rs      | 12 +++
 codex-rs/core/src/error.rs              |  5 ++
 codex-rs/core/src/state/session.rs      |  9 +++
 codex-rs/core/tests/common/responses.rs | 10 +++
 codex-rs/core/tests/suite/client.rs     | 99 +++++++++++++++++++++++++
 codex-rs/protocol/src/protocol.rs       | 25 +++++++
 8 files changed, 254 insertions(+), 4 deletions(-)

diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 7cf60f56..bf1919ce 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -63,7 +63,6 @@ struct ErrorResponse {
 
 #[derive(Debug, Deserialize)]
 struct Error {
     r#type: Option<String>,
-    #[allow(dead_code)]
     code: Option<String>,
     message: Option<String>,
@@ -794,9 +793,13 @@ async fn process_sse(
             if let Some(error) = error {
                 match serde_json::from_value::<Error>(error.clone()) {
                     Ok(error) => {
-                        let delay = try_parse_retry_after(&error);
-                        let message = error.message.unwrap_or_default();
-                        response_error = Some(CodexErr::Stream(message, delay));
+                        if is_context_window_error(&error) {
+                            response_error = Some(CodexErr::ContextWindowExceeded);
+                        } else {
+                            let delay = try_parse_retry_after(&error);
+                            let message = error.message.clone().unwrap_or_default();
+                            response_error = Some(CodexErr::Stream(message, delay));
+                        }
                     }
                     Err(e) => {
                         let error = format!("failed to parse ErrorResponse: {e}");
@@ -922,6 +925,10 @@ fn try_parse_retry_after(err: &Error) -> Option<Duration> {
     None
 }
 
+fn is_context_window_error(error: &Error) -> bool {
+    error.code.as_deref() == Some("context_length_exceeded")
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1179,6 +1186,74 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn context_window_error_is_fatal() {
+        let raw_error = r#"{"type":"response.failed","sequence_number":3,"response":{"id":"resp_5c66275b97b9baef1ed95550adb3b7ec13b17aafd1d2f11b","object":"response","created_at":1759510079,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."},"usage":null,"user":null,"metadata":{}}}"#;
+
+        let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
+        let provider = ModelProviderInfo {
+            name: "test".to_string(),
+            base_url: Some("https://test.com".to_string()),
+            env_key: Some("TEST_API_KEY".to_string()),
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+            request_max_retries: Some(0),
+            stream_max_retries: Some(0),
+            stream_idle_timeout_ms: Some(1000),
+            requires_openai_auth: false,
+        };
+
+        let otel_event_manager = otel_event_manager();
+
+        let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
+
+        assert_eq!(events.len(), 1);
+
+        match &events[0] {
+            Err(err @ CodexErr::ContextWindowExceeded) => {
+                assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
+            }
+            other => panic!("unexpected context window event: {other:?}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn context_window_error_with_newline_is_fatal() {
+        let raw_error = r#"{"type":"response.failed","sequence_number":4,"response":{"id":"resp_fatal_newline","object":"response","created_at":1759510080,"status":"failed","background":false,"error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try\nagain."},"usage":null,"user":null,"metadata":{}}}"#;
+
+        let sse1 = format!("event: response.failed\ndata: {raw_error}\n\n");
+        let provider = ModelProviderInfo {
+            name: "test".to_string(),
+            base_url: Some("https://test.com".to_string()),
+            env_key: Some("TEST_API_KEY".to_string()),
+            env_key_instructions: None,
+            wire_api: WireApi::Responses,
+            query_params: None,
+            http_headers: None,
+            env_http_headers: None,
+            request_max_retries: Some(0),
+            stream_max_retries: Some(0),
+            stream_idle_timeout_ms: Some(1000),
+            requires_openai_auth: false,
+        };
+
+        let otel_event_manager = otel_event_manager();
+
+        let events = collect_events(&[sse1.as_bytes()], provider, otel_event_manager).await;
+
+        assert_eq!(events.len(), 1);
+
+        match &events[0] {
+            Err(err @ CodexErr::ContextWindowExceeded) => {
+                assert_eq!(err.to_string(), CodexErr::ContextWindowExceeded.to_string());
+            }
+            other => panic!("unexpected context window event: {other:?}"),
+        }
+    }
+
     // ────────────────────────────
     // Table-driven test from `main`
     // ────────────────────────────
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 2f13c5ba..84c52d91 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -782,6 +782,17 @@ impl Session {
         self.send_event(event).await;
     }
 
+    async fn set_total_tokens_full(&self, sub_id: &str, turn_context: &TurnContext) {
+        let context_window = turn_context.client.get_model_context_window();
+        if let Some(context_window) = context_window {
+            {
+                let mut state = self.state.lock().await;
+                state.set_token_usage_full(context_window);
+            }
+            self.send_token_count_event(sub_id).await;
+        }
+    }
+
     /// Record a user input item to conversation history and also persist a
     /// corresponding UserMessage EventMsg to rollout.
     async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) {
@@ -1938,6 +1949,10 @@ async fn run_turn(
             Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
             Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
             Err(e @ CodexErr::Fatal(_)) => return Err(e),
+            Err(e @ CodexErr::ContextWindowExceeded) => {
+                sess.set_total_tokens_full(&sub_id, turn_context).await;
+                return Err(e);
+            }
             Err(CodexErr::UsageLimitReached(e)) => {
                 let rate_limits = e.rate_limits.clone();
                 if let Some(rate_limits) = rate_limits {
diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs
index 136e68e4..40c9da7b 100644
--- a/codex-rs/core/src/codex/compact.rs
+++ b/codex-rs/core/src/codex/compact.rs
@@ -103,6 +103,18 @@ async fn run_compact_task_inner(
             Err(CodexErr::Interrupted) => {
                 return;
             }
+            Err(e @ CodexErr::ContextWindowExceeded) => {
+                sess.set_total_tokens_full(&sub_id, turn_context.as_ref())
+                    .await;
+                let event = Event {
+                    id: sub_id.clone(),
+                    msg: EventMsg::Error(ErrorEvent {
+                        message: e.to_string(),
+                    }),
+                };
+                sess.send_event(event).await;
+                return;
+            }
             Err(e) => {
                 if retries < max_retries {
                     retries += 1;
diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs
index aa093379..6fad448b 100644
--- a/codex-rs/core/src/error.rs
+++ b/codex-rs/core/src/error.rs
@@ -55,6 +55,11 @@ pub enum CodexErr {
     #[error("stream disconnected before completion: {0}")]
     Stream(String, Option<Duration>),
 
+    #[error(
+        "Codex ran out of room in the model's context window. Start a new conversation or clear earlier history before retrying."
+    )]
+    ContextWindowExceeded,
+
     #[error("no conversation with id: {0}")]
     ConversationNotFound(ConversationId),
 
diff --git a/codex-rs/core/src/state/session.rs b/codex-rs/core/src/state/session.rs
index f170a10c..8310d91c 100644
--- a/codex-rs/core/src/state/session.rs
+++ b/codex-rs/core/src/state/session.rs
@@ -64,5 +64,14 @@ impl SessionState {
         (self.token_info.clone(), self.latest_rate_limits.clone())
     }
 
+    pub(crate) fn set_token_usage_full(&mut self, context_window: u64) {
+        match &mut self.token_info {
+            Some(info) => info.fill_to_context_window(context_window),
+            None => {
+                self.token_info = Some(TokenUsageInfo::full_context_window(context_window));
+            }
+        }
+    }
+
     // Pending input/approval moved to TurnState.
 }
diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs
index b13e7599..f3b4d6af 100644
--- a/codex-rs/core/tests/common/responses.rs
+++ b/codex-rs/core/tests/common/responses.rs
@@ -135,6 +135,16 @@ pub fn ev_apply_patch_function_call(call_id: &str, patch: &str) -> Value {
     })
 }
 
+pub fn sse_failed(id: &str, code: &str, message: &str) -> String {
+    sse(vec![serde_json::json!({
+        "type": "response.failed",
+        "response": {
+            "id": id,
+            "error": {"code": code, "message": message}
+        }
+    })])
+}
+
 pub fn sse_response(body: String) -> ResponseTemplate {
     ResponseTemplate::new(200)
         .insert_header("content-type", "text/event-stream")
diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs
index 7157a105..c49c38e3 100644
--- a/codex-rs/core/tests/suite/client.rs
+++ b/codex-rs/core/tests/suite/client.rs
@@ -14,6 +14,8 @@ use codex_core::ResponseEvent;
 use codex_core::ResponseItem;
 use codex_core::WireApi;
 use codex_core::built_in_model_providers;
+use codex_core::error::CodexErr;
+use codex_core::model_family::find_family_for_model;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::InputItem;
 use codex_core::protocol::Op;
@@ -26,8 +28,10 @@ use core_test_support::load_default_config_for_test;
 use core_test_support::load_sse_fixture_with_id;
 use core_test_support::responses;
 use core_test_support::skip_if_no_network;
+use core_test_support::test_codex::TestCodex;
 use core_test_support::test_codex::test_codex;
 use core_test_support::wait_for_event;
+use core_test_support::wait_for_event_with_timeout;
 use futures::StreamExt;
 use serde_json::json;
 use std::io::Write;
@@ -37,6 +41,7 @@ use uuid::Uuid;
 use wiremock::Mock;
 use wiremock::MockServer;
 use wiremock::ResponseTemplate;
+use wiremock::matchers::body_string_contains;
 use wiremock::matchers::header_regex;
 use wiremock::matchers::method;
 use wiremock::matchers::path;
@@ -996,6 +1001,100 @@ async fn usage_limit_error_emits_rate_limit_event() -> anyhow::Result<()> {
     Ok(())
 }
 
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn context_window_error_sets_total_tokens_to_model_window() -> anyhow::Result<()> {
+    skip_if_no_network!(Ok(()));
+    let server = MockServer::start().await;
+
+    responses::mount_sse_once_match(
+        &server,
+        body_string_contains("trigger context window"),
+        responses::sse_failed(
+            "resp_context_window",
+            "context_length_exceeded",
+            "Your input exceeds the context window of this model. Please adjust your input and try again.",
+        ),
+    )
+    .await;
+
+    responses::mount_sse_once_match(
+        &server,
+        body_string_contains("seed turn"),
+        sse_completed("resp_seed"),
+    )
+    .await;
+
+    let TestCodex { codex, .. } = test_codex()
+        .with_config(|config| {
+            config.model = "gpt-5".to_string();
+            config.model_family = find_family_for_model("gpt-5").expect("known gpt-5 model family");
+            config.model_context_window = Some(272_000);
+        })
+        .build(&server)
+        .await?;
+
+    codex
+        .submit(Op::UserInput {
+            items: vec![InputItem::Text {
+                text: "seed turn".into(),
+            }],
+        })
+        .await?;
+
+    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+
+    codex
+        .submit(Op::UserInput {
+            items: vec![InputItem::Text {
+                text: "trigger context window".into(),
+            }],
+        })
+        .await?;
+
+    use std::time::Duration;
+
+    let token_event = wait_for_event_with_timeout(
+        &codex,
+        |event| {
+            matches!(
+                event,
+                EventMsg::TokenCount(payload)
+                    if payload.info.as_ref().is_some_and(|info| {
+                        info.model_context_window == Some(info.total_token_usage.total_tokens)
+                            && info.total_token_usage.total_tokens > 0
+                    })
+            )
+        },
+        Duration::from_secs(5),
+    )
+    .await;
+
+    let EventMsg::TokenCount(token_payload) = token_event else {
+        unreachable!("wait_for_event_with_timeout returned unexpected event");
+    };
+
+    let info = token_payload
+        .info
+        .expect("token usage info present when context window is exceeded");
+
+    assert_eq!(info.model_context_window, Some(272_000));
+    assert_eq!(info.total_token_usage.total_tokens, 272_000);
+
+    let error_event = wait_for_event(&codex, |ev| matches!(ev, EventMsg::Error(_))).await;
+    let expected_context_window_message = CodexErr::ContextWindowExceeded.to_string();
+    assert!(
+        matches!(
+            error_event,
+            EventMsg::Error(ref err) if err.message == expected_context_window_message
+        ),
+        "expected context window error; got {error_event:?}"
+    );
+
+    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+
+    Ok(())
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn azure_overrides_assign_properties_used_for_responses_url() {
     skip_if_no_network!();
diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs
index b6b279e0..bdd1d3e2 100644
--- a/codex-rs/protocol/src/protocol.rs
+++ b/codex-rs/protocol/src/protocol.rs
@@ -590,6 +590,31 @@ impl TokenUsageInfo {
         self.total_token_usage.add_assign(last);
         self.last_token_usage = last.clone();
     }
+
+    pub fn fill_to_context_window(&mut self, context_window: u64) {
+        let previous_total = self.total_token_usage.total_tokens;
+        let delta = context_window.saturating_sub(previous_total);
+
+        self.model_context_window = Some(context_window);
+        self.total_token_usage = TokenUsage {
+            total_tokens: context_window,
+            ..TokenUsage::default()
+        };
+        self.last_token_usage = TokenUsage {
+            total_tokens: delta,
+            ..TokenUsage::default()
+        };
+    }
+
+    pub fn full_context_window(context_window: u64) -> Self {
+        let mut info = Self {
+            total_token_usage: TokenUsage::default(),
+            last_token_usage: TokenUsage::default(),
+            model_context_window: Some(context_window),
+        };
+        info.fill_to_context_window(context_window);
+        info
+    }
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize, TS)]