diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 3de3e781..e59dbfa2 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -250,6 +250,28 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -654,6 +676,7 @@ dependencies = [ "thiserror 2.0.12", "time", "tokio", + "tokio-test", "tokio-util", "toml 0.9.1", "tracing", @@ -4516,6 +4539,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.15" diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 22636102..c55d7d39 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -64,4 +64,5 @@ maplit = "1.0.2" predicates = "3" pretty_assertions = "1.4.1" tempfile = "3" +tokio-test = "0.4" wiremock = "0.6" diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 1b8e4c95..2fa182cf 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -395,9 +395,39 @@ async fn stream_from_fixture(path: impl AsRef) -> Result { #[cfg(test)] mod tests { #![allow(clippy::expect_used, clippy::unwrap_used)] + use super::*; use serde_json::json; + use tokio::sync::mpsc; + use tokio_test::io::Builder as IoBuilder; + use tokio_util::io::ReaderStream; + // ──────────────────────────── + // Helpers + // ──────────────────────────── + + /// Runs the SSE parser on pre-chunked byte slices and returns every event + /// (including any final `Err` from a stream-closure check). + async fn collect_events(chunks: &[&[u8]]) -> Vec> { + let mut builder = IoBuilder::new(); + for chunk in chunks { + builder.read(chunk); + } + + let reader = builder.build(); + let stream = ReaderStream::new(reader).map_err(CodexErr::Io); + let (tx, mut rx) = mpsc::channel::>(16); + tokio::spawn(process_sse(stream, tx)); + + let mut events = Vec::new(); + while let Some(ev) = rx.recv().await { + events.push(ev); + } + events + } + + /// Builds an in-memory SSE stream from JSON fixtures and returns only the + /// successfully parsed events (panics on internal channel errors). async fn run_sse(events: Vec) -> Vec { let mut body = String::new(); for e in events { @@ -411,9 +441,11 @@ mod tests { body.push_str(&format!("event: {kind}\ndata: {e}\n\n")); } } + let (tx, mut rx) = mpsc::channel::>(8); let stream = ReaderStream::new(std::io::Cursor::new(body)).map_err(CodexErr::Io); tokio::spawn(process_sse(stream, tx)); + let mut out = Vec::new(); while let Some(ev) = rx.recv().await { out.push(ev.expect("channel closed")); @@ -421,14 +453,104 @@ mod tests { out } - /// Verifies that the SSE adapter emits the expected [`ResponseEvent`] for - /// a variety of `type` values from the Responses API. The test is written - /// table-driven style to keep additions for new event kinds trivial. - /// - /// Each `Case` supplies an input event, a predicate that must match the - /// *first* `ResponseEvent` produced by the adapter, and the total number - /// of events expected after appending a synthetic `response.completed` - /// marker that terminates the stream. + // ──────────────────────────── + // Tests from `implement-test-for-responses-api-sse-parser` + // ──────────────────────────── + + #[tokio::test] + async fn parses_items_and_completed() { + let item1 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello"}] + } + }) + .to_string(); + + let item2 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "World"}] + } + }) + .to_string(); + + let completed = json!({ + "type": "response.completed", + "response": { "id": "resp1" } + }) + .to_string(); + + let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); + let sse2 = format!("event: response.output_item.done\ndata: {item2}\n\n"); + let sse3 = format!("event: response.completed\ndata: {completed}\n\n"); + + let events = collect_events(&[sse1.as_bytes(), sse2.as_bytes(), sse3.as_bytes()]).await; + + assert_eq!(events.len(), 3); + + matches!( + &events[0], + Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) + if role == "assistant" + ); + + matches!( + &events[1], + Ok(ResponseEvent::OutputItemDone(ResponseItem::Message { role, .. })) + if role == "assistant" + ); + + match &events[2] { + Ok(ResponseEvent::Completed { + response_id, + token_usage, + }) => { + assert_eq!(response_id, "resp1"); + assert!(token_usage.is_none()); + } + other => panic!("unexpected third event: {other:?}"), + } + } + + #[tokio::test] + async fn error_when_missing_completed() { + let item1 = json!({ + "type": "response.output_item.done", + "item": { + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello"}] + } + }) + .to_string(); + + let sse1 = format!("event: response.output_item.done\ndata: {item1}\n\n"); + + let events = collect_events(&[sse1.as_bytes()]).await; + + assert_eq!(events.len(), 2); + + matches!(events[0], Ok(ResponseEvent::OutputItemDone(_))); + + match &events[1] { + Err(CodexErr::Stream(msg)) => { + assert_eq!(msg, "stream closed before response.completed") + } + other => panic!("unexpected second event: {other:?}"), + } + } + + // ──────────────────────────── + // Table-driven test from `main` + // ──────────────────────────── + + /// Verifies that the adapter produces the right `ResponseEvent` for a + /// variety of incoming `type` values. #[tokio::test] async fn table_driven_event_kinds() { struct TestCase { @@ -441,11 +563,9 @@ mod tests { fn is_created(ev: &ResponseEvent) -> bool { matches!(ev, ResponseEvent::Created) } - fn is_output(ev: &ResponseEvent) -> bool { matches!(ev, ResponseEvent::OutputItemDone(_)) } - fn is_completed(ev: &ResponseEvent) -> bool { matches!(ev, ResponseEvent::Completed { .. }) } @@ -498,9 +618,14 @@ mod tests { for case in cases { let mut evs = vec![case.event]; evs.push(completed.clone()); + let out = run_sse(evs).await; assert_eq!(out.len(), case.expected_len, "case {}", case.name); - assert!((case.expect_first)(&out[0]), "case {}", case.name); + assert!( + (case.expect_first)(&out[0]), + "first event mismatch in case {}", + case.name + ); } } }