Bug fix: deduplicate assistant messages (#2758)

We are treating assistant messages in a different way than other
messages which resulted in a duplicated history.

See #2698
This commit is contained in:
Ahmed Ibrahim
2025-08-27 01:29:16 -07:00
committed by GitHub
parent d0e06f74e2
commit 2d2f66f9c5
3 changed files with 150 additions and 120 deletions

View File

@@ -12,6 +12,7 @@ use codex_login::CodexAuth;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::wait_for_event;
use serde_json::json;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
@@ -66,7 +67,6 @@ fn write_auth_json(
account_id: Option<&str>,
) -> String {
use base64::Engine as _;
use serde_json::json;
let header = json!({ "alg": "none", "typ": "JWT" });
let payload = json!({
@@ -746,3 +746,151 @@ async fn env_var_overrides_loaded_auth() {
fn create_dummy_codex_auth() -> CodexAuth {
CodexAuth::create_dummy_chatgpt_auth_for_testing()
}
/// Scenario:
/// - Turn 1: user sends U1; model streams deltas then a final assistant message A.
/// - Turn 2: user sends U2; model streams a delta then the same final assistant message A.
/// - Turn 3: user sends U3; model responds (same SSE again, not important).
///
/// We assert that the `input` sent on each turn contains the expected conversation history
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn history_dedupes_streamed_and_final_messages_across_turns() {
// Skip under Codex sandbox network restrictions (mirrors other tests).
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
println!(
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
);
return;
}
// Mock server that will receive three sequential requests and return the same SSE stream
// each time: a few deltas, then a final assistant message, then completed.
let server = MockServer::start().await;
// Build a small SSE stream with deltas and a final assistant message.
// We emit the same body for all 3 turns; ids vary but are unused by assertions.
let sse_raw = r##"[
{"type":"response.output_text.delta", "delta":"Hey "},
{"type":"response.output_text.delta", "delta":"there"},
{"type":"response.output_text.delta", "delta":"!\n"},
{"type":"response.output_item.done", "item":{
"type":"message", "role":"assistant",
"content":[{"type":"output_text","text":"Hey there!\n"}]
}},
{"type":"response.completed", "response": {"id": "__ID__"}}
]"##;
let sse1 = core_test_support::load_sse_fixture_with_id_from_str(sse_raw, "resp1");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(
ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse1.clone(), "text/event-stream"),
)
.expect(3) // respond identically to the three sequential turns
.mount(&server)
.await;
// Configure provider to point to mock server (Responses API) and use API key auth.
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
..built_in_model_providers()["openai"].clone()
};
// Init session with isolated codex home.
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home);
config.model_provider = model_provider;
let conversation_manager =
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
let NewConversation {
conversation: codex,
..
} = conversation_manager
.new_conversation(config)
.await
.expect("create new conversation");
// Turn 1: user sends U1; wait for completion.
codex
.submit(Op::UserInput {
items: vec![InputItem::Text { text: "U1".into() }],
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
// Turn 2: user sends U2; wait for completion.
codex
.submit(Op::UserInput {
items: vec![InputItem::Text { text: "U2".into() }],
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
// Turn 3: user sends U3; wait for completion.
codex
.submit(Op::UserInput {
items: vec![InputItem::Text { text: "U3".into() }],
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
// Inspect the three captured requests.
let requests = server.received_requests().await.unwrap();
assert_eq!(requests.len(), 3, "expected 3 requests (one per turn)");
// Replace full-array compare with tail-only raw JSON compare using a single hard-coded value.
let r3_tail_expected = serde_json::json!([
{
"type": "message",
"id": null,
"role": "user",
"content": [{"type":"input_text","text":"U1"}]
},
{
"type": "message",
"id": null,
"role": "assistant",
"content": [{"type":"output_text","text":"Hey there!\n"}]
},
{
"type": "message",
"id": null,
"role": "user",
"content": [{"type":"input_text","text":"U2"}]
},
{
"type": "message",
"id": null,
"role": "assistant",
"content": [{"type":"output_text","text":"Hey there!\n"}]
},
{
"type": "message",
"id": null,
"role": "user",
"content": [{"type":"input_text","text":"U3"}]
}
]);
let r3_input_array = requests[2]
.body_json::<serde_json::Value>()
.unwrap()
.get("input")
.and_then(|v| v.as_array())
.cloned()
.expect("r3 missing input array");
// skipping earlier context and developer messages
let tail_len = r3_tail_expected.as_array().unwrap().len();
let actual_tail = &r3_input_array[r3_input_array.len() - tail_len..];
assert_eq!(
serde_json::Value::Array(actual_tail.to_vec()),
r3_tail_expected,
"request 3 tail mismatch",
);
}