Add a TurnDiffTracker to create a unified diff for an entire turn (#1770)

This lets us show an accumulating diff across all patches in a turn.
Refer to the docs for TurnDiffTracker for implementation details.

There are multiple ways this could have been done and this felt like the
right tradeoff between reliability and completeness:
*Pros*
* It will pick up all changes to files that the model touched including
if they prettier or another command that updates them.
* It will not pick up changes made by the user or other agents to files
it didn't modify.

*Cons*
* It will pick up changes that the user made to a file that the model
also touched
* It will not pick up changes to codegen or files that were not modified
with apply_patch
This commit is contained in:
Gabriel Peal
2025-08-04 08:57:04 -07:00
committed by GitHub
parent e3565a3f43
commit 1f3318c1c5
9 changed files with 998 additions and 18 deletions

View File

@@ -85,11 +85,13 @@ use crate::protocol::SandboxPolicy;
use crate::protocol::SessionConfiguredEvent;
use crate::protocol::Submission;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TurnDiffEvent;
use crate::rollout::RolloutRecorder;
use crate::safety::SafetyCheck;
use crate::safety::assess_command_safety;
use crate::safety::assess_safety_for_untrusted_command;
use crate::shell;
use crate::turn_diff_tracker::TurnDiffTracker;
use crate::user_notification::UserNotification;
use crate::util::backoff;
@@ -362,7 +364,11 @@ impl Session {
}
}
async fn notify_exec_command_begin(&self, exec_command_context: ExecCommandContext) {
async fn on_exec_command_begin(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
exec_command_context: ExecCommandContext,
) {
let ExecCommandContext {
sub_id,
call_id,
@@ -374,11 +380,15 @@ impl Session {
Some(ApplyPatchCommandContext {
user_explicitly_approved_this_action,
changes,
}) => EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
call_id,
auto_approved: !user_explicitly_approved_this_action,
changes,
}),
}) => {
turn_diff_tracker.on_patch_begin(&changes);
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
call_id,
auto_approved: !user_explicitly_approved_this_action,
changes,
})
}
None => EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
call_id,
command: command_for_display.clone(),
@@ -392,8 +402,10 @@ impl Session {
let _ = self.tx_event.send(event).await;
}
async fn notify_exec_command_end(
#[allow(clippy::too_many_arguments)]
async fn on_exec_command_end(
&self,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
call_id: &str,
output: &ExecToolCallOutput,
@@ -433,6 +445,20 @@ impl Session {
msg,
};
let _ = self.tx_event.send(event).await;
// If this is an apply_patch, after we emit the end patch, emit a second event
// with the full turn diff if there is one.
if is_apply_patch {
let unified_diff = turn_diff_tracker.get_unified_diff();
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
id: sub_id.into(),
msg,
};
let _ = self.tx_event.send(event).await;
}
}
}
/// Helper that emits a BackgroundEvent with the given message. This keeps
@@ -1006,6 +1032,10 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
.await;
let last_agent_message: Option<String>;
// Although from the perspective of codex.rs, TurnDiffTracker has the lifecycle of a Task which contains
// many turns, from the perspective of the user, it is a single turn.
let mut turn_diff_tracker = TurnDiffTracker::new();
loop {
// Note that pending_input would be something like a message the user
// submitted through the UI while the model was running. Though the UI
@@ -1037,7 +1067,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
})
})
.collect();
match run_turn(&sess, sub_id.clone(), turn_input).await {
match run_turn(&sess, &mut turn_diff_tracker, sub_id.clone(), turn_input).await {
Ok(turn_output) => {
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
let mut responses = Vec::<ResponseInputItem>::new();
@@ -1163,6 +1193,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
async fn run_turn(
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
input: Vec<ResponseItem>,
) -> CodexResult<Vec<ProcessedResponseItem>> {
@@ -1177,7 +1208,7 @@ async fn run_turn(
let mut retries = 0;
loop {
match try_run_turn(sess, &sub_id, &prompt).await {
match try_run_turn(sess, turn_diff_tracker, &sub_id, &prompt).await {
Ok(output) => return Ok(output),
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
@@ -1223,6 +1254,7 @@ struct ProcessedResponseItem {
async fn try_run_turn(
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<Vec<ProcessedResponseItem>> {
@@ -1310,7 +1342,8 @@ async fn try_run_turn(
match event {
ResponseEvent::Created => {}
ResponseEvent::OutputItemDone(item) => {
let response = handle_response_item(sess, sub_id, item.clone()).await?;
let response =
handle_response_item(sess, turn_diff_tracker, sub_id, item.clone()).await?;
output.push(ProcessedResponseItem { item, response });
}
@@ -1328,6 +1361,16 @@ async fn try_run_turn(
.ok();
}
let unified_diff = turn_diff_tracker.get_unified_diff();
if let Ok(Some(unified_diff)) = unified_diff {
let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
let event = Event {
id: sub_id.to_string(),
msg,
};
let _ = sess.tx_event.send(event).await;
}
return Ok(output);
}
ResponseEvent::OutputTextDelta(delta) => {
@@ -1432,6 +1475,7 @@ async fn run_compact_task(
async fn handle_response_item(
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: &str,
item: ResponseItem,
) -> CodexResult<Option<ResponseInputItem>> {
@@ -1469,7 +1513,17 @@ async fn handle_response_item(
..
} => {
info!("FunctionCall: {arguments}");
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
Some(
handle_function_call(
sess,
turn_diff_tracker,
sub_id.to_string(),
name,
arguments,
call_id,
)
.await,
)
}
ResponseItem::LocalShellCall {
id,
@@ -1504,6 +1558,7 @@ async fn handle_response_item(
handle_container_exec_with_params(
exec_params,
sess,
turn_diff_tracker,
sub_id.to_string(),
effective_call_id,
)
@@ -1521,6 +1576,7 @@ async fn handle_response_item(
async fn handle_function_call(
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
name: String,
arguments: String,
@@ -1534,7 +1590,8 @@ async fn handle_function_call(
return *output;
}
};
handle_container_exec_with_params(params, sess, sub_id, call_id).await
handle_container_exec_with_params(params, sess, turn_diff_tracker, sub_id, call_id)
.await
}
"update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
_ => {
@@ -1608,6 +1665,7 @@ fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams
async fn handle_container_exec_with_params(
params: ExecParams,
sess: &Session,
turn_diff_tracker: &mut TurnDiffTracker,
sub_id: String,
call_id: String,
) -> ResponseInputItem {
@@ -1755,7 +1813,7 @@ async fn handle_container_exec_with_params(
},
),
};
sess.notify_exec_command_begin(exec_command_context.clone())
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context.clone())
.await;
let params = maybe_run_with_user_profile(params, sess);
@@ -1782,7 +1840,8 @@ async fn handle_container_exec_with_params(
duration,
} = &output;
sess.notify_exec_command_end(
sess.on_exec_command_end(
turn_diff_tracker,
&sub_id,
&call_id,
&output,
@@ -1806,7 +1865,15 @@ async fn handle_container_exec_with_params(
}
}
Err(CodexErr::Sandbox(error)) => {
handle_sandbox_error(params, exec_command_context, error, sandbox_type, sess).await
handle_sandbox_error(
turn_diff_tracker,
params,
exec_command_context,
error,
sandbox_type,
sess,
)
.await
}
Err(e) => {
// Handle non-sandbox errors
@@ -1822,6 +1889,7 @@ async fn handle_container_exec_with_params(
}
async fn handle_sandbox_error(
turn_diff_tracker: &mut TurnDiffTracker,
params: ExecParams,
exec_command_context: ExecCommandContext,
error: SandboxErr,
@@ -1878,7 +1946,8 @@ async fn handle_sandbox_error(
sess.notify_background_event(&sub_id, "retrying command without sandbox")
.await;
sess.notify_exec_command_begin(exec_command_context).await;
sess.on_exec_command_begin(turn_diff_tracker, exec_command_context)
.await;
// This is an escalated retry; the policy will not be
// examined and the sandbox has been set to `None`.
@@ -1905,8 +1974,14 @@ async fn handle_sandbox_error(
duration,
} = &retry_output;
sess.notify_exec_command_end(&sub_id, &call_id, &retry_output, is_apply_patch)
.await;
sess.on_exec_command_end(
turn_diff_tracker,
&sub_id,
&call_id,
&retry_output,
is_apply_patch,
)
.await;
let is_success = *exit_code == 0;
let content = format_exec_output(