diff --git a/codex-rs/core/src/tools/events.rs b/codex-rs/core/src/tools/events.rs index af8afe3e..74cf1ade 100644 --- a/codex-rs/core/src/tools/events.rs +++ b/codex-rs/core/src/tools/events.rs @@ -11,6 +11,7 @@ use crate::protocol::PatchApplyEndEvent; use crate::protocol::TurnDiffEvent; use crate::tools::context::SharedTurnDiffTracker; use std::collections::HashMap; +use std::path::Path; use std::path::PathBuf; use std::time::Duration; @@ -51,6 +52,20 @@ pub(crate) enum ToolEventFailure { Output(ExecToolCallOutput), Message(String), } + +pub(crate) async fn emit_exec_command_begin(ctx: ToolEventCtx<'_>, command: &[String], cwd: &Path) { + ctx.session + .send_event( + ctx.turn, + EventMsg::ExecCommandBegin(ExecCommandBeginEvent { + call_id: ctx.call_id.to_string(), + command: command.to_vec(), + cwd: cwd.to_path_buf(), + parsed_cmd: parse_command(command), + }), + ) + .await; +} // Concrete, allocation-free emitter: avoid trait objects and boxed futures. pub(crate) enum ToolEmitter { Shell { @@ -78,17 +93,7 @@ impl ToolEmitter { pub async fn emit(&self, ctx: ToolEventCtx<'_>, stage: ToolEventStage) { match (self, stage) { (Self::Shell { command, cwd }, ToolEventStage::Begin) => { - ctx.session - .send_event( - ctx.turn, - EventMsg::ExecCommandBegin(ExecCommandBeginEvent { - call_id: ctx.call_id.to_string(), - command: command.clone(), - cwd: cwd.clone(), - parsed_cmd: parse_command(command), - }), - ) - .await; + emit_exec_command_begin(ctx, command, cwd.as_path()).await; } (Self::Shell { .. 
}, ToolEventStage::Success(output)) => { emit_exec_end( diff --git a/codex-rs/core/src/unified_exec/session_manager.rs b/codex-rs/core/src/unified_exec/session_manager.rs index 83b076d9..dc6e6004 100644 --- a/codex-rs/core/src/unified_exec/session_manager.rs +++ b/codex-rs/core/src/unified_exec/session_manager.rs @@ -7,6 +7,9 @@ use tokio::time::Instant; use crate::exec_env::create_env; use crate::sandboxing::ExecEnv; +use crate::tools::events::ToolEmitter; +use crate::tools::events::ToolEventCtx; +use crate::tools::events::ToolEventStage; use crate::tools::orchestrator::ToolOrchestrator; use crate::tools::runtimes::unified_exec::UnifiedExecRequest as UnifiedExecToolRequest; use crate::tools::runtimes::unified_exec::UnifiedExecRuntime; @@ -246,6 +249,13 @@ impl UnifiedExecSessionManager { None => (DEFAULT_TIMEOUT_MS, None), }; + if !request.input_chunks.is_empty() { + let event_ctx = ToolEventCtx::new(context.session, context.turn, context.call_id, None); + let emitter = + ToolEmitter::shell(request.input_chunks.to_vec(), context.turn.cwd.clone()); + emitter.emit(event_ctx, ToolEventStage::Begin).await; + } + let mut acquisition = self.acquire_session(&request, &context).await?; if acquisition.reuse_requested { diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs index 78a1abf1..a53cab8f 100644 --- a/codex-rs/core/tests/suite/unified_exec.rs +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use anyhow::Result; use codex_core::features::Feature; +use codex_core::parse_command::parse_command; use codex_core::protocol::AskForApproval; use codex_core::protocol::EventMsg; use codex_core::protocol::Op; @@ -22,7 +23,10 @@ use core_test_support::skip_if_sandbox; use core_test_support::test_codex::TestCodex; use core_test_support::test_codex::test_codex; use core_test_support::wait_for_event; +use core_test_support::wait_for_event_match; +use 
core_test_support::wait_for_event_with_timeout; use serde_json::Value; +use serde_json::json; fn extract_output_text(item: &Value) -> Option<&str> { item.get("output").and_then(|value| match value { @@ -58,6 +62,180 @@ fn collect_tool_outputs(bodies: &[Value]) -> Result> { Ok(outputs) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_emits_exec_command_begin_event() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.use_experimental_unified_exec_tool = true; + config.features.enable(Feature::UnifiedExec); + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let call_id = "uexec-begin-event"; + let command = vec!["/bin/echo".to_string(), "hello unified exec".to_string()]; + let args = json!({ + "input": command.clone(), + "timeout_ms": 250, + }); + + let responses = vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "unified_exec", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "finished"), + ev_completed("resp-2"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "emit begin event".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let begin_event = wait_for_event_match(&codex, |msg| match msg { + EventMsg::ExecCommandBegin(event) if event.call_id == call_id => Some(event.clone()), + _ => None, + }) + .await; + + assert_eq!(begin_event.command, command); + 
assert_eq!(begin_event.cwd, cwd.path()); + assert_eq!(begin_event.parsed_cmd, parse_command(&command)); + + wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await; + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_skips_begin_event_for_empty_input() -> Result<()> { + use tokio::time::Duration; + + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.use_experimental_unified_exec_tool = true; + config.features.enable(Feature::UnifiedExec); + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let open_call_id = "uexec-open-session"; + let open_command = vec![ + "/bin/sh".to_string(), + "-c".to_string(), + "echo ready".to_string(), + ]; + let open_args = json!({ + "input": open_command.clone(), + "timeout_ms": 200, + }); + + let poll_call_id = "uexec-poll-empty"; + let poll_args = json!({ + "input": Vec::<String>::new(), + "session_id": "0", + "timeout_ms": 150, + }); + + let responses = vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_function_call( + open_call_id, + "unified_exec", + &serde_json::to_string(&open_args)?, + ), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_function_call( + poll_call_id, + "unified_exec", + &serde_json::to_string(&poll_args)?, + ), + ev_completed("resp-2"), + ]), + sse(vec![ + ev_response_created("resp-3"), + ev_assistant_message("msg-1", "complete"), + ev_completed("resp-3"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "check poll event behavior".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: 
SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let mut begin_events = Vec::new(); + loop { + let event_msg = wait_for_event_with_timeout(&codex, |_| true, Duration::from_secs(2)).await; + match event_msg { + EventMsg::ExecCommandBegin(event) => begin_events.push(event), + EventMsg::TaskComplete(_) => break, + _ => {} + } + } + + assert_eq!( + begin_events.len(), + 1, + "expected only the initial command to emit begin event" + ); + assert_eq!(begin_events[0].call_id, open_call_id); + assert_eq!(begin_events[0].command, open_command); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn unified_exec_reuses_session_via_stdin() -> Result<()> { skip_if_no_network!(Ok(()));