Simplify request body assertions (#4845)

We'll have a lot more test like these
This commit is contained in:
pakrym-oai
2025-10-07 01:56:39 -07:00
committed by GitHub
parent b09f62a1c3
commit 35a770e871
10 changed files with 285 additions and 392 deletions

View File

@@ -73,3 +73,28 @@ If you dont have the tool:
### Test assertions
- Tests should use pretty_assertions::assert_eq for clearer diffs. Import this at the top of the test module if it isn't already.
### Integration tests (core)
- Prefer the utilities in `core_test_support::responses` when writing end-to-end Codex tests.
- All `mount_sse*` helpers return a `ResponseMock`; hold onto it so you can assert against outbound `/responses` POST bodies.
- Use `ResponseMock::single_request()` when a test should only issue one POST, or `ResponseMock::requests()` to inspect every captured `ResponsesRequest`.
- `ResponsesRequest` exposes helpers (`body_json`, `input`, `function_call_output`, `custom_tool_call_output`, `call_output`, `header`, `path`, `query_param`) so assertions can target structured payloads instead of manual JSON digging.
- Build SSE payloads with the provided `ev_*` constructors and the `sse(...)`.
- Typical pattern:
```rust
let mock = responses::mount_sse_once(&server, responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
responses::ev_completed("resp-1"),
])).await;
codex.submit(Op::UserTurn { ... }).await?;
// Assert request body if needed.
let request = mock.single_request();
// assert using request.function_call_output(call_id) or request.json_body() or other helpers.
```

View File

@@ -1,11 +1,105 @@
use std::sync::Arc;
use std::sync::Mutex;
use serde_json::Value;
use wiremock::BodyPrintLimit;
use wiremock::Match;
use wiremock::Mock;
use wiremock::MockBuilder;
use wiremock::MockServer;
use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
use wiremock::matchers::path_regex;
#[derive(Debug, Clone)]
pub struct ResponseMock {
requests: Arc<Mutex<Vec<ResponsesRequest>>>,
}
impl ResponseMock {
fn new() -> Self {
Self {
requests: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn single_request(&self) -> ResponsesRequest {
let requests = self.requests.lock().unwrap();
if requests.len() != 1 {
panic!("expected 1 request, got {}", requests.len());
}
requests.first().unwrap().clone()
}
pub fn requests(&self) -> Vec<ResponsesRequest> {
self.requests.lock().unwrap().clone()
}
}
#[derive(Debug, Clone)]
pub struct ResponsesRequest(wiremock::Request);
impl ResponsesRequest {
pub fn body_json(&self) -> Value {
self.0.body_json().unwrap()
}
pub fn input(&self) -> Vec<Value> {
self.0.body_json::<Value>().unwrap()["input"]
.as_array()
.expect("input array not found in request")
.clone()
}
pub fn function_call_output(&self, call_id: &str) -> Value {
self.call_output(call_id, "function_call_output")
}
pub fn custom_tool_call_output(&self, call_id: &str) -> Value {
self.call_output(call_id, "custom_tool_call_output")
}
pub fn call_output(&self, call_id: &str, call_type: &str) -> Value {
self.input()
.iter()
.find(|item| {
item.get("type").unwrap() == call_type && item.get("call_id").unwrap() == call_id
})
.cloned()
.unwrap_or_else(|| panic!("function call output {call_id} item not found in request"))
}
pub fn header(&self, name: &str) -> Option<String> {
self.0
.headers
.get(name)
.and_then(|v| v.to_str().ok())
.map(str::to_string)
}
pub fn path(&self) -> String {
self.0.url.path().to_string()
}
pub fn query_param(&self, name: &str) -> Option<String> {
self.0
.url
.query_pairs()
.find(|(k, _)| k == name)
.map(|(_, v)| v.to_string())
}
}
impl Match for ResponseMock {
fn matches(&self, request: &wiremock::Request) -> bool {
self.requests
.lock()
.unwrap()
.push(ResponsesRequest(request.clone()));
true
}
}
/// Build an SSE stream body from a list of JSON events.
pub fn sse(events: Vec<Value>) -> String {
@@ -161,34 +255,40 @@ pub fn sse_response(body: String) -> ResponseTemplate {
.set_body_raw(body, "text/event-stream")
}
pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String)
fn base_mock() -> (MockBuilder, ResponseMock) {
let response_mock = ResponseMock::new();
let mock = Mock::given(method("POST"))
.and(path_regex(".*/responses$"))
.and(response_mock.clone());
(mock, response_mock)
}
pub async fn mount_sse_once_match<M>(server: &MockServer, matcher: M, body: String) -> ResponseMock
where
M: wiremock::Match + Send + Sync + 'static,
{
Mock::given(method("POST"))
.and(path("/v1/responses"))
.and(matcher)
let (mock, response_mock) = base_mock();
mock.and(matcher)
.respond_with(sse_response(body))
.up_to_n_times(1)
.mount(server)
.await;
response_mock
}
pub async fn mount_sse_once(server: &MockServer, body: String) {
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(sse_response(body))
.expect(1)
pub async fn mount_sse_once(server: &MockServer, body: String) -> ResponseMock {
let (mock, response_mock) = base_mock();
mock.respond_with(sse_response(body))
.up_to_n_times(1)
.mount(server)
.await;
response_mock
}
pub async fn mount_sse(server: &MockServer, body: String) {
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(sse_response(body))
.mount(server)
.await;
pub async fn mount_sse(server: &MockServer, body: String) -> ResponseMock {
let (mock, response_mock) = base_mock();
mock.respond_with(sse_response(body)).mount(server).await;
response_mock
}
pub async fn start_mock_server() -> MockServer {
@@ -201,7 +301,7 @@ pub async fn start_mock_server() -> MockServer {
/// Mounts a sequence of SSE response bodies and serves them in order for each
/// POST to `/v1/responses`. Panics if more requests are received than bodies
/// provided. Also asserts the exact number of expected calls.
pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) {
pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) -> ResponseMock {
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
@@ -228,10 +328,11 @@ pub async fn mount_sse_sequence(server: &MockServer, bodies: Vec<String>) {
responses: bodies,
};
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(responder)
let (mock, response_mock) = base_mock();
mock.respond_with(responder)
.expect(num_calls as u64)
.mount(server)
.await;
response_mock
}

View File

@@ -106,16 +106,12 @@ async fn exec_cli_applies_experimental_instructions_file() {
"data: {\"type\":\"response.created\",\"response\":{}}\n\n",
"data: {\"type\":\"response.completed\",\"response\":{\"id\":\"r1\"}}\n\n"
);
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse, "text/event-stream"),
)
.expect(1)
.mount(&server)
.await;
let resp_mock = core_test_support::responses::mount_sse_once_match(
&server,
path("/v1/responses"),
sse.to_string(),
)
.await;
// Create a temporary instructions file with a unique marker we can assert
// appears in the outbound request payload.
@@ -164,8 +160,8 @@ async fn exec_cli_applies_experimental_instructions_file() {
// Inspect the captured request and verify our custom base instructions were
// included in the `instructions` field.
let request = &server.received_requests().await.unwrap()[0];
let body = request.body_json::<serde_json::Value>().unwrap();
let request = resp_mock.single_request();
let body = request.body_json();
let instructions = body
.get("instructions")
.and_then(|v| v.as_str())

View File

@@ -223,15 +223,9 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
// Mock server that will receive the resumed request
let server = MockServer::start().await;
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let resp_mock =
responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
.await;
// Configure Codex to resume from our file
let model_provider = ModelProviderInfo {
@@ -277,8 +271,8 @@ async fn resume_includes_initial_messages_and_sends_prior_items() {
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let request = &server.received_requests().await.unwrap()[0];
let request_body = request.body_json::<serde_json::Value>().unwrap();
let request = resp_mock.single_request();
let request_body = request.body_json();
let expected_input = json!([
{
"type": "message",
@@ -372,18 +366,9 @@ async fn includes_base_instructions_override_in_request() {
skip_if_no_network!();
// Mock server
let server = MockServer::start().await;
// First request must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let resp_mock =
responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
.await;
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
@@ -414,8 +399,8 @@ async fn includes_base_instructions_override_in_request() {
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let request = &server.received_requests().await.unwrap()[0];
let request_body = request.body_json::<serde_json::Value>().unwrap();
let request = resp_mock.single_request();
let request_body = request.body_json();
assert!(
request_body["instructions"]
@@ -570,16 +555,9 @@ async fn includes_user_instructions_message_in_request() {
skip_if_no_network!();
let server = MockServer::start().await;
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let resp_mock =
responses::mount_sse_once_match(&server, path("/v1/responses"), sse_completed("resp1"))
.await;
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
@@ -610,8 +588,8 @@ async fn includes_user_instructions_message_in_request() {
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let request = &server.received_requests().await.unwrap()[0];
let request_body = request.body_json::<serde_json::Value>().unwrap();
let request = resp_mock.single_request();
let request_body = request.body_json();
assert!(
!request_body["instructions"]

View File

@@ -10,14 +10,11 @@ use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::responses;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
fn sse_completed(id: &str) -> String {
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
@@ -44,16 +41,7 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
let server = MockServer::start().await;
let sse = sse_completed(model);
let template = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse, "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(template)
.expect(1)
.mount(&server)
.await;
let resp_mock = responses::mount_sse_once_match(&server, wiremock::matchers::any(), sse).await;
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
@@ -93,13 +81,7 @@ async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.unwrap();
assert_eq!(
requests.len(),
1,
"expected a single request for model {model}"
);
let body = requests[0].body_json::<serde_json::Value>().unwrap();
let body = resp_mock.single_request().body_json();
tool_identifiers(&body)
}

View File

@@ -58,7 +58,7 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -79,36 +79,12 @@ async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.expect("recorded requests");
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().unwrap())
.collect::<Vec<_>>();
assert!(
!request_bodies.is_empty(),
"expected at least one request body"
);
let tool_output_item = request_bodies
.iter()
.find_map(|body| {
body.get("input")
.and_then(Value::as_array)
.and_then(|items| {
items.iter().find(|item| {
item.get("type").and_then(Value::as_str) == Some("function_call_output")
})
})
})
.unwrap_or_else(|| {
panic!("function_call_output item not found in requests: {request_bodies:#?}")
});
let req = second_mock.single_request();
let tool_output_item = req.function_call_output(call_id);
assert_eq!(
tool_output_item.get("call_id").and_then(Value::as_str),
Some(call_id)
);
let output_text = tool_output_item
.get("output")
.and_then(|value| match value {

View File

@@ -27,16 +27,6 @@ use serde_json::Value;
use serde_json::json;
use wiremock::matchers::any;
fn function_call_output(body: &Value) -> Option<&Value> {
body.get("input")
.and_then(Value::as_array)
.and_then(|items| {
items.iter().find(|item| {
item.get("type").and_then(Value::as_str) == Some("function_call_output")
})
})
}
fn extract_output_text(item: &Value) -> Option<&str> {
item.get("output").and_then(|value| match value {
Value::String(text) => Some(text.as_str()),
@@ -45,12 +35,6 @@ fn extract_output_text(item: &Value) -> Option<&str> {
})
}
fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
requests
.iter()
.find(|body| function_call_output(body).is_some())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
@@ -81,7 +65,7 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
ev_assistant_message("msg-1", "all done"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -102,18 +86,9 @@ async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()>
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.expect("recorded requests");
assert!(!requests.is_empty(), "expected at least one POST request");
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
.expect("function_call_output item not found in requests");
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
let output_text = extract_output_text(output_item).expect("output text present");
let req = second_mock.single_request();
let output_item = req.function_call_output(call_id);
let output_text = extract_output_text(&output_item).expect("output text present");
let exec_output: Value = serde_json::from_str(output_text)?;
assert_eq!(exec_output["metadata"]["exit_code"], 0);
let stdout = exec_output["output"].as_str().expect("stdout field");
@@ -159,7 +134,7 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
ev_assistant_message("msg-1", "plan acknowledged"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -197,22 +172,13 @@ async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
assert!(saw_plan_update, "expected PlanUpdate event");
let requests = server.received_requests().await.expect("recorded requests");
assert!(!requests.is_empty(), "expected at least one POST request");
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
.expect("function_call_output item not found in requests");
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
let req = second_mock.single_request();
let output_item = req.function_call_output(call_id);
assert_eq!(
output_item.get("call_id").and_then(Value::as_str),
Some(call_id)
);
let output_text = extract_output_text(output_item).expect("output text present");
let output_text = extract_output_text(&output_item).expect("output text present");
assert_eq!(output_text, "Plan updated");
Ok(())
@@ -251,7 +217,7 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
ev_assistant_message("msg-1", "malformed plan payload"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -286,22 +252,13 @@ async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
"did not expect PlanUpdate event for malformed payload"
);
let requests = server.received_requests().await.expect("recorded requests");
assert!(!requests.is_empty(), "expected at least one POST request");
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
.expect("function_call_output item not found in requests");
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
let req = second_mock.single_request();
let output_item = req.function_call_output(call_id);
assert_eq!(
output_item.get("call_id").and_then(Value::as_str),
Some(call_id)
);
let output_text = extract_output_text(output_item).expect("output text present");
let output_text = extract_output_text(&output_item).expect("output text present");
assert!(
output_text.contains("failed to parse function arguments"),
"expected parse error message in output text, got {output_text:?}"
@@ -354,7 +311,7 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
ev_assistant_message("msg-1", "patch complete"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -395,22 +352,13 @@ async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()
let patch_end_success =
patch_end_success.expect("expected PatchApplyEnd event to capture success flag");
let requests = server.received_requests().await.expect("recorded requests");
assert!(!requests.is_empty(), "expected at least one POST request");
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
.expect("function_call_output item not found in requests");
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
let req = second_mock.single_request();
let output_item = req.function_call_output(call_id);
assert_eq!(
output_item.get("call_id").and_then(Value::as_str),
Some(call_id)
);
let output_text = extract_output_text(output_item).expect("output text present");
let output_text = extract_output_text(&output_item).expect("output text present");
if let Ok(exec_output) = serde_json::from_str::<Value>(output_text) {
let exit_code = exec_output["metadata"]["exit_code"]
@@ -480,7 +428,7 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
ev_assistant_message("msg-1", "failed"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let second_mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -501,22 +449,13 @@ async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.expect("recorded requests");
assert!(!requests.is_empty(), "expected at least one POST request");
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
.expect("function_call_output item not found in requests");
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
let req = second_mock.single_request();
let output_item = req.function_call_output(call_id);
assert_eq!(
output_item.get("call_id").and_then(Value::as_str),
Some(call_id)
);
let output_text = extract_output_text(output_item).expect("output text present");
let output_text = extract_output_text(&output_item).expect("output text present");
assert!(
output_text.contains("apply_patch verification failed"),

View File

@@ -15,6 +15,7 @@ use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_custom_tool_call;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_response_created;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
@@ -25,7 +26,6 @@ use core_test_support::wait_for_event;
use regex_lite::Regex;
use serde_json::Value;
use serde_json::json;
use wiremock::Request;
async fn submit_turn(
test: &TestCodex,
@@ -58,27 +58,6 @@ async fn submit_turn(
Ok(())
}
fn request_bodies(requests: &[Request]) -> Result<Vec<Value>> {
requests
.iter()
.map(|req| Ok(serde_json::from_slice::<Value>(&req.body)?))
.collect()
}
fn collect_output_items<'a>(bodies: &'a [Value], ty: &str) -> Vec<&'a Value> {
let mut out = Vec::new();
for body in bodies {
if let Some(items) = body.get("input").and_then(Value::as_array) {
for item in items {
if item.get("type").and_then(Value::as_str) == Some(ty) {
out.push(item);
}
}
}
}
out
}
fn tool_names(body: &Value) -> Vec<String> {
body.get("tools")
.and_then(Value::as_array)
@@ -107,18 +86,23 @@ async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> {
let call_id = "custom-unsupported";
let tool_name = "unsupported_tool";
let responses = vec![
mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_custom_tool_call(call_id, tool_name, "\"payload\""),
ev_completed("resp-1"),
]),
)
.await;
let mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
submit_turn(
&test,
@@ -128,13 +112,7 @@ async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> {
)
.await?;
let requests = server.received_requests().await.expect("recorded requests");
let bodies = request_bodies(&requests)?;
let custom_items = collect_output_items(&bodies, "custom_tool_call_output");
assert_eq!(custom_items.len(), 1, "expected single custom tool output");
let item = custom_items[0];
assert_eq!(item.get("call_id").and_then(Value::as_str), Some(call_id));
let item = mock.single_request().custom_tool_call_output(call_id);
let output = item
.get("output")
.and_then(Value::as_str)
@@ -170,7 +148,8 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
"timeout_ms": 1_000,
});
let responses = vec![
mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_function_call(
@@ -180,6 +159,10 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
),
ev_completed("resp-1"),
]),
)
.await;
let second_mock = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-2"),
ev_function_call(
@@ -189,12 +172,16 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
),
ev_completed("resp-2"),
]),
)
.await;
let third_mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "done"),
ev_completed("resp-3"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
submit_turn(
&test,
@@ -204,46 +191,23 @@ async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
)
.await?;
let requests = server.received_requests().await.expect("recorded requests");
let bodies = request_bodies(&requests)?;
let function_outputs = collect_output_items(&bodies, "function_call_output");
for item in &function_outputs {
let call_id = item
.get("call_id")
.and_then(Value::as_str)
.unwrap_or_default();
assert!(
call_id == call_id_blocked || call_id == call_id_success,
"unexpected call id {call_id}"
);
}
let policy = AskForApproval::Never;
let expected_message = format!(
"approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}"
);
let blocked_outputs: Vec<&Value> = function_outputs
.iter()
.filter(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_blocked))
.copied()
.collect();
assert!(
!blocked_outputs.is_empty(),
"expected at least one rejection output for {call_id_blocked}"
let blocked_item = second_mock
.single_request()
.function_call_output(call_id_blocked);
assert_eq!(
blocked_item.get("output").and_then(Value::as_str),
Some(expected_message.as_str()),
"unexpected rejection message"
);
for item in blocked_outputs {
assert_eq!(
item.get("output").and_then(Value::as_str),
Some(expected_message.as_str()),
"unexpected rejection message"
);
}
let success_item = function_outputs
.iter()
.find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_success))
.expect("success output present");
let success_item = third_mock
.single_request()
.function_call_output(call_id_success);
let output_json: Value = serde_json::from_str(
success_item
.get("output")
@@ -282,18 +246,23 @@ async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
}
});
let responses = vec![
mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
local_shell_event,
ev_completed("resp-1"),
]),
)
.await;
let second_mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
submit_turn(
&test,
@@ -303,15 +272,7 @@ async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
)
.await?;
let requests = server.received_requests().await.expect("recorded requests");
let bodies = request_bodies(&requests)?;
let function_outputs = collect_output_items(&bodies, "function_call_output");
assert_eq!(
function_outputs.len(),
1,
"expected a single function output"
);
let item = function_outputs[0];
let item = second_mock.single_request().function_call_output("");
assert_eq!(item.get("call_id").and_then(Value::as_str), Some(""));
assert_eq!(
item.get("output").and_then(Value::as_str),
@@ -329,7 +290,7 @@ async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
ev_assistant_message("msg-1", "done"),
ev_completed("resp-1"),
])];
mount_sse_sequence(&server, responses).await;
let mock = mount_sse_sequence(&server, responses).await;
let mut builder = test_codex().with_config(move |config| {
config.use_experimental_unified_exec_tool = use_unified_exec;
@@ -344,15 +305,8 @@ async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
)
.await?;
let requests = server.received_requests().await.expect("recorded requests");
assert_eq!(
requests.len(),
1,
"expected a single request for tools collection"
);
let bodies = request_bodies(&requests)?;
let first_body = bodies.first().expect("request body present");
Ok(tool_names(first_body))
let first_body = mock.single_request().body_json();
Ok(tool_names(&first_body))
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
@@ -392,18 +346,23 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
"timeout_ms": timeout_ms,
});
let responses = vec![
mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
ev_completed("resp-1"),
]),
)
.await;
let second_mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
submit_turn(
&test,
@@ -413,13 +372,7 @@ async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
)
.await?;
let requests = server.received_requests().await.expect("recorded requests");
let bodies = request_bodies(&requests)?;
let function_outputs = collect_output_items(&bodies, "function_call_output");
let timeout_item = function_outputs
.iter()
.find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
.expect("timeout output present");
let timeout_item = second_mock.single_request().function_call_output(call_id);
let output_str = timeout_item
.get("output")
@@ -478,18 +431,23 @@ async fn shell_sandbox_denied_truncates_error_output() -> Result<()> {
"timeout_ms": 1_000,
});
let responses = vec![
mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
ev_completed("resp-1"),
]),
)
.await;
let second_mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
submit_turn(
&test,
@@ -499,13 +457,7 @@ async fn shell_sandbox_denied_truncates_error_output() -> Result<()> {
)
.await?;
let requests = server.received_requests().await.expect("recorded requests");
let bodies = request_bodies(&requests)?;
let function_outputs = collect_output_items(&bodies, "function_call_output");
let denied_item = function_outputs
.iter()
.find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
.expect("denied output present");
let denied_item = second_mock.single_request().function_call_output(call_id);
let output = denied_item
.get("output")
@@ -558,18 +510,23 @@ async fn shell_spawn_failure_truncates_exec_error() -> Result<()> {
"timeout_ms": 1_000,
});
let responses = vec![
mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
ev_completed("resp-1"),
]),
)
.await;
let second_mock = mount_sse_once(
&server,
sse(vec![
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]),
];
mount_sse_sequence(&server, responses).await;
)
.await;
submit_turn(
&test,
@@ -579,13 +536,7 @@ async fn shell_spawn_failure_truncates_exec_error() -> Result<()> {
)
.await?;
let requests = server.received_requests().await.expect("recorded requests");
let bodies = request_bodies(&requests)?;
let function_outputs = collect_output_items(&bodies, "function_call_output");
let failure_item = function_outputs
.iter()
.find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
.expect("spawn failure output present");
let failure_item = second_mock.single_request().function_call_output(call_id);
let output = failure_item
.get("output")

View File

@@ -22,16 +22,6 @@ use core_test_support::wait_for_event;
use serde_json::Value;
use wiremock::matchers::any;
fn function_call_output(body: &Value) -> Option<&Value> {
body.get("input")
.and_then(Value::as_array)
.and_then(|items| {
items.iter().find(|item| {
item.get("type").and_then(Value::as_str) == Some("function_call_output")
})
})
}
fn find_image_message(body: &Value) -> Option<&Value> {
body.get("input")
.and_then(Value::as_array)
@@ -59,12 +49,6 @@ fn extract_output_text(item: &Value) -> Option<&str> {
})
}
fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
requests
.iter()
.find(|body| function_call_output(body).is_some())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
@@ -100,7 +84,7 @@ async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -137,25 +121,14 @@ async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
assert_eq!(tool_event.call_id, call_id);
assert_eq!(tool_event.path, abs_path);
let requests = server.received_requests().await.expect("recorded requests");
assert!(
requests.len() >= 2,
"expected at least two POST requests, got {}",
requests.len()
);
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let body = mock.single_request().body_json();
let output_item = mock.single_request().function_call_output(call_id);
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
.expect("function_call_output item not found in requests");
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
let output_text = extract_output_text(output_item).expect("output text present");
let output_text = extract_output_text(&output_item).expect("output text present");
assert_eq!(output_text, "attached local image path");
let image_message = find_image_message(body_with_tool_output)
.expect("pending input image message not included in request");
let image_message =
find_image_message(&body).expect("pending input image message not included in request");
let image_url = image_message
.get("content")
.and_then(Value::as_array)
@@ -210,7 +183,7 @@ async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> {
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -231,26 +204,14 @@ async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> {
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.expect("recorded requests");
assert!(
requests.len() >= 2,
"expected at least two POST requests, got {}",
requests.len()
);
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
.expect("function_call_output item not found in requests");
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
let output_text = extract_output_text(output_item).expect("output text present");
let body_with_tool_output = mock.single_request().body_json();
let output_item = mock.single_request().function_call_output(call_id);
let output_text = extract_output_text(&output_item).expect("output text present");
let expected_message = format!("image path `{}` is not a file", abs_path.display());
assert_eq!(output_text, expected_message);
assert!(
find_image_message(body_with_tool_output).is_none(),
find_image_message(&body_with_tool_output).is_none(),
"directory path should not produce an input_image message"
);
@@ -287,7 +248,7 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]);
responses::mount_sse_once_match(&server, any(), second_response).await;
let mock = responses::mount_sse_once_match(&server, any(), second_response).await;
let session_model = session_configured.model.clone();
@@ -308,21 +269,9 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.expect("recorded requests");
assert!(
requests.len() >= 2,
"expected at least two POST requests, got {}",
requests.len()
);
let request_bodies = requests
.iter()
.map(|req| req.body_json::<Value>().expect("request json"))
.collect::<Vec<_>>();
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
.expect("function_call_output item not found in requests");
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
let output_text = extract_output_text(output_item).expect("output text present");
let body_with_tool_output = mock.single_request().body_json();
let output_item = mock.single_request().function_call_output(call_id);
let output_text = extract_output_text(&output_item).expect("output text present");
let expected_prefix = format!("unable to locate image at `{}`:", abs_path.display());
assert!(
output_text.starts_with(&expected_prefix),
@@ -330,7 +279,7 @@ async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
);
assert!(
find_image_message(body_with_tool_output).is_none(),
find_image_message(&body_with_tool_output).is_none(),
"missing file should not produce an input_image message"
);

View File

@@ -28,7 +28,7 @@ async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> {
responses::ev_assistant_message("m1", "fixture hello"),
responses::ev_completed("resp1"),
]);
responses::mount_sse_once_match(&server, any(), body).await;
let response_mock = responses::mount_sse_once_match(&server, any(), body).await;
test.cmd_with_server(&server)
.arg("--skip-git-repo-check")
@@ -43,12 +43,8 @@ async fn exec_includes_output_schema_in_request() -> anyhow::Result<()> {
.assert()
.success();
let requests = server
.received_requests()
.await
.expect("failed to capture requests");
assert_eq!(requests.len(), 1, "expected exactly one request");
let payload: Value = serde_json::from_slice(&requests[0].body)?;
let request = response_mock.single_request();
let payload: Value = request.body_json();
let text = payload.get("text").expect("request missing text field");
let format = text
.get("format")