Refactor env settings into config (#1601)

## Summary
- add OpenAI retry and timeout fields to Config
- inject these settings in tests instead of mutating env vars
- plumb Config values through client and chat completions logic
- document new configuration options

## Testing
- `cargo test -p codex-core --no-run`

------
https://chatgpt.com/codex/tasks/task_i_68792c5b04cc832195c03050c8b6ea94

---------

Co-authored-by: Michael Bolin <mbolin@openai.com>
This commit is contained in:
aibrahim-oai
2025-07-18 12:12:39 -07:00
committed by GitHub
parent d5a2148deb
commit 9846adeabf
12 changed files with 228 additions and 94 deletions

View File

@@ -30,8 +30,6 @@ use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::models::ResponseItem;
@@ -113,7 +111,7 @@ impl ModelClient {
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
// short circuit for tests
warn!(path, "Streaming from fixture");
return stream_from_fixture(path).await;
return stream_from_fixture(path, self.provider.clone()).await;
}
let full_instructions = prompt.get_full_instructions(&self.config.model);
@@ -140,6 +138,7 @@ impl ModelClient {
);
let mut attempt = 0;
let max_retries = self.provider.request_max_retries();
loop {
attempt += 1;
@@ -158,7 +157,11 @@ impl ModelClient {
// spawn task to process SSE
let stream = resp.bytes_stream().map_err(CodexErr::Reqwest);
tokio::spawn(process_sse(stream, tx_event));
tokio::spawn(process_sse(
stream,
tx_event,
self.provider.stream_idle_timeout(),
));
return Ok(ResponseStream { rx_event });
}
@@ -177,7 +180,7 @@ impl ModelClient {
return Err(CodexErr::UnexpectedStatus(status, body));
}
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
if attempt > max_retries {
return Err(CodexErr::RetryLimit(status));
}
@@ -194,7 +197,7 @@ impl ModelClient {
tokio::time::sleep(delay).await;
}
Err(e) => {
if attempt > *OPENAI_REQUEST_MAX_RETRIES {
if attempt > max_retries {
return Err(e.into());
}
let delay = backoff(attempt);
@@ -203,6 +206,10 @@ impl ModelClient {
}
}
}
pub fn get_provider(&self) -> ModelProviderInfo {
self.provider.clone()
}
}
#[derive(Debug, Deserialize, Serialize)]
@@ -254,14 +261,16 @@ struct ResponseCompletedOutputTokensDetails {
reasoning_tokens: u64,
}
async fn process_sse<S>(stream: S, tx_event: mpsc::Sender<Result<ResponseEvent>>)
where
async fn process_sse<S>(
stream: S,
tx_event: mpsc::Sender<Result<ResponseEvent>>,
idle_timeout: Duration,
) where
S: Stream<Item = Result<Bytes>> + Unpin,
{
let mut stream = stream.eventsource();
// If the stream stays completely silent for an extended period treat it as disconnected.
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
// The response id returned from the "complete" message.
let mut response_completed: Option<ResponseCompleted> = None;
@@ -322,7 +331,7 @@ where
// duplicated `output` array embedded in the `response.completed`
// payload. That produced two concrete issues:
// 1. No realtime streaming the user only saw output after the
// entire turn had finished, which broke the typing UX and
// entire turn had finished, which broke the "typing" UX and
// made longrunning turns look stalled.
// 2. Duplicate `function_call_output` items both the
// individual *and* the completed array were forwarded, which
@@ -395,7 +404,10 @@ where
}
/// used in tests to stream from a text SSE file
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
async fn stream_from_fixture(
path: impl AsRef<Path>,
provider: ModelProviderInfo,
) -> Result<ResponseStream> {
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
let f = std::fs::File::open(path.as_ref())?;
let lines = std::io::BufReader::new(f).lines();
@@ -409,7 +421,11 @@ async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
let rdr = std::io::Cursor::new(content);
let stream = ReaderStream::new(rdr).map_err(CodexErr::Io);
tokio::spawn(process_sse(stream, tx_event));
tokio::spawn(process_sse(
stream,
tx_event,
provider.stream_idle_timeout(),
));
Ok(ResponseStream { rx_event })
}
@@ -429,7 +445,10 @@ mod tests {
/// Runs the SSE parser on pre-chunked byte slices and returns every event
/// (including any final `Err` from a stream-closure check).
async fn collect_events(chunks: &[&[u8]]) -> Vec<Result<ResponseEvent>> {
async fn collect_events(
chunks: &[&[u8]],
provider: ModelProviderInfo,
) -> Vec<Result<ResponseEvent>> {
let mut builder = IoBuilder::new();
for chunk in chunks {
builder.read(chunk);
@@ -438,7 +457,7 @@ mod tests {
let reader = builder.build();
let stream = ReaderStream::new(reader).map_err(CodexErr::Io);
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(16);
tokio::spawn(process_sse(stream, tx));
tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout()));
let mut events = Vec::new();
while let Some(ev) = rx.recv().await {
@@ -449,7 +468,10 @@ mod tests {
/// Builds an in-memory SSE stream from JSON fixtures and returns only the
/// successfully parsed events (panics on internal channel errors).
async fn run_sse(events: Vec<serde_json::Value>) -> Vec<ResponseEvent> {
async fn run_sse(
events: Vec<serde_json::Value>,
provider: ModelProviderInfo,
) -> Vec<ResponseEvent> {
let mut body = String::new();
for e in events {
let kind = e
@@ -465,7 +487,7 @@ mod tests {
let (tx, mut rx) = mpsc::channel::<Result<ResponseEvent>>(8);
let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io);
tokio::spawn(process_sse(stream, tx));
tokio::spawn(process_sse(stream, tx, provider.stream_idle_timeout()));
let mut out = Vec::new();
while let Some(ev) = rx.recv().await {
@@ -510,7 +532,25 @@ mod tests {
let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n");
let sse3 = format!("event: response.completed\ndata: {completed}\n\n");
let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await;
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: "https://test.com".to_string(),
env_key: Some("TEST_API_KEY".to_string()),
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
};
let events = collect_events(
&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()],
provider,
)
.await;
assert_eq!(events.len(), 3);
@@ -551,8 +591,21 @@ mod tests {
.to_string();
let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n");
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: "https://test.com".to_string(),
env_key: Some("TEST_API_KEY".to_string()),
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
};
let events = collect_events(&[sse1.as_bytes()]).await;
let events = collect_events(&[sse1.as_bytes()], provider).await;
assert_eq!(events.len(), 2);
@@ -640,7 +693,21 @@ mod tests {
let mut evs = vec![case.event];
evs.push(completed.clone());
let out = run_sse(evs).await;
let provider = ModelProviderInfo {
name: "test".to_string(),
base_url: "https://test.com".to_string(),
env_key: Some("TEST_API_KEY".to_string()),
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(1000),
};
let out = run_sse(evs, provider).await;
assert_eq!(out.len(), case.expected_len, "case {}", case.name);
assert!(
(case.expect_first)(&out[0]),