Bug fix: deduplicate assistant messages (#2758)
We were treating assistant messages differently from other messages, which resulted in duplicated history. See #2698
This commit is contained in:
@@ -1780,11 +1780,6 @@ async fn try_run_turn(
|
|||||||
return Ok(output);
|
return Ok(output);
|
||||||
}
|
}
|
||||||
ResponseEvent::OutputTextDelta(delta) => {
|
ResponseEvent::OutputTextDelta(delta) => {
|
||||||
{
|
|
||||||
let mut st = sess.state.lock_unchecked();
|
|
||||||
st.history.append_assistant_text(&delta);
|
|
||||||
}
|
|
||||||
|
|
||||||
let event = Event {
|
let event = Event {
|
||||||
id: sub_id.to_string(),
|
id: sub_id.to_string(),
|
||||||
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
|
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
|
||||||
|
|||||||
@@ -28,49 +28,7 @@ impl ConversationHistory {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge adjacent assistant messages into a single history entry.
|
self.items.push(item.clone());
|
||||||
// This prevents duplicates when a partial assistant message was
|
|
||||||
// streamed into history earlier in the turn and the final full
|
|
||||||
// message is recorded at turn end.
|
|
||||||
match (&*item, self.items.last_mut()) {
|
|
||||||
(
|
|
||||||
ResponseItem::Message {
|
|
||||||
role: new_role,
|
|
||||||
content: new_content,
|
|
||||||
..
|
|
||||||
},
|
|
||||||
Some(ResponseItem::Message {
|
|
||||||
role: last_role,
|
|
||||||
content: last_content,
|
|
||||||
..
|
|
||||||
}),
|
|
||||||
) if new_role == "assistant" && last_role == "assistant" => {
|
|
||||||
append_text_content(last_content, new_content);
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
self.items.push(item.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Append a text `delta` to the latest assistant message, creating a new
|
|
||||||
/// assistant entry if none exists yet (e.g. first delta for this turn).
|
|
||||||
pub(crate) fn append_assistant_text(&mut self, delta: &str) {
|
|
||||||
match self.items.last_mut() {
|
|
||||||
Some(ResponseItem::Message { role, content, .. }) if role == "assistant" => {
|
|
||||||
append_text_delta(content, delta);
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// Start a new assistant message with the delta.
|
|
||||||
self.items.push(ResponseItem::Message {
|
|
||||||
id: None,
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: vec![codex_protocol::models::ContentItem::OutputText {
|
|
||||||
text: delta.to_string(),
|
|
||||||
}],
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,34 +76,6 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper to append the textual content from `src` into `dst` in place.
|
|
||||||
fn append_text_content(
|
|
||||||
dst: &mut Vec<codex_protocol::models::ContentItem>,
|
|
||||||
src: &Vec<codex_protocol::models::ContentItem>,
|
|
||||||
) {
|
|
||||||
for c in src {
|
|
||||||
if let codex_protocol::models::ContentItem::OutputText { text } = c {
|
|
||||||
append_text_delta(dst, text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Append a single text delta to the last OutputText item in `content`, or
|
|
||||||
/// push a new OutputText item if none exists.
|
|
||||||
fn append_text_delta(content: &mut Vec<codex_protocol::models::ContentItem>, delta: &str) {
|
|
||||||
if let Some(codex_protocol::models::ContentItem::OutputText { text }) = content
|
|
||||||
.iter_mut()
|
|
||||||
.rev()
|
|
||||||
.find(|c| matches!(c, codex_protocol::models::ContentItem::OutputText { .. }))
|
|
||||||
{
|
|
||||||
text.push_str(delta);
|
|
||||||
} else {
|
|
||||||
content.push(codex_protocol::models::ContentItem::OutputText {
|
|
||||||
text: delta.to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -171,49 +101,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn merges_adjacent_assistant_messages() {
|
|
||||||
let mut h = ConversationHistory::default();
|
|
||||||
let a1 = assistant_msg("Hello");
|
|
||||||
let a2 = assistant_msg(", world!");
|
|
||||||
h.record_items([&a1, &a2]);
|
|
||||||
|
|
||||||
let items = h.contents();
|
|
||||||
assert_eq!(
|
|
||||||
items,
|
|
||||||
vec![ResponseItem::Message {
|
|
||||||
id: None,
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: vec![ContentItem::OutputText {
|
|
||||||
text: "Hello, world!".to_string()
|
|
||||||
}]
|
|
||||||
}]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn append_assistant_text_creates_and_appends() {
|
|
||||||
let mut h = ConversationHistory::default();
|
|
||||||
h.append_assistant_text("Hello");
|
|
||||||
h.append_assistant_text(", world");
|
|
||||||
|
|
||||||
// Now record a final full assistant message and verify it merges.
|
|
||||||
let final_msg = assistant_msg("!");
|
|
||||||
h.record_items([&final_msg]);
|
|
||||||
|
|
||||||
let items = h.contents();
|
|
||||||
assert_eq!(
|
|
||||||
items,
|
|
||||||
vec![ResponseItem::Message {
|
|
||||||
id: None,
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: vec![ContentItem::OutputText {
|
|
||||||
text: "Hello, world!".to_string()
|
|
||||||
}]
|
|
||||||
}]
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn filters_non_api_messages() {
|
fn filters_non_api_messages() {
|
||||||
let mut h = ConversationHistory::default();
|
let mut h = ConversationHistory::default();
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ use codex_login::CodexAuth;
|
|||||||
use core_test_support::load_default_config_for_test;
|
use core_test_support::load_default_config_for_test;
|
||||||
use core_test_support::load_sse_fixture_with_id;
|
use core_test_support::load_sse_fixture_with_id;
|
||||||
use core_test_support::wait_for_event;
|
use core_test_support::wait_for_event;
|
||||||
|
use serde_json::json;
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
use wiremock::Mock;
|
use wiremock::Mock;
|
||||||
use wiremock::MockServer;
|
use wiremock::MockServer;
|
||||||
@@ -66,7 +67,6 @@ fn write_auth_json(
|
|||||||
account_id: Option<&str>,
|
account_id: Option<&str>,
|
||||||
) -> String {
|
) -> String {
|
||||||
use base64::Engine as _;
|
use base64::Engine as _;
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
let header = json!({ "alg": "none", "typ": "JWT" });
|
let header = json!({ "alg": "none", "typ": "JWT" });
|
||||||
let payload = json!({
|
let payload = json!({
|
||||||
@@ -746,3 +746,151 @@ async fn env_var_overrides_loaded_auth() {
|
|||||||
fn create_dummy_codex_auth() -> CodexAuth {
|
fn create_dummy_codex_auth() -> CodexAuth {
|
||||||
CodexAuth::create_dummy_chatgpt_auth_for_testing()
|
CodexAuth::create_dummy_chatgpt_auth_for_testing()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Scenario:
|
||||||
|
/// - Turn 1: user sends U1; model streams deltas then a final assistant message A.
|
||||||
|
/// - Turn 2: user sends U2; model streams a delta then the same final assistant message A.
|
||||||
|
/// - Turn 3: user sends U3; model responds (same SSE again, not important).
|
||||||
|
///
|
||||||
|
/// We assert that the `input` sent on turn 3 ends with the expected conversation history, with each assistant message appearing exactly once.
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn history_dedupes_streamed_and_final_messages_across_turns() {
|
||||||
|
// Skip under Codex sandbox network restrictions (mirrors other tests).
|
||||||
|
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
||||||
|
println!(
|
||||||
|
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock server that will receive three sequential requests and return the same SSE stream
|
||||||
|
// each time: a few deltas, then a final assistant message, then completed.
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
|
||||||
|
// Build a small SSE stream with deltas and a final assistant message.
|
||||||
|
// We emit the same body for all 3 turns; ids vary but are unused by assertions.
|
||||||
|
let sse_raw = r##"[
|
||||||
|
{"type":"response.output_text.delta", "delta":"Hey "},
|
||||||
|
{"type":"response.output_text.delta", "delta":"there"},
|
||||||
|
{"type":"response.output_text.delta", "delta":"!\n"},
|
||||||
|
{"type":"response.output_item.done", "item":{
|
||||||
|
"type":"message", "role":"assistant",
|
||||||
|
"content":[{"type":"output_text","text":"Hey there!\n"}]
|
||||||
|
}},
|
||||||
|
{"type":"response.completed", "response": {"id": "__ID__"}}
|
||||||
|
]"##;
|
||||||
|
let sse1 = core_test_support::load_sse_fixture_with_id_from_str(sse_raw, "resp1");
|
||||||
|
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/responses"))
|
||||||
|
.respond_with(
|
||||||
|
ResponseTemplate::new(200)
|
||||||
|
.insert_header("content-type", "text/event-stream")
|
||||||
|
.set_body_raw(sse1.clone(), "text/event-stream"),
|
||||||
|
)
|
||||||
|
.expect(3) // respond identically to the three sequential turns
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Configure provider to point to mock server (Responses API) and use API key auth.
|
||||||
|
let model_provider = ModelProviderInfo {
|
||||||
|
base_url: Some(format!("{}/v1", server.uri())),
|
||||||
|
..built_in_model_providers()["openai"].clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
// Init session with isolated codex home.
|
||||||
|
let codex_home = TempDir::new().unwrap();
|
||||||
|
let mut config = load_default_config_for_test(&codex_home);
|
||||||
|
config.model_provider = model_provider;
|
||||||
|
|
||||||
|
let conversation_manager =
|
||||||
|
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||||
|
let NewConversation {
|
||||||
|
conversation: codex,
|
||||||
|
..
|
||||||
|
} = conversation_manager
|
||||||
|
.new_conversation(config)
|
||||||
|
.await
|
||||||
|
.expect("create new conversation");
|
||||||
|
|
||||||
|
// Turn 1: user sends U1; wait for completion.
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![InputItem::Text { text: "U1".into() }],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
|
|
||||||
|
// Turn 2: user sends U2; wait for completion.
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![InputItem::Text { text: "U2".into() }],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
|
|
||||||
|
// Turn 3: user sends U3; wait for completion.
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![InputItem::Text { text: "U3".into() }],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
|
|
||||||
|
// Inspect the three captured requests.
|
||||||
|
let requests = server.received_requests().await.unwrap();
|
||||||
|
assert_eq!(requests.len(), 3, "expected 3 requests (one per turn)");
|
||||||
|
|
||||||
|
// Compare only the tail of request 3's `input` array against a single hard-coded JSON value.
|
||||||
|
let r3_tail_expected = serde_json::json!([
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"id": null,
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type":"input_text","text":"U1"}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"id": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type":"output_text","text":"Hey there!\n"}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"id": null,
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type":"input_text","text":"U2"}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"id": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type":"output_text","text":"Hey there!\n"}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"id": null,
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type":"input_text","text":"U3"}]
|
||||||
|
}
|
||||||
|
]);
|
||||||
|
|
||||||
|
let r3_input_array = requests[2]
|
||||||
|
.body_json::<serde_json::Value>()
|
||||||
|
.unwrap()
|
||||||
|
.get("input")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.cloned()
|
||||||
|
.expect("r3 missing input array");
|
||||||
|
// skipping earlier context and developer messages
|
||||||
|
let tail_len = r3_tail_expected.as_array().unwrap().len();
|
||||||
|
let actual_tail = &r3_input_array[r3_input_array.len() - tail_len..];
|
||||||
|
assert_eq!(
|
||||||
|
serde_json::Value::Array(actual_tail.to_vec()),
|
||||||
|
r3_tail_expected,
|
||||||
|
"request 3 tail mismatch",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user