Use a unified shell tell to not break cache (#3814)
Currently, we change the tool description according to the sandbox policy and approval policy. This breaks the cache when the user hits `/approvals`. This PR does the following: - Always use the shell with escalation parameter: - removes `create_shell_tool_for_sandbox` and always uses unified tool via `create_shell_tool` - Reject the func call when the model uses escalation parameter when it cannot.
This commit is contained in:
@@ -456,8 +456,6 @@ impl Session {
|
||||
client,
|
||||
tools_config: ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &config.model_family,
|
||||
approval_policy,
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
@@ -1237,8 +1235,6 @@ async fn submission_loop(
|
||||
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &effective_family,
|
||||
approval_policy: new_approval_policy,
|
||||
sandbox_policy: new_sandbox_policy.clone(),
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
@@ -1325,8 +1321,6 @@ async fn submission_loop(
|
||||
client,
|
||||
tools_config: ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &model_family,
|
||||
approval_policy,
|
||||
sandbox_policy: sandbox_policy.clone(),
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
@@ -1553,8 +1547,6 @@ async fn spawn_review_thread(
|
||||
.unwrap_or_else(|| parent_turn_context.client.get_model_family());
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &review_model_family,
|
||||
approval_policy: parent_turn_context.approval_policy,
|
||||
sandbox_policy: parent_turn_context.sandbox_policy.clone(),
|
||||
include_plan_tool: false,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: false,
|
||||
@@ -2724,6 +2716,21 @@ async fn handle_container_exec_with_params(
|
||||
sub_id: String,
|
||||
call_id: String,
|
||||
) -> ResponseInputItem {
|
||||
if params.with_escalated_permissions.unwrap_or(false)
|
||||
&& !matches!(turn_context.approval_policy, AskForApproval::OnRequest)
|
||||
{
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!(
|
||||
"approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}",
|
||||
policy = turn_context.approval_policy
|
||||
),
|
||||
success: None,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// check if this was a patch, and apply it if so
|
||||
let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) {
|
||||
MaybeApplyPatchVerified::Body(changes) => {
|
||||
@@ -3345,6 +3352,7 @@ mod tests {
|
||||
use mcp_types::ContentBlock;
|
||||
use mcp_types::TextContent;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -3594,8 +3602,6 @@ mod tests {
|
||||
);
|
||||
let tools_config = ToolsConfig::new(&ToolsConfigParams {
|
||||
model_family: &config.model_family,
|
||||
approval_policy: config.approval_policy,
|
||||
sandbox_policy: config.sandbox_policy.clone(),
|
||||
include_plan_tool: config.include_plan_tool,
|
||||
include_apply_patch_tool: config.include_apply_patch_tool,
|
||||
include_web_search_request: config.tools_web_search_request,
|
||||
@@ -3735,4 +3741,105 @@ mod tests {
|
||||
|
||||
(rollout_items, live_history.contents())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn rejects_escalated_permissions_when_policy_not_on_request() {
|
||||
use crate::exec::ExecParams;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::SandboxPolicy;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use std::collections::HashMap;
|
||||
|
||||
let (session, mut turn_context) = make_session_and_context();
|
||||
// Ensure policy is NOT OnRequest so the early rejection path triggers
|
||||
turn_context.approval_policy = AskForApproval::OnFailure;
|
||||
|
||||
let params = ExecParams {
|
||||
command: if cfg!(windows) {
|
||||
vec![
|
||||
"cmd.exe".to_string(),
|
||||
"/C".to_string(),
|
||||
"echo hi".to_string(),
|
||||
]
|
||||
} else {
|
||||
vec![
|
||||
"/bin/sh".to_string(),
|
||||
"-c".to_string(),
|
||||
"echo hi".to_string(),
|
||||
]
|
||||
},
|
||||
cwd: turn_context.cwd.clone(),
|
||||
timeout_ms: Some(1000),
|
||||
env: HashMap::new(),
|
||||
with_escalated_permissions: Some(true),
|
||||
justification: Some("test".to_string()),
|
||||
};
|
||||
|
||||
let params2 = ExecParams {
|
||||
with_escalated_permissions: Some(false),
|
||||
..params.clone()
|
||||
};
|
||||
|
||||
let mut turn_diff_tracker = TurnDiffTracker::new();
|
||||
|
||||
let sub_id = "test-sub".to_string();
|
||||
let call_id = "test-call".to_string();
|
||||
|
||||
let resp = handle_container_exec_with_params(
|
||||
params,
|
||||
&session,
|
||||
&turn_context,
|
||||
&mut turn_diff_tracker,
|
||||
sub_id,
|
||||
call_id,
|
||||
)
|
||||
.await;
|
||||
|
||||
let ResponseInputItem::FunctionCallOutput { output, .. } = resp else {
|
||||
panic!("expected FunctionCallOutput");
|
||||
};
|
||||
|
||||
let expected = format!(
|
||||
"approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}",
|
||||
policy = turn_context.approval_policy
|
||||
);
|
||||
|
||||
pretty_assertions::assert_eq!(output.content, expected);
|
||||
|
||||
// Now retry the same command WITHOUT escalated permissions; should succeed.
|
||||
// Force DangerFullAccess to avoid platform sandbox dependencies in tests.
|
||||
turn_context.sandbox_policy = SandboxPolicy::DangerFullAccess;
|
||||
|
||||
let resp2 = handle_container_exec_with_params(
|
||||
params2,
|
||||
&session,
|
||||
&turn_context,
|
||||
&mut turn_diff_tracker,
|
||||
"test-sub".to_string(),
|
||||
"test-call-2".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let ResponseInputItem::FunctionCallOutput { output, .. } = resp2 else {
|
||||
panic!("expected FunctionCallOutput on retry");
|
||||
};
|
||||
|
||||
#[derive(Deserialize, PartialEq, Eq, Debug)]
|
||||
struct ResponseExecMetadata {
|
||||
exit_code: i32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ResponseExecOutput {
|
||||
output: String,
|
||||
metadata: ResponseExecMetadata,
|
||||
}
|
||||
|
||||
let exec_output: ResponseExecOutput =
|
||||
serde_json::from_str(&output.content).expect("valid exec output json");
|
||||
|
||||
pretty_assertions::assert_eq!(exec_output.metadata, ResponseExecMetadata { exit_code: 0 });
|
||||
assert!(exec_output.output.contains("hi"));
|
||||
pretty_assertions::assert_eq!(output.success, Some(true));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user