Move changing turn input functionalities to ConversationHistory (#5473)
We are doing some ad-hoc logic while dealing with conversation history. Ideally, we shouldn't mutate `vec[responseitem]` manually at all and should depend on `ConversationHistory` for those changes. Those changes are: - Adding input to the history - Removing items from the history - Correcting history I am also adding some `error` logs for cases we shouldn't ideally face. For example, we shouldn't be missing `toolcalls` or `outputs`. We shouldn't hit `ContextWindowExceeded` while performing `compact` This refactor will give us granular control over our context management.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::path::PathBuf;
|
||||
@@ -873,7 +872,7 @@ impl Session {
|
||||
history.record_items(std::iter::once(response_item));
|
||||
}
|
||||
RolloutItem::Compacted(compacted) => {
|
||||
let snapshot = history.contents();
|
||||
let snapshot = history.get_history();
|
||||
let user_messages = collect_user_messages(&snapshot);
|
||||
let rebuilt = build_compacted_history(
|
||||
self.build_initial_context(turn_context),
|
||||
@@ -885,7 +884,7 @@ impl Session {
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
history.contents()
|
||||
history.get_history()
|
||||
}
|
||||
|
||||
/// Append ResponseItems to the in-memory conversation history only.
|
||||
@@ -934,11 +933,17 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
// todo (aibrahim): get rid of this method. we shouldn't deal with vec[resposne_item] and rather use ConversationHistory.
|
||||
pub(crate) async fn history_snapshot(&self) -> Vec<ResponseItem> {
|
||||
let state = self.state.lock().await;
|
||||
let mut state = self.state.lock().await;
|
||||
state.history_snapshot()
|
||||
}
|
||||
|
||||
pub(crate) async fn clone_history(&self) -> ConversationHistory {
|
||||
let state = self.state.lock().await;
|
||||
state.clone_history()
|
||||
}
|
||||
|
||||
async fn update_token_usage_info(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
@@ -1030,16 +1035,6 @@ impl Session {
|
||||
self.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
/// Build the full turn input by concatenating the current conversation
|
||||
/// history with additional items for this turn.
|
||||
pub async fn turn_input_with_history(&self, extra: Vec<ResponseItem>) -> Vec<ResponseItem> {
|
||||
let history = {
|
||||
let state = self.state.lock().await;
|
||||
state.history_snapshot()
|
||||
};
|
||||
[history, extra].concat()
|
||||
}
|
||||
|
||||
/// Returns the input if there was no task running to inject into
|
||||
pub async fn inject_input(&self, input: Vec<UserInput>) -> Result<(), Vec<UserInput>> {
|
||||
let mut active = self.active_turn.lock().await;
|
||||
@@ -1526,11 +1521,13 @@ pub(crate) async fn run_task(
|
||||
// model sees a fresh conversation without the parent session's history.
|
||||
// For normal turns, continue recording to the session history as before.
|
||||
let is_review_mode = turn_context.is_review_mode;
|
||||
let mut review_thread_history: Vec<ResponseItem> = Vec::new();
|
||||
|
||||
let mut review_thread_history: ConversationHistory = ConversationHistory::new();
|
||||
if is_review_mode {
|
||||
// Seed review threads with environment context so the model knows the working directory.
|
||||
review_thread_history.extend(sess.build_initial_context(turn_context.as_ref()));
|
||||
review_thread_history.push(initial_input_for_turn.into());
|
||||
review_thread_history
|
||||
.record_items(sess.build_initial_context(turn_context.as_ref()).iter());
|
||||
review_thread_history.record_items(std::iter::once(&initial_input_for_turn.into()));
|
||||
} else {
|
||||
sess.record_input_and_rollout_usermsg(turn_context.as_ref(), &initial_input_for_turn)
|
||||
.await;
|
||||
@@ -1565,12 +1562,12 @@ pub(crate) async fn run_task(
|
||||
// represents an append-only log without duplicates.
|
||||
let turn_input: Vec<ResponseItem> = if is_review_mode {
|
||||
if !pending_input.is_empty() {
|
||||
review_thread_history.extend(pending_input);
|
||||
review_thread_history.record_items(&pending_input);
|
||||
}
|
||||
review_thread_history.clone()
|
||||
review_thread_history.get_history()
|
||||
} else {
|
||||
sess.record_conversation_items(&pending_input).await;
|
||||
sess.turn_input_with_history(pending_input).await
|
||||
sess.history_snapshot().await
|
||||
};
|
||||
|
||||
let turn_input_messages: Vec<String> = turn_input
|
||||
@@ -1708,7 +1705,7 @@ pub(crate) async fn run_task(
|
||||
if !items_to_record_in_conversation_history.is_empty() {
|
||||
if is_review_mode {
|
||||
review_thread_history
|
||||
.extend(items_to_record_in_conversation_history.clone());
|
||||
.record_items(items_to_record_in_conversation_history.iter());
|
||||
} else {
|
||||
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
||||
.await;
|
||||
@@ -1927,61 +1924,6 @@ async fn try_run_turn(
|
||||
task_kind: TaskKind,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> CodexResult<TurnRunResult> {
|
||||
// call_ids that are part of this response.
|
||||
let completed_call_ids = prompt
|
||||
.input
|
||||
.iter()
|
||||
.filter_map(|ri| match ri {
|
||||
ResponseItem::FunctionCallOutput { call_id, .. } => Some(call_id),
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => Some(call_id),
|
||||
ResponseItem::CustomToolCallOutput { call_id, .. } => Some(call_id),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// call_ids that were pending but are not part of this response.
|
||||
// This usually happens because the user interrupted the model before we responded to one of its tool calls
|
||||
// and then the user sent a follow-up message.
|
||||
let missing_calls = {
|
||||
prompt
|
||||
.input
|
||||
.iter()
|
||||
.filter_map(|ri| match ri {
|
||||
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
|
||||
ResponseItem::LocalShellCall {
|
||||
call_id: Some(call_id),
|
||||
..
|
||||
} => Some(call_id),
|
||||
ResponseItem::CustomToolCall { call_id, .. } => Some(call_id),
|
||||
_ => None,
|
||||
})
|
||||
.filter_map(|call_id| {
|
||||
if completed_call_ids.contains(&call_id) {
|
||||
None
|
||||
} else {
|
||||
Some(call_id.clone())
|
||||
}
|
||||
})
|
||||
.map(|call_id| ResponseItem::CustomToolCallOutput {
|
||||
call_id,
|
||||
output: "aborted".to_string(),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
let prompt: Cow<Prompt> = if missing_calls.is_empty() {
|
||||
Cow::Borrowed(prompt)
|
||||
} else {
|
||||
// Add the synthetic aborted missing calls to the beginning of the input to ensure all call ids have responses.
|
||||
let input = [missing_calls, prompt.input.clone()].concat();
|
||||
Cow::Owned(Prompt {
|
||||
input,
|
||||
..prompt.clone()
|
||||
})
|
||||
};
|
||||
|
||||
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
|
||||
cwd: turn_context.cwd.clone(),
|
||||
approval_policy: turn_context.approval_policy,
|
||||
@@ -1990,11 +1932,12 @@ async fn try_run_turn(
|
||||
effort: turn_context.client.get_reasoning_effort(),
|
||||
summary: turn_context.client.get_reasoning_summary(),
|
||||
});
|
||||
|
||||
sess.persist_rollout_items(&[rollout_item]).await;
|
||||
let mut stream = turn_context
|
||||
.client
|
||||
.clone()
|
||||
.stream_with_task_kind(prompt.as_ref(), task_kind)
|
||||
.stream_with_task_kind(prompt, task_kind)
|
||||
.or_cancel(&cancellation_token)
|
||||
.await??;
|
||||
|
||||
@@ -2982,7 +2925,7 @@ mod tests {
|
||||
rollout_items.push(RolloutItem::ResponseItem(assistant1.clone()));
|
||||
|
||||
let summary1 = "summary one";
|
||||
let snapshot1 = live_history.contents();
|
||||
let snapshot1 = live_history.get_history();
|
||||
let user_messages1 = collect_user_messages(&snapshot1);
|
||||
let rebuilt1 = build_compacted_history(
|
||||
session.build_initial_context(turn_context),
|
||||
@@ -3015,7 +2958,7 @@ mod tests {
|
||||
rollout_items.push(RolloutItem::ResponseItem(assistant2.clone()));
|
||||
|
||||
let summary2 = "summary two";
|
||||
let snapshot2 = live_history.contents();
|
||||
let snapshot2 = live_history.get_history();
|
||||
let user_messages2 = collect_user_messages(&snapshot2);
|
||||
let rebuilt2 = build_compacted_history(
|
||||
session.build_initial_context(turn_context),
|
||||
@@ -3047,7 +2990,7 @@ mod tests {
|
||||
live_history.record_items(std::iter::once(&assistant3));
|
||||
rollout_items.push(RolloutItem::ResponseItem(assistant3.clone()));
|
||||
|
||||
(rollout_items, live_history.contents())
|
||||
(rollout_items, live_history.get_history())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
Reference in New Issue
Block a user