feat: context compaction (#3446)
## Compact feature

1. Stops the model when the context window becomes too large.
2. Adds a user turn asking the model to summarize.
3. Builds a bridge that contains all the previous user messages plus the summary, rendered from a template.
4. Starts sampling again from a clean conversation that contains only that bridge.
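For orientation, a minimal sketch of that flow, assuming a toy `Message` type: `over_limit` and `render_bridge` below are hypothetical stand-ins, not the codex-core implementation (the real bridge text is rendered from a template; the wording here only mirrors phrases the tests below look for).

```rust
// Hypothetical sketch of the compaction flow described in the commit message.

#[derive(Clone)]
struct Message {
    role: &'static str, // "user" or "assistant"
    text: String,
}

/// Hypothetical check mirroring `model_auto_compact_token_limit`:
/// has the running token total crossed the configured limit?
fn over_limit(total_tokens: u64, limit: u64) -> bool {
    total_tokens > limit
}

/// Hypothetical bridge renderer: fold every previous *user* message plus the
/// model-produced summary into a single synthetic user turn.
fn render_bridge(history: &[Message], summary: &str) -> Message {
    let user_msgs: Vec<&str> = history
        .iter()
        .filter(|m| m.role == "user")
        .map(|m| m.text.as_str())
        .collect();
    Message {
        role: "user",
        text: format!(
            "Here were the user messages so far:\n{}\n\nSummary of the conversation:\n{}",
            user_msgs.join("\n"),
            summary
        ),
    }
}

fn main() {
    let history = vec![
        Message { role: "user", text: "hello world".into() },
        Message { role: "assistant", text: "FIRST_REPLY".into() },
    ];
    // 1. The token limit is hit, so normal sampling stops...
    assert!(over_limit(330_000, 200_000));
    // 2. ...a summarization turn produces a summary (mocked here)...
    let summary = "SUMMARY_ONLY_CONTEXT";
    // 3. ...and the bridge replaces the old history.
    let bridge = render_bridge(&history, summary);
    // 4. The next turn starts clean: just the bridge plus the new user message.
    let new_history = vec![bridge, Message { role: "user", text: "next turn".into() }];
    assert_eq!(new_history.len(), 2);
    assert!(new_history[0].text.contains("hello world"));
}
```

The integration tests in the diff below exercise exactly these steps against a mocked model server.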
@@ -5,6 +5,7 @@ use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::NewConversation;
use codex_core::built_in_model_providers;
+use codex_core::protocol::ErrorEvent;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
@@ -15,13 +16,20 @@ use core_test_support::load_default_config_for_test;
use core_test_support::wait_for_event;
use serde_json::Value;
use tempfile::TempDir;
+use wiremock::BodyPrintLimit;
use wiremock::Mock;
use wiremock::MockServer;
+use wiremock::Request;
+use wiremock::Respond;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

use pretty_assertions::assert_eq;
+use std::sync::Arc;
+use std::sync::Mutex;
+use std::sync::atomic::AtomicUsize;
+use std::sync::atomic::Ordering;

// --- Test helpers -----------------------------------------------------------
@@ -52,6 +60,22 @@ fn ev_completed(id: &str) -> Value {
    })
}

+fn ev_completed_with_tokens(id: &str, total_tokens: u64) -> Value {
+    serde_json::json!({
+        "type": "response.completed",
+        "response": {
+            "id": id,
+            "usage": {
+                "input_tokens": total_tokens,
+                "input_tokens_details": null,
+                "output_tokens": 0,
+                "output_tokens_details": null,
+                "total_tokens": total_tokens
+            }
+        }
+    })
+}
+
/// Convenience: SSE event for a single assistant message output item.
fn ev_assistant_message(id: &str, text: &str) -> Value {
    serde_json::json!({
@@ -65,6 +89,18 @@ fn ev_assistant_message(id: &str, text: &str) -> Value {
    })
}

+fn ev_function_call(call_id: &str, name: &str, arguments: &str) -> Value {
+    serde_json::json!({
+        "type": "response.output_item.done",
+        "item": {
+            "type": "function_call",
+            "call_id": call_id,
+            "name": name,
+            "arguments": arguments
+        }
+    })
+}
+
fn sse_response(body: String) -> ResponseTemplate {
    ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
@@ -84,10 +120,28 @@ where
        .await;
}

+async fn start_mock_server() -> MockServer {
+    MockServer::builder()
+        .body_print_limit(BodyPrintLimit::Limited(80_000))
+        .start()
+        .await
+}
+
const FIRST_REPLY: &str = "FIRST_REPLY";
const SUMMARY_TEXT: &str = "SUMMARY_ONLY_CONTEXT";
const SUMMARIZE_TRIGGER: &str = "Start Summarization";
const THIRD_USER_MSG: &str = "next turn";
+const AUTO_SUMMARY_TEXT: &str = "AUTO_SUMMARY";
+const FIRST_AUTO_MSG: &str = "token limit start";
+const SECOND_AUTO_MSG: &str = "token limit push";
+const STILL_TOO_BIG_REPLY: &str = "STILL_TOO_BIG";
+const MULTI_AUTO_MSG: &str = "multi auto";
+const SECOND_LARGE_REPLY: &str = "SECOND_LARGE_REPLY";
+const FIRST_AUTO_SUMMARY: &str = "FIRST_AUTO_SUMMARY";
+const SECOND_AUTO_SUMMARY: &str = "SECOND_AUTO_SUMMARY";
+const FINAL_REPLY: &str = "FINAL_REPLY";
+const DUMMY_FUNCTION_NAME: &str = "unsupported_tool";
+const DUMMY_CALL_ID: &str = "call-multi-auto";

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn summarize_context_three_requests_and_instructions() {
@@ -99,7 +153,7 @@ async fn summarize_context_three_requests_and_instructions() {
    }

    // Set up a mock server that we can inspect after the run.
-    let server = MockServer::start().await;
+    let server = start_mock_server().await;

    // SSE 1: assistant replies normally so it is recorded in history.
    let sse1 = sse(vec![
@@ -144,6 +198,7 @@ async fn summarize_context_three_requests_and_instructions() {
    let home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&home);
    config.model_provider = model_provider;
+    config.model_auto_compact_token_limit = Some(200_000);
    let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
    let NewConversation {
        conversation: codex,
@@ -198,7 +253,7 @@ async fn summarize_context_three_requests_and_instructions() {
        "summarization should override base instructions"
    );
    assert!(
-        instr2.contains("You are a summarization assistant"),
+        instr2.contains("You have exceeded the maximum number of tokens"),
        "summarization instructions not applied"
    );

@@ -209,14 +264,17 @@ async fn summarize_context_three_requests_and_instructions() {
    assert_eq!(last2.get("type").unwrap().as_str().unwrap(), "message");
    assert_eq!(last2.get("role").unwrap().as_str().unwrap(), "user");
    let text2 = last2["content"][0]["text"].as_str().unwrap();
-    assert!(text2.contains(SUMMARIZE_TRIGGER));
+    assert!(
+        text2.contains(SUMMARIZE_TRIGGER),
+        "expected summarize trigger, got `{text2}`"
+    );

-    // Third request must contain only the summary from step 2 as prior history plus new user msg.
+    // Third request must contain the refreshed instructions, bridge summary message and new user msg.
    let input3 = body3.get("input").and_then(|v| v.as_array()).unwrap();
    println!("third request body: {body3}");
    assert!(
-        input3.len() >= 2,
-        "expected summary + new user message in third request"
+        input3.len() >= 3,
+        "expected refreshed context and new user message in third request"
    );

    // Collect all (role, text) message tuples.
@@ -232,24 +290,35 @@ async fn summarize_context_three_requests_and_instructions() {
        }
    }

-    // Exactly one assistant message should remain after compaction and the new user message is present.
+    // No previous assistant messages should remain and the new user message is present.
    let assistant_count = messages.iter().filter(|(r, _)| r == "assistant").count();
-    assert_eq!(
-        assistant_count, 1,
-        "exactly one assistant message should remain after compaction"
-    );
+    assert_eq!(assistant_count, 0, "assistant history should be cleared");
    assert!(
        messages
            .iter()
            .any(|(r, t)| r == "user" && t == THIRD_USER_MSG),
        "third request should include the new user message"
    );
+    let Some((_, bridge_text)) = messages.iter().find(|(role, text)| {
+        role == "user"
+            && (text.contains("Here were the user messages")
+                || text.contains("Here are all the user messages"))
+            && text.contains(SUMMARY_TEXT)
+    }) else {
+        panic!("expected a bridge message containing the summary");
+    };
    assert!(
-        !messages.iter().any(|(_, t)| t.contains("hello world")),
-        "third request should not include the original user input"
+        bridge_text.contains("hello world"),
+        "bridge should capture earlier user messages"
    );
    assert!(
-        !messages.iter().any(|(_, t)| t.contains(SUMMARIZE_TRIGGER)),
+        !bridge_text.contains(SUMMARIZE_TRIGGER),
+        "bridge text should not echo the summarize trigger"
    );
    assert!(
        !messages
            .iter()
            .any(|(_, text)| text.contains(SUMMARIZE_TRIGGER)),
        "third request should not include the summarize trigger"
    );

@@ -258,6 +327,7 @@ async fn summarize_context_three_requests_and_instructions() {
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;

    // Verify rollout contains APITurn entries for each API call and a Compacted entry.
+    println!("rollout path: {}", rollout_path.display());
    let text = std::fs::read_to_string(&rollout_path).unwrap_or_else(|e| {
        panic!(
            "failed to read rollout file {}: {e}",
@@ -296,3 +366,506 @@ async fn summarize_context_three_requests_and_instructions() {
        "expected a Compacted entry containing the summarizer output"
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn auto_compact_runs_after_token_limit_hit() {
    if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let server = start_mock_server().await;

    let sse1 = sse(vec![
        ev_assistant_message("m1", FIRST_REPLY),
        ev_completed_with_tokens("r1", 70_000),
    ]);

    let sse2 = sse(vec![
        ev_assistant_message("m2", "SECOND_REPLY"),
        ev_completed_with_tokens("r2", 330_000),
    ]);

    let sse3 = sse(vec![
        ev_assistant_message("m3", AUTO_SUMMARY_TEXT),
        ev_completed_with_tokens("r3", 200),
    ]);

    let first_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains(FIRST_AUTO_MSG)
            && !body.contains(SECOND_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1))
        .mount(&server)
        .await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains(SECOND_AUTO_MSG)
            && body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2))
        .mount(&server)
        .await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3))
        .mount(&server)
        .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
        ..built_in_model_providers()["openai"].clone()
    };

    let home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&home);
    config.model_provider = model_provider;
    config.model_auto_compact_token_limit = Some(200_000);
    let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
    let codex = conversation_manager
        .new_conversation(config)
        .await
        .unwrap()
        .conversation;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: FIRST_AUTO_MSG.into(),
            }],
        })
        .await
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: SECOND_AUTO_MSG.into(),
            }],
        })
        .await
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.unwrap();
    assert_eq!(requests.len(), 3, "auto compact should add a third request");

    let body3 = requests[2].body_json::<serde_json::Value>().unwrap();
    let instructions = body3
        .get("instructions")
        .and_then(|v| v.as_str())
        .unwrap_or_default();
    assert!(
        instructions.contains("You have exceeded the maximum number of tokens"),
        "auto compact should reuse summarization instructions"
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn auto_compact_persists_rollout_entries() {
    if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let server = start_mock_server().await;

    let sse1 = sse(vec![
        ev_assistant_message("m1", FIRST_REPLY),
        ev_completed_with_tokens("r1", 70_000),
    ]);

    let sse2 = sse(vec![
        ev_assistant_message("m2", "SECOND_REPLY"),
        ev_completed_with_tokens("r2", 330_000),
    ]);

    let sse3 = sse(vec![
        ev_assistant_message("m3", AUTO_SUMMARY_TEXT),
        ev_completed_with_tokens("r3", 200),
    ]);

    let first_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains(FIRST_AUTO_MSG)
            && !body.contains(SECOND_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1))
        .mount(&server)
        .await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains(SECOND_AUTO_MSG)
            && body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2))
        .mount(&server)
        .await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3))
        .mount(&server)
        .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
        ..built_in_model_providers()["openai"].clone()
    };

    let home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&home);
    config.model_provider = model_provider;
    let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
    let NewConversation {
        conversation: codex,
        session_configured,
        ..
    } = conversation_manager.new_conversation(config).await.unwrap();

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: FIRST_AUTO_MSG.into(),
            }],
        })
        .await
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: SECOND_AUTO_MSG.into(),
            }],
        })
        .await
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    codex.submit(Op::Shutdown).await.unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::ShutdownComplete)).await;

    let rollout_path = session_configured.rollout_path;
    let text = std::fs::read_to_string(&rollout_path).unwrap_or_else(|e| {
        panic!(
            "failed to read rollout file {}: {e}",
            rollout_path.display()
        )
    });

    let mut turn_context_count = 0usize;
    for line in text.lines() {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        let Ok(entry): Result<RolloutLine, _> = serde_json::from_str(trimmed) else {
            continue;
        };
        match entry.item {
            RolloutItem::TurnContext(_) => {
                turn_context_count += 1;
            }
            RolloutItem::Compacted(_) => {}
            _ => {}
        }
    }

    assert!(
        turn_context_count >= 2,
        "expected at least two turn context entries, got {turn_context_count}"
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn auto_compact_stops_after_failed_attempt() {
    if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let server = start_mock_server().await;

    let sse1 = sse(vec![
        ev_assistant_message("m1", FIRST_REPLY),
        ev_completed_with_tokens("r1", 500),
    ]);

    let sse2 = sse(vec![
        ev_assistant_message("m2", SUMMARY_TEXT),
        ev_completed_with_tokens("r2", 50),
    ]);

    let sse3 = sse(vec![
        ev_assistant_message("m3", STILL_TOO_BIG_REPLY),
        ev_completed_with_tokens("r3", 500),
    ]);

    let first_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains(FIRST_AUTO_MSG)
            && !body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(first_matcher)
        .respond_with(sse_response(sse1.clone()))
        .mount(&server)
        .await;

    let second_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        body.contains("You have exceeded the maximum number of tokens")
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(second_matcher)
        .respond_with(sse_response(sse2.clone()))
        .mount(&server)
        .await;

    let third_matcher = |req: &wiremock::Request| {
        let body = std::str::from_utf8(&req.body).unwrap_or("");
        !body.contains("You have exceeded the maximum number of tokens")
            && body.contains(SUMMARY_TEXT)
    };
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .and(third_matcher)
        .respond_with(sse_response(sse3.clone()))
        .mount(&server)
        .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
        ..built_in_model_providers()["openai"].clone()
    };

    let home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&home);
    config.model_provider = model_provider;
    config.model_auto_compact_token_limit = Some(200);
    let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
    let codex = conversation_manager
        .new_conversation(config)
        .await
        .unwrap()
        .conversation;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: FIRST_AUTO_MSG.into(),
            }],
        })
        .await
        .unwrap();

    let error_event = wait_for_event(&codex, |ev| matches!(ev, EventMsg::Error(_))).await;
    let EventMsg::Error(ErrorEvent { message }) = error_event else {
        panic!("expected error event");
    };
    assert!(
        message.contains("limit"),
        "error message should include limit information: {message}"
    );
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.unwrap();
    assert_eq!(
        requests.len(),
        3,
        "auto compact should attempt at most one summarization before erroring"
    );

    let last_body = requests[2].body_json::<serde_json::Value>().unwrap();
    let instructions = last_body
        .get("instructions")
        .and_then(|v| v.as_str())
        .unwrap_or_default();
    assert!(
        !instructions.contains("You have exceeded the maximum number of tokens"),
        "third request should be the follow-up turn, not another summarization"
    );
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn auto_compact_allows_multiple_attempts_when_interleaved_with_other_turn_events() {
    if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
        println!(
            "Skipping test because it cannot execute when network is disabled in a Codex sandbox."
        );
        return;
    }

    let server = start_mock_server().await;

    let sse1 = sse(vec![
        ev_assistant_message("m1", FIRST_REPLY),
        ev_completed_with_tokens("r1", 500),
    ]);
    let sse2 = sse(vec![
        ev_assistant_message("m2", FIRST_AUTO_SUMMARY),
        ev_completed_with_tokens("r2", 50),
    ]);
    let sse3 = sse(vec![
        ev_function_call(DUMMY_CALL_ID, DUMMY_FUNCTION_NAME, "{}"),
        ev_completed_with_tokens("r3", 150),
    ]);
    let sse4 = sse(vec![
        ev_assistant_message("m4", SECOND_LARGE_REPLY),
        ev_completed_with_tokens("r4", 450),
    ]);
    let sse5 = sse(vec![
        ev_assistant_message("m5", SECOND_AUTO_SUMMARY),
        ev_completed_with_tokens("r5", 60),
    ]);
    let sse6 = sse(vec![
        ev_assistant_message("m6", FINAL_REPLY),
        ev_completed_with_tokens("r6", 120),
    ]);

    #[derive(Clone)]
    struct SeqResponder {
        bodies: Arc<Vec<String>>,
        calls: Arc<AtomicUsize>,
        requests: Arc<Mutex<Vec<Vec<u8>>>>,
    }

    impl SeqResponder {
        fn new(bodies: Vec<String>) -> Self {
            Self {
                bodies: Arc::new(bodies),
                calls: Arc::new(AtomicUsize::new(0)),
                requests: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn recorded_requests(&self) -> Vec<Vec<u8>> {
            self.requests.lock().unwrap().clone()
        }
    }

    impl Respond for SeqResponder {
        fn respond(&self, req: &Request) -> ResponseTemplate {
            let idx = self.calls.fetch_add(1, Ordering::SeqCst);
            self.requests.lock().unwrap().push(req.body.clone());
            let body = self
                .bodies
                .get(idx)
                .unwrap_or_else(|| panic!("unexpected request index {idx}"))
                .clone();
            ResponseTemplate::new(200)
                .insert_header("content-type", "text/event-stream")
                .set_body_raw(body, "text/event-stream")
        }
    }

    let responder = SeqResponder::new(vec![sse1, sse2, sse3, sse4, sse5, sse6]);
    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(responder.clone())
        .expect(6)
        .mount(&server)
        .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
        ..built_in_model_providers()["openai"].clone()
    };

    let home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&home);
    config.model_provider = model_provider;
    config.model_auto_compact_token_limit = Some(200);
    let conversation_manager = ConversationManager::with_auth(CodexAuth::from_api_key("dummy"));
    let codex = conversation_manager
        .new_conversation(config)
        .await
        .unwrap()
        .conversation;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: MULTI_AUTO_MSG.into(),
            }],
        })
        .await
        .unwrap();

    loop {
        let event = codex.next_event().await.unwrap();
        if let EventMsg::TaskComplete(_) = &event.msg
            && !event.id.starts_with("auto-compact-")
        {
            break;
        }
    }

    let request_bodies: Vec<String> = responder
        .recorded_requests()
        .into_iter()
        .map(|body| String::from_utf8(body).unwrap_or_default())
        .collect();
    assert_eq!(
        request_bodies.len(),
        6,
        "expected six requests including two auto compactions"
    );
    assert!(
        request_bodies[0].contains(MULTI_AUTO_MSG),
        "first request should contain the user input"
    );
    assert!(
        request_bodies[1].contains("You have exceeded the maximum number of tokens"),
        "first auto compact request should use summarization instructions"
    );
    assert!(
        request_bodies[3].contains(&format!("unsupported call: {DUMMY_FUNCTION_NAME}")),
        "function call output should be sent before the second auto compact"
    );
    assert!(
        request_bodies[4].contains("You have exceeded the maximum number of tokens"),
        "second auto compact request should reuse summarization instructions"
    );
}