fix: Record EnvironmentContext in SendUserTurn (#3678)
## Summary SendUserTurn has not been correctly handling updates to policies. While the tui protocol handles this in `Op::OverrideTurnContext`, the SendUserTurn should be appending `EnvironmentContext` messages when the sandbox settings change. MCP client behavior should match the cli behavior, so we update `SendUserTurn` message to match. ## Testing - [x] Added prompt caching tests
This commit is contained in:
@@ -1348,10 +1348,21 @@ async fn submission_loop(
|
||||
cwd,
|
||||
is_review_mode: false,
|
||||
};
|
||||
// TODO: record the new environment context in the conversation history
|
||||
|
||||
// if the environment context has changed, record it in the conversation history
|
||||
let previous_env_context = EnvironmentContext::from(turn_context.as_ref());
|
||||
let new_env_context = EnvironmentContext::from(&fresh_turn_context);
|
||||
if !new_env_context.equals_except_shell(&previous_env_context) {
|
||||
sess.record_conversation_items(&[ResponseItem::from(new_env_context)])
|
||||
.await;
|
||||
}
|
||||
|
||||
// Install the new persistent context for subsequent tasks/turns.
|
||||
turn_context = Arc::new(fresh_turn_context);
|
||||
|
||||
// no current task, spawn a new one with the per‑turn context
|
||||
let task =
|
||||
AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items);
|
||||
AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items);
|
||||
sess.set_task(task);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use strum_macros::Display as DeriveDisplay;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::shell::Shell;
|
||||
@@ -71,6 +72,39 @@ impl EnvironmentContext {
|
||||
shell,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compares two environment contexts, ignoring the shell. Useful when
|
||||
/// comparing turn to turn, since the initial environment_context will
|
||||
/// include the shell, and then it is not configurable from turn to turn.
|
||||
pub fn equals_except_shell(&self, other: &EnvironmentContext) -> bool {
|
||||
let EnvironmentContext {
|
||||
cwd,
|
||||
approval_policy,
|
||||
sandbox_mode,
|
||||
network_access,
|
||||
writable_roots,
|
||||
// should compare all fields except shell
|
||||
shell: _,
|
||||
} = other;
|
||||
|
||||
self.cwd == *cwd
|
||||
&& self.approval_policy == *approval_policy
|
||||
&& self.sandbox_mode == *sandbox_mode
|
||||
&& self.network_access == *network_access
|
||||
&& self.writable_roots == *writable_roots
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&TurnContext> for EnvironmentContext {
|
||||
fn from(turn_context: &TurnContext) -> Self {
|
||||
Self::new(
|
||||
Some(turn_context.cwd.clone()),
|
||||
Some(turn_context.approval_policy),
|
||||
Some(turn_context.sandbox_policy.clone()),
|
||||
// Shell is not configurable from turn to turn
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl EnvironmentContext {
|
||||
@@ -140,6 +174,9 @@ impl From<EnvironmentContext> for ResponseItem {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::shell::BashShell;
|
||||
use crate::shell::ZshShell;
|
||||
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
@@ -210,4 +247,82 @@ mod tests {
|
||||
|
||||
assert_eq!(context.serialize_to_xml(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_except_shell_compares_approval_policy() {
|
||||
// Approval policy
|
||||
let context1 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo"], false)),
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::Never),
|
||||
Some(workspace_write_policy(vec!["/repo"], true)),
|
||||
None,
|
||||
);
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_except_shell_compares_sandbox_policy() {
|
||||
let context1 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(SandboxPolicy::new_read_only_policy()),
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(SandboxPolicy::new_workspace_write_policy()),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_except_shell_compares_workspace_write_policy() {
|
||||
let context1 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)),
|
||||
None,
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo", "/tmp"], true)),
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(!context1.equals_except_shell(&context2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn equals_except_shell_ignores_shell() {
|
||||
let context1 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo"], false)),
|
||||
Some(Shell::Bash(BashShell {
|
||||
shell_path: "/bin/bash".into(),
|
||||
bashrc_path: "/home/user/.bashrc".into(),
|
||||
})),
|
||||
);
|
||||
let context2 = EnvironmentContext::new(
|
||||
Some(PathBuf::from("/repo")),
|
||||
Some(AskForApproval::OnRequest),
|
||||
Some(workspace_write_policy(vec!["/repo"], false)),
|
||||
Some(Shell::Zsh(ZshShell {
|
||||
shell_path: "/bin/zsh".into(),
|
||||
zshrc_path: "/home/user/.zshrc".into(),
|
||||
})),
|
||||
);
|
||||
|
||||
assert!(context1.equals_except_shell(&context2));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,20 +5,20 @@ use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct ZshShell {
|
||||
shell_path: String,
|
||||
zshrc_path: String,
|
||||
pub(crate) shell_path: String,
|
||||
pub(crate) zshrc_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct BashShell {
|
||||
shell_path: String,
|
||||
bashrc_path: String,
|
||||
pub(crate) shell_path: String,
|
||||
pub(crate) bashrc_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct PowerShellConfig {
|
||||
exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
|
||||
bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
|
||||
pub(crate) exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
|
||||
pub(crate) bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
|
||||
@@ -12,6 +12,7 @@ use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_core::protocol_config_types::ReasoningEffort;
|
||||
use codex_core::protocol_config_types::ReasoningSummary;
|
||||
use codex_core::shell::Shell;
|
||||
use codex_core::shell::default_user_shell;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::load_sse_fixture_with_id;
|
||||
@@ -23,6 +24,30 @@ use wiremock::ResponseTemplate;
|
||||
use wiremock::matchers::method;
|
||||
use wiremock::matchers::path;
|
||||
|
||||
fn text_user_input(text: String) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": text } ]
|
||||
})
|
||||
}
|
||||
|
||||
fn default_env_context_str(cwd: &str, shell: &Shell) -> String {
|
||||
format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>on-request</approval_policy>
|
||||
<sandbox_mode>read-only</sandbox_mode>
|
||||
<network_access>restricted</network_access>
|
||||
{}</environment_context>"#,
|
||||
cwd,
|
||||
match shell.name() {
|
||||
Some(name) => format!(" <shell>{name}</shell>\n"),
|
||||
None => String::new(),
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Build minimal SSE stream with completed marker using the JSON fixture.
|
||||
fn sse_completed(id: &str) -> String {
|
||||
load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
|
||||
@@ -546,12 +571,262 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() {
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": "hello 2" } ]
|
||||
});
|
||||
let expected_env_text_2 = format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>workspace-write</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
<writable_roots>
|
||||
<root>{}</root>
|
||||
</writable_roots>
|
||||
</environment_context>"#,
|
||||
new_cwd.path().to_string_lossy(),
|
||||
writable.path().to_string_lossy(),
|
||||
);
|
||||
let expected_env_msg_2 = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_env_text_2 } ]
|
||||
});
|
||||
let expected_body2 = serde_json::json!(
|
||||
[
|
||||
body1["input"].as_array().unwrap().as_slice(),
|
||||
[expected_user_message_2].as_slice(),
|
||||
[expected_env_msg_2, expected_user_message_2].as_slice(),
|
||||
]
|
||||
.concat()
|
||||
);
|
||||
assert_eq!(body2["input"], expected_body2);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn send_user_turn_with_no_changes_does_not_send_environment_context() {
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let sse = sse_completed("resp");
|
||||
let template = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse, "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(template)
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let cwd = TempDir::new().unwrap();
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.user_instructions = Some("be consistent and helpful".to_string());
|
||||
|
||||
let default_cwd = config.cwd.clone();
|
||||
let default_approval_policy = config.approval_policy;
|
||||
let default_sandbox_policy = config.sandbox_policy.clone();
|
||||
let default_model = config.model.clone();
|
||||
let default_effort = config.model_reasoning_effort;
|
||||
let default_summary = config.model_reasoning_summary;
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config)
|
||||
.await
|
||||
.expect("create new conversation")
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello 1".into(),
|
||||
}],
|
||||
cwd: default_cwd.clone(),
|
||||
approval_policy: default_approval_policy,
|
||||
sandbox_policy: default_sandbox_policy.clone(),
|
||||
model: default_model.clone(),
|
||||
effort: default_effort,
|
||||
summary: default_summary,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello 2".into(),
|
||||
}],
|
||||
cwd: default_cwd.clone(),
|
||||
approval_policy: default_approval_policy,
|
||||
sandbox_policy: default_sandbox_policy.clone(),
|
||||
model: default_model.clone(),
|
||||
effort: default_effort,
|
||||
summary: default_summary,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(requests.len(), 2, "expected two POST requests");
|
||||
|
||||
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
let body2 = requests[1].body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
let shell = default_user_shell().await;
|
||||
let expected_ui_text =
|
||||
"<user_instructions>\n\nbe consistent and helpful\n\n</user_instructions>";
|
||||
let expected_ui_msg = text_user_input(expected_ui_text.to_string());
|
||||
|
||||
let expected_env_msg_1 = text_user_input(default_env_context_str(
|
||||
&cwd.path().to_string_lossy(),
|
||||
&shell,
|
||||
));
|
||||
let expected_user_message_1 = text_user_input("hello 1".to_string());
|
||||
|
||||
let expected_input_1 = serde_json::Value::Array(vec![
|
||||
expected_ui_msg.clone(),
|
||||
expected_env_msg_1.clone(),
|
||||
expected_user_message_1.clone(),
|
||||
]);
|
||||
assert_eq!(body1["input"], expected_input_1);
|
||||
|
||||
let expected_user_message_2 = text_user_input("hello 2".to_string());
|
||||
let expected_input_2 = serde_json::Value::Array(vec![
|
||||
expected_ui_msg,
|
||||
expected_env_msg_1,
|
||||
expected_user_message_1,
|
||||
expected_user_message_2,
|
||||
]);
|
||||
assert_eq!(body2["input"], expected_input_2);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn send_user_turn_with_changes_sends_environment_context() {
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
let server = MockServer::start().await;
|
||||
|
||||
let sse = sse_completed("resp");
|
||||
let template = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse, "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(template)
|
||||
.expect(2)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
base_url: Some(format!("{}/v1", server.uri())),
|
||||
..built_in_model_providers()["openai"].clone()
|
||||
};
|
||||
|
||||
let cwd = TempDir::new().unwrap();
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
config.cwd = cwd.path().to_path_buf();
|
||||
config.model_provider = model_provider;
|
||||
config.user_instructions = Some("be consistent and helpful".to_string());
|
||||
|
||||
let default_cwd = config.cwd.clone();
|
||||
let default_approval_policy = config.approval_policy;
|
||||
let default_sandbox_policy = config.sandbox_policy.clone();
|
||||
let default_model = config.model.clone();
|
||||
let default_effort = config.model_reasoning_effort;
|
||||
let default_summary = config.model_reasoning_summary;
|
||||
|
||||
let conversation_manager =
|
||||
ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
|
||||
let codex = conversation_manager
|
||||
.new_conversation(config.clone())
|
||||
.await
|
||||
.expect("create new conversation")
|
||||
.conversation;
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello 1".into(),
|
||||
}],
|
||||
cwd: default_cwd.clone(),
|
||||
approval_policy: default_approval_policy,
|
||||
sandbox_policy: default_sandbox_policy.clone(),
|
||||
model: default_model,
|
||||
effort: default_effort,
|
||||
summary: default_summary,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
codex
|
||||
.submit(Op::UserTurn {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello 2".into(),
|
||||
}],
|
||||
cwd: default_cwd.clone(),
|
||||
approval_policy: AskForApproval::Never,
|
||||
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||
model: "o3".to_string(),
|
||||
effort: Some(ReasoningEffort::High),
|
||||
summary: ReasoningSummary::Detailed,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let requests = server.received_requests().await.unwrap();
|
||||
assert_eq!(requests.len(), 2, "expected two POST requests");
|
||||
|
||||
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||
let body2 = requests[1].body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
let shell = default_user_shell().await;
|
||||
let expected_ui_text =
|
||||
"<user_instructions>\n\nbe consistent and helpful\n\n</user_instructions>";
|
||||
let expected_ui_msg = serde_json::json!({
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [ { "type": "input_text", "text": expected_ui_text } ]
|
||||
});
|
||||
let expected_env_text_1 = default_env_context_str(&default_cwd.to_string_lossy(), &shell);
|
||||
let expected_env_msg_1 = text_user_input(expected_env_text_1);
|
||||
let expected_user_message_1 = text_user_input("hello 1".to_string());
|
||||
let expected_input_1 = serde_json::Value::Array(vec![
|
||||
expected_ui_msg.clone(),
|
||||
expected_env_msg_1.clone(),
|
||||
expected_user_message_1.clone(),
|
||||
]);
|
||||
assert_eq!(body1["input"], expected_input_1);
|
||||
|
||||
let expected_env_msg_2 = text_user_input(format!(
|
||||
r#"<environment_context>
|
||||
<cwd>{}</cwd>
|
||||
<approval_policy>never</approval_policy>
|
||||
<sandbox_mode>danger-full-access</sandbox_mode>
|
||||
<network_access>enabled</network_access>
|
||||
</environment_context>"#,
|
||||
default_cwd.to_string_lossy()
|
||||
));
|
||||
let expected_user_message_2 = text_user_input("hello 2".to_string());
|
||||
let expected_input_2 = serde_json::Value::Array(vec![
|
||||
expected_ui_msg,
|
||||
expected_env_msg_1,
|
||||
expected_user_message_1,
|
||||
expected_env_msg_2,
|
||||
expected_user_message_2,
|
||||
]);
|
||||
assert_eq!(body2["input"], expected_input_2);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user