[context] Store context messages in rollouts (#2243)
## Summary Currently, we use request-time logic to determine the user_instructions and environment_context messages. This means that neither of these values can change over time as conversations go on. We want to add in additional details here, so we're migrating these to save these messages to the rollout file instead. This is simpler for the client, and allows us to append additional environment_context messages to each turn if we want ## Testing - [x] Integration test coverage - [x] Tested locally with a few turns, confirmed model could reference environment context and cached token metrics were reasonably high
This commit is contained in:
@@ -5,15 +5,11 @@ use crate::model_family::ModelFamily;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::openai_tools::OpenAiTool;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::protocol::TokenUsage;
|
||||
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
|
||||
use futures::Stream;
|
||||
use serde::Serialize;
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
use std::path::PathBuf;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
@@ -23,62 +19,19 @@ use tokio::sync::mpsc;
|
||||
/// with this content.
|
||||
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
|
||||
|
||||
/// wraps environment context message in a tag for the model to parse more easily.
|
||||
const ENVIRONMENT_CONTEXT_START: &str = "<environment_context>\n\n";
|
||||
const ENVIRONMENT_CONTEXT_END: &str = "\n\n</environment_context>";
|
||||
|
||||
/// wraps user instructions message in a tag for the model to parse more easily.
|
||||
const USER_INSTRUCTIONS_START: &str = "<user_instructions>\n\n";
|
||||
const USER_INSTRUCTIONS_END: &str = "\n\n</user_instructions>";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct EnvironmentContext {
|
||||
pub cwd: PathBuf,
|
||||
pub approval_policy: AskForApproval,
|
||||
pub sandbox_policy: SandboxPolicy,
|
||||
}
|
||||
|
||||
impl Display for EnvironmentContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"Current working directory: {}",
|
||||
self.cwd.to_string_lossy()
|
||||
)?;
|
||||
writeln!(f, "Approval policy: {}", self.approval_policy)?;
|
||||
writeln!(f, "Sandbox policy: {}", self.sandbox_policy)?;
|
||||
|
||||
let network_access = match self.sandbox_policy.clone() {
|
||||
SandboxPolicy::DangerFullAccess => "enabled",
|
||||
SandboxPolicy::ReadOnly => "restricted",
|
||||
SandboxPolicy::WorkspaceWrite { network_access, .. } => {
|
||||
if network_access {
|
||||
"enabled"
|
||||
} else {
|
||||
"restricted"
|
||||
}
|
||||
}
|
||||
};
|
||||
writeln!(f, "Network access: {network_access}")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// API request payload for a single model turn.
|
||||
/// API request payload for a single model turn
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Prompt {
|
||||
/// Conversation context input items.
|
||||
pub input: Vec<ResponseItem>,
|
||||
/// Optional instructions from the user to amend to the built-in agent
|
||||
/// instructions.
|
||||
pub user_instructions: Option<String>,
|
||||
|
||||
/// Whether to store response on server side (disable_response_storage = !store).
|
||||
pub store: bool,
|
||||
|
||||
/// A list of key-value pairs that will be added as a developer message
|
||||
/// for the model to use
|
||||
pub environment_context: Option<EnvironmentContext>,
|
||||
|
||||
/// Tools available to the model, including additional tools sourced from
|
||||
/// external MCP servers.
|
||||
pub tools: Vec<OpenAiTool>,
|
||||
@@ -100,36 +53,19 @@ impl Prompt {
|
||||
Cow::Owned(sections.join("\n"))
|
||||
}
|
||||
|
||||
fn get_formatted_user_instructions(&self) -> Option<String> {
|
||||
self.user_instructions
|
||||
.as_ref()
|
||||
.map(|ui| format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}"))
|
||||
}
|
||||
|
||||
fn get_formatted_environment_context(&self) -> Option<String> {
|
||||
self.environment_context
|
||||
.as_ref()
|
||||
.map(|ec| format!("{ENVIRONMENT_CONTEXT_START}{ec}{ENVIRONMENT_CONTEXT_END}"))
|
||||
}
|
||||
|
||||
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
|
||||
let mut input_with_instructions = Vec::with_capacity(self.input.len() + 2);
|
||||
if let Some(ec) = self.get_formatted_environment_context() {
|
||||
input_with_instructions.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText { text: ec }],
|
||||
});
|
||||
self.input.clone()
|
||||
}
|
||||
|
||||
/// Creates a formatted user instructions message from a string
|
||||
pub(crate) fn format_user_instructions_message(ui: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}"),
|
||||
}],
|
||||
}
|
||||
if let Some(ui) = self.get_formatted_user_instructions() {
|
||||
input_with_instructions.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText { text: ui }],
|
||||
});
|
||||
}
|
||||
input_with_instructions.extend(self.input.clone());
|
||||
input_with_instructions
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,7 +195,6 @@ mod tests {
|
||||
#[test]
|
||||
fn get_full_instructions_no_user_content() {
|
||||
let prompt = Prompt {
|
||||
user_instructions: Some("custom instruction".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let expected = format!("{BASE_INSTRUCTIONS}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}");
|
||||
|
||||
@@ -38,7 +38,6 @@ use crate::apply_patch::convert_apply_patch_to_protocol;
|
||||
use crate::apply_patch::get_writable_roots;
|
||||
use crate::apply_patch::{self};
|
||||
use crate::client::ModelClient;
|
||||
use crate::client_common::EnvironmentContext;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::config::Config;
|
||||
@@ -46,6 +45,7 @@ use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||
use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||
use crate::config_types::ShellEnvironmentPolicy;
|
||||
use crate::conversation_history::ConversationHistory;
|
||||
use crate::environment_context::EnvironmentContext;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::error::SandboxErr;
|
||||
@@ -437,6 +437,20 @@ impl Session {
|
||||
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
|
||||
});
|
||||
|
||||
// record the initial user instructions and environment context, regardless of whether we restored items.
|
||||
if let Some(user_instructions) = sess.get_user_instructions().clone() {
|
||||
sess.record_conversation_items(&[Prompt::format_user_instructions_message(
|
||||
&user_instructions,
|
||||
)])
|
||||
.await;
|
||||
}
|
||||
sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new(
|
||||
sess.get_cwd().to_path_buf(),
|
||||
sess.get_approval_policy(),
|
||||
sess.get_sandbox_policy().clone(),
|
||||
))])
|
||||
.await;
|
||||
|
||||
// Gather history metadata for SessionConfiguredEvent.
|
||||
let (history_log_id, history_entry_count) =
|
||||
crate::message_history::history_metadata(&config).await;
|
||||
@@ -473,6 +487,14 @@ impl Session {
|
||||
&self.cwd
|
||||
}
|
||||
|
||||
pub(crate) fn get_user_instructions(&self) -> Option<String> {
|
||||
self.user_instructions.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn get_sandbox_policy(&self) -> &SandboxPolicy {
|
||||
&self.sandbox_policy
|
||||
}
|
||||
|
||||
fn resolve_path(&self, path: Option<String>) -> PathBuf {
|
||||
path.as_ref()
|
||||
.map(PathBuf::from)
|
||||
@@ -1237,15 +1259,9 @@ async fn run_turn(
|
||||
|
||||
let prompt = Prompt {
|
||||
input,
|
||||
user_instructions: sess.user_instructions.clone(),
|
||||
store: !sess.disable_response_storage,
|
||||
tools,
|
||||
base_instructions_override: sess.base_instructions.clone(),
|
||||
environment_context: Some(EnvironmentContext {
|
||||
cwd: sess.cwd.clone(),
|
||||
approval_policy: sess.approval_policy,
|
||||
sandbox_policy: sess.sandbox_policy.clone(),
|
||||
}),
|
||||
};
|
||||
|
||||
let mut retries = 0;
|
||||
@@ -1483,9 +1499,7 @@ async fn run_compact_task(
|
||||
|
||||
let prompt = Prompt {
|
||||
input: turn_input,
|
||||
user_instructions: None,
|
||||
store: !sess.disable_response_storage,
|
||||
environment_context: None,
|
||||
tools: Vec::new(),
|
||||
base_instructions_override: Some(compact_instructions.clone()),
|
||||
};
|
||||
|
||||
@@ -78,8 +78,9 @@ pub enum HistoryPersistence {
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
|
||||
pub struct Tui {}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Default, Serialize)]
|
||||
#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Default, Serialize, Display)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[strum(serialize_all = "kebab-case")]
|
||||
pub enum SandboxMode {
|
||||
#[serde(rename = "read-only")]
|
||||
#[default]
|
||||
|
||||
86
codex-rs/core/src/environment_context.rs
Normal file
86
codex-rs/core/src/environment_context.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use strum_macros::Display as DeriveDisplay;
|
||||
|
||||
use crate::config_types::SandboxMode;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use std::fmt::Display;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// wraps environment context message in a tag for the model to parse more easily.
|
||||
pub(crate) const ENVIRONMENT_CONTEXT_START: &str = "<environment_context>\n";
|
||||
pub(crate) const ENVIRONMENT_CONTEXT_END: &str = "</environment_context>";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeriveDisplay)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
#[strum(serialize_all = "kebab-case")]
|
||||
pub enum NetworkAccess {
|
||||
Restricted,
|
||||
Enabled,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename = "environment_context", rename_all = "snake_case")]
|
||||
pub(crate) struct EnvironmentContext {
|
||||
pub cwd: PathBuf,
|
||||
pub approval_policy: AskForApproval,
|
||||
pub sandbox_mode: SandboxMode,
|
||||
pub network_access: NetworkAccess,
|
||||
}
|
||||
|
||||
impl EnvironmentContext {
|
||||
pub fn new(
|
||||
cwd: PathBuf,
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
) -> Self {
|
||||
Self {
|
||||
cwd,
|
||||
approval_policy,
|
||||
sandbox_mode: match sandbox_policy {
|
||||
SandboxPolicy::DangerFullAccess => SandboxMode::DangerFullAccess,
|
||||
SandboxPolicy::ReadOnly => SandboxMode::ReadOnly,
|
||||
SandboxPolicy::WorkspaceWrite { .. } => SandboxMode::WorkspaceWrite,
|
||||
},
|
||||
network_access: match sandbox_policy {
|
||||
SandboxPolicy::DangerFullAccess => NetworkAccess::Enabled,
|
||||
SandboxPolicy::ReadOnly => NetworkAccess::Restricted,
|
||||
SandboxPolicy::WorkspaceWrite { network_access, .. } => {
|
||||
if network_access {
|
||||
NetworkAccess::Enabled
|
||||
} else {
|
||||
NetworkAccess::Restricted
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for EnvironmentContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"Current working directory: {}",
|
||||
self.cwd.to_string_lossy()
|
||||
)?;
|
||||
writeln!(f, "Approval policy: {}", self.approval_policy)?;
|
||||
writeln!(f, "Sandbox mode: {}", self.sandbox_mode)?;
|
||||
writeln!(f, "Network access: {}", self.network_access)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EnvironmentContext> for ResponseItem {
|
||||
fn from(ec: EnvironmentContext) -> Self {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: format!("{ENVIRONMENT_CONTEXT_START}{ec}{ENVIRONMENT_CONTEXT_END}"),
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ pub mod config;
|
||||
pub mod config_profile;
|
||||
pub mod config_types;
|
||||
mod conversation_history;
|
||||
mod environment_context;
|
||||
pub mod error;
|
||||
pub mod exec;
|
||||
pub mod exec_env;
|
||||
|
||||
@@ -373,11 +373,11 @@ async fn includes_user_instructions_message_in_request() {
|
||||
.contains("be nice")
|
||||
);
|
||||
assert_message_role(&request_body["input"][0], "user");
|
||||
assert_message_starts_with(&request_body["input"][0], "<environment_context>\n\n");
|
||||
assert_message_ends_with(&request_body["input"][0], "</environment_context>");
|
||||
assert_message_starts_with(&request_body["input"][0], "<user_instructions>");
|
||||
assert_message_ends_with(&request_body["input"][0], "</user_instructions>");
|
||||
assert_message_role(&request_body["input"][1], "user");
|
||||
assert_message_starts_with(&request_body["input"][1], "<user_instructions>\n\n");
|
||||
assert_message_ends_with(&request_body["input"][1], "</user_instructions>");
|
||||
assert_message_starts_with(&request_body["input"][1], "<environment_context>");
|
||||
assert_message_ends_with(&request_body["input"][1], "</environment_context>");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
|
||||
@@ -85,7 +85,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
assert_eq!(requests.len(), 2, "expected two POST requests");
|
||||
|
||||
let expected_env_text = format!(
|
||||
"<environment_context>\n\nCurrent working directory: {}\nApproval policy: on-request\nSandbox policy: read-only\nNetwork access: restricted\n\n\n</environment_context>",
|
||||
"<environment_context>\nCurrent working directory: {}\nApproval policy: on-request\nSandbox mode: read-only\nNetwork access: restricted\n</environment_context>",
|
||||
cwd.path().to_string_lossy()
|
||||
);
|
||||
let expected_ui_text =
|
||||
@@ -113,7 +113,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
||||
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
assert_eq!(
|
||||
body1["input"],
|
||||
serde_json::json!([expected_env_msg, expected_ui_msg, expected_user_message_1])
|
||||
serde_json::json!([expected_ui_msg, expected_env_msg, expected_user_message_1])
|
||||
);
|
||||
|
||||
let expected_user_message_2 = serde_json::json!({
|
||||
|
||||
Reference in New Issue
Block a user