Tui: fix backtracking (#4020)
Backtracking multiple times could drop earlier turns. We now derive the active user-turn positions from the transcript on demand (keying off the latest session header) instead of caching state. This keeps the replayed context intact during repeated edits and adds a regression test.
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
use std::any::TypeId;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::app::App;
|
||||
use crate::history_cell::CompositeHistoryCell;
|
||||
use crate::history_cell::UserHistoryCell;
|
||||
use crate::pager_overlay::Overlay;
|
||||
use crate::tui;
|
||||
@@ -160,43 +162,47 @@ impl App {
|
||||
self.backtrack.primed = true;
|
||||
self.backtrack.base_id = self.chat_widget.conversation_id();
|
||||
self.backtrack.overlay_preview_active = true;
|
||||
let last_user_cell_position = self
|
||||
.transcript_cells
|
||||
.iter()
|
||||
.filter_map(|c| c.as_any().downcast_ref::<UserHistoryCell>())
|
||||
.count() as i64
|
||||
- 1;
|
||||
if last_user_cell_position >= 0 {
|
||||
self.apply_backtrack_selection(last_user_cell_position as usize);
|
||||
let count = user_count(&self.transcript_cells);
|
||||
if let Some(last) = count.checked_sub(1) {
|
||||
self.apply_backtrack_selection(last);
|
||||
}
|
||||
tui.frame_requester().schedule_frame();
|
||||
}
|
||||
|
||||
/// Step selection to the next older user message and update overlay.
|
||||
fn step_backtrack_and_highlight(&mut self, tui: &mut tui::Tui) {
|
||||
let last_user_cell_position = self
|
||||
.transcript_cells
|
||||
.iter()
|
||||
.filter(|c| c.as_any().is::<UserHistoryCell>())
|
||||
.take(self.backtrack.nth_user_message)
|
||||
.count()
|
||||
.saturating_sub(1);
|
||||
self.apply_backtrack_selection(last_user_cell_position);
|
||||
let count = user_count(&self.transcript_cells);
|
||||
if count == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let last_index = count.saturating_sub(1);
|
||||
let next_selection = if self.backtrack.nth_user_message == usize::MAX {
|
||||
last_index
|
||||
} else if self.backtrack.nth_user_message == 0 {
|
||||
0
|
||||
} else {
|
||||
self.backtrack
|
||||
.nth_user_message
|
||||
.saturating_sub(1)
|
||||
.min(last_index)
|
||||
};
|
||||
|
||||
self.apply_backtrack_selection(next_selection);
|
||||
tui.frame_requester().schedule_frame();
|
||||
}
|
||||
|
||||
/// Apply a computed backtrack selection to the overlay and internal counter.
|
||||
fn apply_backtrack_selection(&mut self, nth_user_message: usize) {
|
||||
self.backtrack.nth_user_message = nth_user_message;
|
||||
if let Some(Overlay::Transcript(t)) = &mut self.overlay {
|
||||
let cell = self
|
||||
.transcript_cells
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, c)| c.as_any().is::<UserHistoryCell>())
|
||||
.nth(nth_user_message);
|
||||
if let Some((idx, _)) = cell {
|
||||
t.set_highlight_cell(Some(idx));
|
||||
if let Some(cell_idx) = nth_user_position(&self.transcript_cells, nth_user_message) {
|
||||
self.backtrack.nth_user_message = nth_user_message;
|
||||
if let Some(Overlay::Transcript(t)) = &mut self.overlay {
|
||||
t.set_highlight_cell(Some(cell_idx));
|
||||
}
|
||||
} else {
|
||||
self.backtrack.nth_user_message = usize::MAX;
|
||||
if let Some(Overlay::Transcript(t)) = &mut self.overlay {
|
||||
t.set_highlight_cell(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -217,13 +223,9 @@ impl App {
|
||||
fn overlay_confirm_backtrack(&mut self, tui: &mut tui::Tui) {
|
||||
let nth_user_message = self.backtrack.nth_user_message;
|
||||
if let Some(base_id) = self.backtrack.base_id {
|
||||
let user_cells = self
|
||||
.transcript_cells
|
||||
.iter()
|
||||
.filter_map(|c| c.as_any().downcast_ref::<UserHistoryCell>())
|
||||
.collect::<Vec<_>>();
|
||||
let prefill = user_cells
|
||||
.get(nth_user_message)
|
||||
let prefill = nth_user_position(&self.transcript_cells, nth_user_message)
|
||||
.and_then(|idx| self.transcript_cells.get(idx))
|
||||
.and_then(|cell| cell.as_any().downcast_ref::<UserHistoryCell>())
|
||||
.map(|c| c.message.clone())
|
||||
.unwrap_or_default();
|
||||
self.close_transcript_overlay(tui);
|
||||
@@ -246,14 +248,12 @@ impl App {
|
||||
/// Computes the prefill from the selected user message and requests history.
|
||||
pub(crate) fn confirm_backtrack_from_main(&mut self) {
|
||||
if let Some(base_id) = self.backtrack.base_id {
|
||||
let prefill = self
|
||||
.transcript_cells
|
||||
.iter()
|
||||
.filter(|c| c.as_any().is::<UserHistoryCell>())
|
||||
.nth(self.backtrack.nth_user_message)
|
||||
.and_then(|c| c.as_any().downcast_ref::<UserHistoryCell>())
|
||||
.map(|c| c.message.clone())
|
||||
.unwrap_or_default();
|
||||
let prefill =
|
||||
nth_user_position(&self.transcript_cells, self.backtrack.nth_user_message)
|
||||
.and_then(|idx| self.transcript_cells.get(idx))
|
||||
.and_then(|cell| cell.as_any().downcast_ref::<UserHistoryCell>())
|
||||
.map(|c| c.message.clone())
|
||||
.unwrap_or_default();
|
||||
self.request_backtrack(prefill, base_id, self.backtrack.nth_user_message);
|
||||
}
|
||||
self.reset_backtrack_state();
|
||||
@@ -363,13 +363,41 @@ fn trim_transcript_cells_to_nth_user(
|
||||
return;
|
||||
}
|
||||
|
||||
let cut_idx = transcript_cells
|
||||
if let Some(cut_idx) = nth_user_position(transcript_cells, nth_user_message) {
|
||||
transcript_cells.truncate(cut_idx);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn user_count(cells: &[Arc<dyn crate::history_cell::HistoryCell>]) -> usize {
|
||||
user_positions_iter(cells).count()
|
||||
}
|
||||
|
||||
fn nth_user_position(
|
||||
cells: &[Arc<dyn crate::history_cell::HistoryCell>],
|
||||
nth: usize,
|
||||
) -> Option<usize> {
|
||||
user_positions_iter(cells)
|
||||
.enumerate()
|
||||
.find_map(|(i, idx)| (i == nth).then_some(idx))
|
||||
}
|
||||
|
||||
fn user_positions_iter(
|
||||
cells: &[Arc<dyn crate::history_cell::HistoryCell>],
|
||||
) -> impl Iterator<Item = usize> + '_ {
|
||||
let header_type = TypeId::of::<CompositeHistoryCell>();
|
||||
let user_type = TypeId::of::<UserHistoryCell>();
|
||||
let type_of = |cell: &Arc<dyn crate::history_cell::HistoryCell>| cell.as_any().type_id();
|
||||
|
||||
let start = cells
|
||||
.iter()
|
||||
.rposition(|cell| type_of(cell) == header_type)
|
||||
.map_or(0, |idx| idx + 1);
|
||||
|
||||
cells
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, cell)| cell.as_any().is::<UserHistoryCell>().then_some(idx))
|
||||
.nth(nth_user_message)
|
||||
.unwrap_or(transcript_cells.len());
|
||||
transcript_cells.truncate(cut_idx);
|
||||
.skip(start)
|
||||
.filter_map(move |(idx, cell)| (type_of(cell) == user_type).then_some(idx))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -389,7 +417,6 @@ mod tests {
|
||||
Arc::new(AgentMessageCell::new(vec![Line::from("assistant")], true))
|
||||
as Arc<dyn HistoryCell>,
|
||||
];
|
||||
|
||||
trim_transcript_cells_to_nth_user(&mut cells, 0);
|
||||
|
||||
assert!(cells.is_empty());
|
||||
@@ -406,7 +433,6 @@ mod tests {
|
||||
Arc::new(AgentMessageCell::new(vec![Line::from("after")], false))
|
||||
as Arc<dyn HistoryCell>,
|
||||
];
|
||||
|
||||
trim_transcript_cells_to_nth_user(&mut cells, 0);
|
||||
|
||||
assert_eq!(cells.len(), 1);
|
||||
@@ -440,7 +466,6 @@ mod tests {
|
||||
Arc::new(AgentMessageCell::new(vec![Line::from("tail")], false))
|
||||
as Arc<dyn HistoryCell>,
|
||||
];
|
||||
|
||||
trim_transcript_cells_to_nth_user(&mut cells, 1);
|
||||
|
||||
assert_eq!(cells.len(), 3);
|
||||
|
||||
Reference in New Issue
Block a user