From e3c6903199e47c9b6263223db9f5e83de483db69 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Fri, 12 Sep 2025 13:52:15 -0700 Subject: [PATCH] Add Azure Responses API workaround (#3528) Azure Responses API doesn't work well with store:false and response items. If store = false and id is sent an error is thrown that ID is not found If store = false and id is not sent an error is thrown that ID is required Add detection for Azure urls and add a workaround to preserve reasoning item IDs and send store:true --- codex-rs/core/src/client.rs | 42 ++++++++- codex-rs/core/src/model_provider_info.rs | 92 +++++++++++++++++++ codex-rs/core/tests/suite/client.rs | 108 +++++++++++++++++++++++ 3 files changed, 239 insertions(+), 3 deletions(-) diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index dab1bf68..6bea3820 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -193,6 +193,15 @@ impl ModelClient { None }; + // In general, we want to explicitly send `store: false` when using the Responses API, + // but in practice, the Azure Responses API rejects `store: false`: + // + // - If store = false and id is sent an error is thrown that ID is not found + // - If store = false and id is not sent an error is thrown that ID is required + // + // For Azure, we send `store: true` and preserve reasoning item IDs. + let azure_workaround = self.provider.is_azure_responses_endpoint(); + let payload = ResponsesApiRequest { model: &self.config.model, instructions: &full_instructions, @@ -201,13 +210,19 @@ impl ModelClient { tool_choice: "auto", parallel_tool_calls: false, reasoning, - store: false, + store: azure_workaround, stream: true, include, prompt_cache_key: Some(self.conversation_id.to_string()), text, }; + let mut payload_json = serde_json::to_value(&payload)?; + if azure_workaround { + attach_item_ids(&mut payload_json, &input_with_instructions); + } + let payload_body = serde_json::to_string(&payload_json)?; + let mut attempt = 0; let max_retries = self.provider.request_max_retries(); @@ -220,7 +235,7 @@ impl ModelClient { trace!( "POST to {}: {}", self.provider.get_full_url(&auth), - serde_json::to_string(&payload)? + payload_body.as_str() ); let mut req_builder = self @@ -234,7 +249,7 @@ impl ModelClient { .header("conversation_id", self.conversation_id.to_string()) .header("session_id", self.conversation_id.to_string()) .header(reqwest::header::ACCEPT, "text/event-stream") - .json(&payload); + .json(&payload_json); if let Some(auth) = auth.as_ref() && auth.mode == AuthMode::ChatGPT @@ -431,6 +446,27 @@ struct ResponseCompletedOutputTokensDetails { reasoning_tokens: u64, } +fn attach_item_ids(payload_json: &mut Value, original_items: &[ResponseItem]) { + let Some(input_value) = payload_json.get_mut("input") else { + return; + }; + let serde_json::Value::Array(items) = input_value else { + return; + }; + + for (value, item) in items.iter_mut().zip(original_items.iter()) { + if let ResponseItem::Reasoning { id, .. } = item { + if id.is_empty() { + continue; + } + + if let Some(obj) = value.as_object_mut() { + obj.insert("id".to_string(), Value::String(id.clone())); + } + } + } +} + async fn process_sse( stream: S, tx_event: mpsc::Sender>, diff --git a/codex-rs/core/src/model_provider_info.rs b/codex-rs/core/src/model_provider_info.rs index 7fca131c..2850996d 100644 --- a/codex-rs/core/src/model_provider_info.rs +++ b/codex-rs/core/src/model_provider_info.rs @@ -162,6 +162,21 @@ impl ModelProviderInfo { } } + pub(crate) fn is_azure_responses_endpoint(&self) -> bool { + if self.wire_api != WireApi::Responses { + return false; + } + + if self.name.eq_ignore_ascii_case("azure") { + return true; + } + + self.base_url + .as_ref() + .map(|base| matches_azure_responses_base_url(base)) + .unwrap_or(false) + } + /// Apply provider-specific HTTP headers (both static and environment-based) /// onto an existing `reqwest::RequestBuilder` and return the updated /// builder. @@ -329,6 +344,18 @@ pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo { } } +fn matches_azure_responses_base_url(base_url: &str) -> bool { + let base = base_url.to_ascii_lowercase(); + const AZURE_MARKERS: [&str; 5] = [ + "openai.azure.", + "cognitiveservices.azure.", + "aoai.azure.", + "azure-api.", + "azurefd.", + ]; + AZURE_MARKERS.iter().any(|marker| base.contains(marker)) +} + #[cfg(test)] mod tests { use super::*; @@ -419,4 +446,69 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" } let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap(); assert_eq!(expected_provider, provider); } + + #[test] + fn detects_azure_responses_base_urls() { + fn provider_for(base_url: &str) -> ModelProviderInfo { + ModelProviderInfo { + name: "test".into(), + base_url: Some(base_url.into()), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_openai_auth: false, + } + } + + let positive_cases = [ + "https://foo.openai.azure.com/openai", + "https://foo.openai.azure.us/openai/deployments/bar", + "https://foo.cognitiveservices.azure.cn/openai", + "https://foo.aoai.azure.com/openai", + "https://foo.openai.azure-api.net/openai", + "https://foo.z01.azurefd.net/", + ]; + for base_url in positive_cases { + let provider = provider_for(base_url); + assert!( + provider.is_azure_responses_endpoint(), + "expected {base_url} to be detected as Azure" + ); + } + + let named_provider = ModelProviderInfo { + name: "Azure".into(), + base_url: Some("https://example.com".into()), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: None, + stream_max_retries: None, + stream_idle_timeout_ms: None, + requires_openai_auth: false, + }; + assert!(named_provider.is_azure_responses_endpoint()); + + let negative_cases = [ + "https://api.openai.com/v1", + "https://example.com/openai", + "https://myproxy.azurewebsites.net/openai", + ]; + for base_url in negative_cases { + let provider = provider_for(base_url); + assert!( + !provider.is_azure_responses_endpoint(), + "expected {base_url} not to be detected as Azure" + ); + } + } } diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index 579f37cd..23278807 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -1,18 +1,27 @@ use codex_core::CodexAuth; use codex_core::ConversationManager; +use codex_core::ModelClient; use codex_core::ModelProviderInfo; use codex_core::NewConversation; +use codex_core::Prompt; +use codex_core::ReasoningItemContent; +use codex_core::ResponseEvent; +use codex_core::ResponseItem; use codex_core::WireApi; use codex_core::built_in_model_providers; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; use codex_core::protocol::Op; use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; +use codex_protocol::mcp_protocol::ConversationId; +use codex_protocol::models::ReasoningItemReasoningSummary; use core_test_support::load_default_config_for_test; use core_test_support::load_sse_fixture_with_id; use core_test_support::wait_for_event; +use futures::StreamExt; use serde_json::json; use std::io::Write; +use std::sync::Arc; use tempfile::TempDir; use uuid::Uuid; use wiremock::Mock; @@ -629,6 +638,105 @@ async fn includes_user_instructions_message_in_request() { assert_message_ends_with(&request_body["input"][1], ""); } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn azure_responses_request_includes_store_and_reasoning_ids() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let server = MockServer::start().await; + + let sse_body = concat!( + "data: {\"type\":\"response.created\",\"response\":{}}\n\n", + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\"}}\n\n", + ); + + let template = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_body, "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/openai/responses")) + .respond_with(template) + .expect(1) + .mount(&server) + .await; + + let provider = ModelProviderInfo { + name: "azure".into(), + base_url: Some(format!("{}/openai", server.uri())), + env_key: None, + env_key_instructions: None, + wire_api: WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: Some(5_000), + requires_openai_auth: false, + }; + + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.model_provider_id = provider.name.clone(); + config.model_provider = provider.clone(); + let effort = config.model_reasoning_effort; + let summary = config.model_reasoning_summary; + let config = Arc::new(config); + + let client = ModelClient::new( + Arc::clone(&config), + None, + provider, + effort, + summary, + ConversationId::new(), + ); + + let mut prompt = Prompt::default(); + prompt.input.push(ResponseItem::Reasoning { + id: "reasoning-id".into(), + summary: vec![ReasoningItemReasoningSummary::SummaryText { + text: "summary".into(), + }], + content: Some(vec![ReasoningItemContent::ReasoningText { + text: "content".into(), + }]), + encrypted_content: None, + }); + + let mut stream = client + .stream(&prompt) + .await + .expect("responses stream to start"); + + while let Some(event) = stream.next().await { + if let Ok(ResponseEvent::Completed { .. }) = event { + break; + } + } + + let requests = server + .received_requests() + .await + .expect("mock server collected requests"); + assert_eq!(requests.len(), 1, "expected a single request"); + let body: serde_json::Value = requests[0] + .body_json() + .expect("request body to be valid JSON"); + + assert_eq!(body["store"], serde_json::Value::Bool(true)); + assert_eq!(body["stream"], serde_json::Value::Bool(true)); + assert_eq!( + body["input"][0]["id"], + serde_json::Value::String("reasoning-id".into()) + ); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn azure_overrides_assign_properties_used_for_responses_url() { let existing_env_var_with_random_value = if cfg!(windows) { "USERNAME" } else { "USER" };