From c283f9f6ce350eac608d48e958c5de51380e9403 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Mon, 18 Aug 2025 12:59:19 -0700 Subject: [PATCH] Add an operation to override current task context (#2431) - Added an operation to override current task context - Added a test to check that cache stays the same --- codex-rs/core/src/client.rs | 25 +++++ codex-rs/core/src/codex.rs | 81 +++++++++++++++- codex-rs/core/tests/prompt_caching.rs | 127 ++++++++++++++++++++++++++ codex-rs/protocol/src/protocol.rs | 32 +++++++ 4 files changed, 263 insertions(+), 2 deletions(-) diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index 7319c9a0..86a711e4 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -33,6 +33,7 @@ use crate::error::CodexErr; use crate::error::Result; use crate::error::UsageLimitReachedError; use crate::flags::CODEX_RS_SSE_FIXTURE; +use crate::model_family::ModelFamily; use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; use crate::models::ResponseItem; @@ -311,6 +312,30 @@ impl ModelClient { pub fn get_provider(&self) -> ModelProviderInfo { self.provider.clone() } + + /// Returns the currently configured model slug. + pub fn get_model(&self) -> String { + self.config.model.clone() + } + + /// Returns the currently configured model family. + pub fn get_model_family(&self) -> ModelFamily { + self.config.model_family.clone() + } + + /// Returns the current reasoning effort setting. + pub fn get_reasoning_effort(&self) -> ReasoningEffortConfig { + self.effort + } + + /// Returns the current reasoning summary setting. + pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig { + self.summary + } + + pub fn get_auth(&self) -> Option { + self.auth.clone() + } } #[derive(Debug, Deserialize, Serialize)] diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 8670978e..397246a7 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -989,7 +989,7 @@ async fn submission_loop( rx_sub: Receiver, ) { // Wrap once to avoid cloning TurnContext for each task. - let turn_context = Arc::new(turn_context); + let mut turn_context = Arc::new(turn_context); // To break out of this loop, send Op::Shutdown. while let Ok(sub) = rx_sub.recv().await { debug!(?sub, "Submission"); @@ -997,6 +997,83 @@ async fn submission_loop( Op::Interrupt => { sess.interrupt_task(); } + Op::OverrideTurnContext { + cwd, + approval_policy, + sandbox_policy, + model, + effort, + summary, + } => { + // Recalculate the persistent turn context with provided overrides. + let prev = Arc::clone(&turn_context); + let provider = prev.client.get_provider(); + + // Effective model + family + let (effective_model, effective_family) = if let Some(m) = model { + let fam = + find_family_for_model(&m).unwrap_or_else(|| config.model_family.clone()); + (m, fam) + } else { + (prev.client.get_model(), prev.client.get_model_family()) + }; + + // Effective reasoning settings + let effective_effort = effort.unwrap_or(prev.client.get_reasoning_effort()); + let effective_summary = summary.unwrap_or(prev.client.get_reasoning_summary()); + + let auth = prev.client.get_auth(); + // Build updated config for the client + let mut updated_config = (*config).clone(); + updated_config.model = effective_model.clone(); + updated_config.model_family = effective_family.clone(); + + let client = ModelClient::new( + Arc::new(updated_config), + auth, + provider, + effective_effort, + effective_summary, + sess.session_id, + ); + + let new_approval_policy = approval_policy.unwrap_or(prev.approval_policy); + let new_sandbox_policy = sandbox_policy + .clone() + .unwrap_or(prev.sandbox_policy.clone()); + let new_cwd = cwd.clone().unwrap_or_else(|| prev.cwd.clone()); + + let tools_config = ToolsConfig::new( + &effective_family, + new_approval_policy, + new_sandbox_policy.clone(), + config.include_plan_tool, + config.include_apply_patch_tool, + ); + + let new_turn_context = TurnContext { + client, + tools_config, + user_instructions: prev.user_instructions.clone(), + base_instructions: prev.base_instructions.clone(), + approval_policy: new_approval_policy, + sandbox_policy: new_sandbox_policy.clone(), + shell_environment_policy: prev.shell_environment_policy.clone(), + cwd: new_cwd.clone(), + disable_response_storage: prev.disable_response_storage, + }; + + // Install the new persistent context for subsequent tasks/turns. + turn_context = Arc::new(new_turn_context); + if cwd.is_some() || approval_policy.is_some() || sandbox_policy.is_some() { + sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new( + new_cwd, + new_approval_policy, + new_sandbox_policy, + ))]) + .await; + } + } Op::UserInput { items } => { // attempt to inject input into current task if let Err(items) = sess.inject_input(items) { @@ -1057,7 +1134,7 @@ async fn submission_loop( cwd, disable_response_storage: turn_context.disable_response_storage, }; - + // TODO: record the new environment context in the conversation history // no current task, spawn a new one with the per‑turn context let task = AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items); diff --git a/codex-rs/core/tests/prompt_caching.rs b/codex-rs/core/tests/prompt_caching.rs index d637eb67..e528cb7a 100644 --- a/codex-rs/core/tests/prompt_caching.rs +++ b/codex-rs/core/tests/prompt_caching.rs @@ -1,9 +1,13 @@ use codex_core::ConversationManager; use codex_core::ModelProviderInfo; use codex_core::built_in_model_providers; +use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::InputItem; use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_core::protocol_config_types::ReasoningEffort as ReasoningEffortConfig; +use codex_core::protocol_config_types::ReasoningSummary as ReasoningSummaryConfig; use codex_login::CodexAuth; use core_test_support::load_default_config_for_test; use core_test_support::load_sse_fixture_with_id; @@ -129,3 +133,126 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests ); assert_eq!(body2["input"], expected_body2); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() { + use pretty_assertions::assert_eq; + + let server = MockServer::start().await; + + let sse = sse_completed("resp"); + let template = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse, "text/event-stream"); + + // Expect two POSTs to /v1/responses + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(template) + .expect(2) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + + let cwd = TempDir::new().unwrap(); + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.cwd = cwd.path().to_path_buf(); + config.model_provider = model_provider; + config.user_instructions = Some("be consistent and helpful".to_string()); + + let conversation_manager = ConversationManager::default(); + let codex = conversation_manager + .new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key"))) + .await + .expect("create new conversation") + .conversation; + + // First turn + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello 1".into(), + }], + }) + .await + .unwrap(); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + // Change everything about the turn context. + let new_cwd = TempDir::new().unwrap(); + let writable = TempDir::new().unwrap(); + codex + .submit(Op::OverrideTurnContext { + cwd: Some(new_cwd.path().to_path_buf()), + approval_policy: Some(AskForApproval::Never), + sandbox_policy: Some(SandboxPolicy::WorkspaceWrite { + writable_roots: vec![writable.path().to_path_buf()], + network_access: true, + exclude_tmpdir_env_var: true, + exclude_slash_tmp: true, + }), + model: Some("o3".to_string()), + effort: Some(ReasoningEffortConfig::High), + summary: Some(ReasoningSummaryConfig::Detailed), + }) + .await + .unwrap(); + + // Second turn after overrides + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello 2".into(), + }], + }) + .await + .unwrap(); + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + // Verify we issued exactly two requests, and the cached prefix stayed identical. + let requests = server.received_requests().await.unwrap(); + assert_eq!(requests.len(), 2, "expected two POST requests"); + + let body1 = requests[0].body_json::().unwrap(); + let body2 = requests[1].body_json::().unwrap(); + + // prompt_cache_key should remain constant across overrides + assert_eq!( + body1["prompt_cache_key"], body2["prompt_cache_key"], + "prompt_cache_key should not change across overrides" + ); + + // The entire prefix from the first request should be identical and reused + // as the prefix of the second request, ensuring cache hit potential. + let expected_user_message_2 = serde_json::json!({ + "type": "message", + "id": serde_json::Value::Null, + "role": "user", + "content": [ { "type": "input_text", "text": "hello 2" } ] + }); + // After overriding the turn context, the environment context should be emitted again + // reflecting the new cwd, approval policy and sandbox settings. + let expected_env_text_2 = format!( + "\nCurrent working directory: {}\nApproval policy: never\nSandbox mode: workspace-write\nNetwork access: enabled\n", + new_cwd.path().to_string_lossy() + ); + let expected_env_msg_2 = serde_json::json!({ + "type": "message", + "id": serde_json::Value::Null, + "role": "user", + "content": [ { "type": "input_text", "text": expected_env_text_2 } ] + }); + let expected_body2 = serde_json::json!( + [ + body1["input"].as_array().unwrap().as_slice(), + [expected_env_msg_2, expected_user_message_2].as_slice(), + ] + .concat() + ); + assert_eq!(body2["input"], expected_body2); +} diff --git a/codex-rs/protocol/src/protocol.rs b/codex-rs/protocol/src/protocol.rs index 4b9a2902..2aea2189 100644 --- a/codex-rs/protocol/src/protocol.rs +++ b/codex-rs/protocol/src/protocol.rs @@ -75,6 +75,38 @@ pub enum Op { summary: ReasoningSummaryConfig, }, + /// Override parts of the persistent turn context for subsequent turns. + /// + /// All fields are optional; when omitted, the existing value is preserved. + /// This does not enqueue any input – it only updates defaults used for + /// future `UserInput` turns. + OverrideTurnContext { + /// Updated `cwd` for sandbox/tool calls. + #[serde(skip_serializing_if = "Option::is_none")] + cwd: Option, + + /// Updated command approval policy. + #[serde(skip_serializing_if = "Option::is_none")] + approval_policy: Option, + + /// Updated sandbox policy for tool calls. + #[serde(skip_serializing_if = "Option::is_none")] + sandbox_policy: Option, + + /// Updated model slug. When set, the model family is derived + /// automatically. + #[serde(skip_serializing_if = "Option::is_none")] + model: Option, + + /// Updated reasoning effort (honored only for reasoning-capable models). + #[serde(skip_serializing_if = "Option::is_none")] + effort: Option, + + /// Updated reasoning summary preference (honored only for reasoning-capable models). + #[serde(skip_serializing_if = "Option::is_none")] + summary: Option, + }, + /// Approve a command execution ExecApproval { /// The id of the submission we are approving