diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 375183ab..b170076d 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1348,10 +1348,21 @@ async fn submission_loop( cwd, is_review_mode: false, }; - // TODO: record the new environment context in the conversation history + + // if the environment context has changed, record it in the conversation history + let previous_env_context = EnvironmentContext::from(turn_context.as_ref()); + let new_env_context = EnvironmentContext::from(&fresh_turn_context); + if !new_env_context.equals_except_shell(&previous_env_context) { + sess.record_conversation_items(&[ResponseItem::from(new_env_context)]) + .await; + } + + // Install the new persistent context for subsequent tasks/turns. + turn_context = Arc::new(fresh_turn_context); + // no current task, spawn a new one with the per‑turn context let task = - AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items); + AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items); sess.set_task(task); } } diff --git a/codex-rs/core/src/environment_context.rs b/codex-rs/core/src/environment_context.rs index 89af9e1c..8f3292a2 100644 --- a/codex-rs/core/src/environment_context.rs +++ b/codex-rs/core/src/environment_context.rs @@ -2,6 +2,7 @@ use serde::Deserialize; use serde::Serialize; use strum_macros::Display as DeriveDisplay; +use crate::codex::TurnContext; use crate::protocol::AskForApproval; use crate::protocol::SandboxPolicy; use crate::shell::Shell; @@ -71,6 +72,39 @@ impl EnvironmentContext { shell, } } + + /// Compares two environment contexts, ignoring the shell. Useful when + /// comparing turn to turn, since the initial environment_context will + /// include the shell, and then it is not configurable from turn to turn. + pub fn equals_except_shell(&self, other: &EnvironmentContext) -> bool { + let EnvironmentContext { + cwd, + approval_policy, + sandbox_mode, + network_access, + writable_roots, + // should compare all fields except shell + shell: _, + } = other; + + self.cwd == *cwd + && self.approval_policy == *approval_policy + && self.sandbox_mode == *sandbox_mode + && self.network_access == *network_access + && self.writable_roots == *writable_roots + } +} + +impl From<&TurnContext> for EnvironmentContext { + fn from(turn_context: &TurnContext) -> Self { + Self::new( + Some(turn_context.cwd.clone()), + Some(turn_context.approval_policy), + Some(turn_context.sandbox_policy.clone()), + // Shell is not configurable from turn to turn + None, + ) + } } impl EnvironmentContext { @@ -140,6 +174,9 @@ impl From for ResponseItem { #[cfg(test)] mod tests { + use crate::shell::BashShell; + use crate::shell::ZshShell; + use super::*; use pretty_assertions::assert_eq; @@ -210,4 +247,82 @@ mod tests { assert_eq!(context.serialize_to_xml(), expected); } + + #[test] + fn equals_except_shell_compares_approval_policy() { + // Approval policy + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Some(AskForApproval::OnRequest), + Some(workspace_write_policy(vec!["/repo"], false)), + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Some(AskForApproval::Never), + Some(workspace_write_policy(vec!["/repo"], true)), + None, + ); + assert!(!context1.equals_except_shell(&context2)); + } + + #[test] + fn equals_except_shell_compares_sandbox_policy() { + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Some(AskForApproval::OnRequest), + Some(SandboxPolicy::new_read_only_policy()), + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Some(AskForApproval::OnRequest), + Some(SandboxPolicy::new_workspace_write_policy()), + None, + ); + + assert!(!context1.equals_except_shell(&context2)); + } + + #[test] + fn equals_except_shell_compares_workspace_write_policy() { + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Some(AskForApproval::OnRequest), + Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)), + None, + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Some(AskForApproval::OnRequest), + Some(workspace_write_policy(vec!["/repo", "/tmp"], true)), + None, + ); + + assert!(!context1.equals_except_shell(&context2)); + } + + #[test] + fn equals_except_shell_ignores_shell() { + let context1 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Some(AskForApproval::OnRequest), + Some(workspace_write_policy(vec!["/repo"], false)), + Some(Shell::Bash(BashShell { + shell_path: "/bin/bash".into(), + bashrc_path: "/home/user/.bashrc".into(), + })), + ); + let context2 = EnvironmentContext::new( + Some(PathBuf::from("/repo")), + Some(AskForApproval::OnRequest), + Some(workspace_write_policy(vec!["/repo"], false)), + Some(Shell::Zsh(ZshShell { + shell_path: "/bin/zsh".into(), + zshrc_path: "/home/user/.zshrc".into(), + })), + ); + + assert!(context1.equals_except_shell(&context2)); + } } diff --git a/codex-rs/core/src/shell.rs b/codex-rs/core/src/shell.rs index 1734d5de..cb278fdd 100644 --- a/codex-rs/core/src/shell.rs +++ b/codex-rs/core/src/shell.rs @@ -5,20 +5,20 @@ use std::path::PathBuf; #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct ZshShell { - shell_path: String, - zshrc_path: String, + pub(crate) shell_path: String, + pub(crate) zshrc_path: String, } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct BashShell { - shell_path: String, - bashrc_path: String, + pub(crate) shell_path: String, + pub(crate) bashrc_path: String, } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct PowerShellConfig { - exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe". - bash_exe_fallback: Option, // In case the model generates a bash command. + pub(crate) exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe". + pub(crate) bash_exe_fallback: Option, // In case the model generates a bash command. } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] diff --git a/codex-rs/core/tests/suite/prompt_caching.rs b/codex-rs/core/tests/suite/prompt_caching.rs index c6731fee..a69f57a2 100644 --- a/codex-rs/core/tests/suite/prompt_caching.rs +++ b/codex-rs/core/tests/suite/prompt_caching.rs @@ -12,6 +12,7 @@ use codex_core::protocol::Op; use codex_core::protocol::SandboxPolicy; use codex_core::protocol_config_types::ReasoningEffort; use codex_core::protocol_config_types::ReasoningSummary; +use codex_core::shell::Shell; use codex_core::shell::default_user_shell; use core_test_support::load_default_config_for_test; use core_test_support::load_sse_fixture_with_id; @@ -23,6 +24,30 @@ use wiremock::ResponseTemplate; use wiremock::matchers::method; use wiremock::matchers::path; +fn text_user_input(text: String) -> serde_json::Value { + serde_json::json!({ + "type": "message", + "role": "user", + "content": [ { "type": "input_text", "text": text } ] + }) +} + +fn default_env_context_str(cwd: &str, shell: &Shell) -> String { + format!( + r#" + {} + on-request + read-only + restricted +{}"#, + cwd, + match shell.name() { + Some(name) => format!(" {name}\n"), + None => String::new(), + } + ) +} + /// Build minimal SSE stream with completed marker using the JSON fixture. fn sse_completed(id: &str) -> String { load_sse_fixture_with_id("tests/fixtures/completed_template.json", id) @@ -546,12 +571,262 @@ async fn per_turn_overrides_keep_cached_prefix_and_key_constant() { "role": "user", "content": [ { "type": "input_text", "text": "hello 2" } ] }); + let expected_env_text_2 = format!( + r#" + {} + never + workspace-write + enabled + + {} + +"#, + new_cwd.path().to_string_lossy(), + writable.path().to_string_lossy(), + ); + let expected_env_msg_2 = serde_json::json!({ + "type": "message", + "role": "user", + "content": [ { "type": "input_text", "text": expected_env_text_2 } ] + }); let expected_body2 = serde_json::json!( [ body1["input"].as_array().unwrap().as_slice(), - [expected_user_message_2].as_slice(), + [expected_env_msg_2, expected_user_message_2].as_slice(), ] .concat() ); assert_eq!(body2["input"], expected_body2); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn send_user_turn_with_no_changes_does_not_send_environment_context() { + use pretty_assertions::assert_eq; + + let server = MockServer::start().await; + + let sse = sse_completed("resp"); + let template = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse, "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(template) + .expect(2) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + + let cwd = TempDir::new().unwrap(); + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.cwd = cwd.path().to_path_buf(); + config.model_provider = model_provider; + config.user_instructions = Some("be consistent and helpful".to_string()); + + let default_cwd = config.cwd.clone(); + let default_approval_policy = config.approval_policy; + let default_sandbox_policy = config.sandbox_policy.clone(); + let default_model = config.model.clone(); + let default_effort = config.model_reasoning_effort; + let default_summary = config.model_reasoning_summary; + + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); + let codex = conversation_manager + .new_conversation(config) + .await + .expect("create new conversation") + .conversation; + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "hello 1".into(), + }], + cwd: default_cwd.clone(), + approval_policy: default_approval_policy, + sandbox_policy: default_sandbox_policy.clone(), + model: default_model.clone(), + effort: default_effort, + summary: default_summary, + }) + .await + .unwrap(); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "hello 2".into(), + }], + cwd: default_cwd.clone(), + approval_policy: default_approval_policy, + sandbox_policy: default_sandbox_policy.clone(), + model: default_model.clone(), + effort: default_effort, + summary: default_summary, + }) + .await + .unwrap(); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + let requests = server.received_requests().await.unwrap(); + assert_eq!(requests.len(), 2, "expected two POST requests"); + + let body1 = requests[0].body_json::().unwrap(); + let body2 = requests[1].body_json::().unwrap(); + + let shell = default_user_shell().await; + let expected_ui_text = + "\n\nbe consistent and helpful\n\n"; + let expected_ui_msg = text_user_input(expected_ui_text.to_string()); + + let expected_env_msg_1 = text_user_input(default_env_context_str( + &cwd.path().to_string_lossy(), + &shell, + )); + let expected_user_message_1 = text_user_input("hello 1".to_string()); + + let expected_input_1 = serde_json::Value::Array(vec![ + expected_ui_msg.clone(), + expected_env_msg_1.clone(), + expected_user_message_1.clone(), + ]); + assert_eq!(body1["input"], expected_input_1); + + let expected_user_message_2 = text_user_input("hello 2".to_string()); + let expected_input_2 = serde_json::Value::Array(vec![ + expected_ui_msg, + expected_env_msg_1, + expected_user_message_1, + expected_user_message_2, + ]); + assert_eq!(body2["input"], expected_input_2); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn send_user_turn_with_changes_sends_environment_context() { + use pretty_assertions::assert_eq; + + let server = MockServer::start().await; + + let sse = sse_completed("resp"); + let template = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse, "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(template) + .expect(2) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + + let cwd = TempDir::new().unwrap(); + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.cwd = cwd.path().to_path_buf(); + config.model_provider = model_provider; + config.user_instructions = Some("be consistent and helpful".to_string()); + + let default_cwd = config.cwd.clone(); + let default_approval_policy = config.approval_policy; + let default_sandbox_policy = config.sandbox_policy.clone(); + let default_model = config.model.clone(); + let default_effort = config.model_reasoning_effort; + let default_summary = config.model_reasoning_summary; + + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); + let codex = conversation_manager + .new_conversation(config.clone()) + .await + .expect("create new conversation") + .conversation; + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "hello 1".into(), + }], + cwd: default_cwd.clone(), + approval_policy: default_approval_policy, + sandbox_policy: default_sandbox_policy.clone(), + model: default_model, + effort: default_effort, + summary: default_summary, + }) + .await + .unwrap(); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "hello 2".into(), + }], + cwd: default_cwd.clone(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: "o3".to_string(), + effort: Some(ReasoningEffort::High), + summary: ReasoningSummary::Detailed, + }) + .await + .unwrap(); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + let requests = server.received_requests().await.unwrap(); + assert_eq!(requests.len(), 2, "expected two POST requests"); + + let body1 = requests[0].body_json::().unwrap(); + let body2 = requests[1].body_json::().unwrap(); + + let shell = default_user_shell().await; + let expected_ui_text = + "\n\nbe consistent and helpful\n\n"; + let expected_ui_msg = serde_json::json!({ + "type": "message", + "role": "user", + "content": [ { "type": "input_text", "text": expected_ui_text } ] + }); + let expected_env_text_1 = default_env_context_str(&default_cwd.to_string_lossy(), &shell); + let expected_env_msg_1 = text_user_input(expected_env_text_1); + let expected_user_message_1 = text_user_input("hello 1".to_string()); + let expected_input_1 = serde_json::Value::Array(vec![ + expected_ui_msg.clone(), + expected_env_msg_1.clone(), + expected_user_message_1.clone(), + ]); + assert_eq!(body1["input"], expected_input_1); + + let expected_env_msg_2 = text_user_input(format!( + r#" + {} + never + danger-full-access + enabled +"#, + default_cwd.to_string_lossy() + )); + let expected_user_message_2 = text_user_input("hello 2".to_string()); + let expected_input_2 = serde_json::Value::Array(vec![ + expected_ui_msg, + expected_env_msg_1, + expected_user_message_1, + expected_env_msg_2, + expected_user_message_2, + ]); + assert_eq!(body2["input"], expected_input_2); +}