feat: StreamableShell with exec_command and write_stdin tools (#2574)
codex-rs/Cargo.lock (generated, 75 changes)
@@ -731,6 +731,7 @@ dependencies = [
 "mime_guess",
 "openssl-sys",
 "os_info",
+ "portable-pty",
 "predicates",
 "pretty_assertions",
 "rand 0.9.2",
@@ -1479,6 +1480,12 @@ version = "0.15.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"

+[[package]]
+name = "downcast-rs"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2"
+
 [[package]]
 name = "dupe"
 version = "0.9.1"
@@ -1724,6 +1731,17 @@ dependencies = [
 "simd-adler32",
 ]

+[[package]]
+name = "filedescriptor"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d"
+dependencies = [
+ "libc",
+ "thiserror 1.0.69",
+ "winapi",
+]
+
 [[package]]
 name = "fixedbitset"
 version = "0.4.2"
@@ -3439,6 +3457,27 @@ dependencies = [
 "portable-atomic",
 ]

+[[package]]
+name = "portable-pty"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b4a596a2b3d2752d94f51fac2d4a96737b8705dddd311a32b9af47211f08671e"
+dependencies = [
+ "anyhow",
+ "bitflags 1.3.2",
+ "downcast-rs",
+ "filedescriptor",
+ "lazy_static",
+ "libc",
+ "log",
+ "nix",
+ "serial2",
+ "shared_library",
+ "shell-words",
+ "winapi",
+ "winreg",
+]
+
 [[package]]
 name = "potential_utf"
 version = "0.1.2"
@@ -4366,6 +4405,17 @@ dependencies = [
 "syn 2.0.104",
 ]

+[[package]]
+name = "serial2"
+version = "0.2.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26e1e5956803a69ddd72ce2de337b577898801528749565def03515f82bad5bb"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "winapi",
+]
+
 [[package]]
 name = "sha1"
 version = "0.10.6"
@@ -4397,6 +4447,22 @@ dependencies = [
 "lazy_static",
 ]

+[[package]]
+name = "shared_library"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a9e7e0f2bfae24d8a5b5a66c5b257a83c7412304311512a0c054cd5e619da11"
+dependencies = [
+ "lazy_static",
+ "libc",
+]
+
+[[package]]
+name = "shell-words"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
+
 [[package]]
 name = "shlex"
 version = "1.3.0"
@@ -6176,6 +6242,15 @@ dependencies = [
 "memchr",
 ]

+[[package]]
+name = "winreg"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
+dependencies = [
+ "winapi",
+]
+
 [[package]]
 name = "winsafe"
 version = "0.0.19"
@@ -28,6 +28,7 @@ libc = "0.2.175"
 mcp-types = { path = "../mcp-types" }
 mime_guess = "2.0"
 os_info = "3.12.0"
+portable-pty = "0.9.0"
 rand = "0.9"
 regex-lite = "0.1.6"
 reqwest = { version = "0.12", features = ["json", "stream"] }
@@ -53,6 +53,11 @@ use crate::exec::SandboxType;
 use crate::exec::StdoutStream;
 use crate::exec::StreamOutput;
 use crate::exec::process_exec_tool_call;
+use crate::exec_command::EXEC_COMMAND_TOOL_NAME;
+use crate::exec_command::ExecCommandParams;
+use crate::exec_command::SESSION_MANAGER;
+use crate::exec_command::WRITE_STDIN_TOOL_NAME;
+use crate::exec_command::WriteStdinParams;
 use crate::exec_env::create_env;
 use crate::mcp_connection_manager::McpConnectionManager;
 use crate::mcp_tool_call::handle_mcp_tool_call;
@@ -498,6 +503,7 @@ impl Session {
                 sandbox_policy.clone(),
                 config.include_plan_tool,
                 config.include_apply_patch_tool,
+                config.use_experimental_streamable_shell_tool,
             ),
             user_instructions,
             base_instructions,
@@ -1080,6 +1086,7 @@ async fn submission_loop(
                     new_sandbox_policy.clone(),
                     config.include_plan_tool,
                     config.include_apply_patch_tool,
+                    config.use_experimental_streamable_shell_tool,
                 );

                 let new_turn_context = TurnContext {
@@ -1158,6 +1165,7 @@ async fn submission_loop(
                         sandbox_policy.clone(),
                         config.include_plan_tool,
                         config.include_apply_patch_tool,
+                        config.use_experimental_streamable_shell_tool,
                     ),
                     user_instructions: turn_context.user_instructions.clone(),
                     base_instructions: turn_context.base_instructions.clone(),
@@ -2063,6 +2071,52 @@ async fn handle_function_call(
             .await
         }
         "update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
+        EXEC_COMMAND_TOOL_NAME => {
+            // TODO(mbolin): Sandbox check.
+            let exec_params = match serde_json::from_str::<ExecCommandParams>(&arguments) {
+                Ok(params) => params,
+                Err(e) => {
+                    return ResponseInputItem::FunctionCallOutput {
+                        call_id,
+                        output: FunctionCallOutputPayload {
+                            content: format!("failed to parse function arguments: {e}"),
+                            success: Some(false),
+                        },
+                    };
+                }
+            };
+            let result = SESSION_MANAGER
+                .handle_exec_command_request(exec_params)
+                .await;
+            let function_call_output = crate::exec_command::result_into_payload(result);
+            ResponseInputItem::FunctionCallOutput {
+                call_id,
+                output: function_call_output,
+            }
+        }
+        WRITE_STDIN_TOOL_NAME => {
+            let write_stdin_params = match serde_json::from_str::<WriteStdinParams>(&arguments) {
+                Ok(params) => params,
+                Err(e) => {
+                    return ResponseInputItem::FunctionCallOutput {
+                        call_id,
+                        output: FunctionCallOutputPayload {
+                            content: format!("failed to parse function arguments: {e}"),
+                            success: Some(false),
+                        },
+                    };
+                }
+            };
+            let result = SESSION_MANAGER
+                .handle_write_stdin_request(write_stdin_params)
+                .await;
+            let function_call_output: FunctionCallOutputPayload =
+                crate::exec_command::result_into_payload(result);
+            ResponseInputItem::FunctionCallOutput {
+                call_id,
+                output: function_call_output,
+            }
+        }
         _ => {
             match sess.mcp_connection_manager.parse_tool_name(&name) {
                 Some((server, tool_name)) => {
@@ -174,6 +174,8 @@ pub struct Config {

     /// If set to `true`, the API key will be signed with the `originator` header.
     pub preferred_auth_method: AuthMode,
+
+    pub use_experimental_streamable_shell_tool: bool,
 }

 impl Config {
@@ -469,6 +471,8 @@ pub struct ConfigToml {
     /// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
     pub experimental_instructions_file: Option<PathBuf>,
+
+    pub experimental_use_exec_command_tool: Option<bool>,

     /// The value for the `originator` header included with Responses API requests.
     pub responses_originator_header_internal_override: Option<String>,

@@ -758,6 +762,9 @@ impl Config {
             include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
             responses_originator_header,
             preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
+            use_experimental_streamable_shell_tool: cfg
+                .experimental_use_exec_command_tool
+                .unwrap_or(false),
         };
         Ok(config)
     }
@@ -1124,6 +1131,7 @@ disable_response_storage = true
                 include_apply_patch_tool: false,
                 responses_originator_header: "codex_cli_rs".to_string(),
                 preferred_auth_method: AuthMode::ChatGPT,
+                use_experimental_streamable_shell_tool: false,
             },
             o3_profile_config
         );
@@ -1178,6 +1186,7 @@ disable_response_storage = true
             include_apply_patch_tool: false,
             responses_originator_header: "codex_cli_rs".to_string(),
             preferred_auth_method: AuthMode::ChatGPT,
+            use_experimental_streamable_shell_tool: false,
         };

         assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
@@ -1247,6 +1256,7 @@ disable_response_storage = true
             include_apply_patch_tool: false,
             responses_originator_header: "codex_cli_rs".to_string(),
             preferred_auth_method: AuthMode::ChatGPT,
+            use_experimental_streamable_shell_tool: false,
        };

         assert_eq!(expected_zdr_profile_config, zdr_profile_config);
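A minimal sketch of how the new setting is switched on: the `experimental_use_exec_command_tool` key in `config.toml` (the `ConfigToml` field above) resolves to `Config::use_experimental_streamable_shell_tool`, defaulting to `false`. The harness below is illustrative only and assumes the `toml` crate and the `ConfigToml` import path; it is not part of this change.

```rust
use codex_core::config::ConfigToml; // import path assumed for the sketch

fn main() -> Result<(), toml::de::Error> {
    let cfg: ConfigToml = toml::from_str("experimental_use_exec_command_tool = true")?;
    assert_eq!(cfg.experimental_use_exec_command_tool, Some(true));
    // Config construction above then applies:
    //   use_experimental_streamable_shell_tool: cfg.experimental_use_exec_command_tool.unwrap_or(false)
    Ok(())
}
```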
codex-rs/core/src/exec_command/exec_command_params.rs (new file, 57 lines)
@@ -0,0 +1,57 @@
use serde::Deserialize;
use serde::Serialize;

use crate::exec_command::session_id::SessionId;

#[derive(Debug, Clone, Deserialize)]
pub struct ExecCommandParams {
    pub(crate) cmd: String,

    #[serde(default = "default_yield_time")]
    pub(crate) yield_time_ms: u64,

    #[serde(default = "max_output_tokens")]
    pub(crate) max_output_tokens: u64,

    #[serde(default = "default_shell")]
    pub(crate) shell: String,

    #[serde(default = "default_login")]
    pub(crate) login: bool,
}

fn default_yield_time() -> u64 {
    10_000
}

fn max_output_tokens() -> u64 {
    10_000
}

fn default_login() -> bool {
    true
}

fn default_shell() -> String {
    "/bin/bash".to_string()
}

#[derive(Debug, Deserialize, Serialize)]
pub struct WriteStdinParams {
    pub(crate) session_id: SessionId,
    pub(crate) chars: String,

    #[serde(default = "write_stdin_default_yield_time_ms")]
    pub(crate) yield_time_ms: u64,

    #[serde(default = "write_stdin_default_max_output_tokens")]
    pub(crate) max_output_tokens: u64,
}

fn write_stdin_default_yield_time_ms() -> u64 {
    250
}

fn write_stdin_default_max_output_tokens() -> u64 {
    10_000
}
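For reference, a small sketch (written as if inside the crate, mirroring the parse in `handle_function_call` above) of how the serde defaults fill in everything except `cmd`; it is not part of the change.

```rust
use crate::exec_command::ExecCommandParams;

// Demonstrates the #[serde(default = "...")] attributes above.
fn parse_minimal_args() -> Result<ExecCommandParams, serde_json::Error> {
    // Only `cmd` is required. The omitted fields come from the default fns above:
    // yield_time_ms = 10_000, max_output_tokens = 10_000, shell = "/bin/bash", login = true.
    serde_json::from_str(r#"{ "cmd": "tail -f build.log" }"#)
}
```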
codex-rs/core/src/exec_command/exec_command_session.rs (new file, 83 lines)
@@ -0,0 +1,83 @@
use std::sync::Mutex as StdMutex;

use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

#[derive(Debug)]
pub(crate) struct ExecCommandSession {
    /// Queue for writing bytes to the process stdin (PTY master write side).
    writer_tx: mpsc::Sender<Vec<u8>>,
    /// Broadcast stream of output chunks read from the PTY. New subscribers
    /// receive only chunks emitted after they subscribe.
    output_tx: broadcast::Sender<Vec<u8>>,

    /// Child killer handle for termination on drop (can signal independently
    /// of a thread blocked in `.wait()`).
    killer: StdMutex<Option<Box<dyn portable_pty::ChildKiller + Send + Sync>>>,

    /// JoinHandle for the blocking PTY reader task.
    reader_handle: StdMutex<Option<JoinHandle<()>>>,

    /// JoinHandle for the stdin writer task.
    writer_handle: StdMutex<Option<JoinHandle<()>>>,

    /// JoinHandle for the child wait task.
    wait_handle: StdMutex<Option<JoinHandle<()>>>,
}

impl ExecCommandSession {
    pub(crate) fn new(
        writer_tx: mpsc::Sender<Vec<u8>>,
        output_tx: broadcast::Sender<Vec<u8>>,
        killer: Box<dyn portable_pty::ChildKiller + Send + Sync>,
        reader_handle: JoinHandle<()>,
        writer_handle: JoinHandle<()>,
        wait_handle: JoinHandle<()>,
    ) -> Self {
        Self {
            writer_tx,
            output_tx,
            killer: StdMutex::new(Some(killer)),
            reader_handle: StdMutex::new(Some(reader_handle)),
            writer_handle: StdMutex::new(Some(writer_handle)),
            wait_handle: StdMutex::new(Some(wait_handle)),
        }
    }

    pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
        self.writer_tx.clone()
    }

    pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
        self.output_tx.subscribe()
    }
}

impl Drop for ExecCommandSession {
    fn drop(&mut self) {
        // Best-effort: terminate child first so blocking tasks can complete.
        if let Ok(mut killer_opt) = self.killer.lock()
            && let Some(mut killer) = killer_opt.take()
        {
            let _ = killer.kill();
        }

        // Abort background tasks; they may already have exited after kill.
        if let Ok(mut h) = self.reader_handle.lock()
            && let Some(handle) = h.take()
        {
            handle.abort();
        }
        if let Ok(mut h) = self.writer_handle.lock()
            && let Some(handle) = h.take()
        {
            handle.abort();
        }
        if let Ok(mut h) = self.wait_handle.lock()
            && let Some(handle) = h.take()
        {
            handle.abort();
        }
    }
}
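The doc comment on `output_tx` relies on a property of `tokio::sync::broadcast` worth spelling out: a receiver only sees values sent after it subscribes, which is what gives `write_stdin` its "output from now on" behavior. A self-contained sketch, not from this PR (requires tokio's `rt` and `macros` features):

```rust
use tokio::sync::broadcast;

#[tokio::main]
async fn main() {
    let (tx, _initial_rx) = broadcast::channel::<Vec<u8>>(16);
    let _ = tx.send(b"before subscribe".to_vec()); // never delivered to `rx` below

    let mut rx = tx.subscribe(); // what ExecCommandSession::output_receiver() does
    let _ = tx.send(b"after subscribe".to_vec());

    assert_eq!(rx.recv().await.unwrap(), b"after subscribe".to_vec());
}
```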
codex-rs/core/src/exec_command/mod.rs (new file, 14 lines)
@@ -0,0 +1,14 @@
mod exec_command_params;
mod exec_command_session;
mod responses_api;
mod session_id;
mod session_manager;

pub use exec_command_params::ExecCommandParams;
pub use exec_command_params::WriteStdinParams;
pub use responses_api::EXEC_COMMAND_TOOL_NAME;
pub use responses_api::WRITE_STDIN_TOOL_NAME;
pub use responses_api::create_exec_command_tool_for_responses_api;
pub use responses_api::create_write_stdin_tool_for_responses_api;
pub use session_manager::SESSION_MANAGER;
pub use session_manager::result_into_payload;
codex-rs/core/src/exec_command/responses_api.rs (new file, 98 lines)
@@ -0,0 +1,98 @@
use std::collections::BTreeMap;

use crate::openai_tools::JsonSchema;
use crate::openai_tools::ResponsesApiTool;

pub const EXEC_COMMAND_TOOL_NAME: &str = "exec_command";
pub const WRITE_STDIN_TOOL_NAME: &str = "write_stdin";

pub fn create_exec_command_tool_for_responses_api() -> ResponsesApiTool {
    let mut properties = BTreeMap::<String, JsonSchema>::new();
    properties.insert(
        "cmd".to_string(),
        JsonSchema::String {
            description: Some("The shell command to execute.".to_string()),
        },
    );
    properties.insert(
        "yield_time_ms".to_string(),
        JsonSchema::Number {
            description: Some("The maximum time in milliseconds to wait for output.".to_string()),
        },
    );
    properties.insert(
        "max_output_tokens".to_string(),
        JsonSchema::Number {
            description: Some("The maximum number of tokens to output.".to_string()),
        },
    );
    properties.insert(
        "shell".to_string(),
        JsonSchema::String {
            description: Some("The shell to use. Defaults to \"/bin/bash\".".to_string()),
        },
    );
    properties.insert(
        "login".to_string(),
        JsonSchema::Boolean {
            description: Some(
                "Whether to run the command as a login shell. Defaults to true.".to_string(),
            ),
        },
    );

    ResponsesApiTool {
        name: EXEC_COMMAND_TOOL_NAME.to_owned(),
        description: r#"Execute shell commands on the local machine with streaming output."#
            .to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
            required: Some(vec!["cmd".to_string()]),
            additional_properties: Some(false),
        },
    }
}

pub fn create_write_stdin_tool_for_responses_api() -> ResponsesApiTool {
    let mut properties = BTreeMap::<String, JsonSchema>::new();
    properties.insert(
        "session_id".to_string(),
        JsonSchema::Number {
            description: Some("The ID of the exec_command session.".to_string()),
        },
    );
    properties.insert(
        "chars".to_string(),
        JsonSchema::String {
            description: Some("The characters to write to stdin.".to_string()),
        },
    );
    properties.insert(
        "yield_time_ms".to_string(),
        JsonSchema::Number {
            description: Some(
                "The maximum time in milliseconds to wait for output after writing.".to_string(),
            ),
        },
    );
    properties.insert(
        "max_output_tokens".to_string(),
        JsonSchema::Number {
            description: Some("The maximum number of tokens to output.".to_string()),
        },
    );

    ResponsesApiTool {
        name: WRITE_STDIN_TOOL_NAME.to_owned(),
        description: r#"Write characters to an exec session's stdin. Returns all stdout+stderr received within yield_time_ms.
Can write control characters (\u0003 for Ctrl-C), or an empty string to just poll stdout+stderr."#
            .to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
            required: Some(vec!["session_id".to_string(), "chars".to_string()]),
            additional_properties: Some(false),
        },
    }
}
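A sketch of a well-formed `write_stdin` call against the schema above: `session_id` and `chars` are required, while the two budgets fall back to their defaults (250 ms and 10_000 tokens). Written as if inside the crate; illustrative only.

```rust
use crate::exec_command::WriteStdinParams;

fn parse_ctrl_c_call() -> Result<WriteStdinParams, serde_json::Error> {
    // "\u0003" is Ctrl-C, as called out in the tool description above;
    // an empty `chars` string would merely poll for new output.
    serde_json::from_str(r#"{ "session_id": 7, "chars": "\u0003" }"#)
}
```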
codex-rs/core/src/exec_command/session_id.rs (new file, 5 lines)
@@ -0,0 +1,5 @@
use serde::Deserialize;
use serde::Serialize;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct SessionId(pub u32);
codex-rs/core/src/exec_command/session_manager.rs (new file, 677 lines)
@@ -0,0 +1,677 @@
use std::collections::HashMap;
use std::io::ErrorKind;
use std::io::Read;
use std::sync::Arc;
use std::sync::LazyLock;
use std::sync::Mutex as StdMutex;
use std::sync::atomic::AtomicU32;

use portable_pty::CommandBuilder;
use portable_pty::PtySize;
use portable_pty::native_pty_system;
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::time::Duration;
use tokio::time::Instant;
use tokio::time::timeout;

use crate::exec_command::exec_command_params::ExecCommandParams;
use crate::exec_command::exec_command_params::WriteStdinParams;
use crate::exec_command::exec_command_session::ExecCommandSession;
use crate::exec_command::session_id::SessionId;
use codex_protocol::models::FunctionCallOutputPayload;

pub static SESSION_MANAGER: LazyLock<SessionManager> = LazyLock::new(SessionManager::default);

#[derive(Debug, Default)]
pub struct SessionManager {
    next_session_id: AtomicU32,
    sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
}

#[derive(Debug)]
pub struct ExecCommandOutput {
    wall_time: Duration,
    exit_status: ExitStatus,
    original_token_count: Option<u64>,
    output: String,
}

impl ExecCommandOutput {
    fn to_text_output(&self) -> String {
        let wall_time_secs = self.wall_time.as_secs_f32();
        let termination_status = match self.exit_status {
            ExitStatus::Exited(code) => format!("Process exited with code {code}"),
            ExitStatus::Ongoing(session_id) => {
                format!("Process running with session ID {}", session_id.0)
            }
        };
        let truncation_status = match self.original_token_count {
            Some(tokens) => {
                format!("\nWarning: truncated output (original token count: {tokens})")
            }
            None => "".to_string(),
        };
        format!(
            r#"Wall time: {wall_time_secs:.3} seconds
{termination_status}{truncation_status}
Output:
{output}"#,
            output = self.output
        )
    }
}

#[derive(Debug)]
pub enum ExitStatus {
    Exited(i32),
    Ongoing(SessionId),
}

pub fn result_into_payload(result: Result<ExecCommandOutput, String>) -> FunctionCallOutputPayload {
    match result {
        Ok(output) => FunctionCallOutputPayload {
            content: output.to_text_output(),
            success: Some(true),
        },
        Err(err) => FunctionCallOutputPayload {
            content: err,
            success: Some(false),
        },
    }
}

impl SessionManager {
    /// Processes the request and is required to send a response via `outgoing`.
    pub async fn handle_exec_command_request(
        &self,
        params: ExecCommandParams,
    ) -> Result<ExecCommandOutput, String> {
        // Allocate a session id.
        let session_id = SessionId(
            self.next_session_id
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
        );

        let (session, mut exit_rx) =
            create_exec_command_session(params.clone())
                .await
                .map_err(|err| {
                    format!(
                        "failed to create exec command session for session id {}: {err}",
                        session_id.0
                    )
                })?;

        // Insert into session map.
        let mut output_rx = session.output_receiver();
        self.sessions.lock().await.insert(session_id, session);

        // Collect output until either timeout expires or process exits.
        // Do not cap during collection; truncate at the end if needed.
        // Use a modest initial capacity to avoid large preallocation.
        let cap_bytes_u64 = params.max_output_tokens.saturating_mul(4);
        let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
        let mut collected: Vec<u8> = Vec::with_capacity(4096);

        let start_time = Instant::now();
        let deadline = start_time + Duration::from_millis(params.yield_time_ms);
        let mut exit_code: Option<i32> = None;

        loop {
            if Instant::now() >= deadline {
                break;
            }
            let remaining = deadline.saturating_duration_since(Instant::now());
            tokio::select! {
                biased;
                exit = &mut exit_rx => {
                    exit_code = exit.ok();
                    // Small grace period to pull remaining buffered output
                    let grace_deadline = Instant::now() + Duration::from_millis(25);
                    while Instant::now() < grace_deadline {
                        match timeout(Duration::from_millis(1), output_rx.recv()).await {
                            Ok(Ok(chunk)) => {
                                collected.extend_from_slice(&chunk);
                            }
                            Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
                                // Skip missed messages; keep trying within grace period.
                                continue;
                            }
                            Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
                            Err(_) => break,
                        }
                    }
                    break;
                }
                chunk = timeout(remaining, output_rx.recv()) => {
                    match chunk {
                        Ok(Ok(chunk)) => {
                            collected.extend_from_slice(&chunk);
                        }
                        Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
                            // Skip missed messages; continue collecting fresh output.
                        }
                        Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { break; }
                        Err(_) => { break; }
                    }
                }
            }
        }

        let output = String::from_utf8_lossy(&collected).to_string();

        let exit_status = if let Some(code) = exit_code {
            ExitStatus::Exited(code)
        } else {
            ExitStatus::Ongoing(session_id)
        };

        // If output exceeds cap, truncate the middle and record original token estimate.
        let (output, original_token_count) = truncate_middle(&output, cap_bytes);
        Ok(ExecCommandOutput {
            wall_time: Instant::now().duration_since(start_time),
            exit_status,
            original_token_count,
            output,
        })
    }

    /// Write characters to a session's stdin and collect combined output for up to `yield_time_ms`.
    pub async fn handle_write_stdin_request(
        &self,
        params: WriteStdinParams,
    ) -> Result<ExecCommandOutput, String> {
        let WriteStdinParams {
            session_id,
            chars,
            yield_time_ms,
            max_output_tokens,
        } = params;

        // Grab handles without holding the sessions lock across await points.
        let (writer_tx, mut output_rx) = {
            let sessions = self.sessions.lock().await;
            match sessions.get(&session_id) {
                Some(session) => (session.writer_sender(), session.output_receiver()),
                None => {
                    return Err(format!("unknown session id {}", session_id.0));
                }
            }
        };

        // Write stdin if provided.
        if !chars.is_empty() && writer_tx.send(chars.into_bytes()).await.is_err() {
            return Err("failed to write to stdin".to_string());
        }

        // Collect output up to yield_time_ms, truncating to max_output_tokens bytes.
        let mut collected: Vec<u8> = Vec::with_capacity(4096);
        let start_time = Instant::now();
        let deadline = start_time + Duration::from_millis(yield_time_ms);
        loop {
            let now = Instant::now();
            if now >= deadline {
                break;
            }
            let remaining = deadline - now;
            match timeout(remaining, output_rx.recv()).await {
                Ok(Ok(chunk)) => {
                    // Collect all output within the time budget; truncate at the end.
                    collected.extend_from_slice(&chunk);
                }
                Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
                    // Skip missed messages; continue collecting fresh output.
                }
                Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
                Err(_) => break, // timeout
            }
        }

        // Return structured output, truncating middle if over cap.
        let output = String::from_utf8_lossy(&collected).to_string();
        let cap_bytes_u64 = max_output_tokens.saturating_mul(4);
        let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
        let (output, original_token_count) = truncate_middle(&output, cap_bytes);
        Ok(ExecCommandOutput {
            wall_time: Instant::now().duration_since(start_time),
            exit_status: ExitStatus::Ongoing(session_id),
            original_token_count,
            output,
        })
    }
}

/// Spawn PTY and child process per spawn_exec_command_session logic.
async fn create_exec_command_session(
    params: ExecCommandParams,
) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
    let ExecCommandParams {
        cmd,
        yield_time_ms: _,
        max_output_tokens: _,
        shell,
        login,
    } = params;

    // Use the native pty implementation for the system
    let pty_system = native_pty_system();

    // Create a new pty
    let pair = pty_system.openpty(PtySize {
        rows: 24,
        cols: 80,
        pixel_width: 0,
        pixel_height: 0,
    })?;

    // Spawn a shell into the pty
    let mut command_builder = CommandBuilder::new(shell);
    let shell_mode_opt = if login { "-lc" } else { "-c" };
    command_builder.arg(shell_mode_opt);
    command_builder.arg(cmd);

    let mut child = pair.slave.spawn_command(command_builder)?;
    // Obtain a killer that can signal the process independently of `.wait()`.
    let killer = child.clone_killer();

    // Channel to forward write requests to the PTY writer.
    let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
    // Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
    let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);

    // Reader task: drain PTY and forward chunks to output channel.
    let mut reader = pair.master.try_clone_reader()?;
    let output_tx_clone = output_tx.clone();
    let reader_handle = tokio::task::spawn_blocking(move || {
        let mut buf = [0u8; 8192];
        loop {
            match reader.read(&mut buf) {
                Ok(0) => break, // EOF
                Ok(n) => {
                    // Forward to broadcast; best-effort if there are subscribers.
                    let _ = output_tx_clone.send(buf[..n].to_vec());
                }
                Err(ref e) if e.kind() == ErrorKind::Interrupted => {
                    // Retry on EINTR
                    continue;
                }
                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
                    // We're in a blocking thread; back off briefly and retry.
                    std::thread::sleep(Duration::from_millis(5));
                    continue;
                }
                Err(_) => break,
            }
        }
    });

    // Writer task: apply stdin writes to the PTY writer.
    let writer = pair.master.take_writer()?;
    let writer = Arc::new(StdMutex::new(writer));
    let writer_handle = tokio::spawn({
        let writer = writer.clone();
        async move {
            while let Some(bytes) = writer_rx.recv().await {
                let writer = writer.clone();
                // Perform blocking write on a blocking thread.
                let _ = tokio::task::spawn_blocking(move || {
                    if let Ok(mut guard) = writer.lock() {
                        use std::io::Write;
                        let _ = guard.write_all(&bytes);
                        let _ = guard.flush();
                    }
                })
                .await;
            }
        }
    });

    // Keep the child alive until it exits, then signal exit code.
    let (exit_tx, exit_rx) = oneshot::channel::<i32>();
    let wait_handle = tokio::task::spawn_blocking(move || {
        let code = match child.wait() {
            Ok(status) => status.exit_code() as i32,
            Err(_) => -1,
        };
        let _ = exit_tx.send(code);
    });

    // Create and store the session with channels.
    let session = ExecCommandSession::new(
        writer_tx,
        output_tx,
        killer,
        reader_handle,
        writer_handle,
        wait_handle,
    );
    Ok((session, exit_rx))
}

/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
/// preserving the beginning and the end. Returns the possibly truncated
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
/// if truncation occurred; otherwise returns the original string and `None`.
fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
    // No truncation needed
    if s.len() <= max_bytes {
        return (s.to_string(), None);
    }
    let est_tokens = (s.len() as u64).div_ceil(4);
    if max_bytes == 0 {
        // Cannot keep any content; still return a full marker (never truncated).
        return (
            format!("…{} tokens truncated…", est_tokens),
            Some(est_tokens),
        );
    }

    // Helper to truncate a string to a given byte length on a char boundary.
    fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
        if input.len() <= max_len {
            return input;
        }
        let mut end = max_len;
        while end > 0 && !input.is_char_boundary(end) {
            end -= 1;
        }
        &input[..end]
    }

    // Given a left/right budget, prefer newline boundaries; otherwise fall back
    // to UTF-8 char boundaries.
    fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
        if let Some(head) = s.get(..left_budget)
            && let Some(i) = head.rfind('\n')
        {
            return i + 1; // keep the newline so suffix starts on a fresh line
        }
        truncate_on_boundary(s, left_budget).len()
    }

    fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
        let start_tail = s.len().saturating_sub(right_budget);
        if let Some(tail) = s.get(start_tail..)
            && let Some(i) = tail.find('\n')
        {
            return start_tail + i + 1; // start after newline
        }
        // Fall back to a char boundary at or after start_tail.
        let mut idx = start_tail.min(s.len());
        while idx < s.len() && !s.is_char_boundary(idx) {
            idx += 1;
        }
        idx
    }

    // Refine marker length and budgets until stable. Marker is never truncated.
    let mut guess_tokens = est_tokens; // worst-case: everything truncated
    for _ in 0..4 {
        let marker = format!("…{} tokens truncated…", guess_tokens);
        let marker_len = marker.len();
        let keep_budget = max_bytes.saturating_sub(marker_len);
        if keep_budget == 0 {
            // No room for any content within the cap; return a full, untruncated marker
            // that reflects the entire truncated content.
            return (
                format!("…{} tokens truncated…", est_tokens),
                Some(est_tokens),
            );
        }

        let left_budget = keep_budget / 2;
        let right_budget = keep_budget - left_budget;
        let prefix_end = pick_prefix_end(s, left_budget);
        let mut suffix_start = pick_suffix_start(s, right_budget);
        if suffix_start < prefix_end {
            suffix_start = prefix_end;
        }
        let kept_content_bytes = prefix_end + (s.len() - suffix_start);
        let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
        let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
        if new_tokens == guess_tokens {
            let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
            out.push_str(&s[..prefix_end]);
            out.push_str(&marker);
            // Place marker on its own line for symmetry when we keep line boundaries.
            out.push('\n');
            out.push_str(&s[suffix_start..]);
            return (out, Some(est_tokens));
        }
        guess_tokens = new_tokens;
    }

    // Fallback: use last guess to build output.
    let marker = format!("…{} tokens truncated…", guess_tokens);
    let marker_len = marker.len();
    let keep_budget = max_bytes.saturating_sub(marker_len);
    if keep_budget == 0 {
        return (
            format!("…{} tokens truncated…", est_tokens),
            Some(est_tokens),
        );
    }
    let left_budget = keep_budget / 2;
    let right_budget = keep_budget - left_budget;
    let prefix_end = pick_prefix_end(s, left_budget);
    let suffix_start = pick_suffix_start(s, right_budget);
    let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
    out.push_str(&s[..prefix_end]);
    out.push_str(&marker);
    out.push('\n');
    out.push_str(&s[suffix_start..]);
    (out, Some(est_tokens))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::exec_command::session_id::SessionId;

    /// Test that verifies that [`SessionManager::handle_exec_command_request()`]
    /// and [`SessionManager::handle_write_stdin_request()`] work as expected
    /// in the presence of a process that never terminates (but produces
    /// output continuously).
    #[cfg(unix)]
    #[allow(clippy::print_stderr)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn session_manager_streams_and_truncates_from_now() {
        use crate::exec_command::exec_command_params::ExecCommandParams;
        use crate::exec_command::exec_command_params::WriteStdinParams;
        use tokio::time::sleep;

        let session_manager = SessionManager::default();
        // Long-running loop that prints an increasing counter every ~100ms.
        // Use Python for a portable, reliable sleep across shells/PTYs.
        let cmd = r#"python3 - <<'PY'
import sys, time
count = 0
while True:
    print(count)
    sys.stdout.flush()
    count += 100
    time.sleep(0.1)
PY"#
        .to_string();

        // Start the session and collect ~3s of output.
        let params = ExecCommandParams {
            cmd,
            yield_time_ms: 3_000,
            max_output_tokens: 1_000, // large enough to avoid truncation here
            shell: "/bin/bash".to_string(),
            login: false,
        };
        let initial_output = match session_manager
            .handle_exec_command_request(params.clone())
            .await
        {
            Ok(v) => v,
            Err(e) => {
                // PTY may be restricted in some sandboxes; skip in that case.
                if e.contains("openpty") || e.contains("Operation not permitted") {
                    eprintln!("skipping test due to restricted PTY: {e}");
                    return;
                }
                panic!("exec request failed unexpectedly: {e}");
            }
        };
        eprintln!("initial output: {initial_output:?}");

        // Should be ongoing (we launched a never-ending loop).
        let session_id = match initial_output.exit_status {
            ExitStatus::Ongoing(id) => id,
            _ => panic!("expected ongoing session"),
        };

        // Parse the numeric lines and get the max observed value in the first window.
        let first_nums = extract_monotonic_numbers(&initial_output.output);
        assert!(
            !first_nums.is_empty(),
            "expected some output from first window"
        );
        let first_max = *first_nums.iter().max().unwrap();

        // Wait ~4s so counters progress while we're not reading.
        sleep(Duration::from_millis(4_000)).await;

        // Now read ~3s of output "from now" only.
        // Use a small token cap so truncation occurs and we test middle truncation.
        let write_params = WriteStdinParams {
            session_id,
            chars: String::new(),
            yield_time_ms: 3_000,
            max_output_tokens: 16, // 16 tokens ~= 64 bytes -> likely truncation
        };
        let second = session_manager
            .handle_write_stdin_request(write_params)
            .await
            .expect("write stdin should succeed");

        // Verify truncation metadata and size bound (cap is tokens*4 bytes).
        assert!(second.original_token_count.is_some());
        let cap_bytes = (16u64 * 4) as usize;
        assert!(second.output.len() <= cap_bytes);
        // New middle marker should be present.
        assert!(
            second.output.contains("tokens truncated") && second.output.contains('…'),
            "expected truncation marker in output, got: {}",
            second.output
        );

        // Minimal freshness check: the earliest number we see in the second window
        // should be significantly larger than the last from the first window.
        let second_nums = extract_monotonic_numbers(&second.output);
        assert!(
            !second_nums.is_empty(),
            "expected some numeric output from second window"
        );
        let second_min = *second_nums.iter().min().unwrap();

        // We slept 4 seconds (~40 ticks at 100ms/tick, each +100), so expect
        // an increase of roughly 4000 or more. Allow a generous margin.
        assert!(
            second_min >= first_max + 2000,
            "second_min={second_min} first_max={first_max}",
        );
    }

    #[cfg(unix)]
    fn extract_monotonic_numbers(s: &str) -> Vec<i64> {
        s.lines()
            .filter_map(|line| {
                if !line.is_empty()
                    && line.chars().all(|c| c.is_ascii_digit())
                    && let Ok(n) = line.parse::<i64>()
                {
                    // Our generator increments by 100; ignore spurious fragments.
                    if n % 100 == 0 {
                        return Some(n);
                    }
                }
                None
            })
            .collect()
    }

    #[test]
    fn to_text_output_exited_no_truncation() {
        let out = ExecCommandOutput {
            wall_time: Duration::from_millis(1234),
            exit_status: ExitStatus::Exited(0),
            original_token_count: None,
            output: "hello".to_string(),
        };
        let text = out.to_text_output();
        let expected = r#"Wall time: 1.234 seconds
Process exited with code 0
Output:
hello"#;
        assert_eq!(expected, text);
    }

    #[test]
    fn to_text_output_ongoing_with_truncation() {
        let out = ExecCommandOutput {
            wall_time: Duration::from_millis(500),
            exit_status: ExitStatus::Ongoing(SessionId(42)),
            original_token_count: Some(1000),
            output: "abc".to_string(),
        };
        let text = out.to_text_output();
        let expected = r#"Wall time: 0.500 seconds
Process running with session ID 42
Warning: truncated output (original token count: 1000)
Output:
abc"#;
        assert_eq!(expected, text);
    }

    #[test]
    fn truncate_middle_no_newlines_fallback() {
        // A long string with no newlines that exceeds the cap.
        let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let max_bytes = 16; // force truncation
        let (out, original) = truncate_middle(s, max_bytes);
        // For very small caps, we return the full, untruncated marker,
        // even if it exceeds the cap.
        assert_eq!(out, "…16 tokens truncated…");
        // Original string length is 62 bytes => ceil(62/4) = 16 tokens.
        assert_eq!(original, Some(16));
    }

    #[test]
    fn truncate_middle_prefers_newline_boundaries() {
        // Build a multi-line string of 20 numbered lines (each "NNN\n").
        let mut s = String::new();
        for i in 1..=20 {
            s.push_str(&format!("{i:03}\n"));
        }
        // Total length: 20 lines * 4 bytes per line = 80 bytes.
        assert_eq!(s.len(), 80);

        // Choose a cap that forces truncation while leaving room for
        // a few lines on each side after accounting for the marker.
        let max_bytes = 64;
        // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
        assert_eq!(
            truncate_middle(&s, max_bytes),
            (
                r#"001
002
003
004
…12 tokens truncated…
017
018
019
020
"#
                .to_string(),
                Some(20)
            )
        );
    }
}
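A worked restatement of the budget arithmetic used above (standalone and illustrative): `max_output_tokens` becomes a byte cap at roughly 4 bytes per token, and the reported `original_token_count` is `ceil(len / 4)`.

```rust
fn main() {
    // From the write_stdin test above: max_output_tokens = 16 -> 64-byte cap.
    let max_output_tokens: u64 = 16;
    let cap_bytes = max_output_tokens.saturating_mul(4) as usize;
    assert_eq!(cap_bytes, 64);

    // From truncate_middle_no_newlines_fallback: a 62-byte string estimates
    // ceil(62 / 4) = 16 tokens, which is what the truncation marker reports.
    let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
    assert_eq!(s.len(), 62);
    assert_eq!((s.len() as u64).div_ceil(4), 16);
}
```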
@@ -20,6 +20,7 @@ mod conversation_history;
 mod environment_context;
 pub mod error;
 pub mod exec;
+mod exec_command;
 pub mod exec_env;
 mod flags;
 pub mod git_info;
@@ -56,6 +56,7 @@ pub enum ConfigShellToolType {
     DefaultShell,
     ShellWithRequest { sandbox_policy: SandboxPolicy },
     LocalShell,
+    StreamableShell,
 }

 #[derive(Debug, Clone)]
@@ -72,13 +73,16 @@ impl ToolsConfig {
         sandbox_policy: SandboxPolicy,
         include_plan_tool: bool,
         include_apply_patch_tool: bool,
+        use_streamable_shell_tool: bool,
     ) -> Self {
-        let mut shell_type = if model_family.uses_local_shell_tool {
+        let mut shell_type = if use_streamable_shell_tool {
+            ConfigShellToolType::StreamableShell
+        } else if model_family.uses_local_shell_tool {
             ConfigShellToolType::LocalShell
         } else {
             ConfigShellToolType::DefaultShell
         };
-        if matches!(approval_policy, AskForApproval::OnRequest) {
+        if matches!(approval_policy, AskForApproval::OnRequest) && !use_streamable_shell_tool {
             shell_type = ConfigShellToolType::ShellWithRequest {
                 sandbox_policy: sandbox_policy.clone(),
             }
@@ -492,6 +496,14 @@ pub(crate) fn get_openai_tools(
         ConfigShellToolType::LocalShell => {
             tools.push(OpenAiTool::LocalShell {});
         }
+        ConfigShellToolType::StreamableShell => {
+            tools.push(OpenAiTool::Function(
+                crate::exec_command::create_exec_command_tool_for_responses_api(),
+            ));
+            tools.push(OpenAiTool::Function(
+                crate::exec_command::create_write_stdin_tool_for_responses_api(),
+            ));
+        }
     }

     if config.plan_tool {
@@ -564,6 +576,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             true,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
         let tools = get_openai_tools(&config, Some(HashMap::new()));

@@ -579,6 +592,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             true,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
        );
         let tools = get_openai_tools(&config, Some(HashMap::new()));

@@ -594,6 +608,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
         let tools = get_openai_tools(
             &config,
@@ -688,6 +703,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );

         let tools = get_openai_tools(
@@ -744,6 +760,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );

         let tools = get_openai_tools(
@@ -795,6 +812,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );

         let tools = get_openai_tools(
@@ -849,6 +867,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );

         let tools = get_openai_tools(
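To make the selection order in `ToolsConfig::new` explicit, here is a stand-alone restatement with plain bools in place of the real types (illustrative only, not the crate's API): the streamable shell wins outright, and the on-request escalation variant only replaces the shell when streaming is off.

```rust
#[derive(Debug, PartialEq)]
enum ShellTool {
    DefaultShell,
    ShellWithRequest,
    LocalShell,
    StreamableShell,
}

// Equivalent to the logic above, folded into a single if/else chain.
fn select_shell_tool(
    use_streamable_shell_tool: bool,
    model_uses_local_shell_tool: bool,
    approval_is_on_request: bool,
) -> ShellTool {
    if approval_is_on_request && !use_streamable_shell_tool {
        ShellTool::ShellWithRequest
    } else if use_streamable_shell_tool {
        ShellTool::StreamableShell
    } else if model_uses_local_shell_tool {
        ShellTool::LocalShell
    } else {
        ShellTool::DefaultShell
    }
}

fn main() {
    assert_eq!(select_shell_tool(true, true, true), ShellTool::StreamableShell);
    assert_eq!(select_shell_tool(false, true, true), ShellTool::ShellWithRequest);
    assert_eq!(select_shell_tool(false, true, false), ShellTool::LocalShell);
    assert_eq!(select_shell_tool(false, false, false), ShellTool::DefaultShell);
}
```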