From e3b03eaccb104e50b08a3008eb209486babe05ed Mon Sep 17 00:00:00 2001
From: Michael Bolin
Date: Fri, 22 Aug 2025 18:10:55 -0700
Subject: [PATCH] feat: StreamableShell with exec_command and write_stdin tools (#2574)

---
 codex-rs/Cargo.lock                            |  75 ++
 codex-rs/core/Cargo.toml                       |   1 +
 codex-rs/core/src/codex.rs                     |  54 ++
 codex-rs/core/src/config.rs                    |  10 +
 .../src/exec_command/exec_command_params.rs    |  57 ++
 .../src/exec_command/exec_command_session.rs   |  83 +++
 codex-rs/core/src/exec_command/mod.rs          |  14 +
 .../core/src/exec_command/responses_api.rs     |  98 +++
 codex-rs/core/src/exec_command/session_id.rs   |   5 +
 .../core/src/exec_command/session_manager.rs   | 677 ++++++++++++++++++
 codex-rs/core/src/lib.rs                       |   1 +
 codex-rs/core/src/openai_tools.rs              |  23 +-
 12 files changed, 1096 insertions(+), 2 deletions(-)
 create mode 100644 codex-rs/core/src/exec_command/exec_command_params.rs
 create mode 100644 codex-rs/core/src/exec_command/exec_command_session.rs
 create mode 100644 codex-rs/core/src/exec_command/mod.rs
 create mode 100644 codex-rs/core/src/exec_command/responses_api.rs
 create mode 100644 codex-rs/core/src/exec_command/session_id.rs
 create mode 100644 codex-rs/core/src/exec_command/session_manager.rs

diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock
index 94702a83..dbccbd86 100644
--- a/codex-rs/Cargo.lock
+++ b/codex-rs/Cargo.lock
@@ -731,6 +731,7 @@ dependencies = [
  "mime_guess",
  "openssl-sys",
  "os_info",
+ "portable-pty",
  "predicates",
  "pretty_assertions",
  "rand 0.9.2",
@@ -1479,6 +1480,12 @@ version = "0.15.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b"
 
+[[package]]
+name = "downcast-rs"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2"
+
 [[package]]
 name = "dupe"
 version = "0.9.1"
@@ -1724,6 +1731,17 @@ dependencies = [
  "simd-adler32",
 ]
 
+[[package]]
+name = "filedescriptor"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e40758ed24c9b2eeb76c35fb0aebc66c626084edd827e07e1552279814c6682d"
+dependencies = [
+ "libc",
+ "thiserror 1.0.69",
+ "winapi",
+]
+
 [[package]]
 name = "fixedbitset"
 version = "0.4.2"
@@ -3439,6 +3457,27 @@ dependencies = [
  "portable-atomic",
 ]
 
+[[package]]
+name = "portable-pty"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b4a596a2b3d2752d94f51fac2d4a96737b8705dddd311a32b9af47211f08671e"
+dependencies = [
+ "anyhow",
+ "bitflags 1.3.2",
+ "downcast-rs",
+ "filedescriptor",
+ "lazy_static",
+ "libc",
+ "log",
+ "nix",
+ "serial2",
+ "shared_library",
+ "shell-words",
+ "winapi",
+ "winreg",
+]
+
 [[package]]
 name = "potential_utf"
 version = "0.1.2"
@@ -4366,6 +4405,17 @@ dependencies = [
  "syn 2.0.104",
 ]
 
+[[package]]
+name = "serial2"
+version = "0.2.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26e1e5956803a69ddd72ce2de337b577898801528749565def03515f82bad5bb"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "winapi",
+]
+
 [[package]]
 name = "sha1"
 version = "0.10.6"
@@ -4397,6 +4447,22 @@ dependencies = [
  "lazy_static",
 ]
 
+[[package]]
+name = "shared_library"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a9e7e0f2bfae24d8a5b5a66c5b257a83c7412304311512a0c054cd5e619da11"
+dependencies = [
+ "lazy_static",
+ "libc",
+]
+
+[[package]]
+name = "shell-words"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde"
+
 [[package]]
 name = "shlex"
 version = "1.3.0"
@@ -6176,6 +6242,15 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "winreg"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
+dependencies = [
+ "winapi",
+]
+
 [[package]]
 name = "winsafe"
 version = "0.0.19"
diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml
index 56815ba0..2f2fa7cb 100644
--- a/codex-rs/core/Cargo.toml
+++ b/codex-rs/core/Cargo.toml
@@ -28,6 +28,7 @@ libc = "0.2.175"
 mcp-types = { path = "../mcp-types" }
 mime_guess = "2.0"
 os_info = "3.12.0"
+portable-pty = "0.9.0"
 rand = "0.9"
 regex-lite = "0.1.6"
 reqwest = { version = "0.12", features = ["json", "stream"] }
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index b08c8e4a..ec053ce4 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -53,6 +53,11 @@ use crate::exec::SandboxType;
 use crate::exec::StdoutStream;
 use crate::exec::StreamOutput;
 use crate::exec::process_exec_tool_call;
+use crate::exec_command::EXEC_COMMAND_TOOL_NAME;
+use crate::exec_command::ExecCommandParams;
+use crate::exec_command::SESSION_MANAGER;
+use crate::exec_command::WRITE_STDIN_TOOL_NAME;
+use crate::exec_command::WriteStdinParams;
 use crate::exec_env::create_env;
 use crate::mcp_connection_manager::McpConnectionManager;
 use crate::mcp_tool_call::handle_mcp_tool_call;
@@ -498,6 +503,7 @@ impl Session {
                 sandbox_policy.clone(),
                 config.include_plan_tool,
                 config.include_apply_patch_tool,
+                config.use_experimental_streamable_shell_tool,
             ),
             user_instructions,
             base_instructions,
@@ -1080,6 +1086,7 @@ async fn submission_loop(
                     new_sandbox_policy.clone(),
                     config.include_plan_tool,
                     config.include_apply_patch_tool,
+                    config.use_experimental_streamable_shell_tool,
                 );
 
                 let new_turn_context = TurnContext {
@@ -1158,6 +1165,7 @@ async fn submission_loop(
                         sandbox_policy.clone(),
                         config.include_plan_tool,
                         config.include_apply_patch_tool,
+                        config.use_experimental_streamable_shell_tool,
                     ),
                     user_instructions: turn_context.user_instructions.clone(),
                     base_instructions: turn_context.base_instructions.clone(),
@@ -2063,6 +2071,52 @@ async fn handle_function_call(
             .await
         }
         "update_plan" => handle_update_plan(sess, arguments, sub_id, call_id).await,
+        EXEC_COMMAND_TOOL_NAME => {
+            // TODO(mbolin): Sandbox check.
+            let exec_params = match serde_json::from_str::<ExecCommandParams>(&arguments) {
+                Ok(params) => params,
+                Err(e) => {
+                    return ResponseInputItem::FunctionCallOutput {
+                        call_id,
+                        output: FunctionCallOutputPayload {
+                            content: format!("failed to parse function arguments: {e}"),
+                            success: Some(false),
+                        },
+                    };
+                }
+            };
+            let result = SESSION_MANAGER
+                .handle_exec_command_request(exec_params)
+                .await;
+            let function_call_output = crate::exec_command::result_into_payload(result);
+            ResponseInputItem::FunctionCallOutput {
+                call_id,
+                output: function_call_output,
+            }
+        }
+        WRITE_STDIN_TOOL_NAME => {
+            let write_stdin_params = match serde_json::from_str::<WriteStdinParams>(&arguments) {
+                Ok(params) => params,
+                Err(e) => {
+                    return ResponseInputItem::FunctionCallOutput {
+                        call_id,
+                        output: FunctionCallOutputPayload {
+                            content: format!("failed to parse function arguments: {e}"),
+                            success: Some(false),
+                        },
+                    };
+                }
+            };
+            let result = SESSION_MANAGER
+                .handle_write_stdin_request(write_stdin_params)
+                .await;
+            let function_call_output: FunctionCallOutputPayload =
+                crate::exec_command::result_into_payload(result);
+            ResponseInputItem::FunctionCallOutput {
+                call_id,
+                output: function_call_output,
+            }
+        }
         _ => {
             match sess.mcp_connection_manager.parse_tool_name(&name) {
                 Some((server, tool_name)) => {
diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs
index 67a54eb1..fbf0387a 100644
--- a/codex-rs/core/src/config.rs
+++ b/codex-rs/core/src/config.rs
@@ -174,6 +174,8 @@ pub struct Config {
 
     /// If set to `true`, the API key will be signed with the `originator` header.
     pub preferred_auth_method: AuthMode,
+
+    pub use_experimental_streamable_shell_tool: bool,
 }
 
 impl Config {
@@ -469,6 +471,8 @@ pub struct ConfigToml {
     /// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
     pub experimental_instructions_file: Option<PathBuf>,
 
+    pub experimental_use_exec_command_tool: Option<bool>,
+
     /// The value for the `originator` header included with Responses API requests.
     pub responses_originator_header_internal_override: Option<String>,
 
@@ -758,6 +762,9 @@ impl Config {
             include_apply_patch_tool: include_apply_patch_tool.unwrap_or(false),
             responses_originator_header,
             preferred_auth_method: cfg.preferred_auth_method.unwrap_or(AuthMode::ChatGPT),
+            use_experimental_streamable_shell_tool: cfg
+                .experimental_use_exec_command_tool
+                .unwrap_or(false),
         };
         Ok(config)
     }
@@ -1124,6 +1131,7 @@ disable_response_storage = true
                 include_apply_patch_tool: false,
                 responses_originator_header: "codex_cli_rs".to_string(),
                 preferred_auth_method: AuthMode::ChatGPT,
+                use_experimental_streamable_shell_tool: false,
             },
             o3_profile_config
         );
@@ -1178,6 +1186,7 @@ disable_response_storage = true
             include_apply_patch_tool: false,
             responses_originator_header: "codex_cli_rs".to_string(),
             preferred_auth_method: AuthMode::ChatGPT,
+            use_experimental_streamable_shell_tool: false,
         };
 
         assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
@@ -1247,6 +1256,7 @@ disable_response_storage = true
             include_apply_patch_tool: false,
             responses_originator_header: "codex_cli_rs".to_string(),
             preferred_auth_method: AuthMode::ChatGPT,
+            use_experimental_streamable_shell_tool: false,
         };
 
         assert_eq!(expected_zdr_profile_config, zdr_profile_config);
diff --git a/codex-rs/core/src/exec_command/exec_command_params.rs b/codex-rs/core/src/exec_command/exec_command_params.rs
new file mode 100644
index 00000000..11a3fd45
--- /dev/null
+++ b/codex-rs/core/src/exec_command/exec_command_params.rs
@@ -0,0 +1,57 @@
+use serde::Deserialize;
+use serde::Serialize;
+
+use crate::exec_command::session_id::SessionId;
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct ExecCommandParams {
+    pub(crate) cmd: String,
+
+    #[serde(default = "default_yield_time")]
+    pub(crate) yield_time_ms: u64,
+
+    #[serde(default = "max_output_tokens")]
+    pub(crate) max_output_tokens: u64,
+
+    #[serde(default = "default_shell")]
+    pub(crate) shell: String,
+
+    #[serde(default = "default_login")]
+    pub(crate) login: bool,
+}
+
+fn default_yield_time() -> u64 {
+    10_000
+}
+
+fn max_output_tokens() -> u64 {
+    10_000
+}
+
+fn default_login() -> bool {
+    true
+}
+
+fn default_shell() -> String {
+    "/bin/bash".to_string()
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct WriteStdinParams {
+    pub(crate) session_id: SessionId,
+    pub(crate) chars: String,
+
+    #[serde(default = "write_stdin_default_yield_time_ms")]
+    pub(crate) yield_time_ms: u64,
+
+    #[serde(default = "write_stdin_default_max_output_tokens")]
+    pub(crate) max_output_tokens: u64,
+}
+
+fn write_stdin_default_yield_time_ms() -> u64 {
+    250
+}
+
+fn write_stdin_default_max_output_tokens() -> u64 {
+    10_000
+}
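(Illustrative sketch, not part of the diff: how the serde defaults in exec_command_params.rs behave, assuming `ExecCommandParams` is in scope inside the crate. A tool call that supplies only `cmd` picks up the documented fallbacks.)

// Sketch only: exercises the serde defaults declared above.
fn exec_command_defaults_sketch() -> Result<(), serde_json::Error> {
    // Only `cmd` is provided; every other field uses its #[serde(default = ...)] fn.
    let args = r#"{ "cmd": "tail -f README.md" }"#;
    let params: ExecCommandParams = serde_json::from_str(args)?;
    assert_eq!(params.yield_time_ms, 10_000); // default_yield_time
    assert_eq!(params.max_output_tokens, 10_000); // max_output_tokens
    assert_eq!(params.shell, "/bin/bash"); // default_shell
    assert!(params.login); // default_login
    Ok(())
}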
diff --git a/codex-rs/core/src/exec_command/exec_command_session.rs b/codex-rs/core/src/exec_command/exec_command_session.rs
new file mode 100644
index 00000000..7503150c
--- /dev/null
+++ b/codex-rs/core/src/exec_command/exec_command_session.rs
@@ -0,0 +1,83 @@
+use std::sync::Mutex as StdMutex;
+
+use tokio::sync::broadcast;
+use tokio::sync::mpsc;
+use tokio::task::JoinHandle;
+
+#[derive(Debug)]
+pub(crate) struct ExecCommandSession {
+    /// Queue for writing bytes to the process stdin (PTY master write side).
+    writer_tx: mpsc::Sender<Vec<u8>>,
+    /// Broadcast stream of output chunks read from the PTY. New subscribers
+    /// receive only chunks emitted after they subscribe.
+    output_tx: broadcast::Sender<Vec<u8>>,
+
+    /// Child killer handle for termination on drop (can signal independently
+    /// of a thread blocked in `.wait()`).
+    killer: StdMutex<Option<Box<dyn portable_pty::ChildKiller + Send + Sync>>>,
+
+    /// JoinHandle for the blocking PTY reader task.
+    reader_handle: StdMutex<Option<JoinHandle<()>>>,
+
+    /// JoinHandle for the stdin writer task.
+    writer_handle: StdMutex<Option<JoinHandle<()>>>,
+
+    /// JoinHandle for the child wait task.
+    wait_handle: StdMutex<Option<JoinHandle<()>>>,
+}
+
+impl ExecCommandSession {
+    pub(crate) fn new(
+        writer_tx: mpsc::Sender<Vec<u8>>,
+        output_tx: broadcast::Sender<Vec<u8>>,
+        killer: Box<dyn portable_pty::ChildKiller + Send + Sync>,
+        reader_handle: JoinHandle<()>,
+        writer_handle: JoinHandle<()>,
+        wait_handle: JoinHandle<()>,
+    ) -> Self {
+        Self {
+            writer_tx,
+            output_tx,
+            killer: StdMutex::new(Some(killer)),
+            reader_handle: StdMutex::new(Some(reader_handle)),
+            writer_handle: StdMutex::new(Some(writer_handle)),
+            wait_handle: StdMutex::new(Some(wait_handle)),
+        }
+    }
+
+    pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
+        self.writer_tx.clone()
+    }
+
+    pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
+        self.output_tx.subscribe()
+    }
+}
+
+impl Drop for ExecCommandSession {
+    fn drop(&mut self) {
+        // Best-effort: terminate child first so blocking tasks can complete.
+        if let Ok(mut killer_opt) = self.killer.lock()
+            && let Some(mut killer) = killer_opt.take()
+        {
+            let _ = killer.kill();
+        }
+
+        // Abort background tasks; they may already have exited after kill.
+        if let Ok(mut h) = self.reader_handle.lock()
+            && let Some(handle) = h.take()
+        {
+            handle.abort();
+        }
+        if let Ok(mut h) = self.writer_handle.lock()
+            && let Some(handle) = h.take()
+        {
+            handle.abort();
+        }
+        if let Ok(mut h) = self.wait_handle.lock()
+            && let Some(handle) = h.take()
+        {
+            handle.abort();
+        }
+    }
+}
diff --git a/codex-rs/core/src/exec_command/mod.rs b/codex-rs/core/src/exec_command/mod.rs
new file mode 100644
index 00000000..2fd88d4e
--- /dev/null
+++ b/codex-rs/core/src/exec_command/mod.rs
@@ -0,0 +1,14 @@
+mod exec_command_params;
+mod exec_command_session;
+mod responses_api;
+mod session_id;
+mod session_manager;
+
+pub use exec_command_params::ExecCommandParams;
+pub use exec_command_params::WriteStdinParams;
+pub use responses_api::EXEC_COMMAND_TOOL_NAME;
+pub use responses_api::WRITE_STDIN_TOOL_NAME;
+pub use responses_api::create_exec_command_tool_for_responses_api;
+pub use responses_api::create_write_stdin_tool_for_responses_api;
+pub use session_manager::SESSION_MANAGER;
+pub use session_manager::result_into_payload;
diff --git a/codex-rs/core/src/exec_command/responses_api.rs b/codex-rs/core/src/exec_command/responses_api.rs
new file mode 100644
index 00000000..70b90dd4
--- /dev/null
+++ b/codex-rs/core/src/exec_command/responses_api.rs
@@ -0,0 +1,98 @@
+use std::collections::BTreeMap;
+
+use crate::openai_tools::JsonSchema;
+use crate::openai_tools::ResponsesApiTool;
+
+pub const EXEC_COMMAND_TOOL_NAME: &str = "exec_command";
+pub const WRITE_STDIN_TOOL_NAME: &str = "write_stdin";
+
+pub fn create_exec_command_tool_for_responses_api() -> ResponsesApiTool {
+    let mut properties = BTreeMap::<String, JsonSchema>::new();
+    properties.insert(
+        "cmd".to_string(),
+        JsonSchema::String {
+            description: Some("The shell command to execute.".to_string()),
+        },
+    );
+    properties.insert(
+        "yield_time_ms".to_string(),
+        JsonSchema::Number {
+            description: Some("The maximum time in milliseconds to wait for output.".to_string()),
+        },
+    );
+    properties.insert(
+        "max_output_tokens".to_string(),
+        JsonSchema::Number {
+            description: Some("The maximum number of tokens to output.".to_string()),
+        },
+    );
+    properties.insert(
+        "shell".to_string(),
+        JsonSchema::String {
+            description: Some("The shell to use. Defaults to \"/bin/bash\".".to_string()),
+        },
+    );
+    properties.insert(
+        "login".to_string(),
+        JsonSchema::Boolean {
+            description: Some(
+                "Whether to run the command as a login shell. Defaults to true.".to_string(),
+            ),
+        },
+    );
+
+    ResponsesApiTool {
+        name: EXEC_COMMAND_TOOL_NAME.to_owned(),
+        description: r#"Execute shell commands on the local machine with streaming output."#
+            .to_string(),
+        strict: false,
+        parameters: JsonSchema::Object {
+            properties,
+            required: Some(vec!["cmd".to_string()]),
+            additional_properties: Some(false),
+        },
+    }
+}
+
+pub fn create_write_stdin_tool_for_responses_api() -> ResponsesApiTool {
+    let mut properties = BTreeMap::<String, JsonSchema>::new();
+    properties.insert(
+        "session_id".to_string(),
+        JsonSchema::Number {
+            description: Some("The ID of the exec_command session.".to_string()),
+        },
+    );
+    properties.insert(
+        "chars".to_string(),
+        JsonSchema::String {
+            description: Some("The characters to write to stdin.".to_string()),
+        },
+    );
+    properties.insert(
+        "yield_time_ms".to_string(),
+        JsonSchema::Number {
+            description: Some(
+                "The maximum time in milliseconds to wait for output after writing.".to_string(),
+            ),
+        },
+    );
+    properties.insert(
+        "max_output_tokens".to_string(),
+        JsonSchema::Number {
+            description: Some("The maximum number of tokens to output.".to_string()),
+        },
+    );
+
+    ResponsesApiTool {
+        name: WRITE_STDIN_TOOL_NAME.to_owned(),
+        description: r#"Write characters to an exec session's stdin. Returns all stdout+stderr received within yield_time_ms.
+Can write control characters (\u0003 for Ctrl-C), or an empty string to just poll stdout+stderr."#
+            .to_string(),
+        strict: false,
+        parameters: JsonSchema::Object {
+            properties,
+            required: Some(vec!["session_id".to_string(), "chars".to_string()]),
+            additional_properties: Some(false),
+        },
+    }
+}
diff --git a/codex-rs/core/src/exec_command/session_id.rs b/codex-rs/core/src/exec_command/session_id.rs
new file mode 100644
index 00000000..c97c5d54
--- /dev/null
+++ b/codex-rs/core/src/exec_command/session_id.rs
@@ -0,0 +1,5 @@
+use serde::Deserialize;
+use serde::Serialize;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub(crate) struct SessionId(pub u32);
diff --git a/codex-rs/core/src/exec_command/session_manager.rs b/codex-rs/core/src/exec_command/session_manager.rs
new file mode 100644
index 00000000..213b874b
--- /dev/null
+++ b/codex-rs/core/src/exec_command/session_manager.rs
@@ -0,0 +1,677 @@
+use std::collections::HashMap;
+use std::io::ErrorKind;
+use std::io::Read;
+use std::sync::Arc;
+use std::sync::LazyLock;
+use std::sync::Mutex as StdMutex;
+use std::sync::atomic::AtomicU32;
+
+use portable_pty::CommandBuilder;
+use portable_pty::PtySize;
+use portable_pty::native_pty_system;
+use tokio::sync::Mutex;
+use tokio::sync::mpsc;
+use tokio::sync::oneshot;
+use tokio::time::Duration;
+use tokio::time::Instant;
+use tokio::time::timeout;
+
+use crate::exec_command::exec_command_params::ExecCommandParams;
+use crate::exec_command::exec_command_params::WriteStdinParams;
+use crate::exec_command::exec_command_session::ExecCommandSession;
+use crate::exec_command::session_id::SessionId;
+use codex_protocol::models::FunctionCallOutputPayload;
+
+pub static SESSION_MANAGER: LazyLock<SessionManager> = LazyLock::new(SessionManager::default);
+
+#[derive(Debug, Default)]
+pub struct SessionManager {
+    next_session_id: AtomicU32,
+    sessions: Mutex<HashMap<SessionId, ExecCommandSession>>,
+}
+
+#[derive(Debug)]
+pub struct ExecCommandOutput {
+    wall_time: Duration,
+    exit_status: ExitStatus,
+    original_token_count: Option<u64>,
+    output: String,
+}
+
+impl ExecCommandOutput {
+    fn to_text_output(&self) -> String {
+        let wall_time_secs = self.wall_time.as_secs_f32();
+        let termination_status = match self.exit_status {
+            ExitStatus::Exited(code) => format!("Process exited with code {code}"),
+            ExitStatus::Ongoing(session_id) => {
+                format!("Process running with session ID {}", session_id.0)
+            }
+        };
+        let truncation_status = match self.original_token_count {
+            Some(tokens) => {
+                format!("\nWarning: truncated output (original token count: {tokens})")
+            }
+            None => "".to_string(),
+        };
+        format!(
+            r#"Wall time: {wall_time_secs:.3} seconds
+{termination_status}{truncation_status}
+Output:
+{output}"#,
+            output = self.output
+        )
+    }
+}
+
+#[derive(Debug)]
+pub enum ExitStatus {
+    Exited(i32),
+    Ongoing(SessionId),
+}
+
+pub fn result_into_payload(result: Result<ExecCommandOutput, String>) -> FunctionCallOutputPayload {
+    match result {
+        Ok(output) => FunctionCallOutputPayload {
+            content: output.to_text_output(),
+            success: Some(true),
+        },
+        Err(err) => FunctionCallOutputPayload {
+            content: err,
+            success: Some(false),
+        },
+    }
+}
+
+impl SessionManager {
+    /// Processes the request and is required to send a response via `outgoing`.
+    pub async fn handle_exec_command_request(
+        &self,
+        params: ExecCommandParams,
+    ) -> Result<ExecCommandOutput, String> {
+        // Allocate a session id.
+        let session_id = SessionId(
+            self.next_session_id
+                .fetch_add(1, std::sync::atomic::Ordering::SeqCst),
+        );
+
+        let (session, mut exit_rx) =
+            create_exec_command_session(params.clone())
+                .await
+                .map_err(|err| {
+                    format!(
+                        "failed to create exec command session for session id {}: {err}",
+                        session_id.0
+                    )
+                })?;
+
+        // Insert into session map.
+        let mut output_rx = session.output_receiver();
+        self.sessions.lock().await.insert(session_id, session);
+
+        // Collect output until either timeout expires or process exits.
+        // Do not cap during collection; truncate at the end if needed.
+        // Use a modest initial capacity to avoid large preallocation.
+        let cap_bytes_u64 = params.max_output_tokens.saturating_mul(4);
+        let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+        let mut collected: Vec<u8> = Vec::with_capacity(4096);
+
+        let start_time = Instant::now();
+        let deadline = start_time + Duration::from_millis(params.yield_time_ms);
+        let mut exit_code: Option<i32> = None;
+
+        loop {
+            if Instant::now() >= deadline {
+                break;
+            }
+            let remaining = deadline.saturating_duration_since(Instant::now());
+            tokio::select! {
+                biased;
+                exit = &mut exit_rx => {
+                    exit_code = exit.ok();
+                    // Small grace period to pull remaining buffered output
+                    let grace_deadline = Instant::now() + Duration::from_millis(25);
+                    while Instant::now() < grace_deadline {
+                        match timeout(Duration::from_millis(1), output_rx.recv()).await {
+                            Ok(Ok(chunk)) => {
+                                collected.extend_from_slice(&chunk);
+                            }
+                            Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                                // Skip missed messages; keep trying within grace period.
+                                continue;
+                            }
+                            Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+                            Err(_) => break,
+                        }
+                    }
+                    break;
+                }
+                chunk = timeout(remaining, output_rx.recv()) => {
+                    match chunk {
+                        Ok(Ok(chunk)) => {
+                            collected.extend_from_slice(&chunk);
+                        }
+                        Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                            // Skip missed messages; continue collecting fresh output.
+                        }
+                        Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => { break; }
+                        Err(_) => { break; }
+                    }
+                }
+            }
+        }
+
+        let output = String::from_utf8_lossy(&collected).to_string();
+
+        let exit_status = if let Some(code) = exit_code {
+            ExitStatus::Exited(code)
+        } else {
+            ExitStatus::Ongoing(session_id)
+        };
+
+        // If output exceeds cap, truncate the middle and record original token estimate.
+        let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+        Ok(ExecCommandOutput {
+            wall_time: Instant::now().duration_since(start_time),
+            exit_status,
+            original_token_count,
+            output,
+        })
+    }
+
+    /// Write characters to a session's stdin and collect combined output for up to `yield_time_ms`.
+    pub async fn handle_write_stdin_request(
+        &self,
+        params: WriteStdinParams,
+    ) -> Result<ExecCommandOutput, String> {
+        let WriteStdinParams {
+            session_id,
+            chars,
+            yield_time_ms,
+            max_output_tokens,
+        } = params;
+
+        // Grab handles without holding the sessions lock across await points.
+        let (writer_tx, mut output_rx) = {
+            let sessions = self.sessions.lock().await;
+            match sessions.get(&session_id) {
+                Some(session) => (session.writer_sender(), session.output_receiver()),
+                None => {
+                    return Err(format!("unknown session id {}", session_id.0));
+                }
+            }
+        };
+
+        // Write stdin if provided.
+        if !chars.is_empty() && writer_tx.send(chars.into_bytes()).await.is_err() {
+            return Err("failed to write to stdin".to_string());
+        }
+
+        // Collect output up to yield_time_ms, truncating to max_output_tokens bytes.
+        let mut collected: Vec<u8> = Vec::with_capacity(4096);
+        let start_time = Instant::now();
+        let deadline = start_time + Duration::from_millis(yield_time_ms);
+        loop {
+            let now = Instant::now();
+            if now >= deadline {
+                break;
+            }
+            let remaining = deadline - now;
+            match timeout(remaining, output_rx.recv()).await {
+                Ok(Ok(chunk)) => {
+                    // Collect all output within the time budget; truncate at the end.
+                    collected.extend_from_slice(&chunk);
+                }
+                Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => {
+                    // Skip missed messages; continue collecting fresh output.
+                }
+                Ok(Err(tokio::sync::broadcast::error::RecvError::Closed)) => break,
+                Err(_) => break, // timeout
+            }
+        }
+
+        // Return structured output, truncating middle if over cap.
+        let output = String::from_utf8_lossy(&collected).to_string();
+        let cap_bytes_u64 = max_output_tokens.saturating_mul(4);
+        let cap_bytes: usize = cap_bytes_u64.min(usize::MAX as u64) as usize;
+        let (output, original_token_count) = truncate_middle(&output, cap_bytes);
+        Ok(ExecCommandOutput {
+            wall_time: Instant::now().duration_since(start_time),
+            exit_status: ExitStatus::Ongoing(session_id),
+            original_token_count,
+            output,
+        })
+    }
+}
+
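(Illustrative sketch, not part of the diff: the intended two-step flow through the session manager above, assuming the code lives alongside this module so the private `exit_status` field is visible. `exec_command` starts the process; a command that is still running reports `ExitStatus::Ongoing`, and that session id feeds later `write_stdin` polls.)

// Sketch only: start a long-running command, then poll it via write_stdin.
async fn session_flow_sketch() -> Result<(), String> {
    let first = SESSION_MANAGER
        .handle_exec_command_request(ExecCommandParams {
            cmd: "tail -f README.md".to_string(),
            yield_time_ms: 1_000,
            max_output_tokens: 1_000,
            shell: "/bin/bash".to_string(),
            login: false,
        })
        .await?;
    if let ExitStatus::Ongoing(session_id) = first.exit_status {
        // An empty `chars` write just polls for output produced since the last read.
        let _polled = SESSION_MANAGER
            .handle_write_stdin_request(WriteStdinParams {
                session_id,
                chars: String::new(),
                yield_time_ms: 250,
                max_output_tokens: 1_000,
            })
            .await?;
    }
    Ok(())
}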
+/// Spawn PTY and child process per spawn_exec_command_session logic.
+async fn create_exec_command_session(
+    params: ExecCommandParams,
+) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
+    let ExecCommandParams {
+        cmd,
+        yield_time_ms: _,
+        max_output_tokens: _,
+        shell,
+        login,
+    } = params;
+
+    // Use the native pty implementation for the system
+    let pty_system = native_pty_system();
+
+    // Create a new pty
+    let pair = pty_system.openpty(PtySize {
+        rows: 24,
+        cols: 80,
+        pixel_width: 0,
+        pixel_height: 0,
+    })?;
+
+    // Spawn a shell into the pty
+    let mut command_builder = CommandBuilder::new(shell);
+    let shell_mode_opt = if login { "-lc" } else { "-c" };
+    command_builder.arg(shell_mode_opt);
+    command_builder.arg(cmd);
+
+    let mut child = pair.slave.spawn_command(command_builder)?;
+    // Obtain a killer that can signal the process independently of `.wait()`.
+    let killer = child.clone_killer();
+
+    // Channel to forward write requests to the PTY writer.
+    let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
+    // Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
+    let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
+
+    // Reader task: drain PTY and forward chunks to output channel.
+    let mut reader = pair.master.try_clone_reader()?;
+    let output_tx_clone = output_tx.clone();
+    let reader_handle = tokio::task::spawn_blocking(move || {
+        let mut buf = [0u8; 8192];
+        loop {
+            match reader.read(&mut buf) {
+                Ok(0) => break, // EOF
+                Ok(n) => {
+                    // Forward to broadcast; best-effort if there are subscribers.
+                    let _ = output_tx_clone.send(buf[..n].to_vec());
+                }
+                Err(ref e) if e.kind() == ErrorKind::Interrupted => {
+                    // Retry on EINTR
+                    continue;
+                }
+                Err(ref e) if e.kind() == ErrorKind::WouldBlock => {
+                    // We're in a blocking thread; back off briefly and retry.
+                    std::thread::sleep(Duration::from_millis(5));
+                    continue;
+                }
+                Err(_) => break,
+            }
+        }
+    });
+
+    // Writer task: apply stdin writes to the PTY writer.
+    let writer = pair.master.take_writer()?;
+    let writer = Arc::new(StdMutex::new(writer));
+    let writer_handle = tokio::spawn({
+        let writer = writer.clone();
+        async move {
+            while let Some(bytes) = writer_rx.recv().await {
+                let writer = writer.clone();
+                // Perform blocking write on a blocking thread.
+                let _ = tokio::task::spawn_blocking(move || {
+                    if let Ok(mut guard) = writer.lock() {
+                        use std::io::Write;
+                        let _ = guard.write_all(&bytes);
+                        let _ = guard.flush();
+                    }
+                })
+                .await;
+            }
+        }
+    });
+
+    // Keep the child alive until it exits, then signal exit code.
+    let (exit_tx, exit_rx) = oneshot::channel::<i32>();
+    let wait_handle = tokio::task::spawn_blocking(move || {
+        let code = match child.wait() {
+            Ok(status) => status.exit_code() as i32,
+            Err(_) => -1,
+        };
+        let _ = exit_tx.send(code);
+    });
+
+    // Create and store the session with channels.
+    let session = ExecCommandSession::new(
+        writer_tx,
+        output_tx,
+        killer,
+        reader_handle,
+        writer_handle,
+        wait_handle,
+    );
+    Ok((session, exit_rx))
+}
+
+/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
+/// preserving the beginning and the end. Returns the possibly truncated
+/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
+/// if truncation occurred; otherwise returns the original string and `None`.
+fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
+    // No truncation needed
+    if s.len() <= max_bytes {
+        return (s.to_string(), None);
+    }
+    let est_tokens = (s.len() as u64).div_ceil(4);
+    if max_bytes == 0 {
+        // Cannot keep any content; still return a full marker (never truncated).
+        return (
+            format!("…{} tokens truncated…", est_tokens),
+            Some(est_tokens),
+        );
+    }
+
+    // Helper to truncate a string to a given byte length on a char boundary.
+    fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
+        if input.len() <= max_len {
+            return input;
+        }
+        let mut end = max_len;
+        while end > 0 && !input.is_char_boundary(end) {
+            end -= 1;
+        }
+        &input[..end]
+    }
+
+    // Given a left/right budget, prefer newline boundaries; otherwise fall back
+    // to UTF-8 char boundaries.
+    fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
+        if let Some(head) = s.get(..left_budget)
+            && let Some(i) = head.rfind('\n')
+        {
+            return i + 1; // keep the newline so suffix starts on a fresh line
+        }
+        truncate_on_boundary(s, left_budget).len()
+    }
+
+    fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
+        let start_tail = s.len().saturating_sub(right_budget);
+        if let Some(tail) = s.get(start_tail..)
+            && let Some(i) = tail.find('\n')
+        {
+            return start_tail + i + 1; // start after newline
+        }
+        // Fall back to a char boundary at or after start_tail.
+        let mut idx = start_tail.min(s.len());
+        while idx < s.len() && !s.is_char_boundary(idx) {
+            idx += 1;
+        }
+        idx
+    }
+
+    // Refine marker length and budgets until stable. Marker is never truncated.
+    let mut guess_tokens = est_tokens; // worst-case: everything truncated
+    for _ in 0..4 {
+        let marker = format!("…{} tokens truncated…", guess_tokens);
+        let marker_len = marker.len();
+        let keep_budget = max_bytes.saturating_sub(marker_len);
+        if keep_budget == 0 {
+            // No room for any content within the cap; return a full, untruncated marker
+            // that reflects the entire truncated content.
+            return (
+                format!("…{} tokens truncated…", est_tokens),
+                Some(est_tokens),
+            );
+        }
+
+        let left_budget = keep_budget / 2;
+        let right_budget = keep_budget - left_budget;
+        let prefix_end = pick_prefix_end(s, left_budget);
+        let mut suffix_start = pick_suffix_start(s, right_budget);
+        if suffix_start < prefix_end {
+            suffix_start = prefix_end;
+        }
+        let kept_content_bytes = prefix_end + (s.len() - suffix_start);
+        let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
+        let new_tokens = (truncated_content_bytes as u64).div_ceil(4);
+        if new_tokens == guess_tokens {
+            let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
+            out.push_str(&s[..prefix_end]);
+            out.push_str(&marker);
+            // Place marker on its own line for symmetry when we keep line boundaries.
+            out.push('\n');
+            out.push_str(&s[suffix_start..]);
+            return (out, Some(est_tokens));
+        }
+        guess_tokens = new_tokens;
+    }
+
+    // Fallback: use last guess to build output.
+    let marker = format!("…{} tokens truncated…", guess_tokens);
+    let marker_len = marker.len();
+    let keep_budget = max_bytes.saturating_sub(marker_len);
+    if keep_budget == 0 {
+        return (
+            format!("…{} tokens truncated…", est_tokens),
+            Some(est_tokens),
+        );
+    }
+    let left_budget = keep_budget / 2;
+    let right_budget = keep_budget - left_budget;
+    let prefix_end = pick_prefix_end(s, left_budget);
+    let suffix_start = pick_suffix_start(s, right_budget);
+    let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
+    out.push_str(&s[..prefix_end]);
+    out.push_str(&marker);
+    out.push('\n');
+    out.push_str(&s[suffix_start..]);
+    (out, Some(est_tokens))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::exec_command::session_id::SessionId;
+
+    /// Test that verifies that [`SessionManager::handle_exec_command_request()`]
+    /// and [`SessionManager::handle_write_stdin_request()`] work as expected
+    /// in the presence of a process that never terminates (but produces
+    /// output continuously).
+    #[cfg(unix)]
+    #[allow(clippy::print_stderr)]
+    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
+    async fn session_manager_streams_and_truncates_from_now() {
+        use crate::exec_command::exec_command_params::ExecCommandParams;
+        use crate::exec_command::exec_command_params::WriteStdinParams;
+        use tokio::time::sleep;
+
+        let session_manager = SessionManager::default();
+        // Long-running loop that prints an increasing counter every ~100ms.
+        // Use Python for a portable, reliable sleep across shells/PTYs.
+        let cmd = r#"python3 - <<'PY'
+import sys, time
+count = 0
+while True:
+    print(count)
+    sys.stdout.flush()
+    count += 100
+    time.sleep(0.1)
+PY"#
+        .to_string();
+
+        // Start the session and collect ~3s of output.
+        let params = ExecCommandParams {
+            cmd,
+            yield_time_ms: 3_000,
+            max_output_tokens: 1_000, // large enough to avoid truncation here
+            shell: "/bin/bash".to_string(),
+            login: false,
+        };
+        let initial_output = match session_manager
+            .handle_exec_command_request(params.clone())
+            .await
+        {
+            Ok(v) => v,
+            Err(e) => {
+                // PTY may be restricted in some sandboxes; skip in that case.
+                if e.contains("openpty") || e.contains("Operation not permitted") {
+                    eprintln!("skipping test due to restricted PTY: {e}");
+                    return;
+                }
+                panic!("exec request failed unexpectedly: {e}");
+            }
+        };
+        eprintln!("initial output: {initial_output:?}");
+
+        // Should be ongoing (we launched a never-ending loop).
+        let session_id = match initial_output.exit_status {
+            ExitStatus::Ongoing(id) => id,
+            _ => panic!("expected ongoing session"),
+        };
+
+        // Parse the numeric lines and get the max observed value in the first window.
+        let first_nums = extract_monotonic_numbers(&initial_output.output);
+        assert!(
+            !first_nums.is_empty(),
+            "expected some output from first window"
+        );
+        let first_max = *first_nums.iter().max().unwrap();
+
+        // Wait ~4s so counters progress while we're not reading.
+        sleep(Duration::from_millis(4_000)).await;
+
+        // Now read ~3s of output "from now" only.
+        // Use a small token cap so truncation occurs and we test middle truncation.
+        let write_params = WriteStdinParams {
+            session_id,
+            chars: String::new(),
+            yield_time_ms: 3_000,
+            max_output_tokens: 16, // 16 tokens ~= 64 bytes -> likely truncation
+        };
+        let second = session_manager
+            .handle_write_stdin_request(write_params)
+            .await
+            .expect("write stdin should succeed");
+
+        // Verify truncation metadata and size bound (cap is tokens*4 bytes).
+        assert!(second.original_token_count.is_some());
+        let cap_bytes = (16u64 * 4) as usize;
+        assert!(second.output.len() <= cap_bytes);
+        // New middle marker should be present.
+        assert!(
+            second.output.contains("tokens truncated") && second.output.contains('…'),
+            "expected truncation marker in output, got: {}",
+            second.output
+        );
+
+        // Minimal freshness check: the earliest number we see in the second window
+        // should be significantly larger than the last from the first window.
+        let second_nums = extract_monotonic_numbers(&second.output);
+        assert!(
+            !second_nums.is_empty(),
+            "expected some numeric output from second window"
+        );
+        let second_min = *second_nums.iter().min().unwrap();
+
+        // We slept 4 seconds (~40 ticks at 100ms/tick, each +100), so expect
+        // an increase of roughly 4000 or more. Allow a generous margin.
+        assert!(
+            second_min >= first_max + 2000,
+            "second_min={second_min} first_max={first_max}",
+        );
+    }
+
+    #[cfg(unix)]
+    fn extract_monotonic_numbers(s: &str) -> Vec<u64> {
+        s.lines()
+            .filter_map(|line| {
+                if !line.is_empty()
+                    && line.chars().all(|c| c.is_ascii_digit())
+                    && let Ok(n) = line.parse::<u64>()
+                {
+                    // Our generator increments by 100; ignore spurious fragments.
+                    if n % 100 == 0 {
+                        return Some(n);
+                    }
+                }
+                None
+            })
+            .collect()
+    }
+
+    #[test]
+    fn to_text_output_exited_no_truncation() {
+        let out = ExecCommandOutput {
+            wall_time: Duration::from_millis(1234),
+            exit_status: ExitStatus::Exited(0),
+            original_token_count: None,
+            output: "hello".to_string(),
+        };
+        let text = out.to_text_output();
+        let expected = r#"Wall time: 1.234 seconds
+Process exited with code 0
+Output:
+hello"#;
+        assert_eq!(expected, text);
+    }
+
+    #[test]
+    fn to_text_output_ongoing_with_truncation() {
+        let out = ExecCommandOutput {
+            wall_time: Duration::from_millis(500),
+            exit_status: ExitStatus::Ongoing(SessionId(42)),
+            original_token_count: Some(1000),
+            output: "abc".to_string(),
+        };
+        let text = out.to_text_output();
+        let expected = r#"Wall time: 0.500 seconds
+Process running with session ID 42
+Warning: truncated output (original token count: 1000)
+Output:
+abc"#;
+        assert_eq!(expected, text);
+    }
+
+    #[test]
+    fn truncate_middle_no_newlines_fallback() {
+        // A long string with no newlines that exceeds the cap.
+        let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
+        let max_bytes = 16; // force truncation
+        let (out, original) = truncate_middle(s, max_bytes);
+        // For very small caps, we return the full, untruncated marker,
+        // even if it exceeds the cap.
+        assert_eq!(out, "…16 tokens truncated…");
+        // Original string length is 62 bytes => ceil(62/4) = 16 tokens.
+        assert_eq!(original, Some(16));
+    }
+
+    #[test]
+    fn truncate_middle_prefers_newline_boundaries() {
+        // Build a multi-line string of 20 numbered lines (each "NNN\n").
+        let mut s = String::new();
+        for i in 1..=20 {
+            s.push_str(&format!("{i:03}\n"));
+        }
+        // Total length: 20 lines * 4 bytes per line = 80 bytes.
+        assert_eq!(s.len(), 80);
+
+        // Choose a cap that forces truncation while leaving room for
+        // a few lines on each side after accounting for the marker.
+        let max_bytes = 64;
+        // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20).
+        assert_eq!(
+            truncate_middle(&s, max_bytes),
+            (
+                r#"001
+002
+003
+004
+…12 tokens truncated…
+017
+018
+019
+020
+"#
+                .to_string(),
+                Some(20)
+            )
+        );
+    }
+}
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index 6d4699bc..ae183320 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -20,6 +20,7 @@ mod conversation_history;
 mod environment_context;
 pub mod error;
 pub mod exec;
+mod exec_command;
 pub mod exec_env;
 mod flags;
 pub mod git_info;
diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs
index bb5e6dac..272c901d 100644
--- a/codex-rs/core/src/openai_tools.rs
+++ b/codex-rs/core/src/openai_tools.rs
@@ -56,6 +56,7 @@ pub enum ConfigShellToolType {
     DefaultShell,
     ShellWithRequest { sandbox_policy: SandboxPolicy },
     LocalShell,
+    StreamableShell,
 }
 
 #[derive(Debug, Clone)]
@@ -72,13 +73,16 @@ impl ToolsConfig {
         sandbox_policy: SandboxPolicy,
         include_plan_tool: bool,
         include_apply_patch_tool: bool,
+        use_streamable_shell_tool: bool,
     ) -> Self {
-        let mut shell_type = if model_family.uses_local_shell_tool {
+        let mut shell_type = if use_streamable_shell_tool {
+            ConfigShellToolType::StreamableShell
+        } else if model_family.uses_local_shell_tool {
             ConfigShellToolType::LocalShell
         } else {
             ConfigShellToolType::DefaultShell
         };
-        if matches!(approval_policy, AskForApproval::OnRequest) {
+        if matches!(approval_policy, AskForApproval::OnRequest) && !use_streamable_shell_tool {
             shell_type = ConfigShellToolType::ShellWithRequest {
                 sandbox_policy: sandbox_policy.clone(),
             }
@@ -492,6 +496,14 @@ pub(crate) fn get_openai_tools(
         ConfigShellToolType::LocalShell => {
             tools.push(OpenAiTool::LocalShell {});
         }
+        ConfigShellToolType::StreamableShell => {
+            tools.push(OpenAiTool::Function(
+                crate::exec_command::create_exec_command_tool_for_responses_api(),
+            ));
+            tools.push(OpenAiTool::Function(
+                crate::exec_command::create_write_stdin_tool_for_responses_api(),
+            ));
+        }
     }
 
     if config.plan_tool {
@@ -564,6 +576,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             true,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
         let tools = get_openai_tools(&config, Some(HashMap::new()));
 
@@ -579,6 +592,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             true,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
        );
         let tools = get_openai_tools(&config, Some(HashMap::new()));
 
@@ -594,6 +608,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
         let tools = get_openai_tools(
             &config,
@@ -688,6 +703,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
 
         let tools = get_openai_tools(
@@ -744,6 +760,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
 
         let tools = get_openai_tools(
@@ -795,6 +812,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
 
         let tools = get_openai_tools(
@@ -849,6 +867,7 @@ mod tests {
             SandboxPolicy::ReadOnly,
             false,
             false,
+            /*use_experimental_streamable_shell_tool*/ false,
         );
 
         let tools = get_openai_tools(