diff --git a/codex-rs/core/src/unified_exec/mod.rs b/codex-rs/core/src/unified_exec/mod.rs index 48b82e3e..a8f754c9 100644 --- a/codex-rs/core/src/unified_exec/mod.rs +++ b/codex-rs/core/src/unified_exec/mod.rs @@ -110,11 +110,22 @@ impl ManagedUnifiedExecSession { let buffer_clone = Arc::clone(&output_buffer); let notify_clone = Arc::clone(&output_notify); let output_task = tokio::spawn(async move { - while let Ok(chunk) = receiver.recv().await { - let mut guard = buffer_clone.lock().await; - guard.push_chunk(chunk); - drop(guard); - notify_clone.notify_waiters(); + loop { + match receiver.recv().await { + Ok(chunk) => { + let mut guard = buffer_clone.lock().await; + guard.push_chunk(chunk); + drop(guard); + notify_clone.notify_waiters(); + } + // If we lag behind the broadcast buffer, skip missed + // messages but keep the task alive to continue streaming. + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { + continue; + } + // When the sender closes, exit the task. + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } } }); diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index 5afcc9f5..b8c0fadb 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -167,6 +167,131 @@ async fn unified_exec_reuses_session_via_stdin() -> Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_streams_after_lagged_output() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.use_experimental_unified_exec_tool = true; + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let script = r#"python3 - <<'PY' +import sys +import time + +chunk = b'x' * (1 << 20) +for _ in range(4): + sys.stdout.buffer.write(chunk) + sys.stdout.flush() + +time.sleep(0.2) +for _ in range(5): + sys.stdout.write("TAIL-MARKER\n") + sys.stdout.flush() + time.sleep(0.05) + +time.sleep(0.2) +PY +"#; + + let first_call_id = "uexec-lag-start"; + let first_args = serde_json::json!({ + "input": ["/bin/sh", "-c", script], + "timeout_ms": 25, + }); + + let second_call_id = "uexec-lag-poll"; + let second_args = serde_json::json!({ + "input": Vec::::new(), + "session_id": "0", + "timeout_ms": 800, + }); + + let responses = vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_function_call( + first_call_id, + "unified_exec", + &serde_json::to_string(&first_args)?, + ), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_function_call( + second_call_id, + "unified_exec", + &serde_json::to_string(&second_args)?, + ), + ev_completed("resp-2"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "lag handled"), + ev_completed("resp-3"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "exercise lag handling".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let outputs = collect_tool_outputs(&bodies)?; + + let start_output = outputs + .get(first_call_id) + .expect("missing initial unified_exec output"); + let session_id = start_output["session_id"].as_str().unwrap_or_default(); + assert!( + !session_id.is_empty(), + "expected session id from initial unified_exec response" + ); + + let poll_output = outputs + .get(second_call_id) + .expect("missing poll unified_exec output"); + let poll_text = poll_output["output"].as_str().unwrap_or_default(); + assert!( + poll_text.contains("TAIL-MARKER"), + "expected poll output to contain tail marker, got {poll_text:?}" + ); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn unified_exec_timeout_and_followup_poll() -> Result<()> { skip_if_no_network!(Ok(()));