Always send entire request context (#1641)
Always store the entire conversation history. Request encrypted chain-of-thought (COT) reasoning content when Responses are not being stored. Send the entire input context with every request instead of sending `previous_response_id`.
This commit is contained in:
@@ -41,7 +41,7 @@ pub(crate) async fn stream_chat_completions(
|
|||||||
|
|
||||||
for item in &prompt.input {
|
for item in &prompt.input {
|
||||||
match item {
|
match item {
|
||||||
ResponseItem::Message { role, content } => {
|
ResponseItem::Message { role, content, .. } => {
|
||||||
let mut text = String::new();
|
let mut text = String::new();
|
||||||
for c in content {
|
for c in content {
|
||||||
match c {
|
match c {
|
||||||
@@ -58,6 +58,7 @@ pub(crate) async fn stream_chat_completions(
|
|||||||
name,
|
name,
|
||||||
arguments,
|
arguments,
|
||||||
call_id,
|
call_id,
|
||||||
|
..
|
||||||
} => {
|
} => {
|
||||||
messages.push(json!({
|
messages.push(json!({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
@@ -259,6 +260,7 @@ async fn process_chat_sse<S>(
|
|||||||
content: vec![ContentItem::OutputText {
|
content: vec![ContentItem::OutputText {
|
||||||
text: content.to_string(),
|
text: content.to_string(),
|
||||||
}],
|
}],
|
||||||
|
id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
|
||||||
@@ -300,6 +302,7 @@ async fn process_chat_sse<S>(
|
|||||||
"tool_calls" if fn_call_state.active => {
|
"tool_calls" if fn_call_state.active => {
|
||||||
// Build the FunctionCall response item.
|
// Build the FunctionCall response item.
|
||||||
let item = ResponseItem::FunctionCall {
|
let item = ResponseItem::FunctionCall {
|
||||||
|
id: None,
|
||||||
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
|
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
|
||||||
arguments: fn_call_state.arguments.clone(),
|
arguments: fn_call_state.arguments.clone(),
|
||||||
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
|
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
|
||||||
@@ -402,6 +405,7 @@ where
|
|||||||
}))) => {
|
}))) => {
|
||||||
if !this.cumulative.is_empty() {
|
if !this.cumulative.is_empty() {
|
||||||
let aggregated_item = crate::models::ResponseItem::Message {
|
let aggregated_item = crate::models::ResponseItem::Message {
|
||||||
|
id: None,
|
||||||
role: "assistant".to_string(),
|
role: "assistant".to_string(),
|
||||||
content: vec![crate::models::ContentItem::OutputText {
|
content: vec![crate::models::ContentItem::OutputText {
|
||||||
text: std::mem::take(&mut this.cumulative),
|
text: std::mem::take(&mut this.cumulative),
|
||||||
|
|||||||
@@ -117,6 +117,15 @@ impl ModelClient {
|
|||||||
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
let full_instructions = prompt.get_full_instructions(&self.config.model);
|
||||||
let tools_json = create_tools_json_for_responses_api(prompt, &self.config.model)?;
|
let tools_json = create_tools_json_for_responses_api(prompt, &self.config.model)?;
|
||||||
let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary);
|
let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary);
|
||||||
|
|
||||||
|
// Request encrypted COT if we are not storing responses,
|
||||||
|
// otherwise reasoning items will be referenced by ID
|
||||||
|
let include = if !prompt.store && reasoning.is_some() {
|
||||||
|
vec!["reasoning.encrypted_content".to_string()]
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
let payload = ResponsesApiRequest {
|
let payload = ResponsesApiRequest {
|
||||||
model: &self.config.model,
|
model: &self.config.model,
|
||||||
instructions: &full_instructions,
|
instructions: &full_instructions,
|
||||||
@@ -125,10 +134,10 @@ impl ModelClient {
|
|||||||
tool_choice: "auto",
|
tool_choice: "auto",
|
||||||
parallel_tool_calls: false,
|
parallel_tool_calls: false,
|
||||||
reasoning,
|
reasoning,
|
||||||
previous_response_id: prompt.prev_id.clone(),
|
|
||||||
store: prompt.store,
|
store: prompt.store,
|
||||||
// TODO: make this configurable
|
// TODO: make this configurable
|
||||||
stream: true,
|
stream: true,
|
||||||
|
include,
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!(
|
trace!(
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
|
|||||||
pub struct Prompt {
|
pub struct Prompt {
|
||||||
/// Conversation context input items.
|
/// Conversation context input items.
|
||||||
pub input: Vec<ResponseItem>,
|
pub input: Vec<ResponseItem>,
|
||||||
/// Optional previous response ID (when storage is enabled).
|
|
||||||
pub prev_id: Option<String>,
|
|
||||||
/// Optional instructions from the user to amend to the built-in agent
|
/// Optional instructions from the user to amend to the built-in agent
|
||||||
/// instructions.
|
/// instructions.
|
||||||
pub user_instructions: Option<String>,
|
pub user_instructions: Option<String>,
|
||||||
@@ -133,11 +131,10 @@ pub(crate) struct ResponsesApiRequest<'a> {
|
|||||||
pub(crate) tool_choice: &'static str,
|
pub(crate) tool_choice: &'static str,
|
||||||
pub(crate) parallel_tool_calls: bool,
|
pub(crate) parallel_tool_calls: bool,
|
||||||
pub(crate) reasoning: Option<Reasoning>,
|
pub(crate) reasoning: Option<Reasoning>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub(crate) previous_response_id: Option<String>,
|
|
||||||
/// true when using the Responses API.
|
/// true when using the Responses API.
|
||||||
pub(crate) store: bool,
|
pub(crate) store: bool,
|
||||||
pub(crate) stream: bool,
|
pub(crate) stream: bool,
|
||||||
|
pub(crate) include: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ use tracing::trace;
|
|||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::WireApi;
|
|
||||||
use crate::client::ModelClient;
|
use crate::client::ModelClient;
|
||||||
use crate::client_common::Prompt;
|
use crate::client_common::Prompt;
|
||||||
use crate::client_common::ResponseEvent;
|
use crate::client_common::ResponseEvent;
|
||||||
@@ -191,6 +190,7 @@ pub(crate) struct Session {
|
|||||||
sandbox_policy: SandboxPolicy,
|
sandbox_policy: SandboxPolicy,
|
||||||
shell_environment_policy: ShellEnvironmentPolicy,
|
shell_environment_policy: ShellEnvironmentPolicy,
|
||||||
writable_roots: Mutex<Vec<PathBuf>>,
|
writable_roots: Mutex<Vec<PathBuf>>,
|
||||||
|
disable_response_storage: bool,
|
||||||
|
|
||||||
/// Manager for external MCP servers/tools.
|
/// Manager for external MCP servers/tools.
|
||||||
mcp_connection_manager: McpConnectionManager,
|
mcp_connection_manager: McpConnectionManager,
|
||||||
@@ -219,13 +219,9 @@ impl Session {
|
|||||||
struct State {
|
struct State {
|
||||||
approved_commands: HashSet<Vec<String>>,
|
approved_commands: HashSet<Vec<String>>,
|
||||||
current_task: Option<AgentTask>,
|
current_task: Option<AgentTask>,
|
||||||
/// Call IDs that have been sent from the Responses API but have not been sent back yet.
|
|
||||||
/// You CANNOT send a Responses API follow-up message unless you have sent back the output for all pending calls or else it will 400.
|
|
||||||
pending_call_ids: HashSet<String>,
|
|
||||||
previous_response_id: Option<String>,
|
|
||||||
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
|
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
|
||||||
pending_input: Vec<ResponseInputItem>,
|
pending_input: Vec<ResponseInputItem>,
|
||||||
zdr_transcript: Option<ConversationHistory>,
|
history: ConversationHistory,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
@@ -320,18 +316,11 @@ impl Session {
|
|||||||
debug!("Recording items for conversation: {items:?}");
|
debug!("Recording items for conversation: {items:?}");
|
||||||
self.record_state_snapshot(items).await;
|
self.record_state_snapshot(items).await;
|
||||||
|
|
||||||
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
|
self.state.lock().unwrap().history.record_items(items);
|
||||||
transcript.record_items(items);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
|
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
|
||||||
let snapshot = {
|
let snapshot = { crate::rollout::SessionStateSnapshot {} };
|
||||||
let state = self.state.lock().unwrap();
|
|
||||||
crate::rollout::SessionStateSnapshot {
|
|
||||||
previous_response_id: state.previous_response_id.clone(),
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let recorder = {
|
let recorder = {
|
||||||
let guard = self.rollout.lock().unwrap();
|
let guard = self.rollout.lock().unwrap();
|
||||||
@@ -433,8 +422,6 @@ impl Session {
|
|||||||
pub fn abort(&self) {
|
pub fn abort(&self) {
|
||||||
info!("Aborting existing session");
|
info!("Aborting existing session");
|
||||||
let mut state = self.state.lock().unwrap();
|
let mut state = self.state.lock().unwrap();
|
||||||
// Don't clear pending_call_ids because we need to keep track of them to ensure we don't 400 on the next turn.
|
|
||||||
// We will generate a synthetic aborted response for each pending call id.
|
|
||||||
state.pending_approvals.clear();
|
state.pending_approvals.clear();
|
||||||
state.pending_input.clear();
|
state.pending_input.clear();
|
||||||
if let Some(task) = state.current_task.take() {
|
if let Some(task) = state.current_task.take() {
|
||||||
@@ -479,15 +466,10 @@ impl Drop for Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl State {
|
impl State {
|
||||||
pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self {
|
pub fn partial_clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
approved_commands: self.approved_commands.clone(),
|
approved_commands: self.approved_commands.clone(),
|
||||||
previous_response_id: self.previous_response_id.clone(),
|
history: self.history.clone(),
|
||||||
zdr_transcript: if retain_zdr_transcript {
|
|
||||||
self.zdr_transcript.clone()
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
},
|
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -606,13 +588,11 @@ async fn submission_loop(
|
|||||||
}
|
}
|
||||||
// Optionally resume an existing rollout.
|
// Optionally resume an existing rollout.
|
||||||
let mut restored_items: Option<Vec<ResponseItem>> = None;
|
let mut restored_items: Option<Vec<ResponseItem>> = None;
|
||||||
let mut restored_prev_id: Option<String> = None;
|
|
||||||
let rollout_recorder: Option<RolloutRecorder> =
|
let rollout_recorder: Option<RolloutRecorder> =
|
||||||
if let Some(path) = resume_path.as_ref() {
|
if let Some(path) = resume_path.as_ref() {
|
||||||
match RolloutRecorder::resume(path).await {
|
match RolloutRecorder::resume(path).await {
|
||||||
Ok((rec, saved)) => {
|
Ok((rec, saved)) => {
|
||||||
session_id = saved.session_id;
|
session_id = saved.session_id;
|
||||||
restored_prev_id = saved.state.previous_response_id;
|
|
||||||
if !saved.items.is_empty() {
|
if !saved.items.is_empty() {
|
||||||
restored_items = Some(saved.items);
|
restored_items = Some(saved.items);
|
||||||
}
|
}
|
||||||
@@ -651,22 +631,13 @@ async fn submission_loop(
|
|||||||
);
|
);
|
||||||
|
|
||||||
// abort any current running session and clone its state
|
// abort any current running session and clone its state
|
||||||
let retain_zdr_transcript =
|
|
||||||
record_conversation_history(disable_response_storage, provider.wire_api);
|
|
||||||
let state = match sess.take() {
|
let state = match sess.take() {
|
||||||
Some(sess) => {
|
Some(sess) => {
|
||||||
sess.abort();
|
sess.abort();
|
||||||
sess.state
|
sess.state.lock().unwrap().partial_clone()
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.partial_clone(retain_zdr_transcript)
|
|
||||||
}
|
}
|
||||||
None => State {
|
None => State {
|
||||||
zdr_transcript: if retain_zdr_transcript {
|
history: ConversationHistory::new(),
|
||||||
Some(ConversationHistory::new())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
},
|
|
||||||
..Default::default()
|
..Default::default()
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@@ -717,18 +688,14 @@ async fn submission_loop(
|
|||||||
state: Mutex::new(state),
|
state: Mutex::new(state),
|
||||||
rollout: Mutex::new(rollout_recorder),
|
rollout: Mutex::new(rollout_recorder),
|
||||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||||
|
disable_response_storage,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Patch restored state into the newly created session.
|
// Patch restored state into the newly created session.
|
||||||
if let Some(sess_arc) = &sess {
|
if let Some(sess_arc) = &sess {
|
||||||
if restored_prev_id.is_some() || restored_items.is_some() {
|
if restored_items.is_some() {
|
||||||
let mut st = sess_arc.state.lock().unwrap();
|
let mut st = sess_arc.state.lock().unwrap();
|
||||||
st.previous_response_id = restored_prev_id;
|
st.history.record_items(restored_items.unwrap().iter());
|
||||||
if let (Some(hist), Some(items)) =
|
|
||||||
(st.zdr_transcript.as_mut(), restored_items.as_ref())
|
|
||||||
{
|
|
||||||
hist.record_items(items.iter());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -875,14 +842,8 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
|||||||
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
|
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let mut input_for_next_turn: Vec<ResponseInputItem> = vec![initial_input_for_turn];
|
|
||||||
let last_agent_message: Option<String>;
|
let last_agent_message: Option<String>;
|
||||||
loop {
|
loop {
|
||||||
let mut net_new_turn_input = input_for_next_turn
|
|
||||||
.drain(..)
|
|
||||||
.map(ResponseItem::from)
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
// Note that pending_input would be something like a message the user
|
// Note that pending_input would be something like a message the user
|
||||||
// submitted through the UI while the model was running. Though the UI
|
// submitted through the UI while the model was running. Though the UI
|
||||||
// may support this, the model might not.
|
// may support this, the model might not.
|
||||||
@@ -899,29 +860,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
|||||||
// only record the new items that originated in this turn so that it
|
// only record the new items that originated in this turn so that it
|
||||||
// represents an append-only log without duplicates.
|
// represents an append-only log without duplicates.
|
||||||
let turn_input: Vec<ResponseItem> =
|
let turn_input: Vec<ResponseItem> =
|
||||||
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
|
[sess.state.lock().unwrap().history.contents(), pending_input].concat();
|
||||||
// If we are using Chat/ZDR, we need to send the transcript with
|
|
||||||
// every turn. By induction, `transcript` already contains:
|
|
||||||
// - The `input` that kicked off this task.
|
|
||||||
// - Each `ResponseItem` that was recorded in the previous turn.
|
|
||||||
// - Each response to a `ResponseItem` (in practice, the only
|
|
||||||
// response type we seem to have is `FunctionCallOutput`).
|
|
||||||
//
|
|
||||||
// The only thing the `transcript` does not contain is the
|
|
||||||
// `pending_input` that was injected while the model was
|
|
||||||
// running. We need to add that to the conversation history
|
|
||||||
// so that the model can see it in the next turn.
|
|
||||||
[transcript.contents(), pending_input].concat()
|
|
||||||
} else {
|
|
||||||
// In practice, net_new_turn_input should contain only:
|
|
||||||
// - User messages
|
|
||||||
// - Outputs for function calls requested by the model
|
|
||||||
net_new_turn_input.extend(pending_input);
|
|
||||||
|
|
||||||
// Responses API path – we can just send the new items and
|
|
||||||
// record the same.
|
|
||||||
net_new_turn_input
|
|
||||||
};
|
|
||||||
|
|
||||||
let turn_input_messages: Vec<String> = turn_input
|
let turn_input_messages: Vec<String> = turn_input
|
||||||
.iter()
|
.iter()
|
||||||
@@ -997,8 +936,19 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
(ResponseItem::Reasoning { .. }, None) => {
|
(
|
||||||
// Omit from conversation history.
|
ResponseItem::Reasoning {
|
||||||
|
id,
|
||||||
|
summary,
|
||||||
|
encrypted_content,
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
) => {
|
||||||
|
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
|
||||||
|
id: id.clone(),
|
||||||
|
summary: summary.clone(),
|
||||||
|
encrypted_content: encrypted_content.clone(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
warn!("Unexpected response item: {item:?} with response: {response:?}");
|
warn!("Unexpected response item: {item:?} with response: {response:?}");
|
||||||
@@ -1027,8 +977,6 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
|||||||
});
|
});
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
input_for_next_turn = responses;
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
info!("Turn error: {e:#}");
|
info!("Turn error: {e:#}");
|
||||||
@@ -1056,26 +1004,11 @@ async fn run_turn(
|
|||||||
sub_id: String,
|
sub_id: String,
|
||||||
input: Vec<ResponseItem>,
|
input: Vec<ResponseItem>,
|
||||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||||
// Decide whether to use server-side storage (previous_response_id) or disable it
|
|
||||||
let (prev_id, store) = {
|
|
||||||
let state = sess.state.lock().unwrap();
|
|
||||||
let store = state.zdr_transcript.is_none();
|
|
||||||
let prev_id = if store {
|
|
||||||
state.previous_response_id.clone()
|
|
||||||
} else {
|
|
||||||
// When using ZDR, the Responses API may send previous_response_id
|
|
||||||
// back, but trying to use it results in a 400.
|
|
||||||
None
|
|
||||||
};
|
|
||||||
(prev_id, store)
|
|
||||||
};
|
|
||||||
|
|
||||||
let extra_tools = sess.mcp_connection_manager.list_all_tools();
|
let extra_tools = sess.mcp_connection_manager.list_all_tools();
|
||||||
let prompt = Prompt {
|
let prompt = Prompt {
|
||||||
input,
|
input,
|
||||||
prev_id,
|
|
||||||
user_instructions: sess.user_instructions.clone(),
|
user_instructions: sess.user_instructions.clone(),
|
||||||
store,
|
store: !sess.disable_response_storage,
|
||||||
extra_tools,
|
extra_tools,
|
||||||
base_instructions_override: sess.base_instructions.clone(),
|
base_instructions_override: sess.base_instructions.clone(),
|
||||||
};
|
};
|
||||||
@@ -1149,11 +1082,17 @@ async fn try_run_turn(
|
|||||||
// This usually happens because the user interrupted the model before we responded to one of its tool calls
|
// This usually happens because the user interrupted the model before we responded to one of its tool calls
|
||||||
// and then the user sent a follow-up message.
|
// and then the user sent a follow-up message.
|
||||||
let missing_calls = {
|
let missing_calls = {
|
||||||
sess.state
|
prompt
|
||||||
.lock()
|
.input
|
||||||
.unwrap()
|
|
||||||
.pending_call_ids
|
|
||||||
.iter()
|
.iter()
|
||||||
|
.filter_map(|ri| match ri {
|
||||||
|
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
|
||||||
|
ResponseItem::LocalShellCall {
|
||||||
|
call_id: Some(call_id),
|
||||||
|
..
|
||||||
|
} => Some(call_id),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
.filter_map(|call_id| {
|
.filter_map(|call_id| {
|
||||||
if completed_call_ids.contains(&call_id) {
|
if completed_call_ids.contains(&call_id) {
|
||||||
None
|
None
|
||||||
@@ -1207,31 +1146,14 @@ async fn try_run_turn(
|
|||||||
};
|
};
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
ResponseEvent::Created => {
|
ResponseEvent::Created => {}
|
||||||
let mut state = sess.state.lock().unwrap();
|
|
||||||
// We successfully created a new response and ensured that all pending calls were included so we can clear the pending call ids.
|
|
||||||
state.pending_call_ids.clear();
|
|
||||||
}
|
|
||||||
ResponseEvent::OutputItemDone(item) => {
|
ResponseEvent::OutputItemDone(item) => {
|
||||||
let call_id = match &item {
|
|
||||||
ResponseItem::LocalShellCall {
|
|
||||||
call_id: Some(call_id),
|
|
||||||
..
|
|
||||||
} => Some(call_id),
|
|
||||||
ResponseItem::FunctionCall { call_id, .. } => Some(call_id),
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
if let Some(call_id) = call_id {
|
|
||||||
// We just got a new call id so we need to make sure to respond to it in the next turn.
|
|
||||||
let mut state = sess.state.lock().unwrap();
|
|
||||||
state.pending_call_ids.insert(call_id.clone());
|
|
||||||
}
|
|
||||||
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
||||||
|
|
||||||
output.push(ProcessedResponseItem { item, response });
|
output.push(ProcessedResponseItem { item, response });
|
||||||
}
|
}
|
||||||
ResponseEvent::Completed {
|
ResponseEvent::Completed {
|
||||||
response_id,
|
response_id: _,
|
||||||
token_usage,
|
token_usage,
|
||||||
} => {
|
} => {
|
||||||
if let Some(token_usage) = token_usage {
|
if let Some(token_usage) = token_usage {
|
||||||
@@ -1244,8 +1166,6 @@ async fn try_run_turn(
|
|||||||
.ok();
|
.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut state = sess.state.lock().unwrap();
|
|
||||||
state.previous_response_id = Some(response_id);
|
|
||||||
return Ok(output);
|
return Ok(output);
|
||||||
}
|
}
|
||||||
ResponseEvent::OutputTextDelta(delta) => {
|
ResponseEvent::OutputTextDelta(delta) => {
|
||||||
@@ -1285,7 +1205,7 @@ async fn handle_response_item(
|
|||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
ResponseItem::Reasoning { id: _, summary } => {
|
ResponseItem::Reasoning { summary, .. } => {
|
||||||
for item in summary {
|
for item in summary {
|
||||||
let text = match item {
|
let text = match item {
|
||||||
ReasoningItemReasoningSummary::SummaryText { text } => text,
|
ReasoningItemReasoningSummary::SummaryText { text } => text,
|
||||||
@@ -1302,6 +1222,7 @@ async fn handle_response_item(
|
|||||||
name,
|
name,
|
||||||
arguments,
|
arguments,
|
||||||
call_id,
|
call_id,
|
||||||
|
..
|
||||||
} => {
|
} => {
|
||||||
info!("FunctionCall: {arguments}");
|
info!("FunctionCall: {arguments}");
|
||||||
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
|
Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
|
||||||
@@ -2092,7 +2013,7 @@ fn format_exec_output(output: &str, exit_code: i32, duration: Duration) -> Strin
|
|||||||
|
|
||||||
fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
|
fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<String> {
|
||||||
responses.iter().rev().find_map(|item| {
|
responses.iter().rev().find_map(|item| {
|
||||||
if let ResponseItem::Message { role, content } = item {
|
if let ResponseItem::Message { role, content, .. } = item {
|
||||||
if role == "assistant" {
|
if role == "assistant" {
|
||||||
content.iter().rev().find_map(|ci| {
|
content.iter().rev().find_map(|ci| {
|
||||||
if let ContentItem::OutputText { text } = ci {
|
if let ContentItem::OutputText { text } = ci {
|
||||||
@@ -2109,15 +2030,3 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// See [`ConversationHistory`] for details.
|
|
||||||
fn record_conversation_history(disable_response_storage: bool, wire_api: WireApi) -> bool {
|
|
||||||
if disable_response_storage {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
match wire_api {
|
|
||||||
WireApi::Responses => false,
|
|
||||||
WireApi::Chat => true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,12 +1,7 @@
|
|||||||
use crate::models::ResponseItem;
|
use crate::models::ResponseItem;
|
||||||
|
|
||||||
/// Transcript of conversation history that is needed:
|
/// Transcript of conversation history
|
||||||
/// - for ZDR clients for which previous_response_id is not available, so we
|
#[derive(Debug, Clone, Default)]
|
||||||
/// must include the transcript with every API call. This must include each
|
|
||||||
/// `function_call` and its corresponding `function_call_output`.
|
|
||||||
/// - for clients using the "chat completions" API as opposed to the
|
|
||||||
/// "responses" API.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub(crate) struct ConversationHistory {
|
pub(crate) struct ConversationHistory {
|
||||||
/// The oldest items are at the beginning of the vector.
|
/// The oldest items are at the beginning of the vector.
|
||||||
items: Vec<ResponseItem>,
|
items: Vec<ResponseItem>,
|
||||||
@@ -44,7 +39,8 @@ fn is_api_message(message: &ResponseItem) -> bool {
|
|||||||
ResponseItem::Message { role, .. } => role.as_str() != "system",
|
ResponseItem::Message { role, .. } => role.as_str() != "system",
|
||||||
ResponseItem::FunctionCallOutput { .. }
|
ResponseItem::FunctionCallOutput { .. }
|
||||||
| ResponseItem::FunctionCall { .. }
|
| ResponseItem::FunctionCall { .. }
|
||||||
| ResponseItem::LocalShellCall { .. } => true,
|
| ResponseItem::LocalShellCall { .. }
|
||||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => false,
|
| ResponseItem::Reasoning { .. } => true,
|
||||||
|
ResponseItem::Other => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use std::collections::HashMap;
|
|||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use mcp_types::CallToolResult;
|
use mcp_types::CallToolResult;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use serde::Deserializer;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde::ser::Serializer;
|
use serde::ser::Serializer;
|
||||||
|
|
||||||
@@ -37,12 +38,14 @@ pub enum ContentItem {
|
|||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
pub enum ResponseItem {
|
pub enum ResponseItem {
|
||||||
Message {
|
Message {
|
||||||
|
id: Option<String>,
|
||||||
role: String,
|
role: String,
|
||||||
content: Vec<ContentItem>,
|
content: Vec<ContentItem>,
|
||||||
},
|
},
|
||||||
Reasoning {
|
Reasoning {
|
||||||
id: String,
|
id: String,
|
||||||
summary: Vec<ReasoningItemReasoningSummary>,
|
summary: Vec<ReasoningItemReasoningSummary>,
|
||||||
|
encrypted_content: Option<String>,
|
||||||
},
|
},
|
||||||
LocalShellCall {
|
LocalShellCall {
|
||||||
/// Set when using the chat completions API.
|
/// Set when using the chat completions API.
|
||||||
@@ -53,6 +56,7 @@ pub enum ResponseItem {
|
|||||||
action: LocalShellAction,
|
action: LocalShellAction,
|
||||||
},
|
},
|
||||||
FunctionCall {
|
FunctionCall {
|
||||||
|
id: Option<String>,
|
||||||
name: String,
|
name: String,
|
||||||
// The Responses API returns the function call arguments as a *string* that contains
|
// The Responses API returns the function call arguments as a *string* that contains
|
||||||
// JSON, not as an already‑parsed object. We keep it as a raw string here and let
|
// JSON, not as an already‑parsed object. We keep it as a raw string here and let
|
||||||
@@ -78,7 +82,11 @@ pub enum ResponseItem {
|
|||||||
impl From<ResponseInputItem> for ResponseItem {
|
impl From<ResponseInputItem> for ResponseItem {
|
||||||
fn from(item: ResponseInputItem) -> Self {
|
fn from(item: ResponseInputItem) -> Self {
|
||||||
match item {
|
match item {
|
||||||
ResponseInputItem::Message { role, content } => Self::Message { role, content },
|
ResponseInputItem::Message { role, content } => Self::Message {
|
||||||
|
role,
|
||||||
|
content,
|
||||||
|
id: None,
|
||||||
|
},
|
||||||
ResponseInputItem::FunctionCallOutput { call_id, output } => {
|
ResponseInputItem::FunctionCallOutput { call_id, output } => {
|
||||||
Self::FunctionCallOutput { call_id, output }
|
Self::FunctionCallOutput { call_id, output }
|
||||||
}
|
}
|
||||||
@@ -177,7 +185,7 @@ pub struct ShellToolCallParams {
|
|||||||
pub timeout_ms: Option<u64>,
|
pub timeout_ms: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct FunctionCallOutputPayload {
|
pub struct FunctionCallOutputPayload {
|
||||||
pub content: String,
|
pub content: String,
|
||||||
#[expect(dead_code)]
|
#[expect(dead_code)]
|
||||||
@@ -205,6 +213,19 @@ impl Serialize for FunctionCallOutputPayload {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for FunctionCallOutputPayload {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let s = String::deserialize(deserializer)?;
|
||||||
|
Ok(FunctionCallOutputPayload {
|
||||||
|
content: s,
|
||||||
|
success: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Implement Display so callers can treat the payload like a plain string when logging or doing
|
// Implement Display so callers can treat the payload like a plain string when logging or doing
|
||||||
// trivial substring checks in tests (existing tests call `.contains()` on the output). Display
|
// trivial substring checks in tests (existing tests call `.contains()` on the output). Display
|
||||||
// returns the raw `content` field.
|
// returns the raw `content` field.
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ use tokio::io::AsyncWriteExt;
|
|||||||
use tokio::sync::mpsc::Sender;
|
use tokio::sync::mpsc::Sender;
|
||||||
use tokio::sync::mpsc::{self};
|
use tokio::sync::mpsc::{self};
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
use tracing::warn;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
@@ -30,9 +31,7 @@ pub struct SessionMeta {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||||
pub struct SessionStateSnapshot {
|
pub struct SessionStateSnapshot {}
|
||||||
pub previous_response_id: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||||
pub struct SavedSession {
|
pub struct SavedSession {
|
||||||
@@ -119,8 +118,9 @@ impl RolloutRecorder {
|
|||||||
ResponseItem::Message { .. }
|
ResponseItem::Message { .. }
|
||||||
| ResponseItem::LocalShellCall { .. }
|
| ResponseItem::LocalShellCall { .. }
|
||||||
| ResponseItem::FunctionCall { .. }
|
| ResponseItem::FunctionCall { .. }
|
||||||
| ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
|
| ResponseItem::FunctionCallOutput { .. }
|
||||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
|
| ResponseItem::Reasoning { .. } => filtered.push(item.clone()),
|
||||||
|
ResponseItem::Other => {
|
||||||
// These should never be serialized.
|
// These should never be serialized.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -172,13 +172,17 @@ impl RolloutRecorder {
|
|||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if let Ok(item) = serde_json::from_value::<ResponseItem>(v.clone()) {
|
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||||
match item {
|
Ok(item) => match item {
|
||||||
ResponseItem::Message { .. }
|
ResponseItem::Message { .. }
|
||||||
| ResponseItem::LocalShellCall { .. }
|
| ResponseItem::LocalShellCall { .. }
|
||||||
| ResponseItem::FunctionCall { .. }
|
| ResponseItem::FunctionCall { .. }
|
||||||
| ResponseItem::FunctionCallOutput { .. } => items.push(item),
|
| ResponseItem::FunctionCallOutput { .. }
|
||||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
| ResponseItem::Reasoning { .. } => items.push(item),
|
||||||
|
ResponseItem::Other => {}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
warn!("failed to parse item: {v:?}, error: {e}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -267,13 +271,14 @@ async fn rollout_writer(
|
|||||||
ResponseItem::Message { .. }
|
ResponseItem::Message { .. }
|
||||||
| ResponseItem::LocalShellCall { .. }
|
| ResponseItem::LocalShellCall { .. }
|
||||||
| ResponseItem::FunctionCall { .. }
|
| ResponseItem::FunctionCall { .. }
|
||||||
| ResponseItem::FunctionCallOutput { .. } => {
|
| ResponseItem::FunctionCallOutput { .. }
|
||||||
|
| ResponseItem::Reasoning { .. } => {
|
||||||
if let Ok(json) = serde_json::to_string(&item) {
|
if let Ok(json) = serde_json::to_string(&item) {
|
||||||
let _ = file.write_all(json.as_bytes()).await;
|
let _ = file.write_all(json.as_bytes()).await;
|
||||||
let _ = file.write_all(b"\n").await;
|
let _ = file.write_all(b"\n").await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
ResponseItem::Other => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let _ = file.flush().await;
|
let _ = file.flush().await;
|
||||||
|
|||||||
@@ -1,165 +0,0 @@
|
|||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use codex_core::Codex;
|
|
||||||
use codex_core::ModelProviderInfo;
|
|
||||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
|
||||||
use codex_core::protocol::ErrorEvent;
|
|
||||||
use codex_core::protocol::EventMsg;
|
|
||||||
use codex_core::protocol::InputItem;
|
|
||||||
use codex_core::protocol::Op;
|
|
||||||
mod test_support;
|
|
||||||
use serde_json::Value;
|
|
||||||
use tempfile::TempDir;
|
|
||||||
use test_support::load_default_config_for_test;
|
|
||||||
use test_support::load_sse_fixture_with_id;
|
|
||||||
use tokio::time::timeout;
|
|
||||||
use wiremock::Match;
|
|
||||||
use wiremock::Mock;
|
|
||||||
use wiremock::MockServer;
|
|
||||||
use wiremock::Request;
|
|
||||||
use wiremock::ResponseTemplate;
|
|
||||||
use wiremock::matchers::method;
|
|
||||||
use wiremock::matchers::path;
|
|
||||||
|
|
||||||
/// Matcher asserting that JSON body has NO `previous_response_id` field.
|
|
||||||
struct NoPrevId;
|
|
||||||
|
|
||||||
impl Match for NoPrevId {
|
|
||||||
fn matches(&self, req: &Request) -> bool {
|
|
||||||
serde_json::from_slice::<Value>(&req.body)
|
|
||||||
.map(|v| v.get("previous_response_id").is_none())
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Matcher asserting that JSON body HAS a `previous_response_id` field.
|
|
||||||
struct HasPrevId;
|
|
||||||
|
|
||||||
impl Match for HasPrevId {
|
|
||||||
fn matches(&self, req: &Request) -> bool {
|
|
||||||
serde_json::from_slice::<Value>(&req.body)
|
|
||||||
.map(|v| v.get("previous_response_id").is_some())
|
|
||||||
.unwrap_or(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
|
||||||
fn sse_completed(id: &str) -> String {
|
|
||||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
|
||||||
async fn keeps_previous_response_id_between_tasks() {
|
|
||||||
#![allow(clippy::unwrap_used)]
|
|
||||||
|
|
||||||
if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() {
|
|
||||||
println!(
|
|
||||||
"Skipping test because it cannot execute when network is disabled in a Codex sandbox."
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mock server
|
|
||||||
let server = MockServer::start().await;
|
|
||||||
|
|
||||||
// First request – must NOT include `previous_response_id`.
|
|
||||||
let first = ResponseTemplate::new(200)
|
|
||||||
.insert_header("content-type", "text/event-stream")
|
|
||||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
|
||||||
|
|
||||||
Mock::given(method("POST"))
|
|
||||||
.and(path("/v1/responses"))
|
|
||||||
.and(NoPrevId)
|
|
||||||
.respond_with(first)
|
|
||||||
.expect(1)
|
|
||||||
.mount(&server)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
// Second request – MUST include `previous_response_id`.
|
|
||||||
let second = ResponseTemplate::new(200)
|
|
||||||
.insert_header("content-type", "text/event-stream")
|
|
||||||
.set_body_raw(sse_completed("resp2"), "text/event-stream");
|
|
||||||
|
|
||||||
Mock::given(method("POST"))
|
|
||||||
.and(path("/v1/responses"))
|
|
||||||
.and(HasPrevId)
|
|
||||||
.respond_with(second)
|
|
||||||
.expect(1)
|
|
||||||
.mount(&server)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
// Configure retry behavior explicitly to avoid mutating process-wide
|
|
||||||
// environment variables.
|
|
||||||
let model_provider = ModelProviderInfo {
|
|
||||||
name: "openai".into(),
|
|
||||||
base_url: format!("{}/v1", server.uri()),
|
|
||||||
// Environment variable that should exist in the test environment.
|
|
||||||
// ModelClient will return an error if the environment variable for the
|
|
||||||
// provider is not set.
|
|
||||||
env_key: Some("PATH".into()),
|
|
||||||
env_key_instructions: None,
|
|
||||||
wire_api: codex_core::WireApi::Responses,
|
|
||||||
query_params: None,
|
|
||||||
http_headers: None,
|
|
||||||
env_http_headers: None,
|
|
||||||
// disable retries so we don't get duplicate calls in this test
|
|
||||||
request_max_retries: Some(0),
|
|
||||||
stream_max_retries: Some(0),
|
|
||||||
stream_idle_timeout_ms: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Init session
|
|
||||||
let codex_home = TempDir::new().unwrap();
|
|
||||||
let mut config = load_default_config_for_test(&codex_home);
|
|
||||||
config.model_provider = model_provider;
|
|
||||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
|
||||||
let (codex, _init_id, _session_id) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
|
||||||
|
|
||||||
// Task 1 – triggers first request (no previous_response_id)
|
|
||||||
codex
|
|
||||||
.submit(Op::UserInput {
|
|
||||||
items: vec![InputItem::Text {
|
|
||||||
text: "hello".into(),
|
|
||||||
}],
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Wait for TaskComplete
|
|
||||||
loop {
|
|
||||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Task 2 – should include `previous_response_id` (triggers second request)
|
|
||||||
codex
|
|
||||||
.submit(Op::UserInput {
|
|
||||||
items: vec![InputItem::Text {
|
|
||||||
text: "again".into(),
|
|
||||||
}],
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Wait for TaskComplete or error
|
|
||||||
loop {
|
|
||||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
match ev.msg {
|
|
||||||
EventMsg::TaskComplete(_) => break,
|
|
||||||
EventMsg::Error(ErrorEvent { message }) => {
|
|
||||||
panic!("unexpected error: {message}")
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
// Ignore other events.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user