Add an operation to override current task context (#2431)
- Added an operation to override current task context - Added a test to check that cache stays the same
This commit is contained in:
@@ -33,6 +33,7 @@ use crate::error::CodexErr;
|
|||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
use crate::error::UsageLimitReachedError;
|
use crate::error::UsageLimitReachedError;
|
||||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||||
|
use crate::model_family::ModelFamily;
|
||||||
use crate::model_provider_info::ModelProviderInfo;
|
use crate::model_provider_info::ModelProviderInfo;
|
||||||
use crate::model_provider_info::WireApi;
|
use crate::model_provider_info::WireApi;
|
||||||
use crate::models::ResponseItem;
|
use crate::models::ResponseItem;
|
||||||
@@ -311,6 +312,30 @@ impl ModelClient {
|
|||||||
pub fn get_provider(&self) -> ModelProviderInfo {
|
pub fn get_provider(&self) -> ModelProviderInfo {
|
||||||
self.provider.clone()
|
self.provider.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the currently configured model slug.
|
||||||
|
pub fn get_model(&self) -> String {
|
||||||
|
self.config.model.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the currently configured model family.
|
||||||
|
pub fn get_model_family(&self) -> ModelFamily {
|
||||||
|
self.config.model_family.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the current reasoning effort setting.
|
||||||
|
pub fn get_reasoning_effort(&self) -> ReasoningEffortConfig {
|
||||||
|
self.effort
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the current reasoning summary setting.
|
||||||
|
pub fn get_reasoning_summary(&self) -> ReasoningSummaryConfig {
|
||||||
|
self.summary
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_auth(&self) -> Option<CodexAuth> {
|
||||||
|
self.auth.clone()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
|||||||
@@ -989,7 +989,7 @@ async fn submission_loop(
|
|||||||
rx_sub: Receiver<Submission>,
|
rx_sub: Receiver<Submission>,
|
||||||
) {
|
) {
|
||||||
// Wrap once to avoid cloning TurnContext for each task.
|
// Wrap once to avoid cloning TurnContext for each task.
|
||||||
let turn_context = Arc::new(turn_context);
|
let mut turn_context = Arc::new(turn_context);
|
||||||
// To break out of this loop, send Op::Shutdown.
|
// To break out of this loop, send Op::Shutdown.
|
||||||
while let Ok(sub) = rx_sub.recv().await {
|
while let Ok(sub) = rx_sub.recv().await {
|
||||||
debug!(?sub, "Submission");
|
debug!(?sub, "Submission");
|
||||||
@@ -997,6 +997,83 @@ async fn submission_loop(
|
|||||||
Op::Interrupt => {
|
Op::Interrupt => {
|
||||||
sess.interrupt_task();
|
sess.interrupt_task();
|
||||||
}
|
}
|
||||||
|
Op::OverrideTurnContext {
|
||||||
|
cwd,
|
||||||
|
approval_policy,
|
||||||
|
sandbox_policy,
|
||||||
|
model,
|
||||||
|
effort,
|
||||||
|
summary,
|
||||||
|
} => {
|
||||||
|
// Recalculate the persistent turn context with provided overrides.
|
||||||
|
let prev = Arc::clone(&turn_context);
|
||||||
|
let provider = prev.client.get_provider();
|
||||||
|
|
||||||
|
// Effective model + family
|
||||||
|
let (effective_model, effective_family) = if let Some(m) = model {
|
||||||
|
let fam =
|
||||||
|
find_family_for_model(&m).unwrap_or_else(|| config.model_family.clone());
|
||||||
|
(m, fam)
|
||||||
|
} else {
|
||||||
|
(prev.client.get_model(), prev.client.get_model_family())
|
||||||
|
};
|
||||||
|
|
||||||
|
// Effective reasoning settings
|
||||||
|
let effective_effort = effort.unwrap_or(prev.client.get_reasoning_effort());
|
||||||
|
let effective_summary = summary.unwrap_or(prev.client.get_reasoning_summary());
|
||||||
|
|
||||||
|
let auth = prev.client.get_auth();
|
||||||
|
// Build updated config for the client
|
||||||
|
let mut updated_config = (*config).clone();
|
||||||
|
updated_config.model = effective_model.clone();
|
||||||
|
updated_config.model_family = effective_family.clone();
|
||||||
|
|
||||||
|
let client = ModelClient::new(
|
||||||
|
Arc::new(updated_config),
|
||||||
|
auth,
|
||||||
|
provider,
|
||||||
|
effective_effort,
|
||||||
|
effective_summary,
|
||||||
|
sess.session_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
let new_approval_policy = approval_policy.unwrap_or(prev.approval_policy);
|
||||||
|
let new_sandbox_policy = sandbox_policy
|
||||||
|
.clone()
|
||||||
|
.unwrap_or(prev.sandbox_policy.clone());
|
||||||
|
let new_cwd = cwd.clone().unwrap_or_else(|| prev.cwd.clone());
|
||||||
|
|
||||||
|
let tools_config = ToolsConfig::new(
|
||||||
|
&effective_family,
|
||||||
|
new_approval_policy,
|
||||||
|
new_sandbox_policy.clone(),
|
||||||
|
config.include_plan_tool,
|
||||||
|
config.include_apply_patch_tool,
|
||||||
|
);
|
||||||
|
|
||||||
|
let new_turn_context = TurnContext {
|
||||||
|
client,
|
||||||
|
tools_config,
|
||||||
|
user_instructions: prev.user_instructions.clone(),
|
||||||
|
base_instructions: prev.base_instructions.clone(),
|
||||||
|
approval_policy: new_approval_policy,
|
||||||
|
sandbox_policy: new_sandbox_policy.clone(),
|
||||||
|
shell_environment_policy: prev.shell_environment_policy.clone(),
|
||||||
|
cwd: new_cwd.clone(),
|
||||||
|
disable_response_storage: prev.disable_response_storage,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Install the new persistent context for subsequent tasks/turns.
|
||||||
|
turn_context = Arc::new(new_turn_context);
|
||||||
|
if cwd.is_some() || approval_policy.is_some() || sandbox_policy.is_some() {
|
||||||
|
sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new(
|
||||||
|
new_cwd,
|
||||||
|
new_approval_policy,
|
||||||
|
new_sandbox_policy,
|
||||||
|
))])
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
Op::UserInput { items } => {
|
Op::UserInput { items } => {
|
||||||
// attempt to inject input into current task
|
// attempt to inject input into current task
|
||||||
if let Err(items) = sess.inject_input(items) {
|
if let Err(items) = sess.inject_input(items) {
|
||||||
@@ -1057,7 +1134,7 @@ async fn submission_loop(
|
|||||||
cwd,
|
cwd,
|
||||||
disable_response_storage: turn_context.disable_response_storage,
|
disable_response_storage: turn_context.disable_response_storage,
|
||||||
};
|
};
|
||||||
|
// TODO: record the new environment context in the conversation history
|
||||||
// no current task, spawn a new one with the per‑turn context
|
// no current task, spawn a new one with the per‑turn context
|
||||||
let task =
|
let task =
|
||||||
AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items);
|
AgentTask::spawn(sess.clone(), Arc::new(fresh_turn_context), sub.id, items);
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
use codex_core::ConversationManager;
|
use codex_core::ConversationManager;
|
||||||
use codex_core::ModelProviderInfo;
|
use codex_core::ModelProviderInfo;
|
||||||
use codex_core::built_in_model_providers;
|
use codex_core::built_in_model_providers;
|
||||||
|
use codex_core::protocol::AskForApproval;
|
||||||
use codex_core::protocol::EventMsg;
|
use codex_core::protocol::EventMsg;
|
||||||
use codex_core::protocol::InputItem;
|
use codex_core::protocol::InputItem;
|
||||||
use codex_core::protocol::Op;
|
use codex_core::protocol::Op;
|
||||||
|
use codex_core::protocol::SandboxPolicy;
|
||||||
|
use codex_core::protocol_config_types::ReasoningEffort as ReasoningEffortConfig;
|
||||||
|
use codex_core::protocol_config_types::ReasoningSummary as ReasoningSummaryConfig;
|
||||||
use codex_login::CodexAuth;
|
use codex_login::CodexAuth;
|
||||||
use core_test_support::load_default_config_for_test;
|
use core_test_support::load_default_config_for_test;
|
||||||
use core_test_support::load_sse_fixture_with_id;
|
use core_test_support::load_sse_fixture_with_id;
|
||||||
@@ -129,3 +133,126 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests
|
|||||||
);
|
);
|
||||||
assert_eq!(body2["input"], expected_body2);
|
assert_eq!(body2["input"], expected_body2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn overrides_turn_context_but_keeps_cached_prefix_and_key_constant() {
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
|
||||||
|
let server = MockServer::start().await;
|
||||||
|
|
||||||
|
let sse = sse_completed("resp");
|
||||||
|
let template = ResponseTemplate::new(200)
|
||||||
|
.insert_header("content-type", "text/event-stream")
|
||||||
|
.set_body_raw(sse, "text/event-stream");
|
||||||
|
|
||||||
|
// Expect two POSTs to /v1/responses
|
||||||
|
Mock::given(method("POST"))
|
||||||
|
.and(path("/v1/responses"))
|
||||||
|
.respond_with(template)
|
||||||
|
.expect(2)
|
||||||
|
.mount(&server)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let model_provider = ModelProviderInfo {
|
||||||
|
base_url: Some(format!("{}/v1", server.uri())),
|
||||||
|
..built_in_model_providers()["openai"].clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let cwd = TempDir::new().unwrap();
|
||||||
|
let codex_home = TempDir::new().unwrap();
|
||||||
|
let mut config = load_default_config_for_test(&codex_home);
|
||||||
|
config.cwd = cwd.path().to_path_buf();
|
||||||
|
config.model_provider = model_provider;
|
||||||
|
config.user_instructions = Some("be consistent and helpful".to_string());
|
||||||
|
|
||||||
|
let conversation_manager = ConversationManager::default();
|
||||||
|
let codex = conversation_manager
|
||||||
|
.new_conversation_with_auth(config, Some(CodexAuth::from_api_key("Test API Key")))
|
||||||
|
.await
|
||||||
|
.expect("create new conversation")
|
||||||
|
.conversation;
|
||||||
|
|
||||||
|
// First turn
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![InputItem::Text {
|
||||||
|
text: "hello 1".into(),
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
|
|
||||||
|
// Change everything about the turn context.
|
||||||
|
let new_cwd = TempDir::new().unwrap();
|
||||||
|
let writable = TempDir::new().unwrap();
|
||||||
|
codex
|
||||||
|
.submit(Op::OverrideTurnContext {
|
||||||
|
cwd: Some(new_cwd.path().to_path_buf()),
|
||||||
|
approval_policy: Some(AskForApproval::Never),
|
||||||
|
sandbox_policy: Some(SandboxPolicy::WorkspaceWrite {
|
||||||
|
writable_roots: vec![writable.path().to_path_buf()],
|
||||||
|
network_access: true,
|
||||||
|
exclude_tmpdir_env_var: true,
|
||||||
|
exclude_slash_tmp: true,
|
||||||
|
}),
|
||||||
|
model: Some("o3".to_string()),
|
||||||
|
effort: Some(ReasoningEffortConfig::High),
|
||||||
|
summary: Some(ReasoningSummaryConfig::Detailed),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Second turn after overrides
|
||||||
|
codex
|
||||||
|
.submit(Op::UserInput {
|
||||||
|
items: vec![InputItem::Text {
|
||||||
|
text: "hello 2".into(),
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||||
|
|
||||||
|
// Verify we issued exactly two requests, and the cached prefix stayed identical.
|
||||||
|
let requests = server.received_requests().await.unwrap();
|
||||||
|
assert_eq!(requests.len(), 2, "expected two POST requests");
|
||||||
|
|
||||||
|
let body1 = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||||
|
let body2 = requests[1].body_json::<serde_json::Value>().unwrap();
|
||||||
|
|
||||||
|
// prompt_cache_key should remain constant across overrides
|
||||||
|
assert_eq!(
|
||||||
|
body1["prompt_cache_key"], body2["prompt_cache_key"],
|
||||||
|
"prompt_cache_key should not change across overrides"
|
||||||
|
);
|
||||||
|
|
||||||
|
// The entire prefix from the first request should be identical and reused
|
||||||
|
// as the prefix of the second request, ensuring cache hit potential.
|
||||||
|
let expected_user_message_2 = serde_json::json!({
|
||||||
|
"type": "message",
|
||||||
|
"id": serde_json::Value::Null,
|
||||||
|
"role": "user",
|
||||||
|
"content": [ { "type": "input_text", "text": "hello 2" } ]
|
||||||
|
});
|
||||||
|
// After overriding the turn context, the environment context should be emitted again
|
||||||
|
// reflecting the new cwd, approval policy and sandbox settings.
|
||||||
|
let expected_env_text_2 = format!(
|
||||||
|
"<environment_context>\nCurrent working directory: {}\nApproval policy: never\nSandbox mode: workspace-write\nNetwork access: enabled\n</environment_context>",
|
||||||
|
new_cwd.path().to_string_lossy()
|
||||||
|
);
|
||||||
|
let expected_env_msg_2 = serde_json::json!({
|
||||||
|
"type": "message",
|
||||||
|
"id": serde_json::Value::Null,
|
||||||
|
"role": "user",
|
||||||
|
"content": [ { "type": "input_text", "text": expected_env_text_2 } ]
|
||||||
|
});
|
||||||
|
let expected_body2 = serde_json::json!(
|
||||||
|
[
|
||||||
|
body1["input"].as_array().unwrap().as_slice(),
|
||||||
|
[expected_env_msg_2, expected_user_message_2].as_slice(),
|
||||||
|
]
|
||||||
|
.concat()
|
||||||
|
);
|
||||||
|
assert_eq!(body2["input"], expected_body2);
|
||||||
|
}
|
||||||
|
|||||||
@@ -75,6 +75,38 @@ pub enum Op {
|
|||||||
summary: ReasoningSummaryConfig,
|
summary: ReasoningSummaryConfig,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// Override parts of the persistent turn context for subsequent turns.
|
||||||
|
///
|
||||||
|
/// All fields are optional; when omitted, the existing value is preserved.
|
||||||
|
/// This does not enqueue any input – it only updates defaults used for
|
||||||
|
/// future `UserInput` turns.
|
||||||
|
OverrideTurnContext {
|
||||||
|
/// Updated `cwd` for sandbox/tool calls.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
cwd: Option<PathBuf>,
|
||||||
|
|
||||||
|
/// Updated command approval policy.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
approval_policy: Option<AskForApproval>,
|
||||||
|
|
||||||
|
/// Updated sandbox policy for tool calls.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
sandbox_policy: Option<SandboxPolicy>,
|
||||||
|
|
||||||
|
/// Updated model slug. When set, the model family is derived
|
||||||
|
/// automatically.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
/// Updated reasoning effort (honored only for reasoning-capable models).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
effort: Option<ReasoningEffortConfig>,
|
||||||
|
|
||||||
|
/// Updated reasoning summary preference (honored only for reasoning-capable models).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
summary: Option<ReasoningSummaryConfig>,
|
||||||
|
},
|
||||||
|
|
||||||
/// Approve a command execution
|
/// Approve a command execution
|
||||||
ExecApproval {
|
ExecApproval {
|
||||||
/// The id of the submission we are approving
|
/// The id of the submission we are approving
|
||||||
|
|||||||
Reference in New Issue
Block a user