Use a unified shell tell to not break cache (#3814)

Currently, we change the tool description according to the sandbox
policy and approval policy. This breaks the cache when the user hits
`/approvals`. This PR does the following:
- Always use the shell with escalation parameter:
- removes `create_shell_tool_for_sandbox` and always uses unified tool
via `create_shell_tool`
- Reject the func call when the model uses escalation parameter when it
cannot.
This commit is contained in:
Ahmed Ibrahim
2025-09-18 17:08:28 -07:00
committed by GitHub
parent de64f5f007
commit a7fda70053
3 changed files with 133 additions and 128 deletions

View File

@@ -456,8 +456,6 @@ impl Session {
client,
tools_config: ToolsConfig::new(&ToolsConfigParams {
model_family: &config.model_family,
approval_policy,
sandbox_policy: sandbox_policy.clone(),
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
@@ -1237,8 +1235,6 @@ async fn submission_loop(
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &effective_family,
approval_policy: new_approval_policy,
sandbox_policy: new_sandbox_policy.clone(),
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
@@ -1325,8 +1321,6 @@ async fn submission_loop(
client,
tools_config: ToolsConfig::new(&ToolsConfigParams {
model_family: &model_family,
approval_policy,
sandbox_policy: sandbox_policy.clone(),
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
@@ -1553,8 +1547,6 @@ async fn spawn_review_thread(
.unwrap_or_else(|| parent_turn_context.client.get_model_family());
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &review_model_family,
approval_policy: parent_turn_context.approval_policy,
sandbox_policy: parent_turn_context.sandbox_policy.clone(),
include_plan_tool: false,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: false,
@@ -2724,6 +2716,21 @@ async fn handle_container_exec_with_params(
sub_id: String,
call_id: String,
) -> ResponseInputItem {
if params.with_escalated_permissions.unwrap_or(false)
&& !matches!(turn_context.approval_policy, AskForApproval::OnRequest)
{
return ResponseInputItem::FunctionCallOutput {
call_id,
output: FunctionCallOutputPayload {
content: format!(
"approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}",
policy = turn_context.approval_policy
),
success: None,
},
};
}
// check if this was a patch, and apply it if so
let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
MaybeApplyPatchVerified::Body(changes) => {
@@ -3345,6 +3352,7 @@ mod tests {
use mcp_types::ContentBlock;
use mcp_types::TextContent;
use pretty_assertions::assert_eq;
use serde::Deserialize;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
@@ -3594,8 +3602,6 @@ mod tests {
);
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &config.model_family,
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
include_plan_tool: config.include_plan_tool,
include_apply_patch_tool: config.include_apply_patch_tool,
include_web_search_request: config.tools_web_search_request,
@@ -3735,4 +3741,105 @@ mod tests {
(rollout_items, live_history.contents())
}
#[tokio::test]
async fn rejects_escalated_permissions_when_policy_not_on_request() {
use crate::exec::ExecParams;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use crate::turn_diff_tracker::TurnDiffTracker;
use std::collections::HashMap;
let (session, mut turn_context) = make_session_and_context();
// Ensure policy is NOT OnRequest so the early rejection path triggers
turn_context.approval_policy = AskForApproval::OnFailure;
let params = ExecParams {
command: if cfg!(windows) {
vec![
"cmd.exe".to_string(),
"/C".to_string(),
"echo hi".to_string(),
]
} else {
vec![
"/bin/sh".to_string(),
"-c".to_string(),
"echo hi".to_string(),
]
},
cwd: turn_context.cwd.clone(),
timeout_ms: Some(1000),
env: HashMap::new(),
with_escalated_permissions: Some(true),
justification: Some("test".to_string()),
};
let params2 = ExecParams {
with_escalated_permissions: Some(false),
..params.clone()
};
let mut turn_diff_tracker = TurnDiffTracker::new();
let sub_id = "test-sub".to_string();
let call_id = "test-call".to_string();
let resp = handle_container_exec_with_params(
params,
&session,
&turn_context,
&mut turn_diff_tracker,
sub_id,
call_id,
)
.await;
let ResponseInputItem::FunctionCallOutput { output, .. } = resp else {
panic!("expected FunctionCallOutput");
};
let expected = format!(
"approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}",
policy = turn_context.approval_policy
);
pretty_assertions::assert_eq!(output.content, expected);
// Now retry the same command WITHOUT escalated permissions; should succeed.
// Force DangerFullAccess to avoid platform sandbox dependencies in tests.
turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess;
let resp2 = handle_container_exec_with_params(
params2,
&session,
&turn_context,
&mut turn_diff_tracker,
"test-sub".to_string(),
"test-call-2".to_string(),
)
.await;
let ResponseInputItem::FunctionCallOutput { output, .. } = resp2 else {
panic!("expected FunctionCallOutput on retry");
};
#[derive(Deserialize, PartialEq, Eq, Debug)]
struct ResponseExecMetadata {
exit_code: i32,
}
#[derive(Deserialize)]
struct ResponseExecOutput {
output: String,
metadata: ResponseExecMetadata,
}
let exec_output: ResponseExecOutput =
serde_json::from_str(&output.content).expect("valid exec output json");
pretty_assertions::assert_eq!(exec_output.metadata, ResponseExecMetadata { exit_code: 0 });
assert!(exec_output.output.contains("hi"));
pretty_assertions::assert_eq!(output.success, Some(true));
}
}