chore: align unified_exec (#5442)

Align `unified_exec` with b implementation
This commit is contained in:
jif-oai
2025-10-22 11:50:18 +01:00
committed by GitHub
parent 53cadb4df6
commit 00b1e130b3
8 changed files with 1060 additions and 596 deletions

View File

@@ -76,6 +76,13 @@ pub(crate) enum ToolEmitter {
changes: HashMap<PathBuf, FileChange>,
auto_approved: bool,
},
UnifiedExec {
command: String,
cwd: PathBuf,
// True for `exec_command` and false for `write_stdin`.
#[allow(dead_code)]
is_startup_command: bool,
},
}
impl ToolEmitter {
@@ -90,6 +97,14 @@ impl ToolEmitter {
}
}
pub fn unified_exec(command: String, cwd: PathBuf, is_startup_command: bool) -> Self {
Self::UnifiedExec {
command,
cwd,
is_startup_command,
}
}
pub async fn emit(&self, ctx: ToolEventCtx<'_>, stage: ToolEventStage) {
match (self, stage) {
(Self::Shell { command, cwd }, ToolEventStage::Begin) => {
@@ -181,6 +196,10 @@ impl ToolEmitter {
) => {
emit_patch_end(ctx, String::new(), (*message).to_string(), false).await;
}
(Self::UnifiedExec { command, cwd, .. }, _) => {
// TODO(jif) add end and failures.
emit_exec_command_begin(ctx, &[command.to_string()], cwd.as_path()).await;
}
}
}
}

View File

@@ -1,35 +1,68 @@
use std::time::Duration;
use async_trait::async_trait;
use serde::Deserialize;
use serde::Serialize;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::events::ToolEventStage;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::unified_exec::UnifiedExecRequest;
use crate::unified_exec::ExecCommandRequest;
use crate::unified_exec::UnifiedExecContext;
use crate::unified_exec::UnifiedExecResponse;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::unified_exec::WriteStdinRequest;
pub struct UnifiedExecHandler;
#[derive(Deserialize)]
struct UnifiedExecArgs {
input: Vec<String>,
#[derive(Debug, Deserialize)]
struct ExecCommandArgs {
cmd: String,
#[serde(default = "default_shell")]
shell: String,
#[serde(default = "default_login")]
login: bool,
#[serde(default)]
session_id: Option<String>,
yield_time_ms: Option<u64>,
#[serde(default)]
timeout_ms: Option<u64>,
max_output_tokens: Option<usize>,
}
/// Arguments deserialized from the JSON payload of a `write_stdin` tool call.
///
/// Only `session_id` is required; every other field carries `#[serde(default)]`
/// and falls back to its `Default` value when absent from the payload.
#[derive(Debug, Deserialize)]
struct WriteStdinArgs {
// Identifier of the session to write to; forwarded as-is to `WriteStdinRequest`.
session_id: i32,
// Text forwarded as the request's `input`; an absent field becomes the empty string.
#[serde(default)]
chars: String,
// Forwarded unchanged to the request; `None` when omitted.
#[serde(default)]
yield_time_ms: Option<u64>,
// Forwarded unchanged to the request; `None` when omitted.
#[serde(default)]
max_output_tokens: Option<usize>,
}
/// Serde default for the `shell` argument of `exec_command`: `/bin/bash`.
fn default_shell() -> String {
    String::from("/bin/bash")
}
/// Serde default for the `login` argument of `exec_command`.
///
/// Always `true`, so a call that omits `login` is treated as requesting login
/// semantics.
fn default_login() -> bool {
    true
}
#[async_trait]
impl ToolHandler for UnifiedExecHandler {
fn kind(&self) -> ToolKind {
ToolKind::UnifiedExec
ToolKind::Function
}
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(
payload,
ToolPayload::UnifiedExec { .. } | ToolPayload::Function { .. }
ToolPayload::Function { .. } | ToolPayload::UnifiedExec { .. }
)
}
@@ -38,19 +71,14 @@ impl ToolHandler for UnifiedExecHandler {
session,
turn,
call_id,
tool_name: _tool_name,
tool_name,
payload,
..
} = invocation;
let args = match payload {
ToolPayload::UnifiedExec { arguments } | ToolPayload::Function { arguments } => {
serde_json::from_str::<UnifiedExecArgs>(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?
}
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
ToolPayload::UnifiedExec { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"unified_exec handler received unsupported payload".to_string(),
@@ -58,58 +86,69 @@ impl ToolHandler for UnifiedExecHandler {
}
};
let UnifiedExecArgs {
input,
session_id,
timeout_ms,
} = args;
let manager: &UnifiedExecSessionManager = &session.services.unified_exec_manager;
let context = UnifiedExecContext {
session: &session,
turn: turn.as_ref(),
call_id: &call_id,
};
let parsed_session_id = if let Some(session_id) = session_id {
match session_id.parse::<i32>() {
Ok(parsed) => Some(parsed),
Err(output) => {
return Err(FunctionCallError::RespondToModel(format!(
"invalid session_id: {session_id} due to error {output:?}"
)));
}
let response = match tool_name.as_str() {
"exec_command" => {
let args: ExecCommandArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse exec_command arguments: {err:?}"
))
})?;
let event_ctx =
ToolEventCtx::new(context.session, context.turn, context.call_id, None);
let emitter =
ToolEmitter::unified_exec(args.cmd.clone(), context.turn.cwd.clone(), true);
emitter.emit(event_ctx, ToolEventStage::Begin).await;
manager
.exec_command(
ExecCommandRequest {
command: &args.cmd,
shell: &args.shell,
login: args.login,
yield_time_ms: args.yield_time_ms,
max_output_tokens: args.max_output_tokens,
},
&context,
)
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!("exec_command failed: {err:?}"))
})?
}
"write_stdin" => {
let args: WriteStdinArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse write_stdin arguments: {err:?}"
))
})?;
manager
.write_stdin(WriteStdinRequest {
session_id: args.session_id,
input: &args.chars,
yield_time_ms: args.yield_time_ms,
max_output_tokens: args.max_output_tokens,
})
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!("write_stdin failed: {err:?}"))
})?
}
other => {
return Err(FunctionCallError::RespondToModel(format!(
"unsupported unified exec function {other}"
)));
}
} else {
None
};
let request = UnifiedExecRequest {
input_chunks: &input,
timeout_ms,
};
let value = session
.services
.unified_exec_manager
.handle_request(
request,
crate::unified_exec::UnifiedExecContext {
session: &session,
turn: turn.as_ref(),
call_id: &call_id,
session_id: parsed_session_id,
},
)
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!("unified exec failed: {err:?}"))
})?;
#[derive(serde::Serialize)]
struct SerializedUnifiedExecResult {
session_id: Option<String>,
output: String,
}
let content = serde_json::to_string(&SerializedUnifiedExecResult {
session_id: value.session_id.map(|id| id.to_string()),
output: value.output,
})
.map_err(|err| {
let content = serialize_response(&response).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to serialize unified exec output: {err:?}"
))
@@ -121,3 +160,33 @@ impl ToolHandler for UnifiedExecHandler {
})
}
}
/// Borrowed, serializable view of a `UnifiedExecResponse`, built by
/// `serialize_response` and turned into the JSON string for the tool output.
///
/// Fields are emitted in declaration order. The three `Option` fields are left
/// out of the JSON entirely when they are `None`.
#[derive(Serialize)]
struct SerializedUnifiedExecResponse<'a> {
// Borrowed from the response's `chunk_id`.
chunk_id: &'a str,
// The response's `wall_time` converted to fractional seconds.
wall_time_seconds: f64,
// Borrowed from the response's `output`.
output: &'a str,
// Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<i32>,
// Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
exit_code: Option<i32>,
// Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
original_token_count: Option<usize>,
}
/// Renders a [`UnifiedExecResponse`] as a JSON string.
///
/// The response is projected onto [`SerializedUnifiedExecResponse`], which
/// borrows the string fields and reports the wall time in seconds.
///
/// # Errors
///
/// Returns the `serde_json::Error` produced by `serde_json::to_string` if
/// serialization fails.
fn serialize_response(response: &UnifiedExecResponse) -> Result<String, serde_json::Error> {
    let wall_time_seconds = duration_to_seconds(response.wall_time);
    serde_json::to_string(&SerializedUnifiedExecResponse {
        chunk_id: &response.chunk_id,
        wall_time_seconds,
        output: &response.output,
        session_id: response.session_id,
        exit_code: response.exit_code,
        original_token_count: response.original_token_count,
    })
}
/// Converts a [`Duration`] into fractional seconds.
fn duration_to_seconds(duration: Duration) -> f64 {
    Duration::as_secs_f64(&duration)
}

View File

@@ -15,7 +15,6 @@ use crate::tools::context::ToolPayload;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {
Function,
UnifiedExec,
Mcp,
}
@@ -27,7 +26,6 @@ pub trait ToolHandler: Send + Sync {
matches!(
(self.kind(), payload),
(ToolKind::Function, ToolPayload::Function { .. })
| (ToolKind::UnifiedExec, ToolPayload::UnifiedExec { .. })
| (ToolKind::Mcp, ToolPayload::Mcp { .. })
)
}

View File

@@ -136,48 +136,99 @@ impl From<JsonSchema> for AdditionalProperties {
}
}
fn create_unified_exec_tool() -> ToolSpec {
fn create_exec_command_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
"input".to_string(),
JsonSchema::Array {
items: Box::new(JsonSchema::String { description: None }),
description: Some(
"When no session_id is provided, treat the array as the command and arguments \
to launch. When session_id is set, concatenate the strings (in order) and write \
them to the session's stdin."
.to_string(),
),
},
);
properties.insert(
"session_id".to_string(),
"cmd".to_string(),
JsonSchema::String {
description: Some("Shell command to execute.".to_string()),
},
);
properties.insert(
"shell".to_string(),
JsonSchema::String {
description: Some("Shell binary to launch. Defaults to /bin/bash.".to_string()),
},
);
properties.insert(
"login".to_string(),
JsonSchema::Boolean {
description: Some(
"Identifier for an existing interactive session. If omitted, a new command \
is spawned."
.to_string(),
"Whether to run the shell with -l/-i semantics. Defaults to true.".to_string(),
),
},
);
properties.insert(
"timeout_ms".to_string(),
"yield_time_ms".to_string(),
JsonSchema::Number {
description: Some(
"Maximum time in milliseconds to wait for output after writing the input."
.to_string(),
"How long to wait (in milliseconds) for output before yielding.".to_string(),
),
},
);
properties.insert(
"max_output_tokens".to_string(),
JsonSchema::Number {
description: Some(
"Maximum number of tokens to return. Excess output will be truncated.".to_string(),
),
},
);
ToolSpec::Function(ResponsesApiTool {
name: "unified_exec".to_string(),
name: "exec_command".to_string(),
description:
"Runs a command in a PTY. Provide a session_id to reuse an existing interactive session.".to_string(),
"Runs a command in a PTY, returning output or a session ID for ongoing interaction."
.to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: Some(vec!["input".to_string()]),
required: Some(vec!["cmd".to_string()]),
additional_properties: Some(false.into()),
},
})
}
fn create_write_stdin_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
"session_id".to_string(),
JsonSchema::Number {
description: Some("Identifier of the running unified exec session.".to_string()),
},
);
properties.insert(
"chars".to_string(),
JsonSchema::String {
description: Some("Bytes to write to stdin (may be empty to poll).".to_string()),
},
);
properties.insert(
"yield_time_ms".to_string(),
JsonSchema::Number {
description: Some(
"How long to wait (in milliseconds) for output before yielding.".to_string(),
),
},
);
properties.insert(
"max_output_tokens".to_string(),
JsonSchema::Number {
description: Some(
"Maximum number of tokens to return. Excess output will be truncated.".to_string(),
),
},
);
ToolSpec::Function(ResponsesApiTool {
name: "write_stdin".to_string(),
description:
"Writes characters to an existing unified exec session and returns recent output."
.to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: Some(vec!["session_id".to_string()]),
additional_properties: Some(false.into()),
},
})
@@ -839,19 +890,20 @@ pub(crate) fn build_specs(
|| matches!(config.shell_type, ConfigShellToolType::Streamable);
if use_unified_exec {
builder.push_spec(create_unified_exec_tool());
builder.register_handler("unified_exec", unified_exec_handler);
} else {
match &config.shell_type {
ConfigShellToolType::Default => {
builder.push_spec(create_shell_tool());
}
ConfigShellToolType::Local => {
builder.push_spec(ToolSpec::LocalShell {});
}
ConfigShellToolType::Streamable => {
// Already handled by use_unified_exec.
}
builder.push_spec(create_exec_command_tool());
builder.push_spec(create_write_stdin_tool());
builder.register_handler("exec_command", unified_exec_handler.clone());
builder.register_handler("write_stdin", unified_exec_handler);
}
match &config.shell_type {
ConfigShellToolType::Default => {
builder.push_spec(create_shell_tool());
}
ConfigShellToolType::Local => {
builder.push_spec(ToolSpec::LocalShell {});
}
ConfigShellToolType::Streamable => {
// Already handled by use_unified_exec.
}
}
@@ -986,6 +1038,14 @@ mod tests {
}
}
/// Test helper: maps the configured shell type to the tool name the tests
/// expect for it, or `None` for `Streamable`, which has no name of its own here.
fn shell_tool_name(config: &ToolsConfig) -> Option<&'static str> {
    let tool_name = match config.shell_type {
        ConfigShellToolType::Streamable => return None,
        ConfigShellToolType::Default => "shell",
        ConfigShellToolType::Local => "local_shell",
    };
    Some(tool_name)
}
fn find_tool<'a>(
tools: &'a [ConfiguredToolSpec],
expected_name: &str,
@@ -1009,18 +1069,20 @@ mod tests {
});
let (tools, _) = build_specs(&config, Some(HashMap::new())).build();
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"web_search",
"view_image",
],
);
let mut expected = vec!["exec_command", "write_stdin"];
if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
}
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"web_search",
"view_image",
]);
assert_eq_tool_names(&tools, &expected);
}
#[test]
@@ -1035,18 +1097,20 @@ mod tests {
});
let (tools, _) = build_specs(&config, Some(HashMap::new())).build();
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"web_search",
"view_image",
],
);
let mut expected = vec!["exec_command", "write_stdin"];
if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
}
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"web_search",
"view_image",
]);
assert_eq_tool_names(&tools, &expected);
}
#[test]
@@ -1063,7 +1127,8 @@ mod tests {
});
let (tools, _) = build_specs(&config, None).build();
assert!(!find_tool(&tools, "unified_exec").supports_parallel_tool_calls);
assert!(!find_tool(&tools, "exec_command").supports_parallel_tool_calls);
assert!(!find_tool(&tools, "write_stdin").supports_parallel_tool_calls);
assert!(find_tool(&tools, "grep_files").supports_parallel_tool_calls);
assert!(find_tool(&tools, "list_dir").supports_parallel_tool_calls);
assert!(find_tool(&tools, "read_file").supports_parallel_tool_calls);
@@ -1148,19 +1213,21 @@ mod tests {
)
.build();
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"web_search",
"view_image",
"test_server/do_something_cool",
],
);
let mut expected = vec!["exec_command", "write_stdin"];
if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
}
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"web_search",
"view_image",
"test_server/do_something_cool",
]);
assert_eq_tool_names(&tools, &expected);
let tool = find_tool(&tools, "test_server/do_something_cool");
assert_eq!(
@@ -1267,21 +1334,23 @@ mod tests {
]);
let (tools, _) = build_specs(&config, Some(tools_map)).build();
// Expect unified_exec first, followed by MCP tools sorted by fully-qualified name.
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"view_image",
"test_server/cool",
"test_server/do",
"test_server/something",
],
);
// Expect exec_command/write_stdin first, followed by MCP tools sorted by fully-qualified name.
let mut expected = vec!["exec_command", "write_stdin"];
if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
}
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"view_image",
"test_server/cool",
"test_server/do",
"test_server/something",
]);
assert_eq_tool_names(&tools, &expected);
}
#[test]
@@ -1320,23 +1389,28 @@ mod tests {
)
.build();
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"dash/search",
],
);
let mut expected = vec!["exec_command", "write_stdin"];
let has_shell = if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
true
} else {
false
};
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"dash/search",
]);
assert_eq_tool_names(&tools, &expected);
assert_eq!(
tools[8].spec,
tools[if has_shell { 10 } else { 9 }].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/search".to_string(),
parameters: JsonSchema::Object {
@@ -1389,22 +1463,27 @@ mod tests {
)
.build();
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"dash/paginate",
],
);
let mut expected = vec!["exec_command", "write_stdin"];
let has_shell = if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
true
} else {
false
};
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"dash/paginate",
]);
assert_eq_tool_names(&tools, &expected);
assert_eq!(
tools[8].spec,
tools[if has_shell { 10 } else { 9 }].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/paginate".to_string(),
parameters: JsonSchema::Object {
@@ -1456,22 +1535,26 @@ mod tests {
)
.build();
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"dash/tags",
],
);
let mut expected = vec!["exec_command", "write_stdin"];
let has_shell = if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
true
} else {
false
};
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"dash/tags",
]);
assert_eq_tool_names(&tools, &expected);
assert_eq!(
tools[8].spec,
tools[if has_shell { 10 } else { 9 }].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/tags".to_string(),
parameters: JsonSchema::Object {
@@ -1525,22 +1608,26 @@ mod tests {
)
.build();
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"dash/value",
],
);
let mut expected = vec!["exec_command", "write_stdin"];
let has_shell = if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
true
} else {
false
};
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"dash/value",
]);
assert_eq_tool_names(&tools, &expected);
assert_eq!(
tools[8].spec,
tools[if has_shell { 10 } else { 9 }].spec,
ToolSpec::Function(ResponsesApiTool {
name: "dash/value".to_string(),
parameters: JsonSchema::Object {
@@ -1631,23 +1718,28 @@ mod tests {
)
.build();
assert_eq_tool_names(
&tools,
&[
"unified_exec",
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"test_server/do_something_cool",
],
);
let mut expected = vec!["exec_command", "write_stdin"];
let has_shell = if let Some(shell_tool) = shell_tool_name(&config) {
expected.push(shell_tool);
true
} else {
false
};
expected.extend([
"list_mcp_resources",
"list_mcp_resource_templates",
"read_mcp_resource",
"update_plan",
"apply_patch",
"web_search",
"view_image",
"test_server/do_something_cool",
]);
assert_eq_tool_names(&tools, &expected);
assert_eq!(
tools[8].spec,
tools[if has_shell { 10 } else { 9 }].spec,
ToolSpec::Function(ResponsesApiTool {
name: "test_server/do_something_cool".to_string(),
parameters: JsonSchema::Object {