fix: Record EnvironmentContext in SendUserTurn (#3678)

## Summary
SendUserTurn has not been correctly handling updates to policies. While
the tui protocol handles this in `Op::OverrideTurnContext`, the
SendUserTurn should be appending `EnvironmentContext` messages when the
sandbox settings change. MCP client behavior should match the cli
behavior, so we update `SendUserTurn` message to match.

## Testing
- [x] Added prompt caching tests
This commit is contained in:
Dylan
2025-09-16 11:32:20 -07:00
committed by GitHub
parent 244687303b
commit 11285655c4
4 changed files with 410 additions and 9 deletions

View File

@@ -1348,10 +1348,21 @@ async fn submission_loop(
cwd,
is_review_mode: false,
};
// TODO: record the new environment context in the conversation history
// if the environment context has changed, record it in the conversation history
let previous_env_context = EnvironmentContext::from(turn_context.as_ref());
let new_env_context = EnvironmentContext::from(&fresh_turn_context);
if !new_env_context.equals_except_shell(&previous_env_context) {
sess.record_conversation_items(&[ResponseItem::from(new_env_context)])
.await;
}
// Install the new persistent context for subsequent tasks/turns.
turn_context = Arc::new(fresh_turn_context);
// no current task, spawn a new one with the perturn context
let task =
AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items);
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
sess.set_task(task);
}
}

View File

@@ -2,6 +2,7 @@ use serde::Deserialize;
use serde::Serialize;
use strum_macros::Display as DeriveDisplay;
use crate::codex::TurnContext;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use crate::shell::Shell;
@@ -71,6 +72,39 @@ impl EnvironmentContext {
shell,
}
}
/// Compares two environment contexts, ignoring the shell. Useful when
/// comparing turn to turn, since the initial environment_context will
/// include the shell, and then it is not configurable from turn to turn.
pub fn equals_except_shell(&self, other: &EnvironmentContext) -> bool {
let EnvironmentContext {
cwd,
approval_policy,
sandbox_mode,
network_access,
writable_roots,
// should compare all fields except shell
shell: _,
} = other;
self.cwd == *cwd
&& self.approval_policy == *approval_policy
&& self.sandbox_mode == *sandbox_mode
&& self.network_access == *network_access
&& self.writable_roots == *writable_roots
}
}
impl From<&TurnContext> for EnvironmentContext {
fn from(turn_context: &TurnContext) -> Self {
Self::new(
Some(turn_context.cwd.clone()),
Some(turn_context.approval_policy),
Some(turn_context.sandbox_policy.clone()),
// Shell is not configurable from turn to turn
None,
)
}
}
impl EnvironmentContext {
@@ -140,6 +174,9 @@ impl From<EnvironmentContext> for ResponseItem {
#[cfg(test)]
mod tests {
use crate::shell::BashShell;
use crate::shell::ZshShell;
use super::*;
use pretty_assertions::assert_eq;
@@ -210,4 +247,82 @@ mod tests {
assert_eq!(context.serialize_to_xml(), expected);
}
#[test]
fn equals_except_shell_compares_approval_policy() {
// Approval policy
let context1 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)),
None,
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::Never),
Some(workspace_write_policy(vec!["/repo"], true)),
None,
);
assert!(!context1.equals_except_shell(&context2));
}
#[test]
fn equals_except_shell_compares_sandbox_policy() {
let context1 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(SandboxPolicy::new_read_only_policy()),
None,
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(SandboxPolicy::new_workspace_write_policy()),
None,
);
assert!(!context1.equals_except_shell(&context2));
}
#[test]
fn equals_except_shell_compares_workspace_write_policy() {
let context1 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)),
None,
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo", "/tmp"], true)),
None,
);
assert!(!context1.equals_except_shell(&context2));
}
#[test]
fn equals_except_shell_ignores_shell() {
let context1 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)),
Some(Shell::Bash(BashShell {
shell_path: "/bin/bash".into(),
bashrc_path: "/home/user/.bashrc".into(),
})),
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)),
Some(Shell::Zsh(ZshShell {
shell_path: "/bin/zsh".into(),
zshrc_path: "/home/user/.zshrc".into(),
})),
);
assert!(context1.equals_except_shell(&context2));
}
}

View File

@@ -5,20 +5,20 @@ use std::path::PathBuf;
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ZshShell {
shell_path: String,
zshrc_path: String,
pub(crate) shell_path: String,
pub(crate) zshrc_path: String,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct BashShell {
shell_path: String,
bashrc_path: String,
pub(crate) shell_path: String,
pub(crate) bashrc_path: String,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct PowerShellConfig {
exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
pub(crate) exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
pub(crate) bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]

View File

@@ -12,6 +12,7 @@ use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_core::protocol_config_types::ReasoningEffort;
use codex_core::protocol_config_types::ReasoningSummary;
use codex_core::shell::Shell;
use codex_core::shell::default_user_shell;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
@@ -23,6 +24,30 @@ use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
fn text_user_input(text: String) -> serde_json::Value {
serde_json::json!({
"type": "message",
"role": "user",
"content": [ { "type": "input_text", "text": text } ]
})
}
fn default_env_context_str(cwd: &str, shell: &Shell) -> String {
format!(
r#"<environment_context>
<cwd>{}</cwd>
<approval_policy>on-request</approval_policy>
<sandbox_mode>read-only</sandbox_mode>
<network_access>restricted</network_access>
{}</environment_context>"#,
cwd,
match shell.name() {
Some(name) => format!(" <shell>{name}</shell>\n"),
None => String::new(),
}
)
}
/// Build minimal SSE stream with completed marker using the JSON fixture.
fn sse_completed(id: &str) -> String {
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
@@ -546,12 +571,262 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() {
"role": "user",
"content": [ { "type": "input_text", "text": "hello 2" } ]
});
let expected_env_text_2 = format!(
r#"<environment_context>
<cwd>{}</cwd>
<approval_policy>never</approval_policy>
<sandbox_mode>workspace-write</sandbox_mode>
<network_access>enabled</network_access>
<writable_roots>
<root>{}</root>
</writable_roots>
</environment_context>"#,
new_cwd.path().to_string_lossy(),
writable.path().to_string_lossy(),
);
let expected_env_msg_2 = serde_json::json!({
"type": "message",
"role": "user",
"content": [ { "type": "input_text", "text": expected_env_text_2 } ]
});
let expected_body2 = serde_json::json!(
[
body1["input"].as_array().unwrap().as_slice(),
[expected_user_message_2].as_slice(),
[expected_env_msg_2, expected_user_message_2].as_slice(),
]
.concat()
);
assert_eq!(body2["input"], expected_body2);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn send_user_turn_with_no_changes_does_not_send_environment_context() {
use pretty_assertions::assert_eq;
let server = MockServer::start().await;
let sse = sse_completed("resp");
let template = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse, "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(template)
.expect(2)
.mount(&server)
.await;
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
..built_in_model_providers()["openai"].clone()
};
let cwd = TempDir::new().unwrap();
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home);
config.cwd = cwd.path().to_path_buf();
config.model_provider = model_provider;
config.user_instructions = Some("be consistent and helpful".to_string());
let default_cwd = config.cwd.clone();
let default_approval_policy = config.approval_policy;
let default_sandbox_policy = config.sandbox_policy.clone();
let default_model = config.model.clone();
let default_effort = config.model_reasoning_effort;
let default_summary = config.model_reasoning_summary;
let conversation_manager =
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
let codex = conversation_manager
.new_conversation(config)
.await
.expect("create new conversation")
.conversation;
codex
.submit(Op::UserTurn {
items: vec![InputItem::Text {
text: "hello 1".into(),
}],
cwd: default_cwd.clone(),
approval_policy: default_approval_policy,
sandbox_policy: default_sandbox_policy.clone(),
model: default_model.clone(),
effort: default_effort,
summary: default_summary,
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
codex
.submit(Op::UserTurn {
items: vec![InputItem::Text {
text: "hello 2".into(),
}],
cwd: default_cwd.clone(),
approval_policy: default_approval_policy,
sandbox_policy: default_sandbox_policy.clone(),
model: default_model.clone(),
effort: default_effort,
summary: default_summary,
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.unwrap();
assert_eq!(requests.len(), 2, "expected two POST requests");
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
let body2 = requests[1].body_json::<serde_json::Value>().unwrap();
let shell = default_user_shell().await;
let expected_ui_text =
"<user_instructions>\n\nbe consistent and helpful\n\n</user_instructions>";
let expected_ui_msg = text_user_input(expected_ui_text.to_string());
let expected_env_msg_1 = text_user_input(default_env_context_str(
&cwd.path().to_string_lossy(),
&shell,
));
let expected_user_message_1 = text_user_input("hello 1".to_string());
let expected_input_1 = serde_json::Value::Array(vec![
expected_ui_msg.clone(),
expected_env_msg_1.clone(),
expected_user_message_1.clone(),
]);
assert_eq!(body1["input"], expected_input_1);
let expected_user_message_2 = text_user_input("hello 2".to_string());
let expected_input_2 = serde_json::Value::Array(vec![
expected_ui_msg,
expected_env_msg_1,
expected_user_message_1,
expected_user_message_2,
]);
assert_eq!(body2["input"], expected_input_2);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn send_user_turn_with_changes_sends_environment_context() {
use pretty_assertions::assert_eq;
let server = MockServer::start().await;
let sse = sse_completed("resp");
let template = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse, "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(template)
.expect(2)
.mount(&server)
.await;
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
..built_in_model_providers()["openai"].clone()
};
let cwd = TempDir::new().unwrap();
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home);
config.cwd = cwd.path().to_path_buf();
config.model_provider = model_provider;
config.user_instructions = Some("be consistent and helpful".to_string());
let default_cwd = config.cwd.clone();
let default_approval_policy = config.approval_policy;
let default_sandbox_policy = config.sandbox_policy.clone();
let default_model = config.model.clone();
let default_effort = config.model_reasoning_effort;
let default_summary = config.model_reasoning_summary;
let conversation_manager =
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
let codex = conversation_manager
.new_conversation(config.clone())
.await
.expect("create new conversation")
.conversation;
codex
.submit(Op::UserTurn {
items: vec![InputItem::Text {
text: "hello 1".into(),
}],
cwd: default_cwd.clone(),
approval_policy: default_approval_policy,
sandbox_policy: default_sandbox_policy.clone(),
model: default_model,
effort: default_effort,
summary: default_summary,
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
codex
.submit(Op::UserTurn {
items: vec![InputItem::Text {
text: "hello 2".into(),
}],
cwd: default_cwd.clone(),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: "o3".to_string(),
effort: Some(ReasoningEffort::High),
summary: ReasoningSummary::Detailed,
})
.await
.unwrap();
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let requests = server.received_requests().await.unwrap();
assert_eq!(requests.len(), 2, "expected two POST requests");
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
let body2 = requests[1].body_json::<serde_json::Value>().unwrap();
let shell = default_user_shell().await;
let expected_ui_text =
"<user_instructions>\n\nbe consistent and helpful\n\n</user_instructions>";
let expected_ui_msg = serde_json::json!({
"type": "message",
"role": "user",
"content": [ { "type": "input_text", "text": expected_ui_text } ]
});
let expected_env_text_1 = default_env_context_str(&default_cwd.to_string_lossy(), &shell);
let expected_env_msg_1 = text_user_input(expected_env_text_1);
let expected_user_message_1 = text_user_input("hello 1".to_string());
let expected_input_1 = serde_json::Value::Array(vec![
expected_ui_msg.clone(),
expected_env_msg_1.clone(),
expected_user_message_1.clone(),
]);
assert_eq!(body1["input"], expected_input_1);
let expected_env_msg_2 = text_user_input(format!(
r#"<environment_context>
<cwd>{}</cwd>
<approval_policy>never</approval_policy>
<sandbox_mode>danger-full-access</sandbox_mode>
<network_access>enabled</network_access>
</environment_context>"#,
default_cwd.to_string_lossy()
));
let expected_user_message_2 = text_user_input("hello 2".to_string());
let expected_input_2 = serde_json::Value::Array(vec![
expected_ui_msg,
expected_env_msg_1,
expected_user_message_1,
expected_env_msg_2,
expected_user_message_2,
]);
assert_eq!(body2["input"], expected_input_2);
}