feat: add ZDR support to Rust implementation (#642)
This adds support for the `--disable-response-storage` flag across our multiple Rust CLIs to support customers who have opted into Zero-Data Retention (ZDR). The analogous changes to the TypeScript CLI were: * https://github.com/openai/codex/pull/481 * https://github.com/openai/codex/pull/543 For a client using ZDR, `previous_response_id` will never be available, so the `input` field of an API request must include the full transcript of the conversation thus far. As such, this PR changes the type of `Prompt.input` from `Vec<ResponseInputItem>` to `Vec<ResponseItem>`. Practically speaking, `ResponseItem` was effectively a "superset" of `ResponseInputItem` already. The main difference for us is that `ResponseItem` includes the `FunctionCall` variant that we have to include as part of the conversation history in the ZDR case. Another key change in this PR is modifying `try_run_turn()` so that it returns the `Vec<ResponseItem>` for the turn in addition to the `Vec<ResponseInputItem>` produced by `try_run_turn()`. This is because the caller of `run_turn()` needs to record the `Vec<ResponseItem>` when ZDR is enabled. To that end, this PR introduces `ZdrTranscript` (and adds `zdr_transcript: Option<ZdrTranscript>` to `struct State` in `codex.rs`) to take responsibility for maintaining the conversation transcript in the ZDR case.
This commit is contained in:
@@ -55,6 +55,7 @@ use crate::safety::assess_command_safety;
|
||||
use crate::safety::assess_patch_safety;
|
||||
use crate::safety::SafetyCheck;
|
||||
use crate::util::backoff;
|
||||
use crate::zdr_transcript::ZdrTranscript;
|
||||
|
||||
/// The high-level interface to the Codex system.
|
||||
/// It operates as a queue pair where you send submissions and receive events.
|
||||
@@ -214,6 +215,7 @@ struct State {
|
||||
previous_response_id: Option<String>,
|
||||
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
zdr_transcript: Option<ZdrTranscript>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -399,6 +401,7 @@ impl State {
|
||||
Self {
|
||||
approved_commands: self.approved_commands.clone(),
|
||||
previous_response_id: self.previous_response_id.clone(),
|
||||
zdr_transcript: self.zdr_transcript.clone(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
@@ -489,6 +492,7 @@ async fn submission_loop(
|
||||
instructions,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
disable_response_storage,
|
||||
} => {
|
||||
let model = model.unwrap_or_else(|| OPENAI_DEFAULT_MODEL.to_string());
|
||||
info!(model, "Configuring session");
|
||||
@@ -500,7 +504,14 @@ async fn submission_loop(
|
||||
sess.abort();
|
||||
sess.state.lock().unwrap().partial_clone()
|
||||
}
|
||||
None => State::default(),
|
||||
None => State {
|
||||
zdr_transcript: if disable_response_storage {
|
||||
Some(ZdrTranscript::new())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
||||
// update session
|
||||
@@ -587,18 +598,54 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut turn_input = vec![ResponseInputItem::from(input)];
|
||||
let mut pending_response_input: Vec<ResponseInputItem> = vec![ResponseInputItem::from(input)];
|
||||
loop {
|
||||
let pending_input = sess.get_pending_input();
|
||||
turn_input.splice(0..0, pending_input);
|
||||
let mut net_new_turn_input = pending_response_input
|
||||
.drain(..)
|
||||
.map(ResponseItem::from)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
// may support this, the model might not.
|
||||
let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from);
|
||||
net_new_turn_input.extend(pending_input);
|
||||
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
// If we are using ZDR, we need to send the transcript with every turn.
|
||||
let mut full_transcript = transcript.contents();
|
||||
full_transcript.extend(net_new_turn_input.clone());
|
||||
transcript.record_items(net_new_turn_input);
|
||||
full_transcript
|
||||
} else {
|
||||
net_new_turn_input
|
||||
};
|
||||
|
||||
match run_turn(&sess, sub_id.clone(), turn_input).await {
|
||||
Ok(turn_output) => {
|
||||
if turn_output.is_empty() {
|
||||
let (items, responses): (Vec<_>, Vec<_>) = turn_output
|
||||
.into_iter()
|
||||
.map(|p| (p.item, p.response))
|
||||
.unzip();
|
||||
let responses = responses
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<ResponseInputItem>>();
|
||||
|
||||
// Only attempt to take the lock if there is something to record.
|
||||
if !items.is_empty() {
|
||||
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
transcript.record_items(items);
|
||||
}
|
||||
}
|
||||
|
||||
if responses.is_empty() {
|
||||
debug!("Turn completed");
|
||||
break;
|
||||
}
|
||||
turn_input = turn_output;
|
||||
|
||||
pending_response_input = responses;
|
||||
}
|
||||
Err(e) => {
|
||||
info!("Turn error: {e:#}");
|
||||
@@ -624,21 +671,31 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
async fn run_turn(
|
||||
sess: &Session,
|
||||
sub_id: String,
|
||||
input: Vec<ResponseInputItem>,
|
||||
) -> CodexResult<Vec<ResponseInputItem>> {
|
||||
let prev_id = {
|
||||
input: Vec<ResponseItem>,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
// Decide whether to use server-side storage (previous_response_id) or disable it
|
||||
let (prev_id, store, is_first_turn) = {
|
||||
let state = sess.state.lock().unwrap();
|
||||
state.previous_response_id.clone()
|
||||
let is_first_turn = state.previous_response_id.is_none();
|
||||
if state.zdr_transcript.is_some() {
|
||||
// When using ZDR, the Reponses API may send previous_response_id
|
||||
// back, but trying to use it results in a 400.
|
||||
(None, true, is_first_turn)
|
||||
} else {
|
||||
(state.previous_response_id.clone(), false, is_first_turn)
|
||||
}
|
||||
};
|
||||
|
||||
let instructions = match prev_id {
|
||||
Some(_) => None,
|
||||
None => sess.instructions.clone(),
|
||||
let instructions = if is_first_turn {
|
||||
sess.instructions.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
prev_id,
|
||||
instructions,
|
||||
store,
|
||||
};
|
||||
|
||||
let mut retries = 0;
|
||||
@@ -676,11 +733,20 @@ async fn run_turn(
|
||||
}
|
||||
}
|
||||
|
||||
/// When the model is prompted, it returns a stream of events. Some of these
|
||||
/// events map to a `ResponseItem`. A `ResponseItem` may need to be
|
||||
/// "handled" such that it produces a `ResponseInputItem` that needs to be
|
||||
/// sent back to the model on the next turn.
|
||||
struct ProcessedResponseItem {
|
||||
item: ResponseItem,
|
||||
response: Option<ResponseInputItem>,
|
||||
}
|
||||
|
||||
async fn try_run_turn(
|
||||
sess: &Session,
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<Vec<ResponseInputItem>> {
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
||||
|
||||
// Buffer all the incoming messages from the stream first, then execute them.
|
||||
@@ -694,9 +760,8 @@ async fn try_run_turn(
|
||||
for event in input {
|
||||
match event {
|
||||
ResponseEvent::OutputItemDone(item) => {
|
||||
if let Some(item) = handle_response_item(sess, sub_id, item).await? {
|
||||
output.push(item);
|
||||
}
|
||||
let response = handle_response_item(sess, sub_id, item.clone()).await?;
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
ResponseEvent::Completed { response_id } => {
|
||||
let mut state = sess.state.lock().unwrap();
|
||||
|
||||
Reference in New Issue
Block a user