From c09ed74a163ecea69c32d61ab2bfa1c8490eb611 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Wed, 10 Sep 2025 17:38:11 -0700 Subject: [PATCH] Unified execution (#3288) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Unified PTY-Based Exec Tool Note: this requires setting this flag in the config: `use_experimental_unified_exec_tool=true` - Adds a PTY-backed interactive exec feature (“unified_exec”) with session reuse via session_id, bounded output (128 KiB), and timeout clamping (≤ 60 s). - Protocol: introduces ResponseItem::UnifiedExec { session_id, arguments, timeout_ms }. - Tools: exposes unified_exec as a function tool (Responses API); excluded from Chat Completions payload while still supported in tool lists. - Path handling: resolves commands via PATH (or explicit paths), with UTF‑8/newline‑aware truncation (truncate_middle). - Tests: cover command parsing, path resolution, session persistence/cleanup, multi‑session isolation, timeouts, and truncation behavior. 
--- codex-rs/core/Cargo.toml | 4 +- codex-rs/core/src/codex.rs | 106 +++ codex-rs/core/src/config.rs | 11 + .../src/exec_command/exec_command_session.rs | 9 + codex-rs/core/src/exec_command/mod.rs | 1 + .../core/src/exec_command/session_manager.rs | 158 +---- codex-rs/core/src/lib.rs | 2 + codex-rs/core/src/openai_tools.rs | 123 +++- codex-rs/core/src/truncate.rs | 180 +++++ codex-rs/core/src/unified_exec/errors.rs | 22 + codex-rs/core/src/unified_exec/mod.rs | 653 ++++++++++++++++++ codex-rs/core/tests/suite/prompt_caching.rs | 3 +- codex-rs/protocol/src/models.rs | 1 - 13 files changed, 1088 insertions(+), 185 deletions(-) create mode 100644 codex-rs/core/src/truncate.rs create mode 100644 codex-rs/core/src/unified_exec/errors.rs create mode 100644 codex-rs/core/src/unified_exec/mod.rs diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index b706afb2..69ef7bc4 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -54,6 +54,7 @@ tracing = { version = "0.1.41", features = ["log"] } tree-sitter = "0.25.9" tree-sitter-bash = "0.25.0" uuid = { version = "1", features = ["serde", "v4"] } +which = "6" wildmatch = "2.4.0" @@ -69,9 +70,6 @@ openssl-sys = { version = "*", features = ["vendored"] } [target.aarch64-unknown-linux-musl.dependencies] openssl-sys = { version = "*", features = ["vendored"] } -[target.'cfg(target_os = "windows")'.dependencies] -which = "6" - [dev-dependencies] assert_cmd = "2" core_test_support = { path = "tests/common" } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index f3e9800c..356eadd8 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -26,6 +26,7 @@ use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnAbortedEvent; use futures::prelude::*; use mcp_types::CallToolResult; +use serde::Deserialize; use serde::Serialize; use serde_json; use tokio::sync::oneshot; @@ -112,6 +113,7 @@ use crate::safety::assess_command_safety; use 
crate::safety::assess_safety_for_untrusted_command; use crate::shell; use crate::turn_diff_tracker::TurnDiffTracker; +use crate::unified_exec::UnifiedExecSessionManager; use crate::user_instructions::UserInstructions; use crate::user_notification::UserNotification; use crate::util::backoff; @@ -280,6 +282,7 @@ pub(crate) struct Session { /// Manager for external MCP servers/tools. mcp_connection_manager: McpConnectionManager, session_manager: ExecSessionManager, + unified_exec_manager: UnifiedExecSessionManager, /// External notifier command (will be passed as args to exec()). When /// `None` this feature is disabled. @@ -471,6 +474,7 @@ impl Session { include_web_search_request: config.tools_web_search_request, use_streamable_shell_tool: config.use_experimental_streamable_shell_tool, include_view_image_tool: config.include_view_image_tool, + experimental_unified_exec_tool: config.use_experimental_unified_exec_tool, }), user_instructions, base_instructions, @@ -484,6 +488,7 @@ impl Session { tx_event: tx_event.clone(), mcp_connection_manager, session_manager: ExecSessionManager::default(), + unified_exec_manager: UnifiedExecSessionManager::default(), notify, state: Mutex::new(state), rollout: Mutex::new(Some(rollout_recorder)), @@ -1149,6 +1154,7 @@ async fn submission_loop( include_web_search_request: config.tools_web_search_request, use_streamable_shell_tool: config.use_experimental_streamable_shell_tool, include_view_image_tool: config.include_view_image_tool, + experimental_unified_exec_tool: config.use_experimental_unified_exec_tool, }); let new_turn_context = TurnContext { @@ -1251,6 +1257,8 @@ async fn submission_loop( use_streamable_shell_tool: config .use_experimental_streamable_shell_tool, include_view_image_tool: config.include_view_image_tool, + experimental_unified_exec_tool: config + .use_experimental_unified_exec_tool, }), user_instructions: turn_context.user_instructions.clone(), base_instructions: turn_context.base_instructions.clone(), @@ -2082,6 
+2090,72 @@ async fn handle_response_item( Ok(output) } +async fn handle_unified_exec_tool_call( + sess: &Session, + call_id: String, + session_id: Option, + arguments: Vec, + timeout_ms: Option, +) -> ResponseInputItem { + let parsed_session_id = if let Some(session_id) = session_id { + match session_id.parse::() { + Ok(parsed) => Some(parsed), + Err(output) => { + return ResponseInputItem::FunctionCallOutput { + call_id: call_id.to_string(), + output: FunctionCallOutputPayload { + content: format!("invalid session_id: {session_id} due to error {output}"), + success: Some(false), + }, + }; + } + } + } else { + None + }; + + let request = crate::unified_exec::UnifiedExecRequest { + session_id: parsed_session_id, + input_chunks: &arguments, + timeout_ms, + }; + + let result = sess.unified_exec_manager.handle_request(request).await; + + let output_payload = match result { + Ok(value) => { + #[derive(Serialize)] + struct SerializedUnifiedExecResult<'a> { + session_id: Option, + output: &'a str, + } + + match serde_json::to_string(&SerializedUnifiedExecResult { + session_id: value.session_id.map(|id| id.to_string()), + output: &value.output, + }) { + Ok(serialized) => FunctionCallOutputPayload { + content: serialized, + success: Some(true), + }, + Err(err) => FunctionCallOutputPayload { + content: format!("failed to serialize unified exec output: {err}"), + success: Some(false), + }, + } + } + Err(err) => FunctionCallOutputPayload { + content: format!("unified exec failed: {err}"), + success: Some(false), + }, + }; + + ResponseInputItem::FunctionCallOutput { + call_id, + output: output_payload, + } +} + async fn handle_function_call( sess: &Session, turn_context: &TurnContext, @@ -2109,6 +2183,38 @@ async fn handle_function_call( ) .await } + "unified_exec" => { + #[derive(Deserialize)] + struct UnifiedExecArgs { + input: Vec, + #[serde(default)] + session_id: Option, + #[serde(default)] + timeout_ms: Option, + } + + let args = match serde_json::from_str::(&arguments) 
{ + Ok(args) => args, + Err(err) => { + return ResponseInputItem::FunctionCallOutput { + call_id, + output: FunctionCallOutputPayload { + content: format!("failed to parse function arguments: {err}"), + success: Some(false), + }, + }; + } + }; + + handle_unified_exec_tool_call( + sess, + call_id, + args.session_id, + args.input, + args.timeout_ms, + ) + .await + } "view_image" => { #[derive(serde::Deserialize)] struct SeeImageArgs { diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 8ae9931f..316f7276 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -172,6 +172,9 @@ pub struct Config { pub use_experimental_streamable_shell_tool: bool, + /// If set to `true`, used only the experimental unified exec tool. + pub use_experimental_unified_exec_tool: bool, + /// Include the `view_image` tool that lets the agent attach a local image path to context. pub include_view_image_tool: bool, @@ -487,6 +490,7 @@ pub struct ConfigToml { pub experimental_instructions_file: Option, pub experimental_use_exec_command_tool: Option, + pub experimental_use_unified_exec_tool: Option, pub projects: Option>, @@ -837,6 +841,9 @@ impl Config { use_experimental_streamable_shell_tool: cfg .experimental_use_exec_command_tool .unwrap_or(false), + use_experimental_unified_exec_tool: cfg + .experimental_use_unified_exec_tool + .unwrap_or(true), include_view_image_tool, active_profile: active_profile_name, disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false), @@ -1212,6 +1219,7 @@ model_verbosity = "high" tools_web_search_request: false, preferred_auth_method: AuthMode::ChatGPT, use_experimental_streamable_shell_tool: false, + use_experimental_unified_exec_tool: true, include_view_image_tool: true, active_profile: Some("o3".to_string()), disable_paste_burst: false, @@ -1269,6 +1277,7 @@ model_verbosity = "high" tools_web_search_request: false, preferred_auth_method: AuthMode::ChatGPT, use_experimental_streamable_shell_tool: false, + 
use_experimental_unified_exec_tool: true, include_view_image_tool: true, active_profile: Some("gpt3".to_string()), disable_paste_burst: false, @@ -1341,6 +1350,7 @@ model_verbosity = "high" tools_web_search_request: false, preferred_auth_method: AuthMode::ChatGPT, use_experimental_streamable_shell_tool: false, + use_experimental_unified_exec_tool: true, include_view_image_tool: true, active_profile: Some("zdr".to_string()), disable_paste_burst: false, @@ -1399,6 +1409,7 @@ model_verbosity = "high" tools_web_search_request: false, preferred_auth_method: AuthMode::ChatGPT, use_experimental_streamable_shell_tool: false, + use_experimental_unified_exec_tool: true, include_view_image_tool: true, active_profile: Some("gpt5".to_string()), disable_paste_burst: false, diff --git a/codex-rs/core/src/exec_command/exec_command_session.rs b/codex-rs/core/src/exec_command/exec_command_session.rs index 7503150c..b524506e 100644 --- a/codex-rs/core/src/exec_command/exec_command_session.rs +++ b/codex-rs/core/src/exec_command/exec_command_session.rs @@ -24,6 +24,9 @@ pub(crate) struct ExecCommandSession { /// JoinHandle for the child wait task. wait_handle: StdMutex>>, + + /// Tracks whether the underlying process has exited. 
+ exit_status: std::sync::Arc, } impl ExecCommandSession { @@ -34,6 +37,7 @@ impl ExecCommandSession { reader_handle: JoinHandle<()>, writer_handle: JoinHandle<()>, wait_handle: JoinHandle<()>, + exit_status: std::sync::Arc, ) -> Self { Self { writer_tx, @@ -42,6 +46,7 @@ impl ExecCommandSession { reader_handle: StdMutex::new(Some(reader_handle)), writer_handle: StdMutex::new(Some(writer_handle)), wait_handle: StdMutex::new(Some(wait_handle)), + exit_status, } } @@ -52,6 +57,10 @@ impl ExecCommandSession { pub(crate) fn output_receiver(&self) -> broadcast::Receiver> { self.output_tx.subscribe() } + + pub(crate) fn has_exited(&self) -> bool { + self.exit_status.load(std::sync::atomic::Ordering::SeqCst) + } } impl Drop for ExecCommandSession { diff --git a/codex-rs/core/src/exec_command/mod.rs b/codex-rs/core/src/exec_command/mod.rs index 9cdaa4d3..103b76ad 100644 --- a/codex-rs/core/src/exec_command/mod.rs +++ b/codex-rs/core/src/exec_command/mod.rs @@ -6,6 +6,7 @@ mod session_manager; pub use exec_command_params::ExecCommandParams; pub use exec_command_params::WriteStdinParams; +pub(crate) use exec_command_session::ExecCommandSession; pub use responses_api::EXEC_COMMAND_TOOL_NAME; pub use responses_api::WRITE_STDIN_TOOL_NAME; pub use responses_api::create_exec_command_tool_for_responses_api; diff --git a/codex-rs/core/src/exec_command/session_manager.rs b/codex-rs/core/src/exec_command/session_manager.rs index c547409c..9578610c 100644 --- a/codex-rs/core/src/exec_command/session_manager.rs +++ b/codex-rs/core/src/exec_command/session_manager.rs @@ -3,6 +3,7 @@ use std::io::ErrorKind; use std::io::Read; use std::sync::Arc; use std::sync::Mutex as StdMutex; +use std::sync::atomic::AtomicBool; use std::sync::atomic::AtomicU32; use portable_pty::CommandBuilder; @@ -19,6 +20,7 @@ use crate::exec_command::exec_command_params::ExecCommandParams; use crate::exec_command::exec_command_params::WriteStdinParams; use 
crate::exec_command::exec_command_session::ExecCommandSession; use crate::exec_command::session_id::SessionId; +use crate::truncate::truncate_middle; use codex_protocol::models::FunctionCallOutputPayload; #[derive(Debug, Default)] @@ -327,11 +329,14 @@ async fn create_exec_command_session( // Keep the child alive until it exits, then signal exit code. let (exit_tx, exit_rx) = oneshot::channel::(); + let exit_status = Arc::new(AtomicBool::new(false)); + let wait_exit_status = exit_status.clone(); let wait_handle = tokio::task::spawn_blocking(move || { let code = match child.wait() { Ok(status) => status.exit_code() as i32, Err(_) => -1, }; + wait_exit_status.store(true, std::sync::atomic::Ordering::SeqCst); let _ = exit_tx.send(code); }); @@ -343,116 +348,11 @@ async fn create_exec_command_session( reader_handle, writer_handle, wait_handle, + exit_status, ); Ok((session, exit_rx)) } -/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes, -/// preserving the beginning and the end. Returns the possibly truncated -/// string and `Some(original_token_count)` (estimated at 4 bytes/token) -/// if truncation occurred; otherwise returns the original string and `None`. -fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option) { - // No truncation needed - if s.len() <= max_bytes { - return (s.to_string(), None); - } - let est_tokens = (s.len() as u64).div_ceil(4); - if max_bytes == 0 { - // Cannot keep any content; still return a full marker (never truncated). - return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens)); - } - - // Helper to truncate a string to a given byte length on a char boundary. - fn truncate_on_boundary(input: &str, max_len: usize) -> &str { - if input.len() <= max_len { - return input; - } - let mut end = max_len; - while end > 0 && !input.is_char_boundary(end) { - end -= 1; - } - &input[..end] - } - - // Given a left/right budget, prefer newline boundaries; otherwise fall back - // to UTF-8 char boundaries. 
- fn pick_prefix_end(s: &str, left_budget: usize) -> usize { - if let Some(head) = s.get(..left_budget) - && let Some(i) = head.rfind('\n') - { - return i + 1; // keep the newline so suffix starts on a fresh line - } - truncate_on_boundary(s, left_budget).len() - } - - fn pick_suffix_start(s: &str, right_budget: usize) -> usize { - let start_tail = s.len().saturating_sub(right_budget); - if let Some(tail) = s.get(start_tail..) - && let Some(i) = tail.find('\n') - { - return start_tail + i + 1; // start after newline - } - // Fall back to a char boundary at or after start_tail. - let mut idx = start_tail.min(s.len()); - while idx < s.len() && !s.is_char_boundary(idx) { - idx += 1; - } - idx - } - - // Refine marker length and budgets until stable. Marker is never truncated. - let mut guess_tokens = est_tokens; // worst-case: everything truncated - for _ in 0..4 { - let marker = format!("…{guess_tokens} tokens truncated…"); - let marker_len = marker.len(); - let keep_budget = max_bytes.saturating_sub(marker_len); - if keep_budget == 0 { - // No room for any content within the cap; return a full, untruncated marker - // that reflects the entire truncated content. - return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens)); - } - - let left_budget = keep_budget / 2; - let right_budget = keep_budget - left_budget; - let prefix_end = pick_prefix_end(s, left_budget); - let mut suffix_start = pick_suffix_start(s, right_budget); - if suffix_start < prefix_end { - suffix_start = prefix_end; - } - let kept_content_bytes = prefix_end + (s.len() - suffix_start); - let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes); - let new_tokens = (truncated_content_bytes as u64).div_ceil(4); - if new_tokens == guess_tokens { - let mut out = String::with_capacity(marker_len + kept_content_bytes + 1); - out.push_str(&s[..prefix_end]); - out.push_str(&marker); - // Place marker on its own line for symmetry when we keep line boundaries. 
- out.push('\n'); - out.push_str(&s[suffix_start..]); - return (out, Some(est_tokens)); - } - guess_tokens = new_tokens; - } - - // Fallback: use last guess to build output. - let marker = format!("…{guess_tokens} tokens truncated…"); - let marker_len = marker.len(); - let keep_budget = max_bytes.saturating_sub(marker_len); - if keep_budget == 0 { - return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens)); - } - let left_budget = keep_budget / 2; - let right_budget = keep_budget - left_budget; - let prefix_end = pick_prefix_end(s, left_budget); - let suffix_start = pick_suffix_start(s, right_budget); - let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1); - out.push_str(&s[..prefix_end]); - out.push_str(&marker); - out.push('\n'); - out.push_str(&s[suffix_start..]); - (out, Some(est_tokens)) -} - #[cfg(test)] mod tests { use super::*; @@ -616,50 +516,4 @@ Output: abc"#; assert_eq!(expected, text); } - - #[test] - fn truncate_middle_no_newlines_fallback() { - // A long string with no newlines that exceeds the cap. - let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; - let max_bytes = 16; // force truncation - let (out, original) = truncate_middle(s, max_bytes); - // For very small caps, we return the full, untruncated marker, - // even if it exceeds the cap. - assert_eq!(out, "…16 tokens truncated…"); - // Original string length is 62 bytes => ceil(62/4) = 16 tokens. - assert_eq!(original, Some(16)); - } - - #[test] - fn truncate_middle_prefers_newline_boundaries() { - // Build a multi-line string of 20 numbered lines (each "NNN\n"). - let mut s = String::new(); - for i in 1..=20 { - s.push_str(&format!("{i:03}\n")); - } - // Total length: 20 lines * 4 bytes per line = 80 bytes. - assert_eq!(s.len(), 80); - - // Choose a cap that forces truncation while leaving room for - // a few lines on each side after accounting for the marker. 
- let max_bytes = 64; - // Expect exact output: first 4 lines, marker, last 4 lines, and correct token estimate (80/4 = 20). - assert_eq!( - truncate_middle(&s, max_bytes), - ( - r#"001 -002 -003 -004 -…12 tokens truncated… -017 -018 -019 -020 -"# - .to_string(), - Some(20) - ) - ); - } } diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 77b4037e..5223b6d3 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -35,6 +35,8 @@ mod mcp_tool_call; mod message_history; mod model_provider_info; pub mod parse_command; +mod truncate; +mod unified_exec; mod user_instructions; pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID; pub use model_provider_info::ModelProviderInfo; diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs index 9521e4ee..ec56fdfd 100644 --- a/codex-rs/core/src/openai_tools.rs +++ b/codex-rs/core/src/openai_tools.rs @@ -70,6 +70,7 @@ pub(crate) struct ToolsConfig { pub apply_patch_tool_type: Option, pub web_search_request: bool, pub include_view_image_tool: bool, + pub experimental_unified_exec_tool: bool, } pub(crate) struct ToolsConfigParams<'a> { @@ -81,6 +82,7 @@ pub(crate) struct ToolsConfigParams<'a> { pub(crate) include_web_search_request: bool, pub(crate) use_streamable_shell_tool: bool, pub(crate) include_view_image_tool: bool, + pub(crate) experimental_unified_exec_tool: bool, } impl ToolsConfig { @@ -94,6 +96,7 @@ impl ToolsConfig { include_web_search_request, use_streamable_shell_tool, include_view_image_tool, + experimental_unified_exec_tool, } = params; let mut shell_type = if *use_streamable_shell_tool { ConfigShellToolType::StreamableShell @@ -126,6 +129,7 @@ impl ToolsConfig { apply_patch_tool_type, web_search_request: *include_web_search_request, include_view_image_tool: *include_view_image_tool, + experimental_unified_exec_tool: *experimental_unified_exec_tool, } } } @@ -200,6 +204,53 @@ fn create_shell_tool() -> OpenAiTool { }) } +fn 
create_unified_exec_tool() -> OpenAiTool { + let mut properties = BTreeMap::new(); + properties.insert( + "input".to_string(), + JsonSchema::Array { + items: Box::new(JsonSchema::String { description: None }), + description: Some( + "When no session_id is provided, treat the array as the command and arguments \ + to launch. When session_id is set, concatenate the strings (in order) and write \ + them to the session's stdin." + .to_string(), + ), + }, + ); + properties.insert( + "session_id".to_string(), + JsonSchema::String { + description: Some( + "Identifier for an existing interactive session. If omitted, a new command \ + is spawned." + .to_string(), + ), + }, + ); + properties.insert( + "timeout_ms".to_string(), + JsonSchema::Number { + description: Some( + "Maximum time in milliseconds to wait for output after writing the input." + .to_string(), + ), + }, + ); + + OpenAiTool::Function(ResponsesApiTool { + name: "unified_exec".to_string(), + description: + "Runs a command in a PTY. 
Provide a session_id to reuse an existing interactive session.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["input".to_string()]), + additional_properties: Some(false), + }, + }) +} + fn create_shell_tool_for_sandbox(sandbox_policy: &SandboxPolicy) -> OpenAiTool { let mut properties = BTreeMap::new(); properties.insert( @@ -531,23 +582,27 @@ pub(crate) fn get_openai_tools( ) -> Vec { let mut tools: Vec = Vec::new(); - match &config.shell_type { - ConfigShellToolType::DefaultShell => { - tools.push(create_shell_tool()); - } - ConfigShellToolType::ShellWithRequest { sandbox_policy } => { - tools.push(create_shell_tool_for_sandbox(sandbox_policy)); - } - ConfigShellToolType::LocalShell => { - tools.push(OpenAiTool::LocalShell {}); - } - ConfigShellToolType::StreamableShell => { - tools.push(OpenAiTool::Function( - crate::exec_command::create_exec_command_tool_for_responses_api(), - )); - tools.push(OpenAiTool::Function( - crate::exec_command::create_write_stdin_tool_for_responses_api(), - )); + if config.experimental_unified_exec_tool { + tools.push(create_unified_exec_tool()); + } else { + match &config.shell_type { + ConfigShellToolType::DefaultShell => { + tools.push(create_shell_tool()); + } + ConfigShellToolType::ShellWithRequest { sandbox_policy } => { + tools.push(create_shell_tool_for_sandbox(sandbox_policy)); + } + ConfigShellToolType::LocalShell => { + tools.push(OpenAiTool::LocalShell {}); + } + ConfigShellToolType::StreamableShell => { + tools.push(OpenAiTool::Function( + crate::exec_command::create_exec_command_tool_for_responses_api(), + )); + tools.push(OpenAiTool::Function( + crate::exec_command::create_write_stdin_tool_for_responses_api(), + )); + } } } @@ -574,10 +629,8 @@ pub(crate) fn get_openai_tools( if config.include_view_image_tool { tools.push(create_view_image_tool()); } - if let Some(mcp_tools) = mcp_tools { // Ensure deterministic ordering to maximize prompt cache hits. 
- // HashMap iteration order is non-deterministic, so sort by fully-qualified tool name. let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect(); entries.sort_by(|a, b| a.0.cmp(&b.0)); @@ -639,12 +692,13 @@ mod tests { include_web_search_request: true, use_streamable_shell_tool: false, include_view_image_tool: true, + experimental_unified_exec_tool: true, }); let tools = get_openai_tools(&config, Some(HashMap::new())); assert_eq_tool_names( &tools, - &["local_shell", "update_plan", "web_search", "view_image"], + &["unified_exec", "update_plan", "web_search", "view_image"], ); } @@ -660,12 +714,13 @@ mod tests { include_web_search_request: true, use_streamable_shell_tool: false, include_view_image_tool: true, + experimental_unified_exec_tool: true, }); let tools = get_openai_tools(&config, Some(HashMap::new())); assert_eq_tool_names( &tools, - &["shell", "update_plan", "web_search", "view_image"], + &["unified_exec", "update_plan", "web_search", "view_image"], ); } @@ -681,6 +736,7 @@ mod tests { include_web_search_request: true, use_streamable_shell_tool: false, include_view_image_tool: true, + experimental_unified_exec_tool: true, }); let tools = get_openai_tools( &config, @@ -723,7 +779,7 @@ mod tests { assert_eq_tool_names( &tools, &[ - "shell", + "unified_exec", "web_search", "view_image", "test_server/do_something_cool", @@ -786,6 +842,7 @@ mod tests { include_web_search_request: false, use_streamable_shell_tool: false, include_view_image_tool: true, + experimental_unified_exec_tool: true, }); // Intentionally construct a map with keys that would sort alphabetically. @@ -838,11 +895,11 @@ mod tests { ]); let tools = get_openai_tools(&config, Some(tools_map)); - // Expect shell first, followed by MCP tools sorted by fully-qualified name. + // Expect unified_exec first, followed by MCP tools sorted by fully-qualified name. 
assert_eq_tool_names( &tools, &[ - "shell", + "unified_exec", "view_image", "test_server/cool", "test_server/do", @@ -863,6 +920,7 @@ mod tests { include_web_search_request: true, use_streamable_shell_tool: false, include_view_image_tool: true, + experimental_unified_exec_tool: true, }); let tools = get_openai_tools( @@ -890,7 +948,7 @@ mod tests { assert_eq_tool_names( &tools, - &["shell", "web_search", "view_image", "dash/search"], + &["unified_exec", "web_search", "view_image", "dash/search"], ); assert_eq!( @@ -925,6 +983,7 @@ mod tests { include_web_search_request: true, use_streamable_shell_tool: false, include_view_image_tool: true, + experimental_unified_exec_tool: true, }); let tools = get_openai_tools( @@ -950,7 +1009,7 @@ mod tests { assert_eq_tool_names( &tools, - &["shell", "web_search", "view_image", "dash/paginate"], + &["unified_exec", "web_search", "view_image", "dash/paginate"], ); assert_eq!( tools[3], @@ -982,6 +1041,7 @@ mod tests { include_web_search_request: true, use_streamable_shell_tool: false, include_view_image_tool: true, + experimental_unified_exec_tool: true, }); let tools = get_openai_tools( @@ -1005,7 +1065,10 @@ mod tests { )])), ); - assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/tags"]); + assert_eq_tool_names( + &tools, + &["unified_exec", "web_search", "view_image", "dash/tags"], + ); assert_eq!( tools[3], OpenAiTool::Function(ResponsesApiTool { @@ -1039,6 +1102,7 @@ mod tests { include_web_search_request: true, use_streamable_shell_tool: false, include_view_image_tool: true, + experimental_unified_exec_tool: true, }); let tools = get_openai_tools( @@ -1062,7 +1126,10 @@ mod tests { )])), ); - assert_eq_tool_names(&tools, &["shell", "web_search", "view_image", "dash/value"]); + assert_eq_tool_names( + &tools, + &["unified_exec", "web_search", "view_image", "dash/value"], + ); assert_eq!( tools[3], OpenAiTool::Function(ResponsesApiTool { diff --git a/codex-rs/core/src/truncate.rs 
//! Utilities for truncating large chunks of output while preserving a prefix
//! and suffix on UTF-8 boundaries.

/// Truncate the middle of a UTF-8 string to at most `max_bytes` bytes,
/// preserving the beginning and the end. Returns the possibly truncated
/// string and `Some(original_token_count)` (estimated at 4 bytes/token)
/// if truncation occurred; otherwise returns the original string and `None`.
pub(crate) fn truncate_middle(s: &str, max_bytes: usize) -> (String, Option<u64>) {
    if s.len() <= max_bytes {
        return (s.to_string(), None);
    }

    let est_tokens = (s.len() as u64).div_ceil(4);
    if max_bytes == 0 {
        return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
    }

    // Largest prefix of `input` that fits in `max_len` bytes without
    // splitting a UTF-8 code point.
    fn truncate_on_boundary(input: &str, max_len: usize) -> &str {
        if input.len() <= max_len {
            return input;
        }
        let mut end = max_len;
        while end > 0 && !input.is_char_boundary(end) {
            end -= 1;
        }
        &input[..end]
    }

    // Prefer ending the kept prefix just after a newline so the preserved
    // head consists of whole lines; fall back to a char boundary otherwise.
    fn pick_prefix_end(s: &str, left_budget: usize) -> usize {
        if let Some(head) = s.get(..left_budget) {
            if let Some(i) = head.rfind('\n') {
                return i + 1;
            }
        }
        truncate_on_boundary(s, left_budget).len()
    }

    // Symmetric helper for the tail: start just after the first newline
    // inside the suffix budget, else advance to the next char boundary.
    fn pick_suffix_start(s: &str, right_budget: usize) -> usize {
        let start_tail = s.len().saturating_sub(right_budget);
        if let Some(tail) = s.get(start_tail..) {
            if let Some(i) = tail.find('\n') {
                return start_tail + i + 1;
            }
        }

        let mut idx = start_tail.min(s.len());
        while idx < s.len() && !s.is_char_boundary(idx) {
            idx += 1;
        }
        idx
    }

    // The marker embeds the truncated-token count, whose printed width feeds
    // back into how many bytes of content we can keep. Iterate a few times to
    // let the estimate converge on a fixed point.
    let mut guess_tokens = est_tokens;
    for _ in 0..4 {
        let marker = format!("…{guess_tokens} tokens truncated…");
        let marker_len = marker.len();
        let keep_budget = max_bytes.saturating_sub(marker_len);
        if keep_budget == 0 {
            return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
        }

        let left_budget = keep_budget / 2;
        let right_budget = keep_budget - left_budget;
        let prefix_end = pick_prefix_end(s, left_budget);
        let mut suffix_start = pick_suffix_start(s, right_budget);
        if suffix_start < prefix_end {
            suffix_start = prefix_end;
        }

        let kept_content_bytes = prefix_end + (s.len() - suffix_start);
        let truncated_content_bytes = s.len().saturating_sub(kept_content_bytes);
        let new_tokens = (truncated_content_bytes as u64).div_ceil(4);

        if new_tokens == guess_tokens {
            let mut out = String::with_capacity(marker_len + kept_content_bytes + 1);
            out.push_str(&s[..prefix_end]);
            out.push_str(&marker);
            out.push('\n');
            out.push_str(&s[suffix_start..]);
            return (out, Some(est_tokens));
        }

        guess_tokens = new_tokens;
    }

    // The estimate did not converge; emit a best-effort result using the
    // final guess for the marker text.
    let marker = format!("…{guess_tokens} tokens truncated…");
    let marker_len = marker.len();
    let keep_budget = max_bytes.saturating_sub(marker_len);
    if keep_budget == 0 {
        return (format!("…{est_tokens} tokens truncated…"), Some(est_tokens));
    }

    let left_budget = keep_budget / 2;
    let right_budget = keep_budget - left_budget;
    let prefix_end = pick_prefix_end(s, left_budget);
    let suffix_start = pick_suffix_start(s, right_budget);

    let mut out = String::with_capacity(marker_len + prefix_end + (s.len() - suffix_start) + 1);
    out.push_str(&s[..prefix_end]);
    out.push_str(&marker);
    out.push('\n');
    out.push_str(&s[suffix_start..]);
    (out, Some(est_tokens))
}

#[cfg(test)]
mod tests {
    use super::truncate_middle;

    #[test]
    fn truncate_middle_no_newlines_fallback() {
        let s = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ*";
        let max_bytes = 32;
        let (out, original) = truncate_middle(s, max_bytes);
        assert!(out.starts_with("abc"));
        assert!(out.contains("tokens truncated"));
        assert!(out.ends_with("XYZ*"));
        assert_eq!(original, Some((s.len() as u64).div_ceil(4)));
    }

    #[test]
    fn truncate_middle_prefers_newline_boundaries() {
        let mut s = String::new();
        for i in 1..=20 {
            s.push_str(&format!("{i:03}\n"));
        }
        assert_eq!(s.len(), 80);

        let max_bytes = 64;
        let (out, tokens) = truncate_middle(&s, max_bytes);
        assert!(out.starts_with("001\n002\n003\n004\n"));
        assert!(out.contains("tokens truncated"));
        assert!(out.ends_with("017\n018\n019\n020\n"));
        assert_eq!(tokens, Some(20));
    }

    #[test]
    fn truncate_middle_handles_utf8_content() {
        let s = "😀😀😀😀😀😀😀😀😀😀\nsecond line with ascii text\n";
        let max_bytes = 32;
        let (out, tokens) = truncate_middle(s, max_bytes);

        assert!(out.contains("tokens truncated"));
        assert!(!out.contains('\u{fffd}'));
        assert_eq!(tokens, Some((s.len() as u64).div_ceil(4)));
    }

    #[test]
    fn truncate_middle_prefers_newline_boundaries_2() {
        // Build a multi-line string of 20 numbered lines (each "NNN\n").
        let mut s = String::new();
        for i in 1..=20 {
            s.push_str(&format!("{i:03}\n"));
        }
        // Total length: 20 lines * 4 bytes per line = 80 bytes.
        assert_eq!(s.len(), 80);

        // Choose a cap that forces truncation while leaving room for
        // a few lines on each side after accounting for the marker.
        let max_bytes = 64;
        // Expect exact output: first 4 lines, marker, last 4 lines, and
        // correct token estimate (80 / 4 = 20).
        assert_eq!(
            truncate_middle(&s, max_bytes),
            (
                r#"001
002
003
004
…12 tokens truncated…
017
018
019
020
"#
                .to_string(),
                Some(20)
            )
        );
    }
}
+ assert_eq!( + truncate_middle(&s, max_bytes), + ( + r#"001 +002 +003 +004 +…12 tokens truncated… +017 +018 +019 +020 +"# + .to_string(), + Some(20) + ) + ); + } +} diff --git a/codex-rs/core/src/unified_exec/errors.rs b/codex-rs/core/src/unified_exec/errors.rs new file mode 100644 index 00000000..6bf5bf7e --- /dev/null +++ b/codex-rs/core/src/unified_exec/errors.rs @@ -0,0 +1,22 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub(crate) enum UnifiedExecError { + #[error("Failed to create unified exec session: {pty_error}")] + CreateSession { + #[source] + pty_error: anyhow::Error, + }, + #[error("Unknown session id {session_id}")] + UnknownSessionId { session_id: i32 }, + #[error("failed to write to stdin")] + WriteToStdin, + #[error("missing command line for unified exec request")] + MissingCommandLine, +} + +impl UnifiedExecError { + pub(crate) fn create_session(error: anyhow::Error) -> Self { + Self::CreateSession { pty_error: error } + } +} diff --git a/codex-rs/core/src/unified_exec/mod.rs b/codex-rs/core/src/unified_exec/mod.rs new file mode 100644 index 00000000..25ab3d4d --- /dev/null +++ b/codex-rs/core/src/unified_exec/mod.rs @@ -0,0 +1,653 @@ +use portable_pty::CommandBuilder; +use portable_pty::PtySize; +use portable_pty::native_pty_system; +use std::collections::HashMap; +use std::collections::VecDeque; +use std::io::ErrorKind; +use std::io::Read; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicI32; +use std::sync::atomic::Ordering; +use tokio::sync::Mutex; +use tokio::sync::Notify; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio::time::Duration; +use tokio::time::Instant; + +use crate::exec_command::ExecCommandSession; +use crate::truncate::truncate_middle; + +mod errors; + +pub(crate) use errors::UnifiedExecError; + +const DEFAULT_TIMEOUT_MS: u64 = 1_000; +const MAX_TIMEOUT_MS: u64 = 60_000; +const UNIFIED_EXEC_OUTPUT_MAX_BYTES: usize = 128 * 1024; // 
128 KiB + +#[derive(Debug)] +pub(crate) struct UnifiedExecRequest<'a> { + pub session_id: Option, + pub input_chunks: &'a [String], + pub timeout_ms: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct UnifiedExecResult { + pub session_id: Option, + pub output: String, +} + +#[derive(Debug, Default)] +pub(crate) struct UnifiedExecSessionManager { + next_session_id: AtomicI32, + sessions: Mutex>, +} + +#[derive(Debug)] +struct ManagedUnifiedExecSession { + session: ExecCommandSession, + output_buffer: OutputBuffer, + /// Notifies waiters whenever new output has been appended to + /// `output_buffer`, allowing clients to poll for fresh data. + output_notify: Arc, + output_task: JoinHandle<()>, +} + +#[derive(Debug, Default)] +struct OutputBufferState { + chunks: VecDeque>, + total_bytes: usize, +} + +impl OutputBufferState { + fn push_chunk(&mut self, chunk: Vec) { + self.total_bytes = self.total_bytes.saturating_add(chunk.len()); + self.chunks.push_back(chunk); + + let mut excess = self + .total_bytes + .saturating_sub(UNIFIED_EXEC_OUTPUT_MAX_BYTES); + + while excess > 0 { + match self.chunks.front_mut() { + Some(front) if excess >= front.len() => { + excess -= front.len(); + self.total_bytes = self.total_bytes.saturating_sub(front.len()); + self.chunks.pop_front(); + } + Some(front) => { + front.drain(..excess); + self.total_bytes = self.total_bytes.saturating_sub(excess); + break; + } + None => break, + } + } + } + + fn drain(&mut self) -> Vec> { + let drained: Vec> = self.chunks.drain(..).collect(); + self.total_bytes = 0; + drained + } +} + +type OutputBuffer = Arc>; +type OutputHandles = (OutputBuffer, Arc); + +impl ManagedUnifiedExecSession { + fn new(session: ExecCommandSession) -> Self { + let output_buffer = Arc::new(Mutex::new(OutputBufferState::default())); + let output_notify = Arc::new(Notify::new()); + let mut receiver = session.output_receiver(); + let buffer_clone = Arc::clone(&output_buffer); + let notify_clone = 
Arc::clone(&output_notify); + let output_task = tokio::spawn(async move { + while let Ok(chunk) = receiver.recv().await { + let mut guard = buffer_clone.lock().await; + guard.push_chunk(chunk); + drop(guard); + notify_clone.notify_waiters(); + } + }); + + Self { + session, + output_buffer, + output_notify, + output_task, + } + } + + fn writer_sender(&self) -> mpsc::Sender> { + self.session.writer_sender() + } + + fn output_handles(&self) -> OutputHandles { + ( + Arc::clone(&self.output_buffer), + Arc::clone(&self.output_notify), + ) + } + + fn has_exited(&self) -> bool { + self.session.has_exited() + } +} + +impl Drop for ManagedUnifiedExecSession { + fn drop(&mut self) { + self.output_task.abort(); + } +} + +impl UnifiedExecSessionManager { + pub async fn handle_request( + &self, + request: UnifiedExecRequest<'_>, + ) -> Result { + let (timeout_ms, timeout_warning) = match request.timeout_ms { + Some(requested) if requested > MAX_TIMEOUT_MS => ( + MAX_TIMEOUT_MS, + Some(format!( + "Warning: requested timeout {requested}ms exceeds maximum of {MAX_TIMEOUT_MS}ms; clamping to {MAX_TIMEOUT_MS}ms.\n" + )), + ), + Some(requested) => (requested, None), + None => (DEFAULT_TIMEOUT_MS, None), + }; + + let mut new_session: Option = None; + let session_id; + let writer_tx; + let output_buffer; + let output_notify; + + if let Some(existing_id) = request.session_id { + let mut sessions = self.sessions.lock().await; + match sessions.get(&existing_id) { + Some(session) => { + if session.has_exited() { + sessions.remove(&existing_id); + return Err(UnifiedExecError::UnknownSessionId { + session_id: existing_id, + }); + } + let (buffer, notify) = session.output_handles(); + session_id = existing_id; + writer_tx = session.writer_sender(); + output_buffer = buffer; + output_notify = notify; + } + None => { + return Err(UnifiedExecError::UnknownSessionId { + session_id: existing_id, + }); + } + } + drop(sessions); + } else { + let command = request.input_chunks.to_vec(); + let new_id = 
self.next_session_id.fetch_add(1, Ordering::SeqCst); + let session = create_unified_exec_session(&command).await?; + let managed_session = ManagedUnifiedExecSession::new(session); + let (buffer, notify) = managed_session.output_handles(); + writer_tx = managed_session.writer_sender(); + output_buffer = buffer; + output_notify = notify; + session_id = new_id; + new_session = Some(managed_session); + }; + + if request.session_id.is_some() { + let joined_input = request.input_chunks.join(" "); + if !joined_input.is_empty() && writer_tx.send(joined_input.into_bytes()).await.is_err() + { + return Err(UnifiedExecError::WriteToStdin); + } + } + + let mut collected: Vec = Vec::with_capacity(4096); + let start = Instant::now(); + let deadline = start + Duration::from_millis(timeout_ms); + + loop { + let drained_chunks; + let mut wait_for_output = None; + { + let mut guard = output_buffer.lock().await; + drained_chunks = guard.drain(); + if drained_chunks.is_empty() { + wait_for_output = Some(output_notify.notified()); + } + } + + if drained_chunks.is_empty() { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining == Duration::ZERO { + break; + } + + let notified = wait_for_output.unwrap_or_else(|| output_notify.notified()); + tokio::pin!(notified); + tokio::select! 
{ + _ = &mut notified => {} + _ = tokio::time::sleep(remaining) => break, + } + continue; + } + + for chunk in drained_chunks { + collected.extend_from_slice(&chunk); + } + + if Instant::now() >= deadline { + break; + } + } + + let (output, _maybe_tokens) = truncate_middle( + &String::from_utf8_lossy(&collected), + UNIFIED_EXEC_OUTPUT_MAX_BYTES, + ); + let output = if let Some(warning) = timeout_warning { + format!("{warning}{output}") + } else { + output + }; + + let should_store_session = if let Some(session) = new_session.as_ref() { + !session.has_exited() + } else if request.session_id.is_some() { + let mut sessions = self.sessions.lock().await; + if let Some(existing) = sessions.get(&session_id) { + if existing.has_exited() { + sessions.remove(&session_id); + false + } else { + true + } + } else { + false + } + } else { + true + }; + + if should_store_session { + if let Some(session) = new_session { + self.sessions.lock().await.insert(session_id, session); + } + Ok(UnifiedExecResult { + session_id: Some(session_id), + output, + }) + } else { + Ok(UnifiedExecResult { + session_id: None, + output, + }) + } + } +} + +async fn create_unified_exec_session( + command: &[String], +) -> Result { + if command.is_empty() { + return Err(UnifiedExecError::MissingCommandLine); + } + + let pty_system = native_pty_system(); + + let pair = pty_system + .openpty(PtySize { + rows: 24, + cols: 80, + pixel_width: 0, + pixel_height: 0, + }) + .map_err(UnifiedExecError::create_session)?; + + // Safe thanks to the check at the top of the function. + let mut command_builder = CommandBuilder::new(command[0].clone()); + for arg in &command[1..] 
{ + command_builder.arg(arg); + } + + let mut child = pair + .slave + .spawn_command(command_builder) + .map_err(UnifiedExecError::create_session)?; + let killer = child.clone_killer(); + + let (writer_tx, mut writer_rx) = mpsc::channel::>(128); + let (output_tx, _) = tokio::sync::broadcast::channel::>(256); + + let mut reader = pair + .master + .try_clone_reader() + .map_err(UnifiedExecError::create_session)?; + let output_tx_clone = output_tx.clone(); + let reader_handle = tokio::task::spawn_blocking(move || { + let mut buf = [0u8; 8192]; + loop { + match reader.read(&mut buf) { + Ok(0) => break, + Ok(n) => { + let _ = output_tx_clone.send(buf[..n].to_vec()); + } + Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, + Err(ref e) if e.kind() == ErrorKind::WouldBlock => { + std::thread::sleep(Duration::from_millis(5)); + continue; + } + Err(_) => break, + } + } + }); + + let writer = pair + .master + .take_writer() + .map_err(UnifiedExecError::create_session)?; + let writer = Arc::new(StdMutex::new(writer)); + let writer_handle = tokio::spawn({ + let writer = writer.clone(); + async move { + while let Some(bytes) = writer_rx.recv().await { + let writer = writer.clone(); + let _ = tokio::task::spawn_blocking(move || { + if let Ok(mut guard) = writer.lock() { + use std::io::Write; + let _ = guard.write_all(&bytes); + let _ = guard.flush(); + } + }) + .await; + } + } + }); + + let exit_status = Arc::new(AtomicBool::new(false)); + let wait_exit_status = Arc::clone(&exit_status); + let wait_handle = tokio::task::spawn_blocking(move || { + let _ = child.wait(); + wait_exit_status.store(true, Ordering::SeqCst); + }); + + Ok(ExecCommandSession::new( + writer_tx, + output_tx, + killer, + reader_handle, + writer_handle, + wait_handle, + exit_status, + )) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn push_chunk_trims_only_excess_bytes() { + let mut buffer = OutputBufferState::default(); + buffer.push_chunk(vec![b'a'; 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn push_chunk_trims_only_excess_bytes() {
        let mut buffer = OutputBufferState::default();
        buffer.push_chunk(vec![b'a'; UNIFIED_EXEC_OUTPUT_MAX_BYTES]);
        buffer.push_chunk(vec![b'b']);
        buffer.push_chunk(vec![b'c']);

        // Only the two overflow bytes are trimmed from the oldest chunk;
        // the newest chunks survive intact.
        assert_eq!(buffer.total_bytes, UNIFIED_EXEC_OUTPUT_MAX_BYTES);
        assert_eq!(buffer.chunks.len(), 3);
        assert_eq!(
            buffer.chunks.front().unwrap().len(),
            UNIFIED_EXEC_OUTPUT_MAX_BYTES - 2
        );
        assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'c']);
        assert_eq!(buffer.chunks.pop_back().unwrap(), vec![b'b']);
    }

    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn unified_exec_persists_across_requests() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let open_shell = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["bash".to_string(), "-i".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        let session_id = open_shell.session_id.expect("expected session_id");

        manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &[
                    "export".to_string(),
                    "CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
                ],
                timeout_ms: Some(2_500),
            })
            .await?;

        let out_2 = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        assert!(out_2.output.contains("codex"));

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn multi_unified_exec_sessions() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let shell_a = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        let session_a = shell_a.session_id.expect("expected session id");

        manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_a),
                input_chunks: &["export CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;

        // A fresh session must not observe the variable exported in A.
        let out_2 = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &[
                    "echo".to_string(),
                    "$CODEX_INTERACTIVE_SHELL_VAR\n".to_string(),
                ],
                timeout_ms: Some(1_500),
            })
            .await?;
        assert!(!out_2.output.contains("codex"));

        let out_3 = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_a),
                input_chunks: &["echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        assert!(out_3.output.contains("codex"));

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn unified_exec_timeouts() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let open_shell = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["bash".to_string(), "-i".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        let session_id = open_shell.session_id.expect("expected session id");

        manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &[
                    "export".to_string(),
                    "CODEX_INTERACTIVE_SHELL_VAR=codex\n".to_string(),
                ],
                timeout_ms: Some(1_500),
            })
            .await?;

        // A 10ms budget expires long before the 5s sleep finishes.
        let out_2 = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &["sleep 5 && echo $CODEX_INTERACTIVE_SHELL_VAR\n".to_string()],
                timeout_ms: Some(10),
            })
            .await?;
        assert!(!out_2.output.contains("codex"));

        tokio::time::sleep(Duration::from_secs(7)).await;

        // Polling with no input drains the output produced in the meantime.
        let empty = Vec::new();
        let out_3 = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &empty,
                timeout_ms: Some(100),
            })
            .await?;

        assert!(out_3.output.contains("codex"));

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn requests_with_large_timeout_are_capped() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let result = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["echo".to_string(), "codex".to_string()],
                timeout_ms: Some(120_000),
            })
            .await?;

        assert!(result.output.starts_with(
            "Warning: requested timeout 120000ms exceeds maximum of 60000ms; clamping to 60000ms.\n"
        ));
        assert!(result.output.contains("codex"));

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn completed_commands_do_not_persist_sessions() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();
        let result = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["/bin/echo".to_string(), "codex".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;

        assert!(result.session_id.is_none());
        assert!(result.output.contains("codex"));

        assert!(manager.sessions.lock().await.is_empty());

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn correct_path_resolution() -> Result<(), UnifiedExecError> {
        // Same as above but with a bare command name resolved via PATH.
        let manager = UnifiedExecSessionManager::default();
        let result = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["echo".to_string(), "codex".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;

        assert!(result.session_id.is_none());
        assert!(result.output.contains("codex"));

        assert!(manager.sessions.lock().await.is_empty());

        Ok(())
    }

    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn reusing_completed_session_returns_unknown_session() -> Result<(), UnifiedExecError> {
        let manager = UnifiedExecSessionManager::default();

        let open_shell = manager
            .handle_request(UnifiedExecRequest {
                session_id: None,
                input_chunks: &["/bin/bash".to_string(), "-i".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;
        let session_id = open_shell.session_id.expect("expected session id");

        manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &["exit\n".to_string()],
                timeout_ms: Some(1_500),
            })
            .await?;

        tokio::time::sleep(Duration::from_millis(200)).await;

        let err = manager
            .handle_request(UnifiedExecRequest {
                session_id: Some(session_id),
                input_chunks: &[],
                timeout_ms: Some(100),
            })
            .await
            .expect_err("expected unknown session error");

        match err {
            UnifiedExecError::UnknownSessionId { session_id: err_id } => {
                assert_eq!(err_id, session_id);
            }
            other => panic!("expected UnknownSessionId, got {other:?}"),
        }

        assert!(!manager.sessions.lock().await.contains_key(&session_id));

        Ok(())
    }
}