Tui: fix backtracking (#4020)

Backtracking multiple times could drop earlier turns. We now derive the
active user-turn positions from the transcript on demand (keying off the
latest session header) instead of caching state. This keeps the replayed
context intact during repeated edits and adds a regression test.
This commit is contained in:
friel-openai
2025-09-22 11:16:25 -07:00
committed by GitHub
parent fa80bbb587
commit 76a9b11678
2 changed files with 145 additions and 49 deletions

View File

@@ -443,11 +443,20 @@ impl App {
mod tests { mod tests {
use super::*; use super::*;
use crate::app_backtrack::BacktrackState; use crate::app_backtrack::BacktrackState;
use crate::app_backtrack::user_count;
use crate::chatwidget::tests::make_chatwidget_manual_with_sender; use crate::chatwidget::tests::make_chatwidget_manual_with_sender;
use crate::file_search::FileSearchManager; use crate::file_search::FileSearchManager;
use crate::history_cell::AgentMessageCell;
use crate::history_cell::HistoryCell;
use crate::history_cell::UserHistoryCell;
use crate::history_cell::new_session_info;
use codex_core::AuthManager; use codex_core::AuthManager;
use codex_core::CodexAuth; use codex_core::CodexAuth;
use codex_core::ConversationManager; use codex_core::ConversationManager;
use codex_core::protocol::SessionConfiguredEvent;
use codex_protocol::mcp_protocol::ConversationId;
use ratatui::prelude::Line;
use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicBool;
@@ -498,4 +507,66 @@ mod tests {
Some(ReasoningEffortConfig::High) Some(ReasoningEffortConfig::High)
); );
} }
#[test]
fn backtrack_selection_with_duplicate_history_targets_unique_turn() {
let mut app = make_test_app();
let user_cell = |text: &str| -> Arc<dyn HistoryCell> {
Arc::new(UserHistoryCell {
message: text.to_string(),
}) as Arc<dyn HistoryCell>
};
let agent_cell = |text: &str| -> Arc<dyn HistoryCell> {
Arc::new(AgentMessageCell::new(
vec![Line::from(text.to_string())],
true,
)) as Arc<dyn HistoryCell>
};
let make_header = |is_first| {
let event = SessionConfiguredEvent {
session_id: ConversationId::new(),
model: "gpt-test".to_string(),
reasoning_effort: None,
history_log_id: 0,
history_entry_count: 0,
initial_messages: None,
rollout_path: PathBuf::new(),
};
Arc::new(new_session_info(
app.chat_widget.config_ref(),
event,
is_first,
)) as Arc<dyn HistoryCell>
};
// Simulate the transcript after trimming for a fork, replaying history, and
// appending the edited turn. The session header separates the retained history
// from the forked conversation's replayed turns.
app.transcript_cells = vec![
make_header(true),
user_cell("first question"),
agent_cell("answer first"),
user_cell("follow-up"),
agent_cell("answer follow-up"),
make_header(false),
user_cell("first question"),
agent_cell("answer first"),
user_cell("follow-up (edited)"),
agent_cell("answer edited"),
];
assert_eq!(user_count(&app.transcript_cells), 2);
app.backtrack.base_id = Some(ConversationId::new());
app.backtrack.primed = true;
app.backtrack.nth_user_message = user_count(&app.transcript_cells).saturating_sub(1);
app.confirm_backtrack_from_main();
let (_, nth, prefill) = app.backtrack.pending.clone().expect("pending backtrack");
assert_eq!(nth, 1);
assert_eq!(prefill, "follow-up (edited)");
}
} }

View File

@@ -1,7 +1,9 @@
use std::any::TypeId;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use crate::app::App; use crate::app::App;
use crate::history_cell::CompositeHistoryCell;
use crate::history_cell::UserHistoryCell; use crate::history_cell::UserHistoryCell;
use crate::pager_overlay::Overlay; use crate::pager_overlay::Overlay;
use crate::tui; use crate::tui;
@@ -160,43 +162,47 @@ impl App {
self.backtrack.primed = true; self.backtrack.primed = true;
self.backtrack.base_id = self.chat_widget.conversation_id(); self.backtrack.base_id = self.chat_widget.conversation_id();
self.backtrack.overlay_preview_active = true; self.backtrack.overlay_preview_active = true;
let last_user_cell_position = self let count = user_count(&self.transcript_cells);
.transcript_cells if let Some(last) = count.checked_sub(1) {
.iter() self.apply_backtrack_selection(last);
.filter_map(|c| c.as_any().downcast_ref::<UserHistoryCell>())
.count() as i64
- 1;
if last_user_cell_position >= 0 {
self.apply_backtrack_selection(last_user_cell_position as usize);
} }
tui.frame_requester().schedule_frame(); tui.frame_requester().schedule_frame();
} }
/// Step selection to the next older user message and update overlay. /// Step selection to the next older user message and update overlay.
fn step_backtrack_and_highlight(&mut self, tui: &mut tui::Tui) { fn step_backtrack_and_highlight(&mut self, tui: &mut tui::Tui) {
let last_user_cell_position = self let count = user_count(&self.transcript_cells);
.transcript_cells if count == 0 {
.iter() return;
.filter(|c| c.as_any().is::<UserHistoryCell>()) }
.take(self.backtrack.nth_user_message)
.count() let last_index = count.saturating_sub(1);
.saturating_sub(1); let next_selection = if self.backtrack.nth_user_message == usize::MAX {
self.apply_backtrack_selection(last_user_cell_position); last_index
} else if self.backtrack.nth_user_message == 0 {
0
} else {
self.backtrack
.nth_user_message
.saturating_sub(1)
.min(last_index)
};
self.apply_backtrack_selection(next_selection);
tui.frame_requester().schedule_frame(); tui.frame_requester().schedule_frame();
} }
/// Apply a computed backtrack selection to the overlay and internal counter. /// Apply a computed backtrack selection to the overlay and internal counter.
fn apply_backtrack_selection(&mut self, nth_user_message: usize) { fn apply_backtrack_selection(&mut self, nth_user_message: usize) {
self.backtrack.nth_user_message = nth_user_message; if let Some(cell_idx) = nth_user_position(&self.transcript_cells, nth_user_message) {
if let Some(Overlay::Transcript(t)) = &mut self.overlay { self.backtrack.nth_user_message = nth_user_message;
let cell = self if let Some(Overlay::Transcript(t)) = &mut self.overlay {
.transcript_cells t.set_highlight_cell(Some(cell_idx));
.iter() }
.enumerate() } else {
.filter(|(_, c)| c.as_any().is::<UserHistoryCell>()) self.backtrack.nth_user_message = usize::MAX;
.nth(nth_user_message); if let Some(Overlay::Transcript(t)) = &mut self.overlay {
if let Some((idx, _)) = cell { t.set_highlight_cell(None);
t.set_highlight_cell(Some(idx));
} }
} }
} }
@@ -217,13 +223,9 @@ impl App {
fn overlay_confirm_backtrack(&mut self, tui: &mut tui::Tui) { fn overlay_confirm_backtrack(&mut self, tui: &mut tui::Tui) {
let nth_user_message = self.backtrack.nth_user_message; let nth_user_message = self.backtrack.nth_user_message;
if let Some(base_id) = self.backtrack.base_id { if let Some(base_id) = self.backtrack.base_id {
let user_cells = self let prefill = nth_user_position(&self.transcript_cells, nth_user_message)
.transcript_cells .and_then(|idx| self.transcript_cells.get(idx))
.iter() .and_then(|cell| cell.as_any().downcast_ref::<UserHistoryCell>())
.filter_map(|c| c.as_any().downcast_ref::<UserHistoryCell>())
.collect::<Vec<_>>();
let prefill = user_cells
.get(nth_user_message)
.map(|c| c.message.clone()) .map(|c| c.message.clone())
.unwrap_or_default(); .unwrap_or_default();
self.close_transcript_overlay(tui); self.close_transcript_overlay(tui);
@@ -246,14 +248,12 @@ impl App {
/// Computes the prefill from the selected user message and requests history. /// Computes the prefill from the selected user message and requests history.
pub(crate) fn confirm_backtrack_from_main(&mut self) { pub(crate) fn confirm_backtrack_from_main(&mut self) {
if let Some(base_id) = self.backtrack.base_id { if let Some(base_id) = self.backtrack.base_id {
let prefill = self let prefill =
.transcript_cells nth_user_position(&self.transcript_cells, self.backtrack.nth_user_message)
.iter() .and_then(|idx| self.transcript_cells.get(idx))
.filter(|c| c.as_any().is::<UserHistoryCell>()) .and_then(|cell| cell.as_any().downcast_ref::<UserHistoryCell>())
.nth(self.backtrack.nth_user_message) .map(|c| c.message.clone())
.and_then(|c| c.as_any().downcast_ref::<UserHistoryCell>()) .unwrap_or_default();
.map(|c| c.message.clone())
.unwrap_or_default();
self.request_backtrack(prefill, base_id, self.backtrack.nth_user_message); self.request_backtrack(prefill, base_id, self.backtrack.nth_user_message);
} }
self.reset_backtrack_state(); self.reset_backtrack_state();
@@ -363,13 +363,41 @@ fn trim_transcript_cells_to_nth_user(
return; return;
} }
let cut_idx = transcript_cells if let Some(cut_idx) = nth_user_position(transcript_cells, nth_user_message) {
transcript_cells.truncate(cut_idx);
}
}
pub(crate) fn user_count(cells: &[Arc<dyn crate::history_cell::HistoryCell>]) -> usize {
user_positions_iter(cells).count()
}
fn nth_user_position(
cells: &[Arc<dyn crate::history_cell::HistoryCell>],
nth: usize,
) -> Option<usize> {
user_positions_iter(cells)
.enumerate()
.find_map(|(i, idx)| (i == nth).then_some(idx))
}
fn user_positions_iter(
cells: &[Arc<dyn crate::history_cell::HistoryCell>],
) -> impl Iterator<Item = usize> + '_ {
let header_type = TypeId::of::<CompositeHistoryCell>();
let user_type = TypeId::of::<UserHistoryCell>();
let type_of = |cell: &Arc<dyn crate::history_cell::HistoryCell>| cell.as_any().type_id();
let start = cells
.iter()
.rposition(|cell| type_of(cell) == header_type)
.map_or(0, |idx| idx + 1);
cells
.iter() .iter()
.enumerate() .enumerate()
.filter_map(|(idx, cell)| cell.as_any().is::<UserHistoryCell>().then_some(idx)) .skip(start)
.nth(nth_user_message) .filter_map(move |(idx, cell)| (type_of(cell) == user_type).then_some(idx))
.unwrap_or(transcript_cells.len());
transcript_cells.truncate(cut_idx);
} }
#[cfg(test)] #[cfg(test)]
@@ -389,7 +417,6 @@ mod tests {
Arc::new(AgentMessageCell::new(vec![Line::from("assistant")], true)) Arc::new(AgentMessageCell::new(vec![Line::from("assistant")], true))
as Arc<dyn HistoryCell>, as Arc<dyn HistoryCell>,
]; ];
trim_transcript_cells_to_nth_user(&mut cells, 0); trim_transcript_cells_to_nth_user(&mut cells, 0);
assert!(cells.is_empty()); assert!(cells.is_empty());
@@ -406,7 +433,6 @@ mod tests {
Arc::new(AgentMessageCell::new(vec![Line::from("after")], false)) Arc::new(AgentMessageCell::new(vec![Line::from("after")], false))
as Arc<dyn HistoryCell>, as Arc<dyn HistoryCell>,
]; ];
trim_transcript_cells_to_nth_user(&mut cells, 0); trim_transcript_cells_to_nth_user(&mut cells, 0);
assert_eq!(cells.len(), 1); assert_eq!(cells.len(), 1);
@@ -440,7 +466,6 @@ mod tests {
Arc::new(AgentMessageCell::new(vec![Line::from("tail")], false)) Arc::new(AgentMessageCell::new(vec![Line::from("tail")], false))
as Arc<dyn HistoryCell>, as Arc<dyn HistoryCell>,
]; ];
trim_transcript_cells_to_nth_user(&mut cells, 1); trim_transcript_cells_to_nth_user(&mut cells, 1);
assert_eq!(cells.len(), 3); assert_eq!(cells.len(), 3);