Unified execution (#3288)
## Unified PTY-Based Exec Tool
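Note: this feature is controlled by a config flag. In the config file the key is
`experimental_use_unified_exec_tool = true` (exposed in code as
`Config::use_experimental_unified_exec_tool`); as written in this change, the value
falls back to `true` when the key is absent.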
- Adds a PTY-backed interactive exec feature (“unified_exec”) with session reuse via `session_id`, bounded output (128 KiB), and timeout clamping (≤ 60 s).
- Protocol: introduces `ResponseItem::UnifiedExec { session_id, arguments, timeout_ms }`.
- Tools: exposes `unified_exec` as a function tool (Responses API); it is excluded from the Chat Completions payload while still supported in tool lists.
- Path handling: resolves commands via `PATH` (or explicit paths), with UTF‑8/newline‑aware truncation (`truncate_middle`).
- Tests: cover command parsing, path resolution, session persistence/cleanup, multi‑session isolation, timeouts, and truncation behavior.
This commit is contained in:
@@ -54,6 +54,7 @@ tracing = { version = "0.1.41", features = ["log"] }
|
|||||||
tree-sitter = "0.25.9"
|
tree-sitter = "0.25.9"
|
||||||
tree-sitter-bash = "0.25.0"
|
tree-sitter-bash = "0.25.0"
|
||||||
uuid = { version = "1", features = ["serde", "v4"] }
|
uuid = { version = "1", features = ["serde", "v4"] }
|
||||||
|
which = "6"
|
||||||
wildmatch = "2.4.0"
|
wildmatch = "2.4.0"
|
||||||
|
|
||||||
|
|
||||||
@@ -69,9 +70,6 @@ openssl-sys = { version = "*", features = ["vendored"] }
|
|||||||
[target.aarch64-unknown-linux-musl.dependencies]
|
[target.aarch64-unknown-linux-musl.dependencies]
|
||||||
openssl-sys = { version = "*", features = ["vendored"] }
|
openssl-sys = { version = "*", features = ["vendored"] }
|
||||||
|
|
||||||
[target.'cfg(target_os = "windows")'.dependencies]
|
|
||||||
which = "6"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
assert_cmd = "2"
|
assert_cmd = "2"
|
||||||
core_test_support = { path = "tests/common" }
|
core_test_support = { path = "tests/common" }
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ use codex_protocol::protocol::TurnAbortReason;
|
|||||||
use codex_protocol::protocol::TurnAbortedEvent;
|
use codex_protocol::protocol::TurnAbortedEvent;
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use mcp_types::CallToolResult;
|
use mcp_types::CallToolResult;
|
||||||
|
use serde::Deserialize;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json;
|
use serde_json;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
@@ -112,6 +113,7 @@ use crate::safety::assess_command_safety;
|
|||||||
use crate::safety::assess_safety_for_untrusted_command;
|
use crate::safety::assess_safety_for_untrusted_command;
|
||||||
use crate::shell;
|
use crate::shell;
|
||||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||||
|
use crate::unified_exec::UnifiedExecSessionManager;
|
||||||
use crate::user_instructions::UserInstructions;
|
use crate::user_instructions::UserInstructions;
|
||||||
use crate::user_notification::UserNotification;
|
use crate::user_notification::UserNotification;
|
||||||
use crate::util::backoff;
|
use crate::util::backoff;
|
||||||
@@ -280,6 +282,7 @@ pub(crate) struct Session {
|
|||||||
/// Manager for external MCP servers/tools.
|
/// Manager for external MCP servers/tools.
|
||||||
mcp_connection_manager: McpConnectionManager,
|
mcp_connection_manager: McpConnectionManager,
|
||||||
session_manager: ExecSessionManager,
|
session_manager: ExecSessionManager,
|
||||||
|
unified_exec_manager: UnifiedExecSessionManager,
|
||||||
|
|
||||||
/// External notifier command (will be passed as args to exec()). When
|
/// External notifier command (will be passed as args to exec()). When
|
||||||
/// `None` this feature is disabled.
|
/// `None` this feature is disabled.
|
||||||
@@ -471,6 +474,7 @@ impl Session {
|
|||||||
include_web_search_request: config.tools_web_search_request,
|
include_web_search_request: config.tools_web_search_request,
|
||||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||||
include_view_image_tool: config.include_view_image_tool,
|
include_view_image_tool: config.include_view_image_tool,
|
||||||
|
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||||
}),
|
}),
|
||||||
user_instructions,
|
user_instructions,
|
||||||
base_instructions,
|
base_instructions,
|
||||||
@@ -484,6 +488,7 @@ impl Session {
|
|||||||
tx_event: tx_event.clone(),
|
tx_event: tx_event.clone(),
|
||||||
mcp_connection_manager,
|
mcp_connection_manager,
|
||||||
session_manager: ExecSessionManager::default(),
|
session_manager: ExecSessionManager::default(),
|
||||||
|
unified_exec_manager: UnifiedExecSessionManager::default(),
|
||||||
notify,
|
notify,
|
||||||
state: Mutex::new(state),
|
state: Mutex::new(state),
|
||||||
rollout: Mutex::new(Some(rollout_recorder)),
|
rollout: Mutex::new(Some(rollout_recorder)),
|
||||||
@@ -1149,6 +1154,7 @@ async fn submission_loop(
|
|||||||
include_web_search_request: config.tools_web_search_request,
|
include_web_search_request: config.tools_web_search_request,
|
||||||
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
use_streamable_shell_tool: config.use_experimental_streamable_shell_tool,
|
||||||
include_view_image_tool: config.include_view_image_tool,
|
include_view_image_tool: config.include_view_image_tool,
|
||||||
|
experimental_unified_exec_tool: config.use_experimental_unified_exec_tool,
|
||||||
});
|
});
|
||||||
|
|
||||||
let new_turn_context = TurnContext {
|
let new_turn_context = TurnContext {
|
||||||
@@ -1251,6 +1257,8 @@ async fn submission_loop(
|
|||||||
use_streamable_shell_tool: config
|
use_streamable_shell_tool: config
|
||||||
.use_experimental_streamable_shell_tool,
|
.use_experimental_streamable_shell_tool,
|
||||||
include_view_image_tool: config.include_view_image_tool,
|
include_view_image_tool: config.include_view_image_tool,
|
||||||
|
experimental_unified_exec_tool: config
|
||||||
|
.use_experimental_unified_exec_tool,
|
||||||
}),
|
}),
|
||||||
user_instructions: turn_context.user_instructions.clone(),
|
user_instructions: turn_context.user_instructions.clone(),
|
||||||
base_instructions: turn_context.base_instructions.clone(),
|
base_instructions: turn_context.base_instructions.clone(),
|
||||||
@@ -2082,6 +2090,72 @@ async fn handle_response_item(
|
|||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_unified_exec_tool_call(
|
||||||
|
sess: &Session,
|
||||||
|
call_id: String,
|
||||||
|
session_id: Option<String>,
|
||||||
|
arguments: Vec<String>,
|
||||||
|
timeout_ms: Option<u64>,
|
||||||
|
) -> ResponseInputItem {
|
||||||
|
let parsed_session_id = if let Some(session_id) = session_id {
|
||||||
|
match session_id.parse::<i32>() {
|
||||||
|
Ok(parsed) => Some(parsed),
|
||||||
|
Err(output) => {
|
||||||
|
return ResponseInputItem::FunctionCallOutput {
|
||||||
|
call_id: call_id.to_string(),
|
||||||
|
output: FunctionCallOutputPayload {
|
||||||
|
content: format!("invalid session_id: {session_id} due to error {output}"),
|
||||||
|
success: Some(false),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = crate::unified_exec::UnifiedExecRequest {
|
||||||
|
session_id: parsed_session_id,
|
||||||
|
input_chunks: &arguments,
|
||||||
|
timeout_ms,
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = sess.unified_exec_manager.handle_request(request).await;
|
||||||
|
|
||||||
|
let output_payload = match result {
|
||||||
|
Ok(value) => {
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct SerializedUnifiedExecResult<'a> {
|
||||||
|
session_id: Option<String>,
|
||||||
|
output: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
match serde_json::to_string(&SerializedUnifiedExecResult {
|
||||||
|
session_id: value.session_id.map(|id| id.to_string()),
|
||||||
|
output: &value.output,
|
||||||
|
}) {
|
||||||
|
Ok(serialized) => FunctionCallOutputPayload {
|
||||||
|
content: serialized,
|
||||||
|
success: Some(true),
|
||||||
|
},
|
||||||
|
Err(err) => FunctionCallOutputPayload {
|
||||||
|
content: format!("failed to serialize unified exec output: {err}"),
|
||||||
|
success: Some(false),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => FunctionCallOutputPayload {
|
||||||
|
content: format!("unified exec failed: {err}"),
|
||||||
|
success: Some(false),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
ResponseInputItem::FunctionCallOutput {
|
||||||
|
call_id,
|
||||||
|
output: output_payload,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_function_call(
|
async fn handle_function_call(
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
turn_context: &TurnContext,
|
turn_context: &TurnContext,
|
||||||
@@ -2109,6 +2183,38 @@ async fn handle_function_call(
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
"unified_exec" => {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct UnifiedExecArgs {
|
||||||
|
input: Vec<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
session_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
timeout_ms: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let args = match serde_json::from_str::<UnifiedExecArgs>(&arguments) {
|
||||||
|
Ok(args) => args,
|
||||||
|
Err(err) => {
|
||||||
|
return ResponseInputItem::FunctionCallOutput {
|
||||||
|
call_id,
|
||||||
|
output: FunctionCallOutputPayload {
|
||||||
|
content: format!("failed to parse function arguments: {err}"),
|
||||||
|
success: Some(false),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
handle_unified_exec_tool_call(
|
||||||
|
sess,
|
||||||
|
call_id,
|
||||||
|
args.session_id,
|
||||||
|
args.input,
|
||||||
|
args.timeout_ms,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
"view_image" => {
|
"view_image" => {
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
struct SeeImageArgs {
|
struct SeeImageArgs {
|
||||||
|
|||||||
@@ -172,6 +172,9 @@ pub struct Config {
|
|||||||
|
|
||||||
pub use_experimental_streamable_shell_tool: bool,
|
pub use_experimental_streamable_shell_tool: bool,
|
||||||
|
|
||||||
|
/// If set to `true`, use only the experimental unified exec tool.
|
||||||
|
pub use_experimental_unified_exec_tool: bool,
|
||||||
|
|
||||||
/// Include the `view_image` tool that lets the agent attach a local image path to context.
|
/// Include the `view_image` tool that lets the agent attach a local image path to context.
|
||||||
pub include_view_image_tool: bool,
|
pub include_view_image_tool: bool,
|
||||||
|
|
||||||
@@ -487,6 +490,7 @@ pub struct ConfigToml {
|
|||||||
pub experimental_instructions_file: Option<PathBuf>,
|
pub experimental_instructions_file: Option<PathBuf>,
|
||||||
|
|
||||||
pub experimental_use_exec_command_tool: Option<bool>,
|
pub experimental_use_exec_command_tool: Option<bool>,
|
||||||
|
pub experimental_use_unified_exec_tool: Option<bool>,
|
||||||
|
|
||||||
pub projects: Option<HashMap<String, ProjectConfig>>,
|
pub projects: Option<HashMap<String, ProjectConfig>>,
|
||||||
|
|
||||||
@@ -837,6 +841,9 @@ impl Config {
|
|||||||
use_experimental_streamable_shell_tool: cfg
|
use_experimental_streamable_shell_tool: cfg
|
||||||
.experimental_use_exec_command_tool
|
.experimental_use_exec_command_tool
|
||||||
.unwrap_or(false),
|
.unwrap_or(false),
|
||||||
|
use_experimental_unified_exec_tool: cfg
|
||||||
|
.experimental_use_unified_exec_tool
|
||||||
|
.unwrap_or(true),
|
||||||
include_view_image_tool,
|
include_view_image_tool,
|
||||||
active_profile: active_profile_name,
|
active_profile: active_profile_name,
|
||||||
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false),
|
||||||
@@ -1212,6 +1219,7 @@ model_verbosity = "high"
|
|||||||
tools_web_search_request: false,
|
tools_web_search_request: false,
|
||||||
preferred_auth_method: AuthMode::ChatGPT,
|
preferred_auth_method: AuthMode::ChatGPT,
|
||||||
use_experimental_streamable_shell_tool: false,
|
use_experimental_streamable_shell_tool: false,
|
||||||
|
use_experimental_unified_exec_tool: true,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
active_profile: Some("o3".to_string()),
|
active_profile: Some("o3".to_string()),
|
||||||
disable_paste_burst: false,
|
disable_paste_burst: false,
|
||||||
@@ -1269,6 +1277,7 @@ model_verbosity = "high"
|
|||||||
tools_web_search_request: false,
|
tools_web_search_request: false,
|
||||||
preferred_auth_method: AuthMode::ChatGPT,
|
preferred_auth_method: AuthMode::ChatGPT,
|
||||||
use_experimental_streamable_shell_tool: false,
|
use_experimental_streamable_shell_tool: false,
|
||||||
|
use_experimental_unified_exec_tool: true,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
active_profile: Some("gpt3".to_string()),
|
active_profile: Some("gpt3".to_string()),
|
||||||
disable_paste_burst: false,
|
disable_paste_burst: false,
|
||||||
@@ -1341,6 +1350,7 @@ model_verbosity = "high"
|
|||||||
tools_web_search_request: false,
|
tools_web_search_request: false,
|
||||||
preferred_auth_method: AuthMode::ChatGPT,
|
preferred_auth_method: AuthMode::ChatGPT,
|
||||||
use_experimental_streamable_shell_tool: false,
|
use_experimental_streamable_shell_tool: false,
|
||||||
|
use_experimental_unified_exec_tool: true,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
active_profile: Some("zdr".to_string()),
|
active_profile: Some("zdr".to_string()),
|
||||||
disable_paste_burst: false,
|
disable_paste_burst: false,
|
||||||
@@ -1399,6 +1409,7 @@ model_verbosity = "high"
|
|||||||
tools_web_search_request: false,
|
tools_web_search_request: false,
|
||||||
preferred_auth_method: AuthMode::ChatGPT,
|
preferred_auth_method: AuthMode::ChatGPT,
|
||||||
use_experimental_streamable_shell_tool: false,
|
use_experimental_streamable_shell_tool: false,
|
||||||
|
use_experimental_unified_exec_tool: true,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
active_profile: Some("gpt5".to_string()),
|
active_profile: Some("gpt5".to_string()),
|
||||||
disable_paste_burst: false,
|
disable_paste_burst: false,
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ pub(crate) struct ExecCommandSession {
|
|||||||
|
|
||||||
/// JoinHandle for the child wait task.
|
/// JoinHandle for the child wait task.
|
||||||
wait_handle: StdMutex<Option<JoinHandle<()>>>,
|
wait_handle: StdMutex<Option<JoinHandle<()>>>,
|
||||||
|
|
||||||
|
/// Tracks whether the underlying process has exited.
|
||||||
|
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ExecCommandSession {
|
impl ExecCommandSession {
|
||||||
@@ -34,6 +37,7 @@ impl ExecCommandSession {
|
|||||||
reader_handle: JoinHandle<()>,
|
reader_handle: JoinHandle<()>,
|
||||||
writer_handle: JoinHandle<()>,
|
writer_handle: JoinHandle<()>,
|
||||||
wait_handle: JoinHandle<()>,
|
wait_handle: JoinHandle<()>,
|
||||||
|
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
writer_tx,
|
writer_tx,
|
||||||
@@ -42,6 +46,7 @@ impl ExecCommandSession {
|
|||||||
reader_handle: StdMutex::new(Some(reader_handle)),
|
reader_handle: StdMutex::new(Some(reader_handle)),
|
||||||
writer_handle: StdMutex::new(Some(writer_handle)),
|
writer_handle: StdMutex::new(Some(writer_handle)),
|
||||||
wait_handle: StdMutex::new(Some(wait_handle)),
|
wait_handle: StdMutex::new(Some(wait_handle)),
|
||||||
|
exit_status,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,6 +57,10 @@ impl ExecCommandSession {
|
|||||||
pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
|
pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
|
||||||
self.output_tx.subscribe()
|
self.output_tx.subscribe()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn has_exited(&self) -> bool {
|
||||||
|
self.exit_status.load(std::sync::atomic::Ordering::SeqCst)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for ExecCommandSession {
|
impl Drop for ExecCommandSession {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ mod session_manager;
|
|||||||
|
|
||||||
pub use exec_command_params::ExecCommandParams;
|
pub use exec_command_params::ExecCommandParams;
|
||||||
pub use exec_command_params::WriteStdinParams;
|
pub use exec_command_params::WriteStdinParams;
|
||||||
|
pub(crate) use exec_command_session::ExecCommandSession;
|
||||||
pub use responses_api::EXEC_COMMAND_TOOL_NAME;
|
pub use responses_api::EXEC_COMMAND_TOOL_NAME;
|
||||||
pub use responses_api::WRITE_STDIN_TOOL_NAME;
|
pub use responses_api::WRITE_STDIN_TOOL_NAME;
|
||||||
pub use responses_api::create_exec_command_tool_for_responses_api;
|
pub use responses_api::create_exec_command_tool_for_responses_api;
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use std::io::ErrorKind;
|
|||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::Mutex as StdMutex;
|
use std::sync::Mutex as StdMutex;
|
||||||
|
use std::sync::atomic::AtomicBool;
|
||||||
use std::sync::atomic::AtomicU32;
|
use std::sync::atomic::AtomicU32;
|
||||||
|
|
||||||
use portable_pty::CommandBuilder;
|
use portable_pty::CommandBuilder;
|
||||||
@@ -19,6 +20,7 @@ use crate::exec_command::exec_command_params::ExecCommandParams;
|
|||||||
use crate::exec_command::exec_command_params::WriteStdinParams;
|
use crate::exec_command::exec_command_params::WriteStdinParams;
|
||||||
use crate::exec_command::exec_command_session::ExecCommandSession;
|
use crate::exec_command::exec_command_session::ExecCommandSession;
|
||||||
use crate::exec_command::session_id::SessionId;
|
use crate::exec_command::session_id::SessionId;
|
||||||
|
use crate::truncate::truncate_middle;
|
||||||
use codex_protocol::models::FunctionCallOutputPayload;
|
use codex_protocol::models::FunctionCallOutputPayload;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
@@ -327,11 +329,14 @@ async fn create_exec_command_session(
|
|||||||
|
|
||||||
// Keep the child alive until it exits, then signal exit code.
|
// Keep the child alive until it exits, then signal exit code.
|
||||||
let (exit_tx, exit_rx) = oneshot::channel::<i32>();
|
let (exit_tx, exit_rx) = oneshot::channel::<i32>();
|
||||||
|
let exit_status = Arc::new(AtomicBool::new(false));
|
||||||
|
let wait_exit_status = exit_status.clone();
|
||||||
let wait_handle = tokio::task::spawn_blocking(move || {
|
let wait_handle = tokio::task::spawn_blocking(move || {
|
||||||
let code = match child.wait() {
|
let code = match child.wait() {
|
||||||
Ok(status) => status.exit_code() as i32,
|
Ok(status) => status.exit_code() as i32,
|
||||||
Err(_) => -1,
|
Err(_) => -1,
|
||||||
};
|
};
|
||||||
|
wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||||
let _ = exit_tx.send(code);
|
let _ = exit_tx.send(code);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -343,116 +348,11 @@ async fn create_exec_command_session(
|
|||||||
reader_handle,
|
reader_handle,
|
||||||
writer_handle,
|
writer_handle,
|
||||||
wait_handle,
|
wait_handle,
|
||||||
|
exit_status,
|
||||||
);
|
);
|
||||||
Ok((session, exit_rx))
|
Ok((session, exit_rx))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
|
|
||||||
/// preserving the beginning and the end. Returns the possibly truncated
|
|
||||||
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
|
|
||||||
/// if truncation occurred; otherwise returns the original string and `None`.
|
|
||||||
fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
|
|
||||||
// No truncation needed
|
|
||||||
if s.len() <= max_bytes {
|
|
||||||
return (s.to_string(), None);
|
|
||||||
}
|
|
||||||
let est_tokens = (s.len() as u64).div_ceil(4);
|
|
||||||
if max_bytes == 0 {
|
|
||||||
// Cannot keep any content; still return a full marker (never truncated).
|
|
||||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to truncate a string to a given byte length on a char boundary.
|
|
||||||
fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
|
|
||||||
if input.len() <= max_len {
|
|
||||||
return input;
|
|
||||||
}
|
|
||||||
let mut end = max_len;
|
|
||||||
while end > 0 && !input.is_char_boundary(end) {
|
|
||||||
end -= 1;
|
|
||||||
}
|
|
||||||
&input[..end]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Given a left/right budget, prefer newline boundaries; otherwise fall back
|
|
||||||
// to UTF-8 char boundaries.
|
|
||||||
fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
|
|
||||||
if let Some(head) = s.get(..left_budget)
|
|
||||||
&& let Some(i) = head.rfind('\n')
|
|
||||||
{
|
|
||||||
return i + 1; // keep the newline so suffix starts on a fresh line
|
|
||||||
}
|
|
||||||
truncate_on_boundary(s, left_budget).len()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
|
|
||||||
let start_tail = s.len().saturating_sub(right_budget);
|
|
||||||
if let Some(tail) = s.get(start_tail..)
|
|
||||||
&& let Some(i) = tail.find('\n')
|
|
||||||
{
|
|
||||||
return start_tail + i + 1; // start after newline
|
|
||||||
}
|
|
||||||
// Fall back to a char boundary at or after start_tail.
|
|
||||||
let mut idx = start_tail.min(s.len());
|
|
||||||
while idx < s.len() && !s.is_char_boundary(idx) {
|
|
||||||
idx += 1;
|
|
||||||
}
|
|
||||||
idx
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refine marker length and budgets until stable. Marker is never truncated.
|
|
||||||
let mut guess_tokens = est_tokens; // worst-case: everything truncated
|
|
||||||
for _ in 0..4 {
|
|
||||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
|
||||||
let marker_len = marker.len();
|
|
||||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
|
||||||
if keep_budget == 0 {
|
|
||||||
// No room for any content within the cap; return a full, untruncated marker
|
|
||||||
// that reflects the entire truncated content.
|
|
||||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
|
||||||
}
|
|
||||||
|
|
||||||
let left_budget = keep_budget / 2;
|
|
||||||
let right_budget = keep_budget - left_budget;
|
|
||||||
let prefix_end = pick_prefix_end(s, left_budget);
|
|
||||||
let mut suffix_start = pick_suffix_start(s, right_budget);
|
|
||||||
if suffix_start < prefix_end {
|
|
||||||
suffix_start = prefix_end;
|
|
||||||
}
|
|
||||||
let kept_content_bytes = prefix_end + (s.len() - suffix_start);
|
|
||||||
let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
|
|
||||||
let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
|
|
||||||
if new_tokens == guess_tokens {
|
|
||||||
let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
|
|
||||||
out.push_str(&s[..prefix_end]);
|
|
||||||
out.push_str(&marker);
|
|
||||||
// Place marker on its own line for symmetry when we keep line boundaries.
|
|
||||||
out.push('\n');
|
|
||||||
out.push_str(&s[suffix_start..]);
|
|
||||||
return (out, Some(est_tokens));
|
|
||||||
}
|
|
||||||
guess_tokens = new_tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback: use last guess to build output.
|
|
||||||
let marker = format!("…{guess_tokens} tokens truncated…");
|
|
||||||
let marker_len = marker.len();
|
|
||||||
let keep_budget = max_bytes.saturating_sub(marker_len);
|
|
||||||
if keep_budget == 0 {
|
|
||||||
return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
|
|
||||||
}
|
|
||||||
let left_budget = keep_budget / 2;
|
|
||||||
let right_budget = keep_budget - left_budget;
|
|
||||||
let prefix_end = pick_prefix_end(s, left_budget);
|
|
||||||
let suffix_start = pick_suffix_start(s, right_budget);
|
|
||||||
let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
|
|
||||||
out.push_str(&s[..prefix_end]);
|
|
||||||
out.push_str(&marker);
|
|
||||||
out.push('\n');
|
|
||||||
out.push_str(&s[suffix_start..]);
|
|
||||||
(out, Some(est_tokens))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -616,50 +516,4 @@ Output:
|
|||||||
abc"#;
|
abc"#;
|
||||||
assert_eq!(expected, text);
|
assert_eq!(expected, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn truncate_middle_no_newlines_fallback() {
|
|
||||||
// A long string with no newlines that exceeds the cap.
|
|
||||||
let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
|
|
||||||
let max_bytes = 16; // force truncation
|
|
||||||
let (out, original) = truncate_middle(s, max_bytes);
|
|
||||||
// For very small caps, we return the full, untruncated marker,
|
|
||||||
// even if it exceeds the cap.
|
|
||||||
assert_eq!(out, "…16 tokens truncated…");
|
|
||||||
// Original string length is 62 bytes => ceil(62/4) = 16 tokens.
|
|
||||||
assert_eq!(original, Some(16));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn truncate_middle_prefers_newline_boundaries() {
|
|
||||||
// Build a multi-line string of 20 numbered lines (each "NNN\n").
|
|
||||||
let mut s = String::new();
|
|
||||||
for i in 1..=20 {
|
|
||||||
s.push_str(&format!("{i:03}\n"));
|
|
||||||
}
|
|
||||||
// Total length: 20 lines * 4 bytes per line = 80 bytes.
|
|
||||||
assert_eq!(s.len(), 80);
|
|
||||||
|
|
||||||
// Choose a cap that forces truncation while leaving room for
|
|
||||||
// a few lines on each side after accounting for the marker.
|
|
||||||
let max_bytes = 64;
|
|
||||||
// Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
|
|
||||||
assert_eq!(
|
|
||||||
truncate_middle(&s, max_bytes),
|
|
||||||
(
|
|
||||||
r#"001
|
|
||||||
002
|
|
||||||
003
|
|
||||||
004
|
|
||||||
…12 tokens truncated…
|
|
||||||
017
|
|
||||||
018
|
|
||||||
019
|
|
||||||
020
|
|
||||||
"#
|
|
||||||
.to_string(),
|
|
||||||
Some(20)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ mod mcp_tool_call;
|
|||||||
mod message_history;
|
mod message_history;
|
||||||
mod model_provider_info;
|
mod model_provider_info;
|
||||||
pub mod parse_command;
|
pub mod parse_command;
|
||||||
|
mod truncate;
|
||||||
|
mod unified_exec;
|
||||||
mod user_instructions;
|
mod user_instructions;
|
||||||
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
|
||||||
pub use model_provider_info::ModelProviderInfo;
|
pub use model_provider_info::ModelProviderInfo;
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ pub(crate) struct ToolsConfig {
|
|||||||
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
|
||||||
pub web_search_request: bool,
|
pub web_search_request: bool,
|
||||||
pub include_view_image_tool: bool,
|
pub include_view_image_tool: bool,
|
||||||
|
pub experimental_unified_exec_tool: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct ToolsConfigParams<'a> {
|
pub(crate) struct ToolsConfigParams<'a> {
|
||||||
@@ -81,6 +82,7 @@ pub(crate) struct ToolsConfigParams<'a> {
|
|||||||
pub(crate) include_web_search_request: bool,
|
pub(crate) include_web_search_request: bool,
|
||||||
pub(crate) use_streamable_shell_tool: bool,
|
pub(crate) use_streamable_shell_tool: bool,
|
||||||
pub(crate) include_view_image_tool: bool,
|
pub(crate) include_view_image_tool: bool,
|
||||||
|
pub(crate) experimental_unified_exec_tool: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToolsConfig {
|
impl ToolsConfig {
|
||||||
@@ -94,6 +96,7 @@ impl ToolsConfig {
|
|||||||
include_web_search_request,
|
include_web_search_request,
|
||||||
use_streamable_shell_tool,
|
use_streamable_shell_tool,
|
||||||
include_view_image_tool,
|
include_view_image_tool,
|
||||||
|
experimental_unified_exec_tool,
|
||||||
} = params;
|
} = params;
|
||||||
let mut shell_type = if *use_streamable_shell_tool {
|
let mut shell_type = if *use_streamable_shell_tool {
|
||||||
ConfigShellToolType::StreamableShell
|
ConfigShellToolType::StreamableShell
|
||||||
@@ -126,6 +129,7 @@ impl ToolsConfig {
|
|||||||
apply_patch_tool_type,
|
apply_patch_tool_type,
|
||||||
web_search_request: *include_web_search_request,
|
web_search_request: *include_web_search_request,
|
||||||
include_view_image_tool: *include_view_image_tool,
|
include_view_image_tool: *include_view_image_tool,
|
||||||
|
experimental_unified_exec_tool: *experimental_unified_exec_tool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -200,6 +204,53 @@ fn create_shell_tool() -> OpenAiTool {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_unified_exec_tool() -> OpenAiTool {
|
||||||
|
let mut properties = BTreeMap::new();
|
||||||
|
properties.insert(
|
||||||
|
"input".to_string(),
|
||||||
|
JsonSchema::Array {
|
||||||
|
items: Box::new(JsonSchema::String { description: None }),
|
||||||
|
description: Some(
|
||||||
|
"When no session_id is provided, treat the array as the command and arguments \
|
||||||
|
to launch. When session_id is set, concatenate the strings (in order) and write \
|
||||||
|
them to the session's stdin."
|
||||||
|
.to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
properties.insert(
|
||||||
|
"session_id".to_string(),
|
||||||
|
JsonSchema::String {
|
||||||
|
description: Some(
|
||||||
|
"Identifier for an existing interactive session. If omitted, a new command \
|
||||||
|
is spawned."
|
||||||
|
.to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
properties.insert(
|
||||||
|
"timeout_ms".to_string(),
|
||||||
|
JsonSchema::Number {
|
||||||
|
description: Some(
|
||||||
|
"Maximum time in milliseconds to wait for output after writing the input."
|
||||||
|
.to_string(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
OpenAiTool::Function(ResponsesApiTool {
|
||||||
|
name: "unified_exec".to_string(),
|
||||||
|
description:
|
||||||
|
"Runs a command in a PTY. Provide a session_id to reuse an existing interactive session.".to_string(),
|
||||||
|
strict: false,
|
||||||
|
parameters: JsonSchema::Object {
|
||||||
|
properties,
|
||||||
|
required: Some(vec!["input".to_string()]),
|
||||||
|
additional_properties: Some(false),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
|
fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool {
|
||||||
let mut properties = BTreeMap::new();
|
let mut properties = BTreeMap::new();
|
||||||
properties.insert(
|
properties.insert(
|
||||||
@@ -531,23 +582,27 @@ pub(crate) fn get_openai_tools(
|
|||||||
) -> Vec<OpenAiTool> {
|
) -> Vec<OpenAiTool> {
|
||||||
let mut tools: Vec<OpenAiTool> = Vec::new();
|
let mut tools: Vec<OpenAiTool> = Vec::new();
|
||||||
|
|
||||||
match &config.shell_type {
|
if config.experimental_unified_exec_tool {
|
||||||
ConfigShellToolType::DefaultShell => {
|
tools.push(create_unified_exec_tool());
|
||||||
tools.push(create_shell_tool());
|
} else {
|
||||||
}
|
match &config.shell_type {
|
||||||
ConfigShellToolType::ShellWithRequest { sandbox_policy } => {
|
ConfigShellToolType::DefaultShell => {
|
||||||
tools.push(create_shell_tool_for_sandbox(sandbox_policy));
|
tools.push(create_shell_tool());
|
||||||
}
|
}
|
||||||
ConfigShellToolType::LocalShell => {
|
ConfigShellToolType::ShellWithRequest { sandbox_policy } => {
|
||||||
tools.push(OpenAiTool::LocalShell {});
|
tools.push(create_shell_tool_for_sandbox(sandbox_policy));
|
||||||
}
|
}
|
||||||
ConfigShellToolType::StreamableShell => {
|
ConfigShellToolType::LocalShell => {
|
||||||
tools.push(OpenAiTool::Function(
|
tools.push(OpenAiTool::LocalShell {});
|
||||||
crate::exec_command::create_exec_command_tool_for_responses_api(),
|
}
|
||||||
));
|
ConfigShellToolType::StreamableShell => {
|
||||||
tools.push(OpenAiTool::Function(
|
tools.push(OpenAiTool::Function(
|
||||||
crate::exec_command::create_write_stdin_tool_for_responses_api(),
|
crate::exec_command::create_exec_command_tool_for_responses_api(),
|
||||||
));
|
));
|
||||||
|
tools.push(OpenAiTool::Function(
|
||||||
|
crate::exec_command::create_write_stdin_tool_for_responses_api(),
|
||||||
|
));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -574,10 +629,8 @@ pub(crate) fn get_openai_tools(
|
|||||||
if config.include_view_image_tool {
|
if config.include_view_image_tool {
|
||||||
tools.push(create_view_image_tool());
|
tools.push(create_view_image_tool());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(mcp_tools) = mcp_tools {
|
if let Some(mcp_tools) = mcp_tools {
|
||||||
// Ensure deterministic ordering to maximize prompt cache hits.
|
// Ensure deterministic ordering to maximize prompt cache hits.
|
||||||
// HashMap iteration order is non-deterministic, so sort by fully-qualified tool name.
|
|
||||||
let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect();
|
let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect();
|
||||||
entries.sort_by(|a, b| a.0.cmp(&b.0));
|
entries.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
|
|
||||||
@@ -639,12 +692,13 @@ mod tests {
|
|||||||
include_web_search_request: true,
|
include_web_search_request: true,
|
||||||
use_streamable_shell_tool: false,
|
use_streamable_shell_tool: false,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
|
experimental_unified_exec_tool: true,
|
||||||
});
|
});
|
||||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||||
|
|
||||||
assert_eq_tool_names(
|
assert_eq_tool_names(
|
||||||
&tools,
|
&tools,
|
||||||
&["local_shell", "update_plan", "web_search", "view_image"],
|
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -660,12 +714,13 @@ mod tests {
|
|||||||
include_web_search_request: true,
|
include_web_search_request: true,
|
||||||
use_streamable_shell_tool: false,
|
use_streamable_shell_tool: false,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
|
experimental_unified_exec_tool: true,
|
||||||
});
|
});
|
||||||
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
let tools = get_openai_tools(&config, Some(HashMap::new()));
|
||||||
|
|
||||||
assert_eq_tool_names(
|
assert_eq_tool_names(
|
||||||
&tools,
|
&tools,
|
||||||
&["shell", "update_plan", "web_search", "view_image"],
|
&["unified_exec", "update_plan", "web_search", "view_image"],
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -681,6 +736,7 @@ mod tests {
|
|||||||
include_web_search_request: true,
|
include_web_search_request: true,
|
||||||
use_streamable_shell_tool: false,
|
use_streamable_shell_tool: false,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
|
experimental_unified_exec_tool: true,
|
||||||
});
|
});
|
||||||
let tools = get_openai_tools(
|
let tools = get_openai_tools(
|
||||||
&config,
|
&config,
|
||||||
@@ -723,7 +779,7 @@ mod tests {
|
|||||||
assert_eq_tool_names(
|
assert_eq_tool_names(
|
||||||
&tools,
|
&tools,
|
||||||
&[
|
&[
|
||||||
"shell",
|
"unified_exec",
|
||||||
"web_search",
|
"web_search",
|
||||||
"view_image",
|
"view_image",
|
||||||
"test_server/do_something_cool",
|
"test_server/do_something_cool",
|
||||||
@@ -786,6 +842,7 @@ mod tests {
|
|||||||
include_web_search_request: false,
|
include_web_search_request: false,
|
||||||
use_streamable_shell_tool: false,
|
use_streamable_shell_tool: false,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
|
experimental_unified_exec_tool: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Intentionally construct a map with keys that would sort alphabetically.
|
// Intentionally construct a map with keys that would sort alphabetically.
|
||||||
@@ -838,11 +895,11 @@ mod tests {
|
|||||||
]);
|
]);
|
||||||
|
|
||||||
let tools = get_openai_tools(&config, Some(tools_map));
|
let tools = get_openai_tools(&config, Some(tools_map));
|
||||||
// Expect shell first, followed by MCP tools sorted by fully-qualified name.
|
// Expect unified_exec first, followed by MCP tools sorted by fully-qualified name.
|
||||||
assert_eq_tool_names(
|
assert_eq_tool_names(
|
||||||
&tools,
|
&tools,
|
||||||
&[
|
&[
|
||||||
"shell",
|
"unified_exec",
|
||||||
"view_image",
|
"view_image",
|
||||||
"test_server/cool",
|
"test_server/cool",
|
||||||
"test_server/do",
|
"test_server/do",
|
||||||
@@ -863,6 +920,7 @@ mod tests {
|
|||||||
include_web_search_request: true,
|
include_web_search_request: true,
|
||||||
use_streamable_shell_tool: false,
|
use_streamable_shell_tool: false,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
|
experimental_unified_exec_tool: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
let tools = get_openai_tools(
|
let tools = get_openai_tools(
|
||||||
@@ -890,7 +948,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq_tool_names(
|
assert_eq_tool_names(
|
||||||
&tools,
|
&tools,
|
||||||
&["shell", "web_search", "view_image", "dash/search"],
|
&["unified_exec", "web_search", "view_image", "dash/search"],
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -925,6 +983,7 @@ mod tests {
|
|||||||
include_web_search_request: true,
|
include_web_search_request: true,
|
||||||
use_streamable_shell_tool: false,
|
use_streamable_shell_tool: false,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
|
experimental_unified_exec_tool: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
let tools = get_openai_tools(
|
let tools = get_openai_tools(
|
||||||
@@ -950,7 +1009,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq_tool_names(
|
assert_eq_tool_names(
|
||||||
&tools,
|
&tools,
|
||||||
&["shell", "web_search", "view_image", "dash/paginate"],
|
&["unified_exec", "web_search", "view_image", "dash/paginate"],
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tools[3],
|
tools[3],
|
||||||
@@ -982,6 +1041,7 @@ mod tests {
|
|||||||
include_web_search_request: true,
|
include_web_search_request: true,
|
||||||
use_streamable_shell_tool: false,
|
use_streamable_shell_tool: false,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
|
experimental_unified_exec_tool: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
let tools = get_openai_tools(
|
let tools = get_openai_tools(
|
||||||
@@ -1005,7 +1065,10 @@ mod tests {
|
|||||||
)])),
|
)])),
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/tags"]);
|
assert_eq_tool_names(
|
||||||
|
&tools,
|
||||||
|
&["unified_exec", "web_search", "view_image", "dash/tags"],
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tools[3],
|
tools[3],
|
||||||
OpenAiTool::Function(ResponsesApiTool {
|
OpenAiTool::Function(ResponsesApiTool {
|
||||||
@@ -1039,6 +1102,7 @@ mod tests {
|
|||||||
include_web_search_request: true,
|
include_web_search_request: true,
|
||||||
use_streamable_shell_tool: false,
|
use_streamable_shell_tool: false,
|
||||||
include_view_image_tool: true,
|
include_view_image_tool: true,
|
||||||
|
experimental_unified_exec_tool: true,
|
||||||
});
|
});
|
||||||
|
|
||||||
let tools = get_openai_tools(
|
let tools = get_openai_tools(
|
||||||
@@ -1062,7 +1126,10 @@ mod tests {
|
|||||||
)])),
|
)])),
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/value"]);
|
assert_eq_tool_names(
|
||||||
|
&tools,
|
||||||
|
&["unified_exec", "web_search", "view_image", "dash/value"],
|
||||||
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tools[3],
|
tools[3],
|
||||||
OpenAiTool::Function(ResponsesApiTool {
|
OpenAiTool::Function(ResponsesApiTool {
|
||||||
|
|||||||
180
codex-rs/core/src/truncate.rs
Normal file
180
codex-rs/core/src/truncate.rs
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
//! Utilities for truncating large chunks of output while preserving a prefix
|
||||||
|
//! and suffix on UTF-8 boundaries.
|
||||||
|
|
||||||
|
/// Truncate the middle of a UTF-8 string, keeping a prefix and a suffix and
/// replacing the removed span with a `…N tokens truncated…` marker line.
///
/// `max_bytes` is the budget shared by the marker and the kept content. The
/// result is a soft bound: the newline written after the marker is not part
/// of the budget, and when the marker alone does not fit, the marker is still
/// returned on its own.
///
/// Cuts prefer newline boundaries and otherwise fall back to the nearest
/// UTF-8 character boundary, so the output is always valid UTF-8.
///
/// Returns the possibly truncated string and `Some(original_token_count)`
/// (estimated at 4 bytes/token) if truncation occurred; otherwise returns the
/// original string and `None`.
pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
    /// Renders the marker announcing how many tokens were removed.
    fn marker_for(tokens: u64) -> String {
        format!("…{tokens} tokens truncated…")
    }

    /// Byte offset where the kept prefix ends: just past the last newline
    /// within `left_budget` bytes, else the largest char boundary that fits.
    fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
        if let Some(i) = s.get(..left_budget).and_then(|head| head.rfind('\n')) {
            return i + 1;
        }
        let mut end = left_budget.min(s.len());
        while end > 0 && !s.is_char_boundary(end) {
            end -= 1;
        }
        end
    }

    /// Byte offset where the kept suffix starts: just past the first newline
    /// inside the last `right_budget` bytes, else the next char boundary.
    fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
        let start_tail = s.len().saturating_sub(right_budget);
        if let Some(i) = s.get(start_tail..).and_then(|tail| tail.find('\n')) {
            return start_tail + i + 1;
        }
        let mut idx = start_tail.min(s.len());
        while idx < s.len() && !s.is_char_boundary(idx) {
            idx += 1;
        }
        idx
    }

    if s.len() <= max_bytes {
        return (s.to_string(), None);
    }

    let est_tokens = (s.len() as u64).div_ceil(4);
    if max_bytes == 0 {
        return (marker_for(est_tokens), Some(est_tokens));
    }

    // Lays out the result for a marker announcing `guess` tokens. Returns
    // `None` when the marker leaves no room for any content. The single
    // definition is shared by the fixed-point loop and the final fallback so
    // both apply the same overlap clamp.
    let layout = |guess: u64| -> Option<(String, usize, usize)> {
        let marker = marker_for(guess);
        let keep_budget = max_bytes.saturating_sub(marker.len());
        if keep_budget == 0 {
            return None;
        }
        let left_budget = keep_budget / 2;
        let right_budget = keep_budget - left_budget;
        let prefix_end = pick_prefix_end(s, left_budget);
        // Defensive clamp: the suffix must never start inside the prefix.
        let suffix_start = pick_suffix_start(s, right_budget).max(prefix_end);
        Some((marker, prefix_end, suffix_start))
    };

    // Assembles `prefix + marker + '\n' + suffix`.
    let render = |marker: &str, prefix_end: usize, suffix_start: usize| -> String {
        let kept = prefix_end + (s.len() - suffix_start);
        let mut out = String::with_capacity(marker.len() + kept + 1);
        out.push_str(&s[..prefix_end]);
        out.push_str(marker);
        out.push('\n');
        out.push_str(&s[suffix_start..]);
        out
    };

    // The marker's width depends on the number it prints, and that number
    // depends on how much room the marker leaves, so iterate a few times
    // until the announced count matches what is actually removed.
    let mut guess_tokens = est_tokens;
    for _ in 0..4 {
        let (marker, prefix_end, suffix_start) = match layout(guess_tokens) {
            Some(found) => found,
            None => return (marker_for(est_tokens), Some(est_tokens)),
        };

        let removed_bytes = suffix_start - prefix_end;
        let new_tokens = (removed_bytes as u64).div_ceil(4);
        if new_tokens == guess_tokens {
            return (render(&marker, prefix_end, suffix_start), Some(est_tokens));
        }
        guess_tokens = new_tokens;
    }

    // No fixed point after four rounds: emit the layout for the last guess.
    match layout(guess_tokens) {
        Some((marker, prefix_end, suffix_start)) => {
            (render(&marker, prefix_end, suffix_start), Some(est_tokens))
        }
        None => (marker_for(est_tokens), Some(est_tokens)),
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::truncate_middle;

    /// Builds the 80-byte fixture made of the lines "001\n" through "020\n".
    fn numbered_lines() -> String {
        let mut text = String::new();
        for n in 1..=20 {
            text.push_str(&format!("{n:03}\n"));
        }
        text
    }

    /// Without any newline, the cut falls back to plain byte offsets.
    #[test]
    fn truncate_middle_no_newlines_fallback() {
        let input = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
        let (truncated, original_tokens) = truncate_middle(input, 32);

        assert!(truncated.starts_with("abc"));
        assert!(truncated.contains("tokens truncated"));
        assert!(truncated.ends_with("XYZ*"));
        assert_eq!(original_tokens, Some((input.len() as u64).div_ceil(4)));
    }

    /// Whole lines are kept on both sides of the marker.
    #[test]
    fn truncate_middle_prefers_newline_boundaries() {
        let input = numbered_lines();
        assert_eq!(input.len(), 80);

        let (truncated, original_tokens) = truncate_middle(&input, 64);

        assert!(truncated.starts_with("001\n002\n003\n004\n"));
        assert!(truncated.contains("tokens truncated"));
        assert!(truncated.ends_with("017\n018\n019\n020\n"));
        assert_eq!(original_tokens, Some(20));
    }

    /// Multi-byte characters are never split, so no replacement char appears.
    #[test]
    fn truncate_middle_handles_utf8_content() {
        let input = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
        let (truncated, original_tokens) = truncate_middle(input, 32);

        assert!(truncated.contains("tokens truncated"));
        assert!(!truncated.contains('\u{fffd}'));
        assert_eq!(original_tokens, Some((input.len() as u64).div_ceil(4)));
    }

    /// Pins the exact output: first four lines, marker, last four lines, and
    /// the original-size estimate (80 bytes / 4 = 20 tokens).
    #[test]
    fn truncate_middle_prefers_newline_boundaries_2() {
        // 20 lines * 4 bytes per line = 80 bytes.
        let input = numbered_lines();
        assert_eq!(input.len(), 80);

        // A cap that forces truncation while leaving room for a few lines on
        // each side once the marker is accounted for.
        let cap = 64;
        let expected = r#"001
002
003
004
…12 tokens truncated…
017
018
019
020
"#
        .to_string();

        assert_eq!(truncate_middle(&input, cap), (expected, Some(20)));
    }
}
|
||||||
22
codex-rs/core/src/unified_exec/errors.rs
Normal file
22
codex-rs/core/src/unified_exec/errors.rs
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub(crate) enum UnifiedExecError {
|
||||||
|
#[error("Failed to create unified exec session: {pty_error}")]
|
||||||
|
CreateSession {
|
||||||
|
#[source]
|
||||||
|
pty_error: anyhow::Error,
|
||||||
|
},
|
||||||
|
#[error("Unknown session id {session_id}")]
|
||||||
|
UnknownSessionId { session_id: i32 },
|
||||||
|
#[error("failed to write to stdin")]
|
||||||
|
WriteToStdin,
|
||||||
|
#[error("missing command line for unified exec request")]
|
||||||
|
MissingCommandLine,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UnifiedExecError {
|
||||||
|
pub(crate) fn create_session(error: anyhow::Error) -> Self {
|
||||||
|
Self::CreateSession { pty_error: error }
|
||||||
|
}
|
||||||
|
}
|
||||||
653
codex-rs/core/src/unified_exec/mod.rs
Normal file
653
codex-rs/core/src/unified_exec/mod.rs
Normal file
@@ -0,0 +1,653 @@
|
|||||||
|
use portable_pty::CommandBuilder;
|
||||||
|
use portable_pty::PtySize;
|
||||||
|
use portable_pty::native_pty_system;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::io::ErrorKind;
|
||||||
|
use std::io::Read;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::Mutex as StdMutex;
|
||||||
|
use std::sync::atomic::AtomicBool;
|
||||||
|
use std::sync::atomic::AtomicI32;
|
||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tokio::sync::Notify;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use tokio::time::Duration;
|
||||||
|
use tokio::time::Instant;
|
||||||
|
|
||||||
|
use crate::exec_command::ExecCommandSession;
|
||||||
|
use crate::truncate::truncate_middle;
|
||||||
|
|
||||||
|
mod errors;
|
||||||
|
|
||||||
|
pub(crate) use errors::UnifiedExecError;
|
||||||
|
|
||||||
|
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
|
||||||
|
const MAX_TIMEOUT_MS: u64 = 60_000;
|
||||||
|
const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 128 KiB
|
||||||
|
|
||||||
|
/// One call to the unified exec tool.
#[derive(Debug)]
pub(crate) struct UnifiedExecRequest<'a> {
    /// Session to reuse; `None` spawns a new command.
    pub session_id: Option<i32>,
    /// Command and arguments when spawning, or text to write to stdin when a
    /// session id is given.
    pub input_chunks: &'a [String],
    /// How long to wait for output, in milliseconds; `None` selects the
    /// default.
    pub timeout_ms: Option<u64>,
}
|
||||||
|
|
||||||
|
/// Outcome of a unified exec request.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct UnifiedExecResult {
    /// Id under which the session remains available, or `None` when it was
    /// not kept.
    pub session_id: Option<i32>,
    /// Output gathered while waiting, after truncation.
    pub output: String,
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub(crate) struct UnifiedExecSessionManager {
|
||||||
|
next_session_id: AtomicI32,
|
||||||
|
sessions: Mutex<HashMap<i32, ManagedUnifiedExecSession>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ManagedUnifiedExecSession {
|
||||||
|
session: ExecCommandSession,
|
||||||
|
output_buffer: OutputBuffer,
|
||||||
|
/// Notifies waiters whenever new output has been appended to
|
||||||
|
/// `output_buffer`, allowing clients to poll for fresh data.
|
||||||
|
output_notify: Arc<Notify>,
|
||||||
|
output_task: JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct OutputBufferState {
|
||||||
|
chunks: VecDeque<Vec<u8>>,
|
||||||
|
total_bytes: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OutputBufferState {
|
||||||
|
fn push_chunk(&mut self, chunk: Vec<u8>) {
|
||||||
|
self.total_bytes = self.total_bytes.saturating_add(chunk.len());
|
||||||
|
self.chunks.push_back(chunk);
|
||||||
|
|
||||||
|
let mut excess = self
|
||||||
|
.total_bytes
|
||||||
|
.saturating_sub(UNIFIED_EXEC_OUTPUT_MAX_BYTES);
|
||||||
|
|
||||||
|
while excess > 0 {
|
||||||
|
match self.chunks.front_mut() {
|
||||||
|
Some(front) if excess >= front.len() => {
|
||||||
|
excess -= front.len();
|
||||||
|
self.total_bytes = self.total_bytes.saturating_sub(front.len());
|
||||||
|
self.chunks.pop_front();
|
||||||
|
}
|
||||||
|
Some(front) => {
|
||||||
|
front.drain(..excess);
|
||||||
|
self.total_bytes = self.total_bytes.saturating_sub(excess);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn drain(&mut self) -> Vec<Vec<u8>> {
|
||||||
|
let drained: Vec<Vec<u8>> = self.chunks.drain(..).collect();
|
||||||
|
self.total_bytes = 0;
|
||||||
|
drained
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type OutputBuffer = Arc<Mutex<OutputBufferState>>;
|
||||||
|
type OutputHandles = (OutputBuffer, Arc<Notify>);
|
||||||
|
|
||||||
|
impl ManagedUnifiedExecSession {
|
||||||
|
fn new(session: ExecCommandSession) -> Self {
|
||||||
|
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
|
||||||
|
let output_notify = Arc::new(Notify::new());
|
||||||
|
let mut receiver = session.output_receiver();
|
||||||
|
let buffer_clone = Arc::clone(&output_buffer);
|
||||||
|
let notify_clone = Arc::clone(&output_notify);
|
||||||
|
let output_task = tokio::spawn(async move {
|
||||||
|
while let Ok(chunk) = receiver.recv().await {
|
||||||
|
let mut guard = buffer_clone.lock().await;
|
||||||
|
guard.push_chunk(chunk);
|
||||||
|
drop(guard);
|
||||||
|
notify_clone.notify_waiters();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Self {
|
||||||
|
session,
|
||||||
|
output_buffer,
|
||||||
|
output_notify,
|
||||||
|
output_task,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
|
||||||
|
self.session.writer_sender()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn output_handles(&self) -> OutputHandles {
|
||||||
|
(
|
||||||
|
Arc::clone(&self.output_buffer),
|
||||||
|
Arc::clone(&self.output_notify),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_exited(&self) -> bool {
|
||||||
|
self.session.has_exited()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for ManagedUnifiedExecSession {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.output_task.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UnifiedExecSessionManager {
|
||||||
|
pub async fn handle_request(
|
||||||
|
&self,
|
||||||
|
request: UnifiedExecRequest<'_>,
|
||||||
|
) -> Result<UnifiedExecResult, UnifiedExecError> {
|
||||||
|
let (timeout_ms, timeout_warning) = match request.timeout_ms {
|
||||||
|
Some(requested) if requested > MAX_TIMEOUT_MS => (
|
||||||
|
MAX_TIMEOUT_MS,
|
||||||
|
Some(format!(
|
||||||
|
"Warning: requested timeout {requested}ms exceeds maximum of {MAX_TIMEOUT_MS}ms; clamping to {MAX_TIMEOUT_MS}ms.\n"
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
Some(requested) => (requested, None),
|
||||||
|
None => (DEFAULT_TIMEOUT_MS, None),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut new_session: Option<ManagedUnifiedExecSession> = None;
|
||||||
|
let session_id;
|
||||||
|
let writer_tx;
|
||||||
|
let output_buffer;
|
||||||
|
let output_notify;
|
||||||
|
|
||||||
|
if let Some(existing_id) = request.session_id {
|
||||||
|
let mut sessions = self.sessions.lock().await;
|
||||||
|
match sessions.get(&existing_id) {
|
||||||
|
Some(session) => {
|
||||||
|
if session.has_exited() {
|
||||||
|
sessions.remove(&existing_id);
|
||||||
|
return Err(UnifiedExecError::UnknownSessionId {
|
||||||
|
session_id: existing_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
let (buffer, notify) = session.output_handles();
|
||||||
|
session_id = existing_id;
|
||||||
|
writer_tx = session.writer_sender();
|
||||||
|
output_buffer = buffer;
|
||||||
|
output_notify = notify;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
return Err(UnifiedExecError::UnknownSessionId {
|
||||||
|
session_id: existing_id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
drop(sessions);
|
||||||
|
} else {
|
||||||
|
let command = request.input_chunks.to_vec();
|
||||||
|
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
|
||||||
|
let session = create_unified_exec_session(&command).await?;
|
||||||
|
let managed_session = ManagedUnifiedExecSession::new(session);
|
||||||
|
let (buffer, notify) = managed_session.output_handles();
|
||||||
|
writer_tx = managed_session.writer_sender();
|
||||||
|
output_buffer = buffer;
|
||||||
|
output_notify = notify;
|
||||||
|
session_id = new_id;
|
||||||
|
new_session = Some(managed_session);
|
||||||
|
};
|
||||||
|
|
||||||
|
if request.session_id.is_some() {
|
||||||
|
let joined_input = request.input_chunks.join(" ");
|
||||||
|
if !joined_input.is_empty() && writer_tx.send(joined_input.into_bytes()).await.is_err()
|
||||||
|
{
|
||||||
|
return Err(UnifiedExecError::WriteToStdin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut collected: Vec<u8> = Vec::with_capacity(4096);
|
||||||
|
let start = Instant::now();
|
||||||
|
let deadline = start + Duration::from_millis(timeout_ms);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let drained_chunks;
|
||||||
|
let mut wait_for_output = None;
|
||||||
|
{
|
||||||
|
let mut guard = output_buffer.lock().await;
|
||||||
|
drained_chunks = guard.drain();
|
||||||
|
if drained_chunks.is_empty() {
|
||||||
|
wait_for_output = Some(output_notify.notified());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if drained_chunks.is_empty() {
|
||||||
|
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||||
|
if remaining == Duration::ZERO {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let notified = wait_for_output.unwrap_or_else(|| output_notify.notified());
|
||||||
|
tokio::pin!(notified);
|
||||||
|
tokio::select! {
|
||||||
|
_ = &mut notified => {}
|
||||||
|
_ = tokio::time::sleep(remaining) => break,
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for chunk in drained_chunks {
|
||||||
|
collected.extend_from_slice(&chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
if Instant::now() >= deadline {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (output, _maybe_tokens) = truncate_middle(
|
||||||
|
&String::from_utf8_lossy(&collected),
|
||||||
|
UNIFIED_EXEC_OUTPUT_MAX_BYTES,
|
||||||
|
);
|
||||||
|
let output = if let Some(warning) = timeout_warning {
|
||||||
|
format!("{warning}{output}")
|
||||||
|
} else {
|
||||||
|
output
|
||||||
|
};
|
||||||
|
|
||||||
|
let should_store_session = if let Some(session) = new_session.as_ref() {
|
||||||
|
!session.has_exited()
|
||||||
|
} else if request.session_id.is_some() {
|
||||||
|
let mut sessions = self.sessions.lock().await;
|
||||||
|
if let Some(existing) = sessions.get(&session_id) {
|
||||||
|
if existing.has_exited() {
|
||||||
|
sessions.remove(&session_id);
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
};
|
||||||
|
|
||||||
|
if should_store_session {
|
||||||
|
if let Some(session) = new_session {
|
||||||
|
self.sessions.lock().await.insert(session_id, session);
|
||||||
|
}
|
||||||
|
Ok(UnifiedExecResult {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
output,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(UnifiedExecResult {
|
||||||
|
session_id: None,
|
||||||
|
output,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_unified_exec_session(
|
||||||
|
command: &[String],
|
||||||
|
) -> Result<ExecCommandSession, UnifiedExecError> {
|
||||||
|
if command.is_empty() {
|
||||||
|
return Err(UnifiedExecError::MissingCommandLine);
|
||||||
|
}
|
||||||
|
|
||||||
|
let pty_system = native_pty_system();
|
||||||
|
|
||||||
|
let pair = pty_system
|
||||||
|
.openpty(PtySize {
|
||||||
|
rows: 24,
|
||||||
|
cols: 80,
|
||||||
|
pixel_width: 0,
|
||||||
|
pixel_height: 0,
|
||||||
|
})
|
||||||
|
.map_err(UnifiedExecError::create_session)?;
|
||||||
|
|
||||||
|
// Safe thanks to the check at the top of the function.
|
||||||
|
let mut command_builder = CommandBuilder::new(command[0].clone());
|
||||||
|
for arg in &command[1..] {
|
||||||
|
command_builder.arg(arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut child = pair
|
||||||
|
.slave
|
||||||
|
.spawn_command(command_builder)
|
||||||
|
.map_err(UnifiedExecError::create_session)?;
|
||||||
|
let killer = child.clone_killer();
|
||||||
|
|
||||||
|
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
|
||||||
|
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
|
||||||
|
|
||||||
|
let mut reader = pair
|
||||||
|
.master
|
||||||
|
.try_clone_reader()
|
||||||
|
.map_err(UnifiedExecError::create_session)?;
|
||||||
|
let output_tx_clone = output_tx.clone();
|
||||||
|
let reader_handle = tokio::task::spawn_blocking(move || {
|
||||||
|
let mut buf = [0u8; 8192];
|
||||||
|
loop {
|
||||||
|
match reader.read(&mut buf) {
|
||||||
|
Ok(0) => break,
|
||||||
|
Ok(n) => {
|
||||||
|
let _ = output_tx_clone.send(buf[..n].to_vec());
|
||||||
|
}
|
||||||
|
Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
|
||||||
|
Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
|
||||||
|
std::thread::sleep(Duration::from_millis(5));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(_) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let writer = pair
|
||||||
|
.master
|
||||||
|
.take_writer()
|
||||||
|
.map_err(UnifiedExecError::create_session)?;
|
||||||
|
let writer = Arc::new(StdMutex::new(writer));
|
||||||
|
let writer_handle = tokio::spawn({
|
||||||
|
let writer = writer.clone();
|
||||||
|
async move {
|
||||||
|
while let Some(bytes) = writer_rx.recv().await {
|
||||||
|
let writer = writer.clone();
|
||||||
|
let _ = tokio::task::spawn_blocking(move || {
|
||||||
|
if let Ok(mut guard) = writer.lock() {
|
||||||
|
use std::io::Write;
|
||||||
|
let _ = guard.write_all(&bytes);
|
||||||
|
let _ = guard.flush();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let exit_status = Arc::new(AtomicBool::new(false));
|
||||||
|
let wait_exit_status = Arc::clone(&exit_status);
|
||||||
|
let wait_handle = tokio::task::spawn_blocking(move || {
|
||||||
|
let _ = child.wait();
|
||||||
|
wait_exit_status.store(true, Ordering::SeqCst);
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(ExecCommandSession::new(
|
||||||
|
writer_tx,
|
||||||
|
output_tx,
|
||||||
|
killer,
|
||||||
|
reader_handle,
|
||||||
|
writer_handle,
|
||||||
|
wait_handle,
|
||||||
|
exit_status,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn push_chunk_trims_only_excess_bytes() {
|
||||||
|
let mut buffer = OutputBufferState::default();
|
||||||
|
buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]);
|
||||||
|
buffer.push_chunk(vec![b'b']);
|
||||||
|
buffer.push_chunk(vec![b'c']);
|
||||||
|
|
||||||
|
assert_eq!(buffer.total_bytes, UNIFIED_EXEC_OUTPUT_MAX_BYTES);
|
||||||
|
assert_eq!(buffer.chunks.len(), 3);
|
||||||
|
assert_eq!(
|
||||||
|
buffer.chunks.front().unwrap().len(),
|
||||||
|
UNIFIED_EXEC_OUTPUT_MAX_BYTES - 2
|
||||||
|
);
|
||||||
|
assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'c']);
|
||||||
|
assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'b']);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn unified_exec_persists_across_requests_jif() -> Result<(), UnifiedExecError> {
|
||||||
|
let manager = UnifiedExecSessionManager::default();
|
||||||
|
|
||||||
|
let open_shell = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: None,
|
||||||
|
input_chunks: &["bash".to_string(), "-i".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
let session_id = open_shell.session_id.expect("expected session_id");
|
||||||
|
|
||||||
|
manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
input_chunks: &[
|
||||||
|
"export".to_string(),
|
||||||
|
"CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
|
||||||
|
],
|
||||||
|
timeout_ms: Some(2_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let out_2 = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
assert!(out_2.output.contains("codex"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn multi_unified_exec_sessions() -> Result<(), UnifiedExecError> {
|
||||||
|
let manager = UnifiedExecSessionManager::default();
|
||||||
|
|
||||||
|
let shell_a = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: None,
|
||||||
|
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
let session_a = shell_a.session_id.expect("expected session id");
|
||||||
|
|
||||||
|
manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_a),
|
||||||
|
input_chunks: &["export CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let out_2 = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: None,
|
||||||
|
input_chunks: &[
|
||||||
|
"echo".to_string(),
|
||||||
|
"$CODEX_INTERACTIVE_SHELL_VAR\n".to_string(),
|
||||||
|
],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
assert!(!out_2.output.contains("codex"));
|
||||||
|
|
||||||
|
let out_3 = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_a),
|
||||||
|
input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
assert!(out_3.output.contains("codex"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unified_exec_timeouts() -> Result<(), UnifiedExecError> {
|
||||||
|
let manager = UnifiedExecSessionManager::default();
|
||||||
|
|
||||||
|
let open_shell = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: None,
|
||||||
|
input_chunks: &["bash".to_string(), "-i".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
let session_id = open_shell.session_id.expect("expected session id");
|
||||||
|
|
||||||
|
manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
input_chunks: &[
|
||||||
|
"export".to_string(),
|
||||||
|
"CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
|
||||||
|
],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let out_2 = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
input_chunks: &["sleep 5 && echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
|
||||||
|
timeout_ms: Some(10),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
assert!(!out_2.output.contains("codex"));
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_secs(7)).await;
|
||||||
|
|
||||||
|
let empty = Vec::new();
|
||||||
|
let out_3 = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
input_chunks: &empty,
|
||||||
|
timeout_ms: Some(100),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert!(out_3.output.contains("codex"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn requests_with_large_timeout_are_capped() -> Result<(), UnifiedExecError> {
|
||||||
|
let manager = UnifiedExecSessionManager::default();
|
||||||
|
|
||||||
|
let result = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: None,
|
||||||
|
input_chunks: &["echo".to_string(), "codex".to_string()],
|
||||||
|
timeout_ms: Some(120_000),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert!(result.output.starts_with(
|
||||||
|
"Warning: requested timeout 120000ms exceeds maximum of 60000ms; clamping to 60000ms.\n"
|
||||||
|
));
|
||||||
|
assert!(result.output.contains("codex"));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn completed_commands_do_not_persist_sessions() -> Result<(), UnifiedExecError> {
|
||||||
|
let manager = UnifiedExecSessionManager::default();
|
||||||
|
let result = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: None,
|
||||||
|
input_chunks: &["/bin/echo".to_string(), "codex".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert!(result.session_id.is_none());
|
||||||
|
assert!(result.output.contains("codex"));
|
||||||
|
|
||||||
|
assert!(manager.sessions.lock().await.is_empty());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn correct_path_resolution() -> Result<(), UnifiedExecError> {
|
||||||
|
let manager = UnifiedExecSessionManager::default();
|
||||||
|
let result = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: None,
|
||||||
|
input_chunks: &["echo".to_string(), "codex".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
assert!(result.session_id.is_none());
|
||||||
|
assert!(result.output.contains("codex"));
|
||||||
|
|
||||||
|
assert!(manager.sessions.lock().await.is_empty());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn reusing_completed_session_returns_unknown_session() -> Result<(), UnifiedExecError> {
|
||||||
|
let manager = UnifiedExecSessionManager::default();
|
||||||
|
|
||||||
|
let open_shell = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: None,
|
||||||
|
input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
let session_id = open_shell.session_id.expect("expected session id");
|
||||||
|
|
||||||
|
manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
input_chunks: &["exit\n".to_string()],
|
||||||
|
timeout_ms: Some(1_500),
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||||
|
|
||||||
|
let err = manager
|
||||||
|
.handle_request(UnifiedExecRequest {
|
||||||
|
session_id: Some(session_id),
|
||||||
|
input_chunks: &[],
|
||||||
|
timeout_ms: Some(100),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect_err("expected unknown session error");
|
||||||
|
|
||||||
|
match err {
|
||||||
|
UnifiedExecError::UnknownSessionId { session_id: err_id } => {
|
||||||
|
assert_eq!(err_id, session_id);
|
||||||
|
}
|
||||||
|
other => panic!("expected UnknownSessionId, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(!manager.sessions.lock().await.contains_key(&session_id));
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -191,7 +191,8 @@ async fn prompt_tools_are_consistent_across_requests() {
|
|||||||
let expected_instructions: &str = include_str!("../../prompt.md");
|
let expected_instructions: &str = include_str!("../../prompt.md");
|
||||||
// our internal implementation is responsible for keeping tools in sync
|
// our internal implementation is responsible for keeping tools in sync
|
||||||
// with the OpenAI schema, so we just verify the tool presence here
|
// with the OpenAI schema, so we just verify the tool presence here
|
||||||
let expected_tools_names: &[&str] = &["shell", "update_plan", "apply_patch", "view_image"];
|
let expected_tools_names: &[&str] =
|
||||||
|
&["unified_exec", "update_plan", "apply_patch", "view_image"];
|
||||||
let body0 = requests[0].body_json::<serde_json::Value>().unwrap();
|
let body0 = requests[0].body_json::<serde_json::Value>().unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
body0["instructions"],
|
body0["instructions"],
|
||||||
|
|||||||
@@ -115,7 +115,6 @@ pub enum ResponseItem {
|
|||||||
status: Option<String>,
|
status: Option<String>,
|
||||||
action: WebSearchAction,
|
action: WebSearchAction,
|
||||||
},
|
},
|
||||||
|
|
||||||
#[serde(other)]
|
#[serde(other)]
|
||||||
Other,
|
Other,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user