diff --git a/AGENTS.md b/AGENTS.md index accfe447..832f1d65 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -73,3 +73,28 @@ If you don’t have the tool: ### Test assertions - Tests should use pretty_assertions::assert_eq for clearer diffs. Import this at the top of the test module if it isn't already. + +### Integration tests (core) + +- Prefer the utilities in `core_test_support::responses` when writing end-to-end Codex tests. + +- All `mount_sse*` helpers return a `ResponseMock`; hold onto it so you can assert against outbound `/responses` POST bodies. +- Use `ResponseMock::single_request()` when a test should only issue one POST, or `ResponseMock::requests()` to inspect every captured `ResponsesRequest`. +- `ResponsesRequest` exposes helpers (`body_json`, `input`, `function_call_output`, `custom_tool_call_output`, `call_output`, `header`, `path`, `query_param`) so assertions can target structured payloads instead of manual JSON digging. +- Build SSE payloads with the provided `ev_*` constructors and the `sse(...)` helper. + +- Typical pattern: + + ```rust + let mock = responses::mount_sse_once(&server, responses::sse(vec![ + responses::ev_response_created("resp-1"), + responses::ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), + responses::ev_completed("resp-1"), + ])).await; + + codex.submit(Op::UserTurn { ... }).await?; + + // Assert request body if needed. + let request = mock.single_request(); + // assert using request.function_call_output(call_id) or request.body_json() or other helpers. 
+ ``` diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs index 24ea824a..98b3eca1 100644 --- a/codex-rs/core/tests/common/responses.rs +++ b/codex-rs/core/tests/common/responses.rs @@ -1,11 +1,105 @@ +use std::sync::Arc; +use std::sync::Mutex; + use serde_json::Value; use wiremock::BodyPrintLimit; +use wiremock::Match; use wiremock::Mock; +use wiremock::MockBuilder; use wiremock::MockServer; use wiremock::Respond; use wiremock::ResponseTemplate; use wiremock::matchers::method; -use wiremock::matchers::path; +use wiremock::matchers::path_regex; + +#[derive(Debug, Clone)] +pub struct ResponseMock { + requests: Arc>>, +} + +impl ResponseMock { + fn new() -> Self { + Self { + requests: Arc::new(Mutex::new(Vec::new())), + } + } + + pub fn single_request(&self) -> ResponsesRequest { + let requests = self.requests.lock().unwrap(); + if requests.len() != 1 { + panic!("expected 1 request, got {}", requests.len()); + } + requests.first().unwrap().clone() + } + + pub fn requests(&self) -> Vec { + self.requests.lock().unwrap().clone() + } +} + +#[derive(Debug, Clone)] +pub struct ResponsesRequest(wiremock::Request); + +impl ResponsesRequest { + pub fn body_json(&self) -> Value { + self.0.body_json().unwrap() + } + + pub fn input(&self) -> Vec { + self.0.body_json::().unwrap()["input"] + .as_array() + .expect("input array not found in request") + .clone() + } + + pub fn function_call_output(&self, call_id: &str) -> Value { + self.call_output(call_id, "function_call_output") + } + + pub fn custom_tool_call_output(&self, call_id: &str) -> Value { + self.call_output(call_id, "custom_tool_call_output") + } + + pub fn call_output(&self, call_id: &str, call_type: &str) -> Value { + self.input() + .iter() + .find(|item| { + item.get("type").unwrap() == call_type && item.get("call_id").unwrap() == call_id + }) + .cloned() + .unwrap_or_else(|| panic!("function call output {call_id} item not found in request")) + } + + pub fn header(&self, 
name: &str) -> Option { + self.0 + .headers + .get(name) + .and_then(|v| v.to_str().ok()) + .map(str::to_string) + } + + pub fn path(&self) -> String { + self.0.url.path().to_string() + } + + pub fn query_param(&self, name: &str) -> Option { + self.0 + .url + .query_pairs() + .find(|(k, _)| k == name) + .map(|(_, v)| v.to_string()) + } +} + +impl Match for ResponseMock { + fn matches(&self, request: &wiremock::Request) -> bool { + self.requests + .lock() + .unwrap() + .push(ResponsesRequest(request.clone())); + true + } +} /// Build an SSE stream body from a list of JSON events. pub fn sse(events: Vec) -> String { @@ -161,34 +255,40 @@ pub fn sse_response(body: String) -> ResponseTemplate { .set_body_raw(body, "text/event-stream") } -pub async fn mount_sse_once_match(server: &MockServer, matcher: M, body: String) +fn base_mock() -> (MockBuilder, ResponseMock) { + let response_mock = ResponseMock::new(); + let mock = Mock::given(method("POST")) + .and(path_regex(".*/responses$")) + .and(response_mock.clone()); + (mock, response_mock) +} + +pub async fn mount_sse_once_match(server: &MockServer, matcher: M, body: String) -> ResponseMock where M: wiremock::Match + Send + Sync + 'static, { - Mock::given(method("POST")) - .and(path("/v1/responses")) - .and(matcher) + let (mock, response_mock) = base_mock(); + mock.and(matcher) .respond_with(sse_response(body)) .up_to_n_times(1) .mount(server) .await; + response_mock } -pub async fn mount_sse_once(server: &MockServer, body: String) { - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(sse_response(body)) - .expect(1) +pub async fn mount_sse_once(server: &MockServer, body: String) -> ResponseMock { + let (mock, response_mock) = base_mock(); + mock.respond_with(sse_response(body)) + .up_to_n_times(1) .mount(server) .await; + response_mock } -pub async fn mount_sse(server: &MockServer, body: String) { - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(sse_response(body)) - 
.mount(server) - .await; +pub async fn mount_sse(server: &MockServer, body: String) -> ResponseMock { + let (mock, response_mock) = base_mock(); + mock.respond_with(sse_response(body)).mount(server).await; + response_mock } pub async fn start_mock_server() -> MockServer { @@ -201,7 +301,7 @@ pub async fn start_mock_server() -> MockServer { /// Mounts a sequence of SSE response bodies and serves them in order for each /// POST to `/v1/responses`. Panics if more requests are received than bodies /// provided. Also asserts the exact number of expected calls. -pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec) { +pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec) -> ResponseMock { use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; @@ -228,10 +328,11 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec) { responses: bodies, }; - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(responder) + let (mock, response_mock) = base_mock(); + mock.respond_with(responder) .expect(num_calls as u64) .mount(server) .await; + + response_mock } diff --git a/codex-rs/core/tests/suite/cli_stream.rs b/codex-rs/core/tests/suite/cli_stream.rs index 8fc36772..f9408d5a 100644 --- a/codex-rs/core/tests/suite/cli_stream.rs +++ b/codex-rs/core/tests/suite/cli_stream.rs @@ -106,16 +106,12 @@ async fn exec_cli_applies_experimental_instructions_file() { "data: {\"type\":\"response.created\",\"response\":{}}\n\n", "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"r1\"}}\n\n" ); - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with( - ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_raw(sse, "text/event-stream"), - ) - .expect(1) - .mount(&server) - .await; + let resp_mock = core_test_support::responses::mount_sse_once_match( + &server, + path("/v1/responses"), + sse.to_string(), + ) + .await; // Create a temporary instructions file with 
a unique marker we can assert // appears in the outbound request payload. @@ -164,8 +160,8 @@ async fn exec_cli_applies_experimental_instructions_file() { // Inspect the captured request and verify our custom base instructions were // included in the `instructions` field. - let request = &server.received_requests().await.unwrap()[0]; - let body = request.body_json::().unwrap(); + let request = resp_mock.single_request(); + let body = request.body_json(); let instructions = body .get("instructions") .and_then(|v| v.as_str()) diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs index c49c38e3..eb14dabb 100644 --- a/codex-rs/core/tests/suite/client.rs +++ b/codex-rs/core/tests/suite/client.rs @@ -223,15 +223,9 @@ async fn resume_includes_initial_messages_and_sends_prior_items() { // Mock server that will receive the resumed request let server = MockServer::start().await; - let first = ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_raw(sse_completed("resp1"), "text/event-stream"); - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(first) - .expect(1) - .mount(&server) - .await; + let resp_mock = + responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1")) + .await; // Configure Codex to resume from our file let model_provider = ModelProviderInfo { @@ -277,8 +271,8 @@ async fn resume_includes_initial_messages_and_sends_prior_items() { .unwrap(); wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - let request = &server.received_requests().await.unwrap()[0]; - let request_body = request.body_json::().unwrap(); + let request = resp_mock.single_request(); + let request_body = request.body_json(); let expected_input = json!([ { "type": "message", @@ -372,18 +366,9 @@ async fn includes_base_instructions_override_in_request() { skip_if_no_network!(); // Mock server let server = MockServer::start().await; - - // First 
request – must NOT include `previous_response_id`. - let first = ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_raw(sse_completed("resp1"), "text/event-stream"); - - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(first) - .expect(1) - .mount(&server) - .await; + let resp_mock = + responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1")) + .await; let model_provider = ModelProviderInfo { base_url: Some(format!("{}/v1", server.uri())), @@ -414,8 +399,8 @@ async fn includes_base_instructions_override_in_request() { wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - let request = &server.received_requests().await.unwrap()[0]; - let request_body = request.body_json::().unwrap(); + let request = resp_mock.single_request(); + let request_body = request.body_json(); assert!( request_body["instructions"] @@ -570,16 +555,9 @@ async fn includes_user_instructions_message_in_request() { skip_if_no_network!(); let server = MockServer::start().await; - let first = ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_raw(sse_completed("resp1"), "text/event-stream"); - - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(first) - .expect(1) - .mount(&server) - .await; + let resp_mock = + responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1")) + .await; let model_provider = ModelProviderInfo { base_url: Some(format!("{}/v1", server.uri())), @@ -610,8 +588,8 @@ async fn includes_user_instructions_message_in_request() { wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - let request = &server.received_requests().await.unwrap()[0]; - let request_body = request.body_json::().unwrap(); + let request = resp_mock.single_request(); + let request_body = request.body_json(); assert!( !request_body["instructions"] diff --git 
a/codex-rs/core/tests/suite/model_tools.rs b/codex-rs/core/tests/suite/model_tools.rs index 6a7c5762..ee7b44d4 100644 --- a/codex-rs/core/tests/suite/model_tools.rs +++ b/codex-rs/core/tests/suite/model_tools.rs @@ -10,14 +10,11 @@ use codex_core::protocol::InputItem; use codex_core::protocol::Op; use core_test_support::load_default_config_for_test; use core_test_support::load_sse_fixture_with_id; +use core_test_support::responses; use core_test_support::skip_if_no_network; use core_test_support::wait_for_event; use tempfile::TempDir; -use wiremock::Mock; use wiremock::MockServer; -use wiremock::ResponseTemplate; -use wiremock::matchers::method; -use wiremock::matchers::path; fn sse_completed(id: &str) -> String { load_sse_fixture_with_id("tests/fixtures/completed_template.json", id) @@ -44,16 +41,7 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec { let server = MockServer::start().await; let sse = sse_completed(model); - let template = ResponseTemplate::new(200) - .insert_header("content-type", "text/event-stream") - .set_body_raw(sse, "text/event-stream"); - - Mock::given(method("POST")) - .and(path("/v1/responses")) - .respond_with(template) - .expect(1) - .mount(&server) - .await; + let resp_mock = responses::mount_sse_once_match(&server, wiremock::matchers::any(), sse).await; let model_provider = ModelProviderInfo { base_url: Some(format!("{}/v1", server.uri())), @@ -93,13 +81,7 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec { .unwrap(); wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - let requests = server.received_requests().await.unwrap(); - assert_eq!( - requests.len(), - 1, - "expected a single request for model {model}" - ); - let body = requests[0].body_json::().unwrap(); + let body = resp_mock.single_request().body_json(); tool_identifiers(&body) } diff --git a/codex-rs/core/tests/suite/read_file.rs b/codex-rs/core/tests/suite/read_file.rs index a6c8a7a1..fc5a94f9 100644 --- 
a/codex-rs/core/tests/suite/read_file.rs +++ b/codex-rs/core/tests/suite/read_file.rs @@ -58,7 +58,7 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> { ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), ]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -79,36 +79,12 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> { wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; - let requests = server.received_requests().await.expect("recorded requests"); - let request_bodies = requests - .iter() - .map(|req| req.body_json::().unwrap()) - .collect::>(); - assert!( - !request_bodies.is_empty(), - "expected at least one request body" - ); - - let tool_output_item = request_bodies - .iter() - .find_map(|body| { - body.get("input") - .and_then(Value::as_array) - .and_then(|items| { - items.iter().find(|item| { - item.get("type").and_then(Value::as_str) == Some("function_call_output") - }) - }) - }) - .unwrap_or_else(|| { - panic!("function_call_output item not found in requests: {request_bodies:#?}") - }); - + let req = second_mock.single_request(); + let tool_output_item = req.function_call_output(call_id); assert_eq!( tool_output_item.get("call_id").and_then(Value::as_str), Some(call_id) ); - let output_text = tool_output_item .get("output") .and_then(|value| match value { diff --git a/codex-rs/core/tests/suite/tool_harness.rs b/codex-rs/core/tests/suite/tool_harness.rs index 14e0e1c8..eaefe7d9 100644 --- a/codex-rs/core/tests/suite/tool_harness.rs +++ b/codex-rs/core/tests/suite/tool_harness.rs @@ -27,16 +27,6 @@ use serde_json::Value; use serde_json::json; use wiremock::matchers::any; -fn function_call_output(body: &Value) -> Option<&Value> { - body.get("input") - .and_then(Value::as_array) - .and_then(|items| { - 
items.iter().find(|item| { - item.get("type").and_then(Value::as_str) == Some("function_call_output") - }) - }) -} - fn extract_output_text(item: &Value) -> Option<&str> { item.get("output").and_then(|value| match value { Value::String(text) => Some(text.as_str()), @@ -45,12 +35,6 @@ fn extract_output_text(item: &Value) -> Option<&str> { }) } -fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> { - requests - .iter() - .find(|body| function_call_output(body).is_some()) -} - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -81,7 +65,7 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> ev_assistant_message("msg-1", "all done"), ev_completed("resp-2"), ]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -102,18 +86,9 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; - let requests = server.received_requests().await.expect("recorded requests"); - assert!(!requests.is_empty(), "expected at least one POST request"); - - let request_bodies = requests - .iter() - .map(|req| req.body_json::().expect("request json")) - .collect::>(); - - let body_with_tool_output = find_request_with_function_call_output(&request_bodies) - .expect("function_call_output item not found in requests"); - let output_item = function_call_output(body_with_tool_output).expect("tool output item"); - let output_text = extract_output_text(output_item).expect("output text present"); + let req = second_mock.single_request(); + let output_item = req.function_call_output(call_id); + let output_text = 
extract_output_text(&output_item).expect("output text present"); let exec_output: Value = serde_json::from_str(output_text)?; assert_eq!(exec_output["metadata"]["exit_code"], 0); let stdout = exec_output["output"].as_str().expect("stdout field"); @@ -159,7 +134,7 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> { ev_assistant_message("msg-1", "plan acknowledged"), ev_completed("resp-2"), ]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -197,22 +172,13 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> { assert!(saw_plan_update, "expected PlanUpdate event"); - let requests = server.received_requests().await.expect("recorded requests"); - assert!(!requests.is_empty(), "expected at least one POST request"); - - let request_bodies = requests - .iter() - .map(|req| req.body_json::().expect("request json")) - .collect::>(); - - let body_with_tool_output = find_request_with_function_call_output(&request_bodies) - .expect("function_call_output item not found in requests"); - let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + let req = second_mock.single_request(); + let output_item = req.function_call_output(call_id); assert_eq!( output_item.get("call_id").and_then(Value::as_str), Some(call_id) ); - let output_text = extract_output_text(output_item).expect("output text present"); + let output_text = extract_output_text(&output_item).expect("output text present"); assert_eq!(output_text, "Plan updated"); Ok(()) @@ -251,7 +217,7 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> { ev_assistant_message("msg-1", "malformed plan payload"), ev_completed("resp-2"), ]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let second_mock = 
responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -286,22 +252,13 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> { "did not expect PlanUpdate event for malformed payload" ); - let requests = server.received_requests().await.expect("recorded requests"); - assert!(!requests.is_empty(), "expected at least one POST request"); - - let request_bodies = requests - .iter() - .map(|req| req.body_json::().expect("request json")) - .collect::>(); - - let body_with_tool_output = find_request_with_function_call_output(&request_bodies) - .expect("function_call_output item not found in requests"); - let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + let req = second_mock.single_request(); + let output_item = req.function_call_output(call_id); assert_eq!( output_item.get("call_id").and_then(Value::as_str), Some(call_id) ); - let output_text = extract_output_text(output_item).expect("output text present"); + let output_text = extract_output_text(&output_item).expect("output text present"); assert!( output_text.contains("failed to parse function arguments"), "expected parse error message in output text, got {output_text:?}" @@ -354,7 +311,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<() ev_assistant_message("msg-1", "patch complete"), ev_completed("resp-2"), ]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -395,22 +352,13 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<() let patch_end_success = patch_end_success.expect("expected PatchApplyEnd event to capture success flag"); - let requests = server.received_requests().await.expect("recorded requests"); - assert!(!requests.is_empty(), 
"expected at least one POST request"); - - let request_bodies = requests - .iter() - .map(|req| req.body_json::().expect("request json")) - .collect::>(); - - let body_with_tool_output = find_request_with_function_call_output(&request_bodies) - .expect("function_call_output item not found in requests"); - let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + let req = second_mock.single_request(); + let output_item = req.function_call_output(call_id); assert_eq!( output_item.get("call_id").and_then(Value::as_str), Some(call_id) ); - let output_text = extract_output_text(output_item).expect("output text present"); + let output_text = extract_output_text(&output_item).expect("output text present"); if let Ok(exec_output) = serde_json::from_str::(output_text) { let exit_code = exec_output["metadata"]["exit_code"] @@ -480,7 +428,7 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> { ev_assistant_message("msg-1", "failed"), ev_completed("resp-2"), ]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -501,22 +449,13 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> { wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; - let requests = server.received_requests().await.expect("recorded requests"); - assert!(!requests.is_empty(), "expected at least one POST request"); - - let request_bodies = requests - .iter() - .map(|req| req.body_json::().expect("request json")) - .collect::>(); - - let body_with_tool_output = find_request_with_function_call_output(&request_bodies) - .expect("function_call_output item not found in requests"); - let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + let req = second_mock.single_request(); + let output_item = 
req.function_call_output(call_id); assert_eq!( output_item.get("call_id").and_then(Value::as_str), Some(call_id) ); - let output_text = extract_output_text(output_item).expect("output text present"); + let output_text = extract_output_text(&output_item).expect("output text present"); assert!( output_text.contains("apply_patch verification failed"), diff --git a/codex-rs/core/tests/suite/tools.rs b/codex-rs/core/tests/suite/tools.rs index 08826a1a..27e709f2 100644 --- a/codex-rs/core/tests/suite/tools.rs +++ b/codex-rs/core/tests/suite/tools.rs @@ -15,6 +15,7 @@ use core_test_support::responses::ev_completed; use core_test_support::responses::ev_custom_tool_call; use core_test_support::responses::ev_function_call; use core_test_support::responses::ev_response_created; +use core_test_support::responses::mount_sse_once; use core_test_support::responses::mount_sse_sequence; use core_test_support::responses::sse; use core_test_support::responses::start_mock_server; @@ -25,7 +26,6 @@ use core_test_support::wait_for_event; use regex_lite::Regex; use serde_json::Value; use serde_json::json; -use wiremock::Request; async fn submit_turn( test: &TestCodex, @@ -58,27 +58,6 @@ async fn submit_turn( Ok(()) } -fn request_bodies(requests: &[Request]) -> Result> { - requests - .iter() - .map(|req| Ok(serde_json::from_slice::(&req.body)?)) - .collect() -} - -fn collect_output_items<'a>(bodies: &'a [Value], ty: &str) -> Vec<&'a Value> { - let mut out = Vec::new(); - for body in bodies { - if let Some(items) = body.get("input").and_then(Value::as_array) { - for item in items { - if item.get("type").and_then(Value::as_str) == Some(ty) { - out.push(item); - } - } - } - } - out -} - fn tool_names(body: &Value) -> Vec { body.get("tools") .and_then(Value::as_array) @@ -107,18 +86,23 @@ async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> { let call_id = "custom-unsupported"; let tool_name = "unsupported_tool"; - let responses = vec![ + mount_sse_once( + &server, 
sse(vec![ ev_response_created("resp-1"), ev_custom_tool_call(call_id, tool_name, "\"payload\""), ev_completed("resp-1"), ]), + ) + .await; + let mock = mount_sse_once( + &server, sse(vec![ ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), ]), - ]; - mount_sse_sequence(&server, responses).await; + ) + .await; submit_turn( &test, @@ -128,13 +112,7 @@ async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> { ) .await?; - let requests = server.received_requests().await.expect("recorded requests"); - let bodies = request_bodies(&requests)?; - let custom_items = collect_output_items(&bodies, "custom_tool_call_output"); - assert_eq!(custom_items.len(), 1, "expected single custom tool output"); - let item = custom_items[0]; - assert_eq!(item.get("call_id").and_then(Value::as_str), Some(call_id)); - + let item = mock.single_request().custom_tool_call_output(call_id); let output = item .get("output") .and_then(Value::as_str) @@ -170,7 +148,8 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> { "timeout_ms": 1_000, }); - let responses = vec![ + mount_sse_once( + &server, sse(vec![ ev_response_created("resp-1"), ev_function_call( @@ -180,6 +159,10 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> { ), ev_completed("resp-1"), ]), + ) + .await; + let second_mock = mount_sse_once( + &server, sse(vec![ ev_response_created("resp-2"), ev_function_call( @@ -189,12 +172,16 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> { ), ev_completed("resp-2"), ]), + ) + .await; + let third_mock = mount_sse_once( + &server, sse(vec![ ev_assistant_message("msg-1", "done"), ev_completed("resp-3"), ]), - ]; - mount_sse_sequence(&server, responses).await; + ) + .await; submit_turn( &test, @@ -204,46 +191,23 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> { ) .await?; - let requests = server.received_requests().await.expect("recorded requests"); - let bodies = 
request_bodies(&requests)?; - let function_outputs = collect_output_items(&bodies, "function_call_output"); - for item in &function_outputs { - let call_id = item - .get("call_id") - .and_then(Value::as_str) - .unwrap_or_default(); - assert!( - call_id == call_id_blocked || call_id == call_id_success, - "unexpected call id {call_id}" - ); - } - let policy = AskForApproval::Never; let expected_message = format!( "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}" ); - let blocked_outputs: Vec<&Value> = function_outputs - .iter() - .filter(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_blocked)) - .copied() - .collect(); - assert!( - !blocked_outputs.is_empty(), - "expected at least one rejection output for {call_id_blocked}" + let blocked_item = second_mock + .single_request() + .function_call_output(call_id_blocked); + assert_eq!( + blocked_item.get("output").and_then(Value::as_str), + Some(expected_message.as_str()), + "unexpected rejection message" ); - for item in blocked_outputs { - assert_eq!( - item.get("output").and_then(Value::as_str), - Some(expected_message.as_str()), - "unexpected rejection message" - ); - } - let success_item = function_outputs - .iter() - .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_success)) - .expect("success output present"); + let success_item = third_mock + .single_request() + .function_call_output(call_id_success); let output_json: Value = serde_json::from_str( success_item .get("output") @@ -282,18 +246,23 @@ async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> { } }); - let responses = vec![ + mount_sse_once( + &server, sse(vec![ ev_response_created("resp-1"), local_shell_event, ev_completed("resp-1"), ]), + ) + .await; + let second_mock = mount_sse_once( + &server, sse(vec![ ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), ]), - ]; - mount_sse_sequence(&server, 
responses).await; + ) + .await; submit_turn( &test, @@ -303,15 +272,7 @@ async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> { ) .await?; - let requests = server.received_requests().await.expect("recorded requests"); - let bodies = request_bodies(&requests)?; - let function_outputs = collect_output_items(&bodies, "function_call_output"); - assert_eq!( - function_outputs.len(), - 1, - "expected a single function output" - ); - let item = function_outputs[0]; + let item = second_mock.single_request().function_call_output(""); assert_eq!(item.get("call_id").and_then(Value::as_str), Some("")); assert_eq!( item.get("output").and_then(Value::as_str), @@ -329,7 +290,7 @@ async fn collect_tools(use_unified_exec: bool) -> Result> { ev_assistant_message("msg-1", "done"), ev_completed("resp-1"), ])]; - mount_sse_sequence(&server, responses).await; + let mock = mount_sse_sequence(&server, responses).await; let mut builder = test_codex().with_config(move |config| { config.use_experimental_unified_exec_tool = use_unified_exec; @@ -344,15 +305,8 @@ async fn collect_tools(use_unified_exec: bool) -> Result> { ) .await?; - let requests = server.received_requests().await.expect("recorded requests"); - assert_eq!( - requests.len(), - 1, - "expected a single request for tools collection" - ); - let bodies = request_bodies(&requests)?; - let first_body = bodies.first().expect("request body present"); - Ok(tool_names(first_body)) + let first_body = mock.single_request().body_json(); + Ok(tool_names(&first_body)) } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -392,18 +346,23 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> { "timeout_ms": timeout_ms, }); - let responses = vec![ + mount_sse_once( + &server, sse(vec![ ev_response_created("resp-1"), ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), ev_completed("resp-1"), ]), + ) + .await; + let second_mock = mount_sse_once( + &server, sse(vec![ 
ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), ]), - ]; - mount_sse_sequence(&server, responses).await; + ) + .await; submit_turn( &test, @@ -413,13 +372,7 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> { ) .await?; - let requests = server.received_requests().await.expect("recorded requests"); - let bodies = request_bodies(&requests)?; - let function_outputs = collect_output_items(&bodies, "function_call_output"); - let timeout_item = function_outputs - .iter() - .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id)) - .expect("timeout output present"); + let timeout_item = second_mock.single_request().function_call_output(call_id); let output_str = timeout_item .get("output") @@ -478,18 +431,23 @@ async fn shell_sandbox_denied_truncates_error_output() -> Result<()> { "timeout_ms": 1_000, }); - let responses = vec![ + mount_sse_once( + &server, sse(vec![ ev_response_created("resp-1"), ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), ev_completed("resp-1"), ]), + ) + .await; + let second_mock = mount_sse_once( + &server, sse(vec![ ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), ]), - ]; - mount_sse_sequence(&server, responses).await; + ) + .await; submit_turn( &test, @@ -499,13 +457,7 @@ async fn shell_sandbox_denied_truncates_error_output() -> Result<()> { ) .await?; - let requests = server.received_requests().await.expect("recorded requests"); - let bodies = request_bodies(&requests)?; - let function_outputs = collect_output_items(&bodies, "function_call_output"); - let denied_item = function_outputs - .iter() - .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id)) - .expect("denied output present"); + let denied_item = second_mock.single_request().function_call_output(call_id); let output = denied_item .get("output") @@ -558,18 +510,23 @@ async fn shell_spawn_failure_truncates_exec_error() -> Result<()> { "timeout_ms": 1_000, }); - let 
responses = vec![ + mount_sse_once( + &server, sse(vec![ ev_response_created("resp-1"), ev_function_call(call_id, "shell", &serde_json::to_string(&args)?), ev_completed("resp-1"), ]), + ) + .await; + let second_mock = mount_sse_once( + &server, sse(vec![ ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), ]), - ]; - mount_sse_sequence(&server, responses).await; + ) + .await; submit_turn( &test, @@ -579,13 +536,7 @@ async fn shell_spawn_failure_truncates_exec_error() -> Result<()> { ) .await?; - let requests = server.received_requests().await.expect("recorded requests"); - let bodies = request_bodies(&requests)?; - let function_outputs = collect_output_items(&bodies, "function_call_output"); - let failure_item = function_outputs - .iter() - .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id)) - .expect("spawn failure output present"); + let failure_item = second_mock.single_request().function_call_output(call_id); let output = failure_item .get("output") diff --git a/codex-rs/core/tests/suite/view_image.rs b/codex-rs/core/tests/suite/view_image.rs index 16913f4f..bdb67ad6 100644 --- a/codex-rs/core/tests/suite/view_image.rs +++ b/codex-rs/core/tests/suite/view_image.rs @@ -22,16 +22,6 @@ use core_test_support::wait_for_event; use serde_json::Value; use wiremock::matchers::any; -fn function_call_output(body: &Value) -> Option<&Value> { - body.get("input") - .and_then(Value::as_array) - .and_then(|items| { - items.iter().find(|item| { - item.get("type").and_then(Value::as_str) == Some("function_call_output") - }) - }) -} - fn find_image_message(body: &Value) -> Option<&Value> { body.get("input") .and_then(Value::as_array) @@ -59,12 +49,6 @@ fn extract_output_text(item: &Value) -> Option<&str> { }) } -fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> { - requests - .iter() - .find(|body| function_call_output(body).is_some()) -} - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn 
view_image_tool_attaches_local_image() -> anyhow::Result<()> { skip_if_no_network!(Ok(())); @@ -100,7 +84,7 @@ async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> { ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), ]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let mock = responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -137,25 +121,14 @@ async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> { assert_eq!(tool_event.call_id, call_id); assert_eq!(tool_event.path, abs_path); - let requests = server.received_requests().await.expect("recorded requests"); - assert!( - requests.len() >= 2, - "expected at least two POST requests, got {}", - requests.len() - ); - let request_bodies = requests - .iter() - .map(|req| req.body_json::().expect("request json")) - .collect::>(); + let body = mock.single_request().body_json(); + let output_item = mock.single_request().function_call_output(call_id); - let body_with_tool_output = find_request_with_function_call_output(&request_bodies) - .expect("function_call_output item not found in requests"); - let output_item = function_call_output(body_with_tool_output).expect("tool output item"); - let output_text = extract_output_text(output_item).expect("output text present"); + let output_text = extract_output_text(&output_item).expect("output text present"); assert_eq!(output_text, "attached local image path"); - let image_message = find_image_message(body_with_tool_output) - .expect("pending input image message not included in request"); + let image_message = + find_image_message(&body).expect("pending input image message not included in request"); let image_url = image_message .get("content") .and_then(Value::as_array) @@ -210,7 +183,7 @@ async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> { ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), 
]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let mock = responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -231,26 +204,14 @@ async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> { wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; - let requests = server.received_requests().await.expect("recorded requests"); - assert!( - requests.len() >= 2, - "expected at least two POST requests, got {}", - requests.len() - ); - let request_bodies = requests - .iter() - .map(|req| req.body_json::().expect("request json")) - .collect::>(); - - let body_with_tool_output = find_request_with_function_call_output(&request_bodies) - .expect("function_call_output item not found in requests"); - let output_item = function_call_output(body_with_tool_output).expect("tool output item"); - let output_text = extract_output_text(output_item).expect("output text present"); + let body_with_tool_output = mock.single_request().body_json(); + let output_item = mock.single_request().function_call_output(call_id); + let output_text = extract_output_text(&output_item).expect("output text present"); let expected_message = format!("image path `{}` is not a file", abs_path.display()); assert_eq!(output_text, expected_message); assert!( - find_image_message(body_with_tool_output).is_none(), + find_image_message(&body_with_tool_output).is_none(), "directory path should not produce an input_image message" ); @@ -287,7 +248,7 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> { ev_assistant_message("msg-1", "done"), ev_completed("resp-2"), ]); - responses::mount_sse_once_match(&server, any(), second_response).await; + let mock = responses::mount_sse_once_match(&server, any(), second_response).await; let session_model = session_configured.model.clone(); @@ -308,21 +269,9 @@ async fn 
view_image_tool_errors_when_file_missing() -> anyhow::Result<()> { wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; - let requests = server.received_requests().await.expect("recorded requests"); - assert!( - requests.len() >= 2, - "expected at least two POST requests, got {}", - requests.len() - ); - let request_bodies = requests - .iter() - .map(|req| req.body_json::().expect("request json")) - .collect::>(); - - let body_with_tool_output = find_request_with_function_call_output(&request_bodies) - .expect("function_call_output item not found in requests"); - let output_item = function_call_output(body_with_tool_output).expect("tool output item"); - let output_text = extract_output_text(output_item).expect("output text present"); + let body_with_tool_output = mock.single_request().body_json(); + let output_item = mock.single_request().function_call_output(call_id); + let output_text = extract_output_text(&output_item).expect("output text present"); let expected_prefix = format!("unable to locate image at `{}`:", abs_path.display()); assert!( output_text.starts_with(&expected_prefix), @@ -330,7 +279,7 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> { ); assert!( - find_image_message(body_with_tool_output).is_none(), + find_image_message(&body_with_tool_output).is_none(), "missing file should not produce an input_image message" ); diff --git a/codex-rs/exec/tests/suite/output_schema.rs b/codex-rs/exec/tests/suite/output_schema.rs index b054484f..913270ef 100644 --- a/codex-rs/exec/tests/suite/output_schema.rs +++ b/codex-rs/exec/tests/suite/output_schema.rs @@ -28,7 +28,7 @@ async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> { responses::ev_assistant_message("m1", "fixture hello"), responses::ev_completed("resp1"), ]); - responses::mount_sse_once_match(&server, any(), body).await; + let response_mock = responses::mount_sse_once_match(&server, any(), body).await; 
test.cmd_with_server(&server) .arg("--skip-git-repo-check") @@ -43,12 +43,8 @@ async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> { .assert() .success(); - let requests = server - .received_requests() - .await - .expect("failed to capture requests"); - assert_eq!(requests.len(), 1, "expected exactly one request"); - let payload: Value = serde_json::from_slice(&requests[0].body)?; + let request = response_mock.single_request(); + let payload: Value = request.body_json(); let text = payload.get("text").expect("request missing text field"); let format = text .get("format")