feat: context compaction (#3446)

## Compact feature:
1. Stops the model when the context window become too large
2. Add a user turn, asking for the model to summarize
3. Build a bridge that contains all the previous user message + the
summary. Rendered from a template
4. Start sampling again from a clean conversation with only that bridge
This commit is contained in:
jif-oai
2025-09-12 13:07:10 -07:00
committed by GitHub
parent d4848e558b
commit ea225df22e
14 changed files with 1243 additions and 326 deletions

View File

@@ -16,7 +16,6 @@ use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::MaybeApplyPatchVerified;
use codex_apply_patch::maybe_parse_apply_patch_verified;
use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::protocol::CompactedItem;
use codex_protocol::protocol::ConversationPathResponseEvent;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::TaskStartedEvent;
@@ -77,7 +76,6 @@ use crate::parse_command::parse_command;
use crate::plan_tool::handle_update_plan;
use crate::project_doc::get_user_instructions;
use crate::protocol::AgentMessageDeltaEvent;
use crate::protocol::AgentMessageEvent;
use crate::protocol::AgentReasoningDeltaEvent;
use crate::protocol::AgentReasoningRawContentDeltaEvent;
use crate::protocol::AgentReasoningSectionBreakEvent;
@@ -102,6 +100,7 @@ use crate::protocol::SessionConfiguredEvent;
use crate::protocol::StreamErrorEvent;
use crate::protocol::Submission;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TokenUsage;
use crate::protocol::TokenUsageInfo;
use crate::protocol::TurnDiffEvent;
use crate::protocol::WebSearchBeginEvent;
@@ -127,6 +126,8 @@ use codex_protocol::models::ResponseItem;
use codex_protocol::models::ShellToolCallParams;
use codex_protocol::protocol::InitialHistory;
mod compact;
// A convenience extension trait for acquiring mutex locks where poisoning is
// unrecoverable and should abort the program. This avoids scattered `.unwrap()`
// calls on `lock()` while still surfacing a clear panic message when a lock is
@@ -264,6 +265,7 @@ struct State {
pending_input: Vec<ResponseInputItem>,
history: ConversationHistory,
token_info: Option<TokenUsageInfo>,
next_internal_sub_id: u64,
}
/// Context for an initialized model agent
@@ -534,6 +536,13 @@ impl Session {
}
}
fn next_internal_sub_id(&self) -> String {
let mut state = self.state.lock_unchecked();
let id = state.next_internal_sub_id;
state.next_internal_sub_id += 1;
format!("auto-compact-{id}")
}
async fn record_initial_history(
&self,
turn_context: &TurnContext,
@@ -707,6 +716,21 @@ impl Session {
}
}
fn update_token_usage_info(
&self,
turn_context: &TurnContext,
token_usage: &Option<TokenUsage>,
) -> Option<TokenUsageInfo> {
let mut state = self.state.lock_unchecked();
let info = TokenUsageInfo::new_or_append(
&state.token_info,
token_usage,
turn_context.client.get_model_context_window(),
);
state.token_info = info.clone();
info
}
/// Record a user input item to conversation history and also persist a
/// corresponding UserMessage EventMsg to rollout.
async fn record_input_and_rollout_usermsg(&self, response_input: &ResponseInputItem) {
@@ -1026,8 +1050,7 @@ impl AgentTask {
let sess = sess.clone();
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move { run_task(sess, tc.as_ref(), sub_id, input).await })
.abort_handle()
tokio::spawn(async move { run_task(sess, tc, sub_id, input).await }).abort_handle()
};
Self {
sess,
@@ -1048,7 +1071,7 @@ impl AgentTask {
let sub_id = sub_id.clone();
let tc = Arc::clone(&turn_context);
tokio::spawn(async move {
run_compact_task(sess, tc.as_ref(), sub_id, input, compact_instructions).await
compact::run_compact_task(sess, tc, sub_id, input, compact_instructions).await
})
.abort_handle()
};
@@ -1342,21 +1365,16 @@ async fn submission_loop(
sess.send_event(event).await;
}
Op::Compact => {
// Create a summarization request as user input
const SUMMARIZATION_PROMPT: &str = include_str!("prompt_for_compact_command.md");
// Attempt to inject input into current task
if let Err(items) = sess.inject_input(vec![InputItem::Text {
text: "Start Summarization".to_string(),
text: compact::COMPACT_TRIGGER_TEXT.to_string(),
}]) {
let task = AgentTask::compact(
compact::spawn_compact_task(
sess.clone(),
Arc::clone(&turn_context),
sub.id,
items,
SUMMARIZATION_PROMPT.to_string(),
);
sess.set_task(task);
}
}
Op::Shutdown => {
@@ -1435,7 +1453,7 @@ async fn submission_loop(
/// conversation history and consider the task complete.
async fn run_task(
sess: Arc<Session>,
turn_context: &TurnContext,
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
) {
@@ -1458,6 +1476,7 @@ async fn run_task(
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
// many turns, from the perspective of the user, it is a single turn.
let mut turn_diff_tracker = TurnDiffTracker::new();
let mut auto_compact_recently_attempted = false;
loop {
// Note that pending_input would be something like a message the user
@@ -1492,7 +1511,7 @@ async fn run_task(
.collect();
match run_turn(
&sess,
turn_context,
turn_context.as_ref(),
&mut turn_diff_tracker,
sub_id.clone(),
turn_input,
@@ -1500,9 +1519,23 @@ async fn run_task(
.await
{
Ok(turn_output) => {
let TurnRunResult {
processed_items,
total_token_usage,
} = turn_output;
let limit = turn_context
.client
.get_auto_compact_token_limit()
.unwrap_or(i64::MAX);
let total_usage_tokens = total_token_usage
.as_ref()
.map(|usage| usage.tokens_in_context_window());
let token_limit_reached = total_usage_tokens
.map(|tokens| (tokens as i64) >= limit)
.unwrap_or(false);
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
let mut responses = Vec::<ResponseInputItem>::new();
for processed_response_item in turn_output {
for processed_response_item in processed_items {
let ProcessedResponseItem { item, response } = processed_response_item;
match (&item, &response) {
(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
@@ -1599,8 +1632,31 @@ async fn run_task(
.await;
}
if token_limit_reached {
if auto_compact_recently_attempted {
let limit_str = limit.to_string();
let current_tokens = total_usage_tokens
.map(|tokens| tokens.to_string())
.unwrap_or_else(|| "unknown".to_string());
let event = Event {
id: sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: format!(
"Conversation is still above the token limit after automatic summarization (limit {limit_str}, current {current_tokens}). Please start a new session or trim your input."
),
}),
};
sess.send_event(event).await;
break;
}
auto_compact_recently_attempted = true;
compact::run_inline_auto_compact_task(sess.clone(), turn_context.clone()).await;
continue;
}
auto_compact_recently_attempted = false;
if responses.is_empty() {
debug!("Turn completed");
last_agent_message = get_last_assistant_message_from_turn(
&items_to_record_in_conversation_history,
);
@@ -1611,6 +1667,7 @@ async fn run_task(
});
break;
}
continue;
}
Err(e) => {
info!("Turn error: {e:#}");
@@ -1640,7 +1697,7 @@ async fn run_turn(
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
input: Vec<ResponseItem>,
) -> CodexResult<Vec<ProcessedResponseItem>> {
) -> CodexResult<TurnRunResult> {
let tools = get_openai_tools(
&turn_context.tools_config,
Some(sess.mcp_connection_manager.list_all_tools()),
@@ -1704,13 +1761,19 @@ struct ProcessedResponseItem {
response: Option<ResponseInputItem>,
}
#[derive(Debug)]
struct TurnRunResult {
processed_items: Vec<ProcessedResponseItem>,
total_token_usage: Option<TokenUsage>,
}
async fn try_run_turn(
sess: &Session,
turn_context: &TurnContext,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<Vec<ProcessedResponseItem>> {
) -> CodexResult<TurnRunResult> {
// call_ids that are part of this response.
let completed_call_ids = prompt
.input
@@ -1828,16 +1891,7 @@ async fn try_run_turn(
response_id: _,
token_usage,
} => {
let info = {
let mut st = sess.state.lock_unchecked();
let info = TokenUsageInfo::new_or_append(
&st.token_info,
&token_usage,
turn_context.client.get_model_context_window(),
);
st.token_info = info.clone();
info
};
let info = sess.update_token_usage_info(turn_context, &token_usage);
let _ = sess
.send_event(Event {
id: sub_id.to_string(),
@@ -1855,7 +1909,12 @@ async fn try_run_turn(
sess.send_event(event).await;
}
return Ok(output);
let result = TurnRunResult {
processed_items: output,
total_token_usage: token_usage.clone(),
};
return Ok(result);
}
ResponseEvent::OutputTextDelta(delta) => {
let event = Event {
@@ -1893,95 +1952,6 @@ async fn try_run_turn(
}
}
async fn run_compact_task(
sess: Arc<Session>,
turn_context: &TurnContext,
sub_id: String,
input: Vec<InputItem>,
compact_instructions: String,
) {
let model_context_window = turn_context.client.get_model_context_window();
let start_event = Event {
id: sub_id.clone(),
msg: EventMsg::TaskStarted(TaskStartedEvent {
model_context_window,
}),
};
sess.send_event(start_event).await;
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
let turn_input: Vec<ResponseItem> =
sess.turn_input_with_history(vec![initial_input_for_turn.clone().into()]);
let prompt = Prompt {
input: turn_input,
tools: Vec::new(),
base_instructions_override: Some(compact_instructions.clone()),
};
let max_retries = turn_context.client.get_provider().stream_max_retries();
let mut retries = 0;
loop {
let attempt_result = drain_to_completed(&sess, turn_context, &sub_id, &prompt).await;
match attempt_result {
Ok(()) => break,
Err(CodexErr::Interrupted) => return,
Err(e) => {
if retries < max_retries {
retries += 1;
let delay = backoff(retries);
sess.notify_stream_error(
&sub_id,
format!(
"stream error: {e}; retrying {retries}/{max_retries} in {delay:?}"
),
)
.await;
tokio::time::sleep(delay).await;
continue;
} else {
let event = Event {
id: sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: e.to_string(),
}),
};
sess.send_event(event).await;
return;
}
}
}
}
sess.remove_task(&sub_id);
let rollout_item = {
let mut state = sess.state.lock_unchecked();
state.history.keep_last_messages(1);
RolloutItem::Compacted(CompactedItem {
message: state.history.last_agent_message(),
})
};
sess.persist_rollout_items(&[rollout_item]).await;
let event = Event {
id: sub_id.clone(),
msg: EventMsg::AgentMessage(AgentMessageEvent {
message: "Compact task completed".to_string(),
}),
};
sess.send_event(event).await;
let event = Event {
id: sub_id.clone(),
msg: EventMsg::TaskComplete(TaskCompleteEvent {
last_agent_message: None,
}),
};
sess.send_event(event).await;
}
async fn handle_response_item(
sess: &Session,
turn_context: &TurnContext,
@@ -2964,7 +2934,7 @@ fn format_exec_output(exec_output: &ExecToolCallOutput) -> String {
serde_json::to_string(&payload).expect("serialize ExecOutput")
}
fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
responses.iter().rev().find_map(|item| {
if let ResponseItem::Message { role, content, .. } = item {
if role == "assistant" {
@@ -2983,68 +2953,6 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
}
})
}
async fn drain_to_completed(
sess: &Session,
turn_context: &TurnContext,
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<()> {
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
sandbox_policy: turn_context.sandbox_policy.clone(),
model: turn_context.client.get_model(),
effort: turn_context.client.get_reasoning_effort(),
summary: turn_context.client.get_reasoning_summary(),
});
sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context.client.clone().stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {
return Err(CodexErr::Stream(
"stream closed before response.completed".into(),
None,
));
};
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
// Record only to in-memory conversation history; avoid state snapshot.
let mut state = sess.state.lock_unchecked();
state.history.record_items(std::slice::from_ref(&item));
}
Ok(ResponseEvent::Completed {
response_id: _,
token_usage,
}) => {
let info = {
let mut st = sess.state.lock_unchecked();
let info = TokenUsageInfo::new_or_append(
&st.token_info,
&token_usage,
turn_context.client.get_model_context_window(),
);
st.token_info = info.clone();
info
};
sess.tx_event
.send(Event {
id: sub_id.to_string(),
msg: EventMsg::TokenCount(crate::protocol::TokenCountEvent { info }),
})
.await
.ok();
return Ok(());
}
Ok(_) => continue,
Err(e) => return Err(e),
}
}
}
fn convert_call_tool_result_to_function_call_output_payload(
call_tool_result: &CallToolResult,
) -> FunctionCallOutputPayload {