From 33d3ecbccca4b92cfb2a77002387de30302f337f Mon Sep 17 00:00:00 2001 From: jif-oai Date: Fri, 3 Oct 2025 13:21:06 +0100 Subject: [PATCH] chore: refactor tool handling (#4510) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Tool System Refactor - Centralizes tool definitions and execution in `core/src/tools/*`: specs (`spec.rs`), handlers (`handlers/*`), router (`router.rs`), registry/dispatch (`registry.rs`), and shared context (`context.rs`). One registry now builds the model-visible tool list and binds handlers. - Router converts model responses to tool calls; Registry dispatches with consistent telemetry via `codex-rs/otel` and unified error handling. Function, Local Shell, MCP, and experimental `unified_exec` all flow through this path; legacy shell aliases still work. - Rationale: reduce per‑tool boilerplate, keep spec/handler in sync, and make adding tools predictable and testable. Example: `read_file` - Spec: `core/src/tools/spec.rs` (see `create_read_file_tool`, registered by `build_specs`). - Handler: `core/src/tools/handlers/read_file.rs` (absolute `file_path`, 1‑indexed `offset`, `limit`, `L#: ` prefixes, safe truncation). - E2E test: `core/tests/suite/read_file.rs` validates the tool returns the requested lines. 
## Next steps: - Decompose `handle_container_exec_with_params` - Add parallel tool calls --- codex-rs/Cargo.lock | 5 + codex-rs/Cargo.toml | 2 + codex-rs/core/Cargo.toml | 1 + codex-rs/core/src/client_common.rs | 56 +- codex-rs/core/src/codex.rs | 929 +++--------- codex-rs/core/src/error.rs | 3 + .../core/src/exec_command/responses_api.rs | 2 +- codex-rs/core/src/executor/mod.rs | 2 +- codex-rs/core/src/executor/runner.rs | 3 +- codex-rs/core/src/function_tool.rs | 4 + codex-rs/core/src/lib.rs | 3 +- codex-rs/core/src/model_family.rs | 2 +- codex-rs/core/src/openai_tools.rs | 1190 +--------------- codex-rs/core/src/tools/context.rs | 244 ++++ .../handlers/apply_patch.rs} | 106 +- .../core/src/tools/handlers/exec_stream.rs | 71 + codex-rs/core/src/tools/handlers/mcp.rs | 70 + codex-rs/core/src/tools/handlers/mod.rs | 19 + .../{plan_tool.rs => tools/handlers/plan.rs} | 68 +- codex-rs/core/src/tools/handlers/read_file.rs | 255 ++++ codex-rs/core/src/tools/handlers/shell.rs | 103 ++ .../handlers}/tool_apply_patch.lark | 0 .../core/src/tools/handlers/unified_exec.rs | 112 ++ .../core/src/tools/handlers/view_image.rs | 96 ++ codex-rs/core/src/tools/mod.rs | 280 ++++ codex-rs/core/src/tools/registry.rs | 197 +++ codex-rs/core/src/tools/router.rs | 177 +++ codex-rs/core/src/tools/spec.rs | 1269 +++++++++++++++++ codex-rs/core/tests/common/lib.rs | 18 +- codex-rs/core/tests/suite/mod.rs | 6 + codex-rs/core/tests/suite/model_tools.rs | 124 ++ codex-rs/core/tests/suite/prompt_caching.rs | 8 +- codex-rs/core/tests/suite/read_file.rs | 124 ++ codex-rs/core/tests/suite/tool_harness.rs | 568 ++++++++ codex-rs/core/tests/suite/tools.rs | 450 ++++++ codex-rs/core/tests/suite/unified_exec.rs | 280 ++++ codex-rs/core/tests/suite/view_image.rs | 351 +++++ .../src/event_processor_with_human_output.rs | 4 +- .../src/event_processor_with_jsonl_output.rs | 4 +- codex-rs/exec/src/lib.rs | 2 +- .../tests/event_processor_with_json_output.rs | 11 +- codex-rs/otel/src/otel_event_manager.rs | 
14 +- codex-rs/protocol/src/models.rs | 1 + codex-rs/tui/src/chatwidget.rs | 3 +- codex-rs/tui/src/chatwidget/tests.rs | 6 +- codex-rs/tui/src/history_cell.rs | 6 +- codex-rs/utils/string/Cargo.toml | 7 + codex-rs/utils/string/src/lib.rs | 38 + 48 files changed, 5288 insertions(+), 2006 deletions(-) create mode 100644 codex-rs/core/src/tools/context.rs rename codex-rs/core/src/{tool_apply_patch.rs => tools/handlers/apply_patch.rs} (60%) create mode 100644 codex-rs/core/src/tools/handlers/exec_stream.rs create mode 100644 codex-rs/core/src/tools/handlers/mcp.rs create mode 100644 codex-rs/core/src/tools/handlers/mod.rs rename codex-rs/core/src/{plan_tool.rs => tools/handlers/plan.rs} (63%) create mode 100644 codex-rs/core/src/tools/handlers/read_file.rs create mode 100644 codex-rs/core/src/tools/handlers/shell.rs rename codex-rs/core/src/{ => tools/handlers}/tool_apply_patch.lark (100%) create mode 100644 codex-rs/core/src/tools/handlers/unified_exec.rs create mode 100644 codex-rs/core/src/tools/handlers/view_image.rs create mode 100644 codex-rs/core/src/tools/mod.rs create mode 100644 codex-rs/core/src/tools/registry.rs create mode 100644 codex-rs/core/src/tools/router.rs create mode 100644 codex-rs/core/src/tools/spec.rs create mode 100644 codex-rs/core/tests/suite/model_tools.rs create mode 100644 codex-rs/core/tests/suite/read_file.rs create mode 100644 codex-rs/core/tests/suite/tool_harness.rs create mode 100644 codex-rs/core/tests/suite/tools.rs create mode 100644 codex-rs/core/tests/suite/unified_exec.rs create mode 100644 codex-rs/core/tests/suite/view_image.rs create mode 100644 codex-rs/utils/string/Cargo.toml create mode 100644 codex-rs/utils/string/src/lib.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 7a3ca22c..aa9736d2 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -861,6 +861,7 @@ dependencies = [ "codex-otel", "codex-protocol", "codex-rmcp-client", + "codex-utils-string", "core_test_support", "dirs", "dunce", @@ 
-1254,6 +1255,10 @@ dependencies = [ "tokio", ] +[[package]] +name = "codex-utils-string" +version = "0.0.0" + [[package]] name = "color-eyre" version = "0.6.5" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 06c83c81..f8e88797 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -32,6 +32,7 @@ members = [ "git-apply", "utils/json-to-toml", "utils/readiness", + "utils/string", ] resolver = "2" @@ -71,6 +72,7 @@ codex-rmcp-client = { path = "rmcp-client" } codex-tui = { path = "tui" } codex-utils-json-to-toml = { path = "utils/json-to-toml" } codex-utils-readiness = { path = "utils/readiness" } +codex-utils-string = { path = "utils/string" } core_test_support = { path = "core/tests/common" } mcp-types = { path = "mcp-types" } mcp_test_support = { path = "mcp-server/tests/common" } diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index 8c56e7d1..5054ce98 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -26,6 +26,7 @@ codex-rmcp-client = { workspace = true } codex-protocol = { workspace = true } codex-app-server-protocol = { workspace = true } codex-otel = { workspace = true, features = ["otel"] } +codex-utils-string = { workspace = true } dirs = { workspace = true } dunce = { workspace = true } env-flags = { workspace = true } diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index b695581d..624a4ccb 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -1,6 +1,6 @@ +use crate::client_common::tools::ToolSpec; use crate::error::Result; use crate::model_family::ModelFamily; -use crate::openai_tools::OpenAiTool; use crate::protocol::RateLimitSnapshot; use crate::protocol::TokenUsage; use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS; @@ -29,7 +29,7 @@ pub struct Prompt { /// Tools available to the model, including additional tools sourced from /// external MCP servers. 
- pub(crate) tools: Vec, + pub(crate) tools: Vec, /// Optional override for the built-in BASE_INSTRUCTIONS. pub base_instructions_override: Option, @@ -49,8 +49,8 @@ impl Prompt { // AND // - there is no apply_patch tool present let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool { - OpenAiTool::Function(f) => f.name == "apply_patch", - OpenAiTool::Freeform(f) => f.name == "apply_patch", + ToolSpec::Function(f) => f.name == "apply_patch", + ToolSpec::Freeform(f) => f.name == "apply_patch", _ => false, }); if self.base_instructions_override.is_none() @@ -160,6 +160,54 @@ pub(crate) struct ResponsesApiRequest<'a> { pub(crate) text: Option, } +pub(crate) mod tools { + use crate::openai_tools::JsonSchema; + use serde::Deserialize; + use serde::Serialize; + + /// When serialized as JSON, this produces a valid "Tool" in the OpenAI + /// Responses API. + #[derive(Debug, Clone, Serialize, PartialEq)] + #[serde(tag = "type")] + pub(crate) enum ToolSpec { + #[serde(rename = "function")] + Function(ResponsesApiTool), + #[serde(rename = "local_shell")] + LocalShell {}, + // TODO: Understand why we get an error on web_search although the API docs say it's supported. + // https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C + #[serde(rename = "web_search")] + WebSearch {}, + #[serde(rename = "custom")] + Freeform(FreeformTool), + } + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct FreeformTool { + pub(crate) name: String, + pub(crate) description: String, + pub(crate) format: FreeformToolFormat, + } + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub struct FreeformToolFormat { + pub(crate) r#type: String, + pub(crate) syntax: String, + pub(crate) definition: String, + } + + #[derive(Debug, Clone, Serialize, PartialEq)] + pub struct ResponsesApiTool { + pub(crate) name: String, + pub(crate) description: String, + /// TODO: Validation. 
When strict is set to true, the JSON schema, + /// `required` and `additional_properties` must be present. All fields in + /// `properties` must be present in `required`. + pub(crate) strict: bool, + pub(crate) parameters: JsonSchema, + } +} + pub(crate) fn create_reasoning_param_for_request( model_family: &ModelFamily, effort: Option, diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 2bba2b42..2f13c5ba 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::collections::HashMap; use std::fmt::Debug; use std::path::PathBuf; use std::sync::Arc; @@ -15,8 +14,6 @@ use crate::user_notification::UserNotifier; use async_channel::Receiver; use async_channel::Sender; use codex_apply_patch::ApplyPatchAction; -use codex_apply_patch::MaybeApplyPatchVerified; -use codex_apply_patch::maybe_parse_apply_patch_verified; use codex_protocol::ConversationId; use codex_protocol::protocol::ConversationPathResponseEvent; use codex_protocol::protocol::ExitedReviewModeEvent; @@ -28,8 +25,6 @@ use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnContextItem; use futures::prelude::*; use mcp_types::CallToolResult; -use serde::Deserialize; -use serde::Serialize; use serde_json; use serde_json::Value; use tokio::sync::Mutex; @@ -41,9 +36,6 @@ use tracing::trace; use tracing::warn; use crate::ModelProviderInfo; -use crate::apply_patch; -use crate::apply_patch::ApplyPatchExec; -use crate::apply_patch::InternalApplyPatchInvocation; use crate::apply_patch::convert_apply_patch_to_protocol; use crate::client::ModelClient; use crate::client_common::Prompt; @@ -54,32 +46,21 @@ use crate::conversation_history::ConversationHistory; use crate::environment_context::EnvironmentContext; use crate::error::CodexErr; use crate::error::Result as CodexResult; -use crate::error::SandboxErr; -use crate::exec::ExecParams; use crate::exec::ExecToolCallOutput; -use crate::exec::StdoutStream; 
#[cfg(test)] use crate::exec::StreamOutput; -use crate::exec_command::EXEC_COMMAND_TOOL_NAME; use crate::exec_command::ExecCommandParams; use crate::exec_command::ExecSessionManager; -use crate::exec_command::WRITE_STDIN_TOOL_NAME; use crate::exec_command::WriteStdinParams; -use crate::exec_env::create_env; -use crate::executor::ExecutionMode; use crate::executor::Executor; use crate::executor::ExecutorConfig; use crate::executor::normalize_exec_result; use crate::mcp_connection_manager::McpConnectionManager; -use crate::mcp_tool_call::handle_mcp_tool_call; use crate::model_family::find_family_for_model; use crate::openai_model_info::get_model_info; -use crate::openai_tools::ApplyPatchToolArgs; use crate::openai_tools::ToolsConfig; use crate::openai_tools::ToolsConfigParams; -use crate::openai_tools::get_openai_tools; use crate::parse_command::parse_command; -use crate::plan_tool::handle_update_plan; use crate::project_doc::get_user_instructions; use crate::protocol::AgentMessageDeltaEvent; use crate::protocol::AgentReasoningDeltaEvent; @@ -94,7 +75,6 @@ use crate::protocol::EventMsg; use crate::protocol::ExecApprovalRequestEvent; use crate::protocol::ExecCommandBeginEvent; use crate::protocol::ExecCommandEndEvent; -use crate::protocol::FileChange; use crate::protocol::InputItem; use crate::protocol::ListCustomPromptsResponseEvent; use crate::protocol::Op; @@ -110,7 +90,6 @@ use crate::protocol::Submission; use crate::protocol::TokenCountEvent; use crate::protocol::TokenUsage; use crate::protocol::TurnDiffEvent; -use crate::protocol::ViewImageToolCallEvent; use crate::protocol::WebSearchBeginEvent; use crate::rollout::RolloutRecorder; use crate::rollout::RolloutRecorderParams; @@ -120,6 +99,8 @@ use crate::state::SessionServices; use crate::tasks::CompactTask; use crate::tasks::RegularTask; use crate::tasks::ReviewTask; +use crate::tools::ToolRouter; +use crate::tools::format_exec_output_str; use crate::turn_diff_tracker::TurnDiffTracker; use 
crate::unified_exec::UnifiedExecSessionManager; use crate::user_instructions::UserInstructions; @@ -131,10 +112,8 @@ use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig; use codex_protocol::custom_prompts::CustomPrompt; use codex_protocol::models::ContentItem; use codex_protocol::models::FunctionCallOutputPayload; -use codex_protocol::models::LocalShellAction; use codex_protocol::models::ResponseInputItem; use codex_protocol::models::ResponseItem; -use codex_protocol::models::ShellToolCallParams; use codex_protocol::protocol::InitialHistory; pub mod compact; @@ -160,13 +139,6 @@ pub struct CodexSpawnOk { pub(crate) const INITIAL_SUBMIT_ID: &str = ""; pub(crate) const SUBMISSION_CHANNEL_CAPACITY: usize = 64; -// Model-formatting limits: clients get full streams; oonly content sent to the model is truncated. -pub(crate) const MODEL_FORMAT_MAX_BYTES: usize = 10 * 1024; // 10 KiB -pub(crate) const MODEL_FORMAT_MAX_LINES: usize = 256; // lines -pub(crate) const MODEL_FORMAT_HEAD_LINES: usize = MODEL_FORMAT_MAX_LINES / 2; -pub(crate) const MODEL_FORMAT_TAIL_LINES: usize = MODEL_FORMAT_MAX_LINES - MODEL_FORMAT_HEAD_LINES; // 128 -pub(crate) const MODEL_FORMAT_HEAD_BYTES: usize = MODEL_FORMAT_MAX_BYTES / 2; - impl Codex { /// Spawn a new [`Codex`] and initialize the session. 
pub async fn spawn( @@ -266,7 +238,7 @@ pub(crate) struct Session { tx_event: Sender, state: Mutex, pub(crate) active_turn: Mutex>, - services: SessionServices, + pub(crate) services: SessionServices, next_internal_sub_id: AtomicU64, } @@ -289,7 +261,7 @@ pub(crate) struct TurnContext { } impl TurnContext { - fn resolve_path(&self, path: Option) -> PathBuf { + pub(crate) fn resolve_path(&self, path: Option) -> PathBuf { path.as_ref() .map(PathBuf::from) .map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p)) @@ -534,6 +506,10 @@ impl Session { Ok((sess, turn_context)) } + pub(crate) fn get_tx_event(&self) -> Sender { + self.tx_event.clone() + } + fn next_internal_sub_id(&self) -> String { let id = self .next_internal_sub_id @@ -937,7 +913,7 @@ impl Session { /// command even on error. /// /// Returns the output of the exec tool call. - async fn run_exec_with_events( + pub(crate) async fn run_exec_with_events( &self, turn_diff_tracker: &mut TurnDiffTracker, prepared: PreparedExec, @@ -1043,6 +1019,49 @@ impl Session { .await } + pub(crate) fn parse_mcp_tool_name(&self, tool_name: &str) -> Option<(String, String)> { + self.services + .mcp_connection_manager + .parse_tool_name(tool_name) + } + + pub(crate) async fn handle_exec_command_tool( + &self, + params: ExecCommandParams, + ) -> Result { + let result = self + .services + .session_manager + .handle_exec_command_request(params) + .await; + match result { + Ok(output) => Ok(output.to_text_output()), + Err(err) => Err(FunctionCallError::RespondToModel(err)), + } + } + + pub(crate) async fn handle_write_stdin_tool( + &self, + params: WriteStdinParams, + ) -> Result { + self.services + .session_manager + .handle_write_stdin_request(params) + .await + .map(|output| output.to_text_output()) + .map_err(FunctionCallError::RespondToModel) + } + + pub(crate) async fn run_unified_exec_request( + &self, + request: crate::unified_exec::UnifiedExecRequest<'_>, + ) -> Result { + self.services + .unified_exec_manager + 
.handle_request(request) + .await + } + pub async fn interrupt_task(self: &Arc) { info!("interrupt received: abort current task, if any"); self.abort_all_tasks(TurnAbortReason::Interrupted).await; @@ -1080,23 +1099,6 @@ impl Drop for Session { } } -#[derive(Clone, Debug)] -pub(crate) struct ExecCommandContext { - pub(crate) sub_id: String, - pub(crate) call_id: String, - pub(crate) command_for_display: Vec, - pub(crate) cwd: PathBuf, - pub(crate) apply_patch: Option, - pub(crate) tool_name: String, - pub(crate) otel_event_manager: OtelEventManager, -} - -#[derive(Clone, Debug)] -pub(crate) struct ApplyPatchCommandContext { - pub(crate) user_explicitly_approved_this_action: bool, - pub(crate) changes: HashMap, -} - async fn submission_loop( sess: Arc, turn_context: TurnContext, @@ -1910,24 +1912,32 @@ async fn run_turn( sub_id: String, input: Vec, ) -> CodexResult { - let tools = get_openai_tools( - &turn_context.tools_config, - Some(sess.services.mcp_connection_manager.list_all_tools()), - ); + let mcp_tools = sess.services.mcp_connection_manager.list_all_tools(); + let router = ToolRouter::from_config(&turn_context.tools_config, Some(mcp_tools)); let prompt = Prompt { input, - tools, + tools: router.specs().to_vec(), base_instructions_override: turn_context.base_instructions.clone(), output_schema: turn_context.final_output_json_schema.clone(), }; let mut retries = 0; loop { - match try_run_turn(sess, turn_context, turn_diff_tracker, &sub_id, &prompt).await { + match try_run_turn( + &router, + sess, + turn_context, + turn_diff_tracker, + &sub_id, + &prompt, + ) + .await + { Ok(output) => return Ok(output), Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted), Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)), + Err(e @ CodexErr::Fatal(_)) => return Err(e), Err(CodexErr::UsageLimitReached(e)) => { let rate_limits = e.rate_limits.clone(); if let Some(rate_limits) = rate_limits { @@ -1986,6 +1996,7 @@ struct TurnRunResult { } async fn 
try_run_turn( + router: &crate::tools::ToolRouter, sess: &Session, turn_context: &TurnContext, turn_diff_tracker: &mut TurnDiffTracker, @@ -2087,6 +2098,7 @@ async fn try_run_turn( ResponseEvent::Created => {} ResponseEvent::OutputItemDone(item) => { let response = handle_response_item( + router, sess, turn_context, turn_diff_tracker, @@ -2177,6 +2189,7 @@ async fn try_run_turn( } async fn handle_response_item( + router: &crate::tools::ToolRouter, sess: &Session, turn_context: &TurnContext, turn_diff_tracker: &mut TurnDiffTracker, @@ -2184,715 +2197,79 @@ async fn handle_response_item( item: ResponseItem, ) -> CodexResult> { debug!(?item, "Output item"); - let output = match item { - ResponseItem::FunctionCall { - name, - arguments, - call_id, - .. - } => { - info!("FunctionCall: {name}({arguments})"); - if let Some((server, tool_name)) = - sess.services.mcp_connection_manager.parse_tool_name(&name) - { - let resp = handle_mcp_tool_call( - sess, - sub_id, - call_id.clone(), - server, - tool_name, - arguments, - ) - .await; - Some(resp) - } else { - let result = turn_context - .client - .get_otel_event_manager() - .log_tool_result(name.as_str(), call_id.as_str(), arguments.as_str(), || { - handle_function_call( - sess, - turn_context, - turn_diff_tracker, - sub_id.to_string(), - name.to_owned(), - arguments.to_owned(), - call_id.clone(), - ) - }) - .await; - let output = match result { - Ok(content) => FunctionCallOutputPayload { - content, - success: Some(true), - }, - Err(FunctionCallError::RespondToModel(msg)) => FunctionCallOutputPayload { - content: msg, - success: Some(false), - }, - }; - Some(ResponseInputItem::FunctionCallOutput { call_id, output }) + match ToolRouter::build_tool_call(sess, item.clone()) { + Ok(Some(call)) => { + let payload_preview = call.payload.log_payload().into_owned(); + tracing::info!("ToolCall: {} {}", call.tool_name, payload_preview); + match router + .dispatch_tool_call(sess, turn_context, turn_diff_tracker, sub_id, call) + .await 
+ { + Ok(response) => Ok(Some(response)), + Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), + Err(other) => unreachable!("non-fatal tool error returned: {other:?}"), } } - ResponseItem::LocalShellCall { - id, - call_id, - status: _, - action, - } => { - let name = "local_shell"; - let LocalShellAction::Exec(action) = action; - tracing::info!("LocalShellCall: {action:?}"); - let params = ShellToolCallParams { - command: action.command, - workdir: action.working_directory, - timeout_ms: action.timeout_ms, - with_escalated_permissions: None, - justification: None, - }; - let effective_call_id = match (call_id, id) { - (Some(call_id), _) => call_id, - (None, Some(id)) => id, - (None, None) => { - let error_message = "LocalShellCall without call_id or id"; - - turn_context - .client - .get_otel_event_manager() - .log_tool_failed(name, error_message); - - error!(error_message); - return Ok(Some(ResponseInputItem::FunctionCallOutput { - call_id: "".to_string(), - output: FunctionCallOutputPayload { - content: error_message.to_string(), - success: None, - }, - })); + Ok(None) => { + match &item { + ResponseItem::Message { .. } + | ResponseItem::Reasoning { .. } + | ResponseItem::WebSearchCall { .. } => { + let msgs = match &item { + ResponseItem::Message { .. 
} if turn_context.is_review_mode => { + trace!("suppressing assistant Message in review mode"); + Vec::new() + } + _ => map_response_item_to_event_messages( + &item, + sess.show_raw_agent_reasoning(), + ), + }; + for msg in msgs { + let event = Event { + id: sub_id.to_string(), + msg, + }; + sess.send_event(event).await; + } } - }; - - let exec_params = to_exec_params(params, turn_context); - { - let result = turn_context - .client - .get_otel_event_manager() - .log_tool_result( - name, - effective_call_id.as_str(), - exec_params.command.join(" ").as_str(), - || { - handle_container_exec_with_params( - name, - exec_params, - sess, - turn_context, - turn_diff_tracker, - sub_id.to_string(), - effective_call_id.clone(), - ) - }, - ) - .await; - - let output = match result { - Ok(content) => FunctionCallOutputPayload { - content, - success: Some(true), - }, - Err(FunctionCallError::RespondToModel(msg)) => FunctionCallOutputPayload { - content: msg, - success: Some(false), - }, - }; - Some(ResponseInputItem::FunctionCallOutput { - call_id: effective_call_id, - output, - }) + ResponseItem::FunctionCallOutput { .. } + | ResponseItem::CustomToolCallOutput { .. 
} => { + debug!("unexpected tool output from stream"); + } + _ => {} } + + Ok(None) } - ResponseItem::CustomToolCall { - id: _, - call_id, - name, - input, - status: _, - } => { - let result = turn_context + Err(FunctionCallError::MissingLocalShellCallId) => { + let msg = "LocalShellCall without call_id or id"; + turn_context .client .get_otel_event_manager() - .log_tool_result(name.as_str(), call_id.as_str(), input.as_str(), || { - handle_custom_tool_call( - sess, - turn_context, - turn_diff_tracker, - sub_id.to_string(), - name.to_owned(), - input.to_owned(), - call_id.clone(), - ) - }) - .await; + .log_tool_failed("local_shell", msg); + error!(msg); - let output = match result { - Ok(content) => content, - Err(FunctionCallError::RespondToModel(msg)) => msg, - }; - Some(ResponseInputItem::CustomToolCallOutput { call_id, output }) + Ok(Some(ResponseInputItem::FunctionCallOutput { + call_id: String::new(), + output: FunctionCallOutputPayload { + content: msg.to_string(), + success: None, + }, + })) } - ResponseItem::FunctionCallOutput { .. } => { - debug!("unexpected FunctionCallOutput from stream"); - None + Err(FunctionCallError::RespondToModel(msg)) => { + Ok(Some(ResponseInputItem::FunctionCallOutput { + call_id: String::new(), + output: FunctionCallOutputPayload { + content: msg, + success: None, + }, + })) } - ResponseItem::CustomToolCallOutput { .. } => { - debug!("unexpected CustomToolCallOutput from stream"); - None - } - ResponseItem::Message { .. } - | ResponseItem::Reasoning { .. } - | ResponseItem::WebSearchCall { .. } => { - // In review child threads, suppress assistant message events but - // keep reasoning/web search. - let msgs = match &item { - ResponseItem::Message { .. 
} if turn_context.is_review_mode => { - trace!("suppressing assistant Message in review mode"); - Vec::new() - } - _ => map_response_item_to_event_messages(&item, sess.show_raw_agent_reasoning()), - }; - for msg in msgs { - let event = Event { - id: sub_id.to_string(), - msg, - }; - sess.send_event(event).await; - } - None - } - ResponseItem::Other => None, - }; - Ok(output) -} - -async fn handle_unified_exec_tool_call( - sess: &Session, - session_id: Option, - arguments: Vec, - timeout_ms: Option, -) -> Result { - let parsed_session_id = if let Some(session_id) = session_id { - match session_id.parse::() { - Ok(parsed) => Some(parsed), - Err(output) => { - return Err(FunctionCallError::RespondToModel(format!( - "invalid session_id: {session_id} due to error {output:?}" - ))); - } - } - } else { - None - }; - - let request = crate::unified_exec::UnifiedExecRequest { - session_id: parsed_session_id, - input_chunks: &arguments, - timeout_ms, - }; - - let value = sess - .services - .unified_exec_manager - .handle_request(request) - .await - .map_err(|err| { - FunctionCallError::RespondToModel(format!("unified exec failed: {err:?}")) - })?; - - #[derive(Serialize)] - struct SerializedUnifiedExecResult { - session_id: Option, - output: String, + Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)), } - - serde_json::to_string(&SerializedUnifiedExecResult { - session_id: value.session_id.map(|id| id.to_string()), - output: value.output, - }) - .map_err(|err| { - FunctionCallError::RespondToModel(format!( - "failed to serialize unified exec output: {err:?}" - )) - }) -} - -async fn handle_function_call( - sess: &Session, - turn_context: &TurnContext, - turn_diff_tracker: &mut TurnDiffTracker, - sub_id: String, - name: String, - arguments: String, - call_id: String, -) -> Result { - match name.as_str() { - "container.exec" | "shell" => { - let params = parse_container_exec_arguments(arguments, turn_context, &call_id)?; - 
handle_container_exec_with_params( - name.as_str(), - params, - sess, - turn_context, - turn_diff_tracker, - sub_id, - call_id, - ) - .await - } - "unified_exec" => { - #[derive(Deserialize)] - struct UnifiedExecArgs { - input: Vec, - #[serde(default)] - session_id: Option, - #[serde(default)] - timeout_ms: Option, - } - - let args: UnifiedExecArgs = serde_json::from_str(&arguments).map_err(|err| { - FunctionCallError::RespondToModel(format!( - "failed to parse function arguments: {err:?}" - )) - })?; - - handle_unified_exec_tool_call(sess, args.session_id, args.input, args.timeout_ms).await - } - "view_image" => { - #[derive(serde::Deserialize)] - struct SeeImageArgs { - path: String, - } - let args: SeeImageArgs = serde_json::from_str(&arguments).map_err(|e| { - FunctionCallError::RespondToModel(format!( - "failed to parse function arguments: {e:?}" - )) - })?; - let abs = turn_context.resolve_path(Some(args.path)); - sess.inject_input(vec![InputItem::LocalImage { path: abs.clone() }]) - .await - .map_err(|_| { - FunctionCallError::RespondToModel( - "unable to attach image (no active task)".to_string(), - ) - })?; - sess.send_event(Event { - id: sub_id.clone(), - msg: EventMsg::ViewImageToolCall(ViewImageToolCallEvent { - call_id: call_id.clone(), - path: abs, - }), - }) - .await; - - Ok("attached local image path".to_string()) - } - "apply_patch" => { - let args: ApplyPatchToolArgs = serde_json::from_str(&arguments).map_err(|e| { - FunctionCallError::RespondToModel(format!( - "failed to parse function arguments: {e:?}" - )) - })?; - let exec_params = ExecParams { - command: vec!["apply_patch".to_string(), args.input.clone()], - cwd: turn_context.cwd.clone(), - timeout_ms: None, - env: HashMap::new(), - with_escalated_permissions: None, - justification: None, - }; - handle_container_exec_with_params( - name.as_str(), - exec_params, - sess, - turn_context, - turn_diff_tracker, - sub_id, - call_id, - ) - .await - } - "update_plan" => handle_update_plan(sess, 
arguments, sub_id, call_id).await, - EXEC_COMMAND_TOOL_NAME => { - // TODO(mbolin): Sandbox check. - let exec_params: ExecCommandParams = serde_json::from_str(&arguments).map_err(|e| { - FunctionCallError::RespondToModel(format!( - "failed to parse function arguments: {e:?}" - )) - })?; - let result = sess - .services - .session_manager - .handle_exec_command_request(exec_params) - .await; - match result { - Ok(output) => Ok(output.to_text_output()), - Err(err) => Err(FunctionCallError::RespondToModel(err)), - } - } - WRITE_STDIN_TOOL_NAME => { - let write_stdin_params = - serde_json::from_str::(&arguments).map_err(|e| { - FunctionCallError::RespondToModel(format!( - "failed to parse function arguments: {e:?}" - )) - })?; - - let result = sess - .services - .session_manager - .handle_write_stdin_request(write_stdin_params) - .await - .map_err(FunctionCallError::RespondToModel)?; - - Ok(result.to_text_output()) - } - _ => Err(FunctionCallError::RespondToModel(format!( - "unsupported call: {name}" - ))), - } -} - -async fn handle_custom_tool_call( - sess: &Session, - turn_context: &TurnContext, - turn_diff_tracker: &mut TurnDiffTracker, - sub_id: String, - name: String, - input: String, - call_id: String, -) -> Result { - info!("CustomToolCall: {name} {input}"); - match name.as_str() { - "apply_patch" => { - let exec_params = ExecParams { - command: vec!["apply_patch".to_string(), input.clone()], - cwd: turn_context.cwd.clone(), - timeout_ms: None, - env: HashMap::new(), - with_escalated_permissions: None, - justification: None, - }; - - handle_container_exec_with_params( - name.as_str(), - exec_params, - sess, - turn_context, - turn_diff_tracker, - sub_id, - call_id, - ) - .await - } - _ => { - debug!("unexpected CustomToolCall from stream"); - Err(FunctionCallError::RespondToModel(format!( - "unsupported custom tool call: {name}" - ))) - } - } -} - -fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams { - ExecParams { - command: 
params.command, - cwd: turn_context.resolve_path(params.workdir.clone()), - timeout_ms: params.timeout_ms, - env: create_env(&turn_context.shell_environment_policy), - with_escalated_permissions: params.with_escalated_permissions, - justification: params.justification, - } -} - -fn parse_container_exec_arguments( - arguments: String, - turn_context: &TurnContext, - _call_id: &str, -) -> Result { - serde_json::from_str::(&arguments) - .map(|p| to_exec_params(p, turn_context)) - .map_err(|e| { - FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e:?}")) - }) -} - -async fn handle_container_exec_with_params( - tool_name: &str, - params: ExecParams, - sess: &Session, - turn_context: &TurnContext, - turn_diff_tracker: &mut TurnDiffTracker, - sub_id: String, - call_id: String, -) -> Result { - let otel_event_manager = turn_context.client.get_otel_event_manager(); - - if params.with_escalated_permissions.unwrap_or(false) - && !matches!(turn_context.approval_policy, AskForApproval::OnRequest) - { - return Err(FunctionCallError::RespondToModel(format!( - "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}", - policy = turn_context.approval_policy - ))); - } - - // check if this was a patch, and apply it if so - let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { - MaybeApplyPatchVerified::Body(changes) => { - match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await { - InternalApplyPatchInvocation::Output(item) => return item, - InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => { - Some(apply_patch_exec) - } - } - } - MaybeApplyPatchVerified::CorrectnessError(parse_error) => { - // It looks like an invocation of `apply_patch`, but we - // could not resolve it into a patch that would apply - // cleanly. Return to model for resample. 
- return Err(FunctionCallError::RespondToModel(format!( - "error: {parse_error:#?}" - ))); - } - MaybeApplyPatchVerified::ShellParseError(error) => { - trace!("Failed to parse shell command, {error:?}"); - None - } - MaybeApplyPatchVerified::NotApplyPatch => None, - }; - - let command_for_display = if let Some(exec) = apply_patch_exec.as_ref() { - vec!["apply_patch".to_string(), exec.action.patch.clone()] - } else { - params.command.clone() - }; - - let exec_command_context = ExecCommandContext { - sub_id: sub_id.clone(), - call_id: call_id.clone(), - command_for_display: command_for_display.clone(), - cwd: params.cwd.clone(), - apply_patch: apply_patch_exec.as_ref().map( - |ApplyPatchExec { - action, - user_explicitly_approved_this_action, - }| ApplyPatchCommandContext { - user_explicitly_approved_this_action: *user_explicitly_approved_this_action, - changes: convert_apply_patch_to_protocol(action), - }, - ), - tool_name: tool_name.to_string(), - otel_event_manager, - }; - - let mode = match apply_patch_exec { - Some(exec) => ExecutionMode::ApplyPatch(exec), - None => ExecutionMode::Shell, - }; - - sess.services.executor.update_environment( - turn_context.sandbox_policy.clone(), - turn_context.cwd.clone(), - ); - - let prepared_exec = PreparedExec::new( - exec_command_context, - params, - command_for_display, - mode, - Some(StdoutStream { - sub_id: sub_id.clone(), - call_id: call_id.clone(), - tx_event: sess.tx_event.clone(), - }), - turn_context.shell_environment_policy.use_profile, - ); - - let output_result = sess - .run_exec_with_events( - turn_diff_tracker, - prepared_exec, - turn_context.approval_policy, - ) - .await; - - match output_result { - Ok(output) => { - let ExecToolCallOutput { exit_code, .. 
} = &output; - let content = format_exec_output(&output); - if *exit_code == 0 { - Ok(content) - } else { - Err(FunctionCallError::RespondToModel(content)) - } - } - Err(ExecError::Function(err)) => Err(err), - Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => Err( - FunctionCallError::RespondToModel(format_exec_output(&output)), - ), - Err(ExecError::Codex(err)) => Err(FunctionCallError::RespondToModel(format!( - "execution error: {err:?}" - ))), - } -} - -fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String { - let ExecToolCallOutput { - aggregated_output, .. - } = exec_output; - - // Head+tail truncation for the model: show the beginning and end with an elision. - // Clients still receive full streams; only this formatted summary is capped. - - let mut s = &aggregated_output.text; - let prefixed_str: String; - - if exec_output.timed_out { - prefixed_str = format!( - "command timed out after {} milliseconds\n", - exec_output.duration.as_millis() - ) + s; - s = &prefixed_str; - } - - let total_lines = s.lines().count(); - if s.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES { - return s.to_string(); - } - - let lines: Vec<&str> = s.lines().collect(); - let head_take = MODEL_FORMAT_HEAD_LINES.min(lines.len()); - let tail_take = MODEL_FORMAT_TAIL_LINES.min(lines.len().saturating_sub(head_take)); - let omitted = lines.len().saturating_sub(head_take + tail_take); - - // Join head and tail blocks (lines() strips newlines; reinsert them) - let head_block = lines - .iter() - .take(head_take) - .cloned() - .collect::>() - .join("\n"); - let tail_block = if tail_take > 0 { - lines[lines.len() - tail_take..].join("\n") - } else { - String::new() - }; - let marker = format!("\n[... 
omitted {omitted} of {total_lines} lines ...]\n\n"); - - // Byte budgets for head/tail around the marker - let mut head_budget = MODEL_FORMAT_HEAD_BYTES.min(MODEL_FORMAT_MAX_BYTES); - let tail_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(head_budget + marker.len()); - if tail_budget == 0 && marker.len() >= MODEL_FORMAT_MAX_BYTES { - // Degenerate case: marker alone exceeds budget; return a clipped marker - return take_bytes_at_char_boundary(&marker, MODEL_FORMAT_MAX_BYTES).to_string(); - } - if tail_budget == 0 { - // Make room for the marker by shrinking head - head_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(marker.len()); - } - - // Enforce line-count cap by trimming head/tail lines - let head_lines_text = head_block; - let tail_lines_text = tail_block; - // Build final string respecting byte budgets - let head_part = take_bytes_at_char_boundary(&head_lines_text, head_budget); - let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(s.len())); - - result.push_str(head_part); - result.push_str(&marker); - - let remaining = MODEL_FORMAT_MAX_BYTES.saturating_sub(result.len()); - let tail_budget_final = remaining; - let tail_part = take_last_bytes_at_char_boundary(&tail_lines_text, tail_budget_final); - result.push_str(tail_part); - - result -} - -// Truncate a &str to a byte budget at a char boundary (prefix) -#[inline] -fn take_bytes_at_char_boundary(s: &str, maxb: usize) -> &str { - if s.len() <= maxb { - return s; - } - let mut last_ok = 0; - for (i, ch) in s.char_indices() { - let nb = i + ch.len_utf8(); - if nb > maxb { - break; - } - last_ok = nb; - } - &s[..last_ok] -} - -// Take a suffix of a &str within a byte budget at a char boundary -#[inline] -fn take_last_bytes_at_char_boundary(s: &str, maxb: usize) -> &str { - if s.len() <= maxb { - return s; - } - let mut start = s.len(); - let mut used = 0usize; - for (i, ch) in s.char_indices().rev() { - let nb = ch.len_utf8(); - if used + nb > maxb { - break; - } - start = i; - used += nb; - if 
start == 0 { - break; - } - } - &s[start..] -} - -/// Exec output is a pre-serialized JSON payload -fn format_exec_output(exec_output: &ExecToolCallOutput) -> String { - let ExecToolCallOutput { - exit_code, - duration, - .. - } = exec_output; - - #[derive(Serialize)] - struct ExecMetadata { - exit_code: i32, - duration_seconds: f32, - } - - #[derive(Serialize)] - struct ExecOutput<'a> { - output: &'a str, - metadata: ExecMetadata, - } - - // round to 1 decimal place - let duration_seconds = ((duration.as_secs_f32()) * 10.0).round() / 10.0; - - let formatted_output = format_exec_output_str(exec_output); - - let payload = ExecOutput { - output: &formatted_output, - metadata: ExecMetadata { - exit_code: *exit_code, - duration_seconds, - }, - }; - - #[expect(clippy::expect_used)] - serde_json::to_string(&payload).expect("serialize ExecOutput") } pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option { @@ -3006,6 +2383,8 @@ pub(crate) async fn exit_review_mode( use crate::executor::errors::ExecError; use crate::executor::linkers::PreparedExec; +use crate::tools::context::ApplyPatchCommandContext; +use crate::tools::context::ExecCommandContext; #[cfg(test)] pub(crate) use tests::make_session_and_context; @@ -3021,6 +2400,13 @@ mod tests { use crate::state::TaskKind; use crate::tasks::SessionTask; use crate::tasks::SessionTaskContext; + use crate::tools::MODEL_FORMAT_HEAD_LINES; + use crate::tools::MODEL_FORMAT_MAX_BYTES; + use crate::tools::MODEL_FORMAT_MAX_LINES; + use crate::tools::MODEL_FORMAT_TAIL_LINES; + use crate::tools::ToolRouter; + use crate::tools::handle_container_exec_with_params; + use crate::turn_diff_tracker::TurnDiffTracker; use codex_app_server_protocol::AuthMode; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; @@ -3512,6 +2898,43 @@ mod tests { ); } + #[tokio::test] + async fn fatal_tool_error_stops_turn_and_reports_error() { + let (session, turn_context, _rx) = 
make_session_and_context_with_rx(); + let session_ref = session.as_ref(); + let turn_context_ref = turn_context.as_ref(); + let router = ToolRouter::from_config( + &turn_context_ref.tools_config, + Some(session_ref.services.mcp_connection_manager.list_all_tools()), + ); + let mut tracker = TurnDiffTracker::new(); + let item = ResponseItem::CustomToolCall { + id: None, + status: None, + call_id: "call-1".to_string(), + name: "shell".to_string(), + input: "{}".to_string(), + }; + + let err = handle_response_item( + &router, + session_ref, + turn_context_ref, + &mut tracker, + "sub-id", + item, + ) + .await + .expect_err("expected fatal error"); + + match err { + CodexErr::Fatal(message) => { + assert_eq!(message, "tool shell invoked with incompatible payload"); + } + other => panic!("expected CodexErr::Fatal, got {other:?}"), + } + } + fn sample_rollout( session: &Session, turn_context: &TurnContext, diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs index 7482c512..aa093379 100644 --- a/codex-rs/core/src/error.rs +++ b/codex-rs/core/src/error.rs @@ -108,6 +108,9 @@ pub enum CodexErr { #[error("unsupported operation: {0}")] UnsupportedOperation(String), + #[error("Fatal error: {0}")] + Fatal(String), + // ----------------------------------------------------------------- // Automatic conversions for common external error types // ----------------------------------------------------------------- diff --git a/codex-rs/core/src/exec_command/responses_api.rs b/codex-rs/core/src/exec_command/responses_api.rs index 10629f43..24f6d35c 100644 --- a/codex-rs/core/src/exec_command/responses_api.rs +++ b/codex-rs/core/src/exec_command/responses_api.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; +use crate::client_common::tools::ResponsesApiTool; use crate::openai_tools::JsonSchema; -use crate::openai_tools::ResponsesApiTool; pub const EXEC_COMMAND_TOOL_NAME: &str = "exec_command"; pub const WRITE_STDIN_TOOL_NAME: &str = "write_stdin"; diff --git 
a/codex-rs/core/src/executor/mod.rs b/codex-rs/core/src/executor/mod.rs index a5a305c6..97d7b292 100644 --- a/codex-rs/core/src/executor/mod.rs +++ b/codex-rs/core/src/executor/mod.rs @@ -10,11 +10,11 @@ pub(crate) use runner::ExecutorConfig; pub(crate) use runner::normalize_exec_result; pub(crate) mod linkers { - use crate::codex::ExecCommandContext; use crate::exec::ExecParams; use crate::exec::StdoutStream; use crate::executor::backends::ExecutionMode; use crate::executor::runner::ExecutionRequest; + use crate::tools::context::ExecCommandContext; pub struct PreparedExec { pub(crate) context: ExecCommandContext, diff --git a/codex-rs/core/src/executor/runner.rs b/codex-rs/core/src/executor/runner.rs index 68cdfb61..f475aad6 100644 --- a/codex-rs/core/src/executor/runner.rs +++ b/codex-rs/core/src/executor/runner.rs @@ -6,7 +6,6 @@ use std::time::Duration; use super::backends::ExecutionMode; use super::backends::backend_for_mode; use super::cache::ApprovalCache; -use crate::codex::ExecCommandContext; use crate::codex::Session; use crate::error::CodexErr; use crate::error::SandboxErr; @@ -24,6 +23,7 @@ use crate::protocol::AskForApproval; use crate::protocol::ReviewDecision; use crate::protocol::SandboxPolicy; use crate::shell; +use crate::tools::context::ExecCommandContext; use codex_otel::otel_event_manager::ToolDecisionSource; #[derive(Clone, Debug)] @@ -303,6 +303,7 @@ pub(crate) fn normalize_exec_result( let message = match err { ExecError::Function(FunctionCallError::RespondToModel(msg)) => msg.clone(), ExecError::Codex(e) => get_error_message_ui(e), + err => err.to_string(), }; let synthetic = ExecToolCallOutput { exit_code: -1, diff --git a/codex-rs/core/src/function_tool.rs b/codex-rs/core/src/function_tool.rs index 756cef3e..240e0436 100644 --- a/codex-rs/core/src/function_tool.rs +++ b/codex-rs/core/src/function_tool.rs @@ -4,4 +4,8 @@ use thiserror::Error; pub enum FunctionCallError { #[error("{0}")] RespondToModel(String), + #[error("LocalShellCall 
without call_id or id")] + MissingLocalShellCallId, + #[error("Fatal error: {0}")] + Fatal(String), } diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 94350a44..0d42f6f1 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -57,7 +57,6 @@ pub mod default_client; pub mod model_family; mod openai_model_info; mod openai_tools; -pub mod plan_tool; pub mod project_doc; mod rollout; pub(crate) mod safety; @@ -65,7 +64,7 @@ pub mod seatbelt; pub mod shell; pub mod spawn; pub mod terminal; -mod tool_apply_patch; +mod tools; pub mod turn_diff_tracker; pub use rollout::ARCHIVED_SESSIONS_SUBDIR; pub use rollout::INTERACTIVE_SESSION_SOURCES; diff --git a/codex-rs/core/src/model_family.rs b/codex-rs/core/src/model_family.rs index 54c18dae..8796d7e7 100644 --- a/codex-rs/core/src/model_family.rs +++ b/codex-rs/core/src/model_family.rs @@ -1,5 +1,5 @@ use crate::config_types::ReasoningSummaryFormat; -use crate::tool_apply_patch::ApplyPatchToolType; +use crate::tools::handlers::apply_patch::ApplyPatchToolType; /// The `instructions` field in the payload sent to a model should always start /// with this content. diff --git a/codex-rs/core/src/openai_tools.rs b/codex-rs/core/src/openai_tools.rs index 9e9b9388..0e10f909 100644 --- a/codex-rs/core/src/openai_tools.rs +++ b/codex-rs/core/src/openai_tools.rs @@ -1,1189 +1 @@ -use serde::Deserialize; -use serde::Serialize; -use serde_json::Value as JsonValue; -use serde_json::json; -use std::collections::BTreeMap; -use std::collections::HashMap; - -use crate::model_family::ModelFamily; -use crate::plan_tool::PLAN_TOOL; -use crate::tool_apply_patch::ApplyPatchToolType; -use crate::tool_apply_patch::create_apply_patch_freeform_tool; -use crate::tool_apply_patch::create_apply_patch_json_tool; - -#[derive(Debug, Clone, Serialize, PartialEq)] -pub struct ResponsesApiTool { - pub(crate) name: String, - pub(crate) description: String, - /// TODO: Validation. 
When strict is set to true, the JSON schema, - /// `required` and `additional_properties` must be present. All fields in - /// `properties` must be present in `required`. - pub(crate) strict: bool, - pub(crate) parameters: JsonSchema, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct FreeformTool { - pub(crate) name: String, - pub(crate) description: String, - pub(crate) format: FreeformToolFormat, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct FreeformToolFormat { - pub(crate) r#type: String, - pub(crate) syntax: String, - pub(crate) definition: String, -} - -/// When serialized as JSON, this produces a valid "Tool" in the OpenAI -/// Responses API. -#[derive(Debug, Clone, Serialize, PartialEq)] -#[serde(tag = "type")] -pub(crate) enum OpenAiTool { - #[serde(rename = "function")] - Function(ResponsesApiTool), - #[serde(rename = "local_shell")] - LocalShell {}, - // TODO: Understand why we get an error on web_search although the API docs say it's supported. 
- // https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C - #[serde(rename = "web_search")] - WebSearch {}, - #[serde(rename = "custom")] - Freeform(FreeformTool), -} - -#[derive(Debug, Clone)] -pub enum ConfigShellToolType { - Default, - Local, - Streamable, -} - -#[derive(Debug, Clone)] -pub(crate) struct ToolsConfig { - pub shell_type: ConfigShellToolType, - pub plan_tool: bool, - pub apply_patch_tool_type: Option, - pub web_search_request: bool, - pub include_view_image_tool: bool, - pub experimental_unified_exec_tool: bool, -} - -pub(crate) struct ToolsConfigParams<'a> { - pub(crate) model_family: &'a ModelFamily, - pub(crate) include_plan_tool: bool, - pub(crate) include_apply_patch_tool: bool, - pub(crate) include_web_search_request: bool, - pub(crate) use_streamable_shell_tool: bool, - pub(crate) include_view_image_tool: bool, - pub(crate) experimental_unified_exec_tool: bool, -} - -impl ToolsConfig { - pub fn new(params: &ToolsConfigParams) -> Self { - let ToolsConfigParams { - model_family, - include_plan_tool, - include_apply_patch_tool, - include_web_search_request, - use_streamable_shell_tool, - include_view_image_tool, - experimental_unified_exec_tool, - } = params; - let shell_type = if *use_streamable_shell_tool { - ConfigShellToolType::Streamable - } else if model_family.uses_local_shell_tool { - ConfigShellToolType::Local - } else { - ConfigShellToolType::Default - }; - - let apply_patch_tool_type = match model_family.apply_patch_tool_type { - Some(ApplyPatchToolType::Freeform) => Some(ApplyPatchToolType::Freeform), - Some(ApplyPatchToolType::Function) => Some(ApplyPatchToolType::Function), - None => { - if *include_apply_patch_tool { - Some(ApplyPatchToolType::Freeform) - } else { - None - } - } - }; - - Self { - shell_type, - plan_tool: *include_plan_tool, - apply_patch_tool_type, - web_search_request: *include_web_search_request, - include_view_image_tool: 
*include_view_image_tool, - experimental_unified_exec_tool: *experimental_unified_exec_tool, - } - } -} - -/// Whether additional properties are allowed, and if so, any required schema -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(untagged)] -pub(crate) enum AdditionalProperties { - Boolean(bool), - Schema(Box), -} - -impl From for AdditionalProperties { - fn from(b: bool) -> Self { - Self::Boolean(b) - } -} - -impl From for AdditionalProperties { - fn from(s: JsonSchema) -> Self { - Self::Schema(Box::new(s)) - } -} - -/// Generic JSON‑Schema subset needed for our tool definitions -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[serde(tag = "type", rename_all = "lowercase")] -pub(crate) enum JsonSchema { - Boolean { - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - }, - String { - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - }, - /// MCP schema allows "number" | "integer" for Number - #[serde(alias = "integer")] - Number { - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - }, - Array { - items: Box, - - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - }, - Object { - properties: BTreeMap, - #[serde(skip_serializing_if = "Option::is_none")] - required: Option>, - #[serde( - rename = "additionalProperties", - skip_serializing_if = "Option::is_none" - )] - additional_properties: Option, - }, -} - -fn create_unified_exec_tool() -> OpenAiTool { - let mut properties = BTreeMap::new(); - properties.insert( - "input".to_string(), - JsonSchema::Array { - items: Box::new(JsonSchema::String { description: None }), - description: Some( - "When no session_id is provided, treat the array as the command and arguments \ - to launch. When session_id is set, concatenate the strings (in order) and write \ - them to the session's stdin." 
- .to_string(), - ), - }, - ); - properties.insert( - "session_id".to_string(), - JsonSchema::String { - description: Some( - "Identifier for an existing interactive session. If omitted, a new command \ - is spawned." - .to_string(), - ), - }, - ); - properties.insert( - "timeout_ms".to_string(), - JsonSchema::Number { - description: Some( - "Maximum time in milliseconds to wait for output after writing the input." - .to_string(), - ), - }, - ); - - OpenAiTool::Function(ResponsesApiTool { - name: "unified_exec".to_string(), - description: - "Runs a command in a PTY. Provide a session_id to reuse an existing interactive session.".to_string(), - strict: false, - parameters: JsonSchema::Object { - properties, - required: Some(vec!["input".to_string()]), - additional_properties: Some(false.into()), - }, - }) -} - -fn create_shell_tool() -> OpenAiTool { - let mut properties = BTreeMap::new(); - properties.insert( - "command".to_string(), - JsonSchema::Array { - items: Box::new(JsonSchema::String { description: None }), - description: Some("The command to execute".to_string()), - }, - ); - properties.insert( - "workdir".to_string(), - JsonSchema::String { - description: Some("The working directory to execute the command in".to_string()), - }, - ); - properties.insert( - "timeout_ms".to_string(), - JsonSchema::Number { - description: Some("The timeout for the command in milliseconds".to_string()), - }, - ); - - properties.insert( - "with_escalated_permissions".to_string(), - JsonSchema::Boolean { - description: Some("Whether to request escalated permissions. Set to true if command needs to be run without sandbox restrictions".to_string()), - }, - ); - properties.insert( - "justification".to_string(), - JsonSchema::String { - description: Some("Only set if with_escalated_permissions is true. 
1-sentence explanation of why we want to run this command.".to_string()), - }, - ); - - OpenAiTool::Function(ResponsesApiTool { - name: "shell".to_string(), - description: "Runs a shell command and returns its output.".to_string(), - strict: false, - parameters: JsonSchema::Object { - properties, - required: Some(vec!["command".to_string()]), - additional_properties: Some(false.into()), - }, - }) -} - -fn create_view_image_tool() -> OpenAiTool { - // Support only local filesystem path. - let mut properties = BTreeMap::new(); - properties.insert( - "path".to_string(), - JsonSchema::String { - description: Some("Local filesystem path to an image file".to_string()), - }, - ); - - OpenAiTool::Function(ResponsesApiTool { - name: "view_image".to_string(), - description: - "Attach a local image (by filesystem path) to the conversation context for this turn." - .to_string(), - strict: false, - parameters: JsonSchema::Object { - properties, - required: Some(vec!["path".to_string()]), - additional_properties: Some(false.into()), - }, - }) -} -/// TODO(dylan): deprecate once we get rid of json tool -#[derive(Serialize, Deserialize)] -pub(crate) struct ApplyPatchToolArgs { - pub(crate) input: String, -} - -/// Returns JSON values that are compatible with Function Calling in the -/// Responses API: -/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses -pub fn create_tools_json_for_responses_api( - tools: &[OpenAiTool], -) -> crate::error::Result> { - let mut tools_json = Vec::new(); - - for tool in tools { - let json = serde_json::to_value(tool)?; - tools_json.push(json); - } - - Ok(tools_json) -} -/// Returns JSON values that are compatible with Function Calling in the -/// Chat Completions API: -/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat -pub(crate) fn create_tools_json_for_chat_completions_api( - tools: &[OpenAiTool], -) -> crate::error::Result> { - // We start with the JSON for the Responses API and than rewrite it 
to match - // the chat completions tool call format. - let responses_api_tools_json = create_tools_json_for_responses_api(tools)?; - let tools_json = responses_api_tools_json - .into_iter() - .filter_map(|mut tool| { - if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) { - return None; - } - - if let Some(map) = tool.as_object_mut() { - // Remove "type" field as it is not needed in chat completions. - map.remove("type"); - Some(json!({ - "type": "function", - "function": map, - })) - } else { - None - } - }) - .collect::>(); - Ok(tools_json) -} - -pub(crate) fn mcp_tool_to_openai_tool( - fully_qualified_name: String, - tool: mcp_types::Tool, -) -> Result { - let mcp_types::Tool { - description, - mut input_schema, - .. - } = tool; - - // OpenAI models mandate the "properties" field in the schema. The Agents - // SDK fixed this by inserting an empty object for "properties" if it is not - // already present https://github.com/openai/openai-agents-python/issues/449 - // so here we do the same. - if input_schema.properties.is_none() { - input_schema.properties = Some(serde_json::Value::Object(serde_json::Map::new())); - } - - // Serialize to a raw JSON value so we can sanitize schemas coming from MCP - // servers. Some servers omit the top-level or nested `type` in JSON - // Schemas (e.g. using enum/anyOf), or use unsupported variants like - // `integer`. Our internal JsonSchema is a small subset and requires - // `type`, so we coerce/sanitize here for compatibility. - let mut serialized_input_schema = serde_json::to_value(input_schema)?; - sanitize_json_schema(&mut serialized_input_schema); - let input_schema = serde_json::from_value::(serialized_input_schema)?; - - Ok(ResponsesApiTool { - name: fully_qualified_name, - description: description.unwrap_or_default(), - strict: false, - parameters: input_schema, - }) -} - -/// Sanitize a JSON Schema (as serde_json::Value) so it can fit our limited -/// JsonSchema enum. 
This function: -/// - Ensures every schema object has a "type". If missing, infers it from -/// common keywords (properties => object, items => array, enum/const/format => string) -/// and otherwise defaults to "string". -/// - Fills required child fields (e.g. array items, object properties) with -/// permissive defaults when absent. -fn sanitize_json_schema(value: &mut JsonValue) { - match value { - JsonValue::Bool(_) => { - // JSON Schema boolean form: true/false. Coerce to an accept-all string. - *value = json!({ "type": "string" }); - } - JsonValue::Array(arr) => { - for v in arr.iter_mut() { - sanitize_json_schema(v); - } - } - JsonValue::Object(map) => { - // First, recursively sanitize known nested schema holders - if let Some(props) = map.get_mut("properties") - && let Some(props_map) = props.as_object_mut() - { - for (_k, v) in props_map.iter_mut() { - sanitize_json_schema(v); - } - } - if let Some(items) = map.get_mut("items") { - sanitize_json_schema(items); - } - // Some schemas use oneOf/anyOf/allOf - sanitize their entries - for combiner in ["oneOf", "anyOf", "allOf", "prefixItems"] { - if let Some(v) = map.get_mut(combiner) { - sanitize_json_schema(v); - } - } - - // Normalize/ensure type - let mut ty = map.get("type").and_then(|v| v.as_str()).map(str::to_string); - - // If type is an array (union), pick first supported; else leave to inference - if ty.is_none() - && let Some(JsonValue::Array(types)) = map.get("type") - { - for t in types { - if let Some(tt) = t.as_str() - && matches!( - tt, - "object" | "array" | "string" | "number" | "integer" | "boolean" - ) - { - ty = Some(tt.to_string()); - break; - } - } - } - - // Infer type if still missing - if ty.is_none() { - if map.contains_key("properties") - || map.contains_key("required") - || map.contains_key("additionalProperties") - { - ty = Some("object".to_string()); - } else if map.contains_key("items") || map.contains_key("prefixItems") { - ty = Some("array".to_string()); - } else if 
map.contains_key("enum") - || map.contains_key("const") - || map.contains_key("format") - { - ty = Some("string".to_string()); - } else if map.contains_key("minimum") - || map.contains_key("maximum") - || map.contains_key("exclusiveMinimum") - || map.contains_key("exclusiveMaximum") - || map.contains_key("multipleOf") - { - ty = Some("number".to_string()); - } - } - // If we still couldn't infer, default to string - let ty = ty.unwrap_or_else(|| "string".to_string()); - map.insert("type".to_string(), JsonValue::String(ty.to_string())); - - // Ensure object schemas have properties map - if ty == "object" { - if !map.contains_key("properties") { - map.insert( - "properties".to_string(), - JsonValue::Object(serde_json::Map::new()), - ); - } - // If additionalProperties is an object schema, sanitize it too. - // Leave booleans as-is, since JSON Schema allows boolean here. - if let Some(ap) = map.get_mut("additionalProperties") { - let is_bool = matches!(ap, JsonValue::Bool(_)); - if !is_bool { - sanitize_json_schema(ap); - } - } - } - - // Ensure array schemas have items - if ty == "array" && !map.contains_key("items") { - map.insert("items".to_string(), json!({ "type": "string" })); - } - } - _ => {} - } -} - -/// Returns a list of OpenAiTools based on the provided config and MCP tools. -/// Note that the keys of mcp_tools should be fully qualified names. See -/// [`McpConnectionManager`] for more details. 
-pub(crate) fn get_openai_tools( - config: &ToolsConfig, - mcp_tools: Option>, -) -> Vec { - let mut tools: Vec = Vec::new(); - - if config.experimental_unified_exec_tool { - tools.push(create_unified_exec_tool()); - } else { - match &config.shell_type { - ConfigShellToolType::Default => { - tools.push(create_shell_tool()); - } - ConfigShellToolType::Local => { - tools.push(OpenAiTool::LocalShell {}); - } - ConfigShellToolType::Streamable => { - tools.push(OpenAiTool::Function( - crate::exec_command::create_exec_command_tool_for_responses_api(), - )); - tools.push(OpenAiTool::Function( - crate::exec_command::create_write_stdin_tool_for_responses_api(), - )); - } - } - } - - if config.plan_tool { - tools.push(PLAN_TOOL.clone()); - } - - if let Some(apply_patch_tool_type) = &config.apply_patch_tool_type { - match apply_patch_tool_type { - ApplyPatchToolType::Freeform => { - tools.push(create_apply_patch_freeform_tool()); - } - ApplyPatchToolType::Function => { - tools.push(create_apply_patch_json_tool()); - } - } - } - - if config.web_search_request { - tools.push(OpenAiTool::WebSearch {}); - } - - // Include the view_image tool so the agent can attach images to context. - if config.include_view_image_tool { - tools.push(create_view_image_tool()); - } - if let Some(mcp_tools) = mcp_tools { - // Ensure deterministic ordering to maximize prompt cache hits. 
- let mut entries: Vec<(String, mcp_types::Tool)> = mcp_tools.into_iter().collect(); - entries.sort_by(|a, b| a.0.cmp(&b.0)); - - for (name, tool) in entries.into_iter() { - match mcp_tool_to_openai_tool(name.clone(), tool.clone()) { - Ok(converted_tool) => tools.push(OpenAiTool::Function(converted_tool)), - Err(e) => { - tracing::error!("Failed to convert {name:?} MCP tool to OpenAI tool: {e:?}"); - } - } - } - } - - tools -} - -#[cfg(test)] -mod tests { - use crate::model_family::find_family_for_model; - use mcp_types::ToolInputSchema; - use pretty_assertions::assert_eq; - - use super::*; - - fn assert_eq_tool_names(tools: &[OpenAiTool], expected_names: &[&str]) { - let tool_names = tools - .iter() - .map(|tool| match tool { - OpenAiTool::Function(ResponsesApiTool { name, .. }) => name, - OpenAiTool::LocalShell {} => "local_shell", - OpenAiTool::WebSearch {} => "web_search", - OpenAiTool::Freeform(FreeformTool { name, .. }) => name, - }) - .collect::>(); - - assert_eq!( - tool_names.len(), - expected_names.len(), - "tool_name mismatch, {tool_names:?}, {expected_names:?}", - ); - for (name, expected_name) in tool_names.iter().zip(expected_names.iter()) { - assert_eq!( - name, expected_name, - "tool_name mismatch, {name:?}, {expected_name:?}" - ); - } - } - - #[test] - fn test_get_openai_tools() { - let model_family = find_family_for_model("codex-mini-latest") - .expect("codex-mini-latest should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: true, - include_apply_patch_tool: false, - include_web_search_request: true, - use_streamable_shell_tool: false, - include_view_image_tool: true, - experimental_unified_exec_tool: true, - }); - let tools = get_openai_tools(&config, Some(HashMap::new())); - - assert_eq_tool_names( - &tools, - &["unified_exec", "update_plan", "web_search", "view_image"], - ); - } - - #[test] - fn test_get_openai_tools_default_shell() { - let model_family = 
find_family_for_model("o3").expect("o3 should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: true, - include_apply_patch_tool: false, - include_web_search_request: true, - use_streamable_shell_tool: false, - include_view_image_tool: true, - experimental_unified_exec_tool: true, - }); - let tools = get_openai_tools(&config, Some(HashMap::new())); - - assert_eq_tool_names( - &tools, - &["unified_exec", "update_plan", "web_search", "view_image"], - ); - } - - #[test] - fn test_get_openai_tools_mcp_tools() { - let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: false, - include_apply_patch_tool: false, - include_web_search_request: true, - use_streamable_shell_tool: false, - include_view_image_tool: true, - experimental_unified_exec_tool: true, - }); - let tools = get_openai_tools( - &config, - Some(HashMap::from([( - "test_server/do_something_cool".to_string(), - mcp_types::Tool { - name: "do_something_cool".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({ - "string_argument": { - "type": "string", - }, - "number_argument": { - "type": "number", - }, - "object_argument": { - "type": "object", - "properties": { - "string_property": { "type": "string" }, - "number_property": { "type": "number" }, - }, - "required": [ - "string_property", - "number_property", - ], - "additionalProperties": Some(false), - }, - })), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("Do something cool".to_string()), - }, - )])), - ); - - assert_eq_tool_names( - &tools, - &[ - "unified_exec", - "web_search", - "view_image", - "test_server/do_something_cool", - ], - ); - - assert_eq!( - tools[3], - OpenAiTool::Function(ResponsesApiTool { - name: 
"test_server/do_something_cool".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([ - ( - "string_argument".to_string(), - JsonSchema::String { description: None } - ), - ( - "number_argument".to_string(), - JsonSchema::Number { description: None } - ), - ( - "object_argument".to_string(), - JsonSchema::Object { - properties: BTreeMap::from([ - ( - "string_property".to_string(), - JsonSchema::String { description: None } - ), - ( - "number_property".to_string(), - JsonSchema::Number { description: None } - ), - ]), - required: Some(vec![ - "string_property".to_string(), - "number_property".to_string(), - ]), - additional_properties: Some(false.into()), - }, - ), - ]), - required: None, - additional_properties: None, - }, - description: "Do something cool".to_string(), - strict: false, - }) - ); - } - - #[test] - fn test_get_openai_tools_mcp_tools_with_additional_properties_schema() { - let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: false, - include_apply_patch_tool: false, - include_web_search_request: true, - use_streamable_shell_tool: false, - include_view_image_tool: true, - experimental_unified_exec_tool: true, - }); - let tools = get_openai_tools( - &config, - Some(HashMap::from([( - "test_server/do_something_cool".to_string(), - mcp_types::Tool { - name: "do_something_cool".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({ - "string_argument": { - "type": "string", - }, - "number_argument": { - "type": "number", - }, - "object_argument": { - "type": "object", - "properties": { - "string_property": { "type": "string" }, - "number_property": { "type": "number" }, - }, - "required": [ - "string_property", - "number_property", - ], - "additionalProperties": { - "type": "object", - "properties": { - "addtl_prop": { "type": "string" }, - }, - "required": [ 
- "addtl_prop", - ], - "additionalProperties": false, - }, - }, - })), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("Do something cool".to_string()), - }, - )])), - ); - - assert_eq_tool_names( - &tools, - &[ - "unified_exec", - "web_search", - "view_image", - "test_server/do_something_cool", - ], - ); - - assert_eq!( - tools[3], - OpenAiTool::Function(ResponsesApiTool { - name: "test_server/do_something_cool".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([ - ( - "string_argument".to_string(), - JsonSchema::String { description: None } - ), - ( - "number_argument".to_string(), - JsonSchema::Number { description: None } - ), - ( - "object_argument".to_string(), - JsonSchema::Object { - properties: BTreeMap::from([ - ( - "string_property".to_string(), - JsonSchema::String { description: None } - ), - ( - "number_property".to_string(), - JsonSchema::Number { description: None } - ), - ]), - required: Some(vec![ - "string_property".to_string(), - "number_property".to_string(), - ]), - additional_properties: Some( - JsonSchema::Object { - properties: BTreeMap::from([( - "addtl_prop".to_string(), - JsonSchema::String { description: None } - ),]), - required: Some(vec!["addtl_prop".to_string(),]), - additional_properties: Some(false.into()), - } - .into() - ), - }, - ), - ]), - required: None, - additional_properties: None, - }, - description: "Do something cool".to_string(), - strict: false, - }) - ); - } - - #[test] - fn test_get_openai_tools_mcp_tools_sorted_by_name() { - let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: false, - include_apply_patch_tool: false, - include_web_search_request: false, - use_streamable_shell_tool: false, - include_view_image_tool: true, - experimental_unified_exec_tool: true, - 
}); - - // Intentionally construct a map with keys that would sort alphabetically. - let tools_map: HashMap = HashMap::from([ - ( - "test_server/do".to_string(), - mcp_types::Tool { - name: "a".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({})), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("a".to_string()), - }, - ), - ( - "test_server/something".to_string(), - mcp_types::Tool { - name: "b".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({})), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("b".to_string()), - }, - ), - ( - "test_server/cool".to_string(), - mcp_types::Tool { - name: "c".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({})), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("c".to_string()), - }, - ), - ]); - - let tools = get_openai_tools(&config, Some(tools_map)); - // Expect unified_exec first, followed by MCP tools sorted by fully-qualified name. 
- assert_eq_tool_names( - &tools, - &[ - "unified_exec", - "view_image", - "test_server/cool", - "test_server/do", - "test_server/something", - ], - ); - } - - #[test] - fn test_mcp_tool_property_missing_type_defaults_to_string() { - let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: false, - include_apply_patch_tool: false, - include_web_search_request: true, - use_streamable_shell_tool: false, - include_view_image_tool: true, - experimental_unified_exec_tool: true, - }); - - let tools = get_openai_tools( - &config, - Some(HashMap::from([( - "dash/search".to_string(), - mcp_types::Tool { - name: "search".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({ - "query": { - "description": "search query" - } - })), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("Search docs".to_string()), - }, - )])), - ); - - assert_eq_tool_names( - &tools, - &["unified_exec", "web_search", "view_image", "dash/search"], - ); - - assert_eq!( - tools[3], - OpenAiTool::Function(ResponsesApiTool { - name: "dash/search".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([( - "query".to_string(), - JsonSchema::String { - description: Some("search query".to_string()) - } - )]), - required: None, - additional_properties: None, - }, - description: "Search docs".to_string(), - strict: false, - }) - ); - } - - #[test] - fn test_mcp_tool_integer_normalized_to_number() { - let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: false, - include_apply_patch_tool: false, - include_web_search_request: true, - use_streamable_shell_tool: false, - include_view_image_tool: true, 
- experimental_unified_exec_tool: true, - }); - - let tools = get_openai_tools( - &config, - Some(HashMap::from([( - "dash/paginate".to_string(), - mcp_types::Tool { - name: "paginate".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({ - "page": { "type": "integer" } - })), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("Pagination".to_string()), - }, - )])), - ); - - assert_eq_tool_names( - &tools, - &["unified_exec", "web_search", "view_image", "dash/paginate"], - ); - assert_eq!( - tools[3], - OpenAiTool::Function(ResponsesApiTool { - name: "dash/paginate".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([( - "page".to_string(), - JsonSchema::Number { description: None } - )]), - required: None, - additional_properties: None, - }, - description: "Pagination".to_string(), - strict: false, - }) - ); - } - - #[test] - fn test_mcp_tool_array_without_items_gets_default_string_items() { - let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: false, - include_apply_patch_tool: false, - include_web_search_request: true, - use_streamable_shell_tool: false, - include_view_image_tool: true, - experimental_unified_exec_tool: true, - }); - - let tools = get_openai_tools( - &config, - Some(HashMap::from([( - "dash/tags".to_string(), - mcp_types::Tool { - name: "tags".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({ - "tags": { "type": "array" } - })), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("Tags".to_string()), - }, - )])), - ); - - assert_eq_tool_names( - &tools, - &["unified_exec", "web_search", "view_image", "dash/tags"], - ); - assert_eq!( - tools[3], - 
OpenAiTool::Function(ResponsesApiTool { - name: "dash/tags".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([( - "tags".to_string(), - JsonSchema::Array { - items: Box::new(JsonSchema::String { description: None }), - description: None - } - )]), - required: None, - additional_properties: None, - }, - description: "Tags".to_string(), - strict: false, - }) - ); - } - - #[test] - fn test_mcp_tool_anyof_defaults_to_string() { - let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); - let config = ToolsConfig::new(&ToolsConfigParams { - model_family: &model_family, - include_plan_tool: false, - include_apply_patch_tool: false, - include_web_search_request: true, - use_streamable_shell_tool: false, - include_view_image_tool: true, - experimental_unified_exec_tool: true, - }); - - let tools = get_openai_tools( - &config, - Some(HashMap::from([( - "dash/value".to_string(), - mcp_types::Tool { - name: "value".to_string(), - input_schema: ToolInputSchema { - properties: Some(serde_json::json!({ - "value": { "anyOf": [ { "type": "string" }, { "type": "number" } ] } - })), - required: None, - r#type: "object".to_string(), - }, - output_schema: None, - title: None, - annotations: None, - description: Some("AnyOf Value".to_string()), - }, - )])), - ); - - assert_eq_tool_names( - &tools, - &["unified_exec", "web_search", "view_image", "dash/value"], - ); - assert_eq!( - tools[3], - OpenAiTool::Function(ResponsesApiTool { - name: "dash/value".to_string(), - parameters: JsonSchema::Object { - properties: BTreeMap::from([( - "value".to_string(), - JsonSchema::String { description: None } - )]), - required: None, - additional_properties: None, - }, - description: "AnyOf Value".to_string(), - strict: false, - }) - ); - } - - #[test] - fn test_shell_tool() { - let tool = super::create_shell_tool(); - let OpenAiTool::Function(ResponsesApiTool { - description, name, .. 
- }) = &tool - else { - panic!("expected function tool"); - }; - assert_eq!(name, "shell"); - - let expected = "Runs a shell command and returns its output."; - assert_eq!(description, expected); - } -} +pub use crate::tools::spec::*; diff --git a/codex-rs/core/src/tools/context.rs b/codex-rs/core/src/tools/context.rs new file mode 100644 index 00000000..b6b458f1 --- /dev/null +++ b/codex-rs/core/src/tools/context.rs @@ -0,0 +1,244 @@ +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::tools::TELEMETRY_PREVIEW_MAX_BYTES; +use crate::tools::TELEMETRY_PREVIEW_MAX_LINES; +use crate::tools::TELEMETRY_PREVIEW_TRUNCATION_NOTICE; +use crate::turn_diff_tracker::TurnDiffTracker; +use codex_otel::otel_event_manager::OtelEventManager; +use codex_protocol::models::FunctionCallOutputPayload; +use codex_protocol::models::ResponseInputItem; +use codex_protocol::models::ShellToolCallParams; +use codex_protocol::protocol::FileChange; +use codex_utils_string::take_bytes_at_char_boundary; +use mcp_types::CallToolResult; +use std::borrow::Cow; +use std::collections::HashMap; +use std::path::PathBuf; + +pub struct ToolInvocation<'a> { + pub session: &'a Session, + pub turn: &'a TurnContext, + pub tracker: &'a mut TurnDiffTracker, + pub sub_id: &'a str, + pub call_id: String, + pub tool_name: String, + pub payload: ToolPayload, +} + +#[derive(Clone)] +pub enum ToolPayload { + Function { + arguments: String, + }, + Custom { + input: String, + }, + LocalShell { + params: ShellToolCallParams, + }, + UnifiedExec { + arguments: String, + }, + Mcp { + server: String, + tool: String, + raw_arguments: String, + }, +} + +impl ToolPayload { + pub fn log_payload(&self) -> Cow<'_, str> { + match self { + ToolPayload::Function { arguments } => Cow::Borrowed(arguments), + ToolPayload::Custom { input } => Cow::Borrowed(input), + ToolPayload::LocalShell { params } => Cow::Owned(params.command.join(" ")), + ToolPayload::UnifiedExec { arguments } => Cow::Borrowed(arguments), + 
ToolPayload::Mcp { raw_arguments, .. } => Cow::Borrowed(raw_arguments), + } + } +} + +#[derive(Clone)] +pub enum ToolOutput { + Function { + content: String, + success: Option, + }, + Mcp { + result: Result, + }, +} + +impl ToolOutput { + pub fn log_preview(&self) -> String { + match self { + ToolOutput::Function { content, .. } => telemetry_preview(content), + ToolOutput::Mcp { result } => format!("{result:?}"), + } + } + + pub fn success_for_logging(&self) -> bool { + match self { + ToolOutput::Function { success, .. } => success.unwrap_or(true), + ToolOutput::Mcp { result } => result.is_ok(), + } + } + + pub fn into_response(self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem { + match self { + ToolOutput::Function { content, success } => { + if matches!(payload, ToolPayload::Custom { .. }) { + ResponseInputItem::CustomToolCallOutput { + call_id: call_id.to_string(), + output: content, + } + } else { + ResponseInputItem::FunctionCallOutput { + call_id: call_id.to_string(), + output: FunctionCallOutputPayload { content, success }, + } + } + } + ToolOutput::Mcp { result } => ResponseInputItem::McpToolCallOutput { + call_id: call_id.to_string(), + result, + }, + } + } +} + +fn telemetry_preview(content: &str) -> String { + let truncated_slice = take_bytes_at_char_boundary(content, TELEMETRY_PREVIEW_MAX_BYTES); + let truncated_by_bytes = truncated_slice.len() < content.len(); + + let mut preview = String::new(); + let mut lines_iter = truncated_slice.lines(); + for idx in 0..TELEMETRY_PREVIEW_MAX_LINES { + match lines_iter.next() { + Some(line) => { + if idx > 0 { + preview.push('\n'); + } + preview.push_str(line); + } + None => break, + } + } + let truncated_by_lines = lines_iter.next().is_some(); + + if !truncated_by_bytes && !truncated_by_lines { + return content.to_string(); + } + + if preview.len() < truncated_slice.len() + && truncated_slice + .as_bytes() + .get(preview.len()) + .is_some_and(|byte| *byte == b'\n') + { + preview.push('\n'); + } + + 
if !preview.is_empty() && !preview.ends_with('\n') { + preview.push('\n'); + } + preview.push_str(TELEMETRY_PREVIEW_TRUNCATION_NOTICE); + + preview +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn custom_tool_calls_should_roundtrip_as_custom_outputs() { + let payload = ToolPayload::Custom { + input: "patch".to_string(), + }; + let response = ToolOutput::Function { + content: "patched".to_string(), + success: Some(true), + } + .into_response("call-42", &payload); + + match response { + ResponseInputItem::CustomToolCallOutput { call_id, output } => { + assert_eq!(call_id, "call-42"); + assert_eq!(output, "patched"); + } + other => panic!("expected CustomToolCallOutput, got {other:?}"), + } + } + + #[test] + fn function_payloads_remain_function_outputs() { + let payload = ToolPayload::Function { + arguments: "{}".to_string(), + }; + let response = ToolOutput::Function { + content: "ok".to_string(), + success: Some(true), + } + .into_response("fn-1", &payload); + + match response { + ResponseInputItem::FunctionCallOutput { call_id, output } => { + assert_eq!(call_id, "fn-1"); + assert_eq!(output.content, "ok"); + assert_eq!(output.success, Some(true)); + } + other => panic!("expected FunctionCallOutput, got {other:?}"), + } + } + + #[test] + fn telemetry_preview_returns_original_within_limits() { + let content = "short output"; + assert_eq!(telemetry_preview(content), content); + } + + #[test] + fn telemetry_preview_truncates_by_bytes() { + let content = "x".repeat(TELEMETRY_PREVIEW_MAX_BYTES + 8); + let preview = telemetry_preview(&content); + + assert!(preview.contains(TELEMETRY_PREVIEW_TRUNCATION_NOTICE)); + assert!( + preview.len() + <= TELEMETRY_PREVIEW_MAX_BYTES + TELEMETRY_PREVIEW_TRUNCATION_NOTICE.len() + 1 + ); + } + + #[test] + fn telemetry_preview_truncates_by_lines() { + let content = (0..(TELEMETRY_PREVIEW_MAX_LINES + 5)) + .map(|idx| format!("line {idx}")) + .collect::>() + .join("\n"); + + let preview = 
telemetry_preview(&content); + let lines: Vec<&str> = preview.lines().collect(); + + assert!(lines.len() <= TELEMETRY_PREVIEW_MAX_LINES + 1); + assert_eq!(lines.last(), Some(&TELEMETRY_PREVIEW_TRUNCATION_NOTICE)); + } +} + +#[derive(Clone, Debug)] +pub(crate) struct ExecCommandContext { + pub(crate) sub_id: String, + pub(crate) call_id: String, + pub(crate) command_for_display: Vec, + pub(crate) cwd: PathBuf, + pub(crate) apply_patch: Option, + pub(crate) tool_name: String, + pub(crate) otel_event_manager: OtelEventManager, +} + +#[derive(Clone, Debug)] +pub(crate) struct ApplyPatchCommandContext { + pub(crate) user_explicitly_approved_this_action: bool, + pub(crate) changes: HashMap, +} diff --git a/codex-rs/core/src/tool_apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs similarity index 60% rename from codex-rs/core/src/tool_apply_patch.rs rename to codex-rs/core/src/tools/handlers/apply_patch.rs index 5f34b0d2..ced0898a 100644 --- a/codex-rs/core/src/tool_apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -1,15 +1,99 @@ +use std::collections::BTreeMap; +use std::collections::HashMap; + +use crate::client_common::tools::FreeformTool; +use crate::client_common::tools::FreeformToolFormat; +use crate::client_common::tools::ResponsesApiTool; +use crate::client_common::tools::ToolSpec; +use crate::exec::ExecParams; +use crate::function_tool::FunctionCallError; +use crate::openai_tools::JsonSchema; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::handle_container_exec_with_params; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use crate::tools::spec::ApplyPatchToolArgs; +use async_trait::async_trait; use serde::Deserialize; use serde::Serialize; -use std::collections::BTreeMap; -use crate::openai_tools::FreeformTool; -use crate::openai_tools::FreeformToolFormat; -use crate::openai_tools::JsonSchema; -use 
crate::openai_tools::OpenAiTool; -use crate::openai_tools::ResponsesApiTool; +pub struct ApplyPatchHandler; const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark"); +#[async_trait] +impl ToolHandler for ApplyPatchHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!( + payload, + ToolPayload::Function { .. } | ToolPayload::Custom { .. } + ) + } + + async fn handle( + &self, + invocation: ToolInvocation<'_>, + ) -> Result { + let ToolInvocation { + session, + turn, + tracker, + sub_id, + call_id, + tool_name, + payload, + } = invocation; + + let patch_input = match payload { + ToolPayload::Function { arguments } => { + let args: ApplyPatchToolArgs = serde_json::from_str(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {e:?}" + )) + })?; + args.input + } + ToolPayload::Custom { input } => input, + _ => { + return Err(FunctionCallError::RespondToModel( + "apply_patch handler received unsupported payload".to_string(), + )); + } + }; + + let exec_params = ExecParams { + command: vec!["apply_patch".to_string(), patch_input.clone()], + cwd: turn.cwd.clone(), + timeout_ms: None, + env: HashMap::new(), + with_escalated_permissions: None, + justification: None, + }; + + let content = handle_container_exec_with_params( + tool_name.as_str(), + exec_params, + session, + turn, + tracker, + sub_id.to_string(), + call_id.clone(), + ) + .await?; + + Ok(ToolOutput::Function { + content, + success: Some(true), + }) + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] #[serde(rename_all = "snake_case")] pub enum ApplyPatchToolType { @@ -19,8 +103,8 @@ pub enum ApplyPatchToolType { /// Returns a custom tool that can be used to edit files. 
Well-suited for GPT-5 models /// https://platform.openai.com/docs/guides/function-calling#custom-tools -pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool { - OpenAiTool::Freeform(FreeformTool { +pub(crate) fn create_apply_patch_freeform_tool() -> ToolSpec { + ToolSpec::Freeform(FreeformTool { name: "apply_patch".to_string(), description: "Use the `apply_patch` tool to edit files".to_string(), format: FreeformToolFormat { @@ -32,7 +116,7 @@ pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool { } /// Returns a json tool that can be used to edit files. Should only be used with gpt-oss models -pub(crate) fn create_apply_patch_json_tool() -> OpenAiTool { +pub(crate) fn create_apply_patch_json_tool() -> ToolSpec { let mut properties = BTreeMap::new(); properties.insert( "input".to_string(), @@ -41,7 +125,7 @@ pub(crate) fn create_apply_patch_json_tool() -> OpenAiTool { }, ); - OpenAiTool::Function(ResponsesApiTool { + ToolSpec::Function(ResponsesApiTool { name: "apply_patch".to_string(), description: r#"Use the `apply_patch` tool to edit files. Your patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope: @@ -111,7 +195,7 @@ It is important to remember: - You must prefix new lines with `+` even when creating a new file - File references can only be relative, NEVER ABSOLUTE. 
"# - .to_string(), + .to_string(), strict: false, parameters: JsonSchema::Object { properties, diff --git a/codex-rs/core/src/tools/handlers/exec_stream.rs b/codex-rs/core/src/tools/handlers/exec_stream.rs new file mode 100644 index 00000000..db9d4b0b --- /dev/null +++ b/codex-rs/core/src/tools/handlers/exec_stream.rs @@ -0,0 +1,71 @@ +use async_trait::async_trait; + +use crate::exec_command::EXEC_COMMAND_TOOL_NAME; +use crate::exec_command::ExecCommandParams; +use crate::exec_command::WRITE_STDIN_TOOL_NAME; +use crate::exec_command::WriteStdinParams; +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +pub struct ExecStreamHandler; + +#[async_trait] +impl ToolHandler for ExecStreamHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle( + &self, + invocation: ToolInvocation<'_>, + ) -> Result { + let ToolInvocation { + session, + tool_name, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "exec_stream handler received unsupported payload".to_string(), + )); + } + }; + + let content = match tool_name.as_str() { + EXEC_COMMAND_TOOL_NAME => { + let params: ExecCommandParams = serde_json::from_str(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {e:?}" + )) + })?; + session.handle_exec_command_tool(params).await? + } + WRITE_STDIN_TOOL_NAME => { + let params: WriteStdinParams = serde_json::from_str(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {e:?}" + )) + })?; + session.handle_write_stdin_tool(params).await? 
+ } + _ => { + return Err(FunctionCallError::RespondToModel(format!( + "exec_stream handler does not support tool {tool_name}" + ))); + } + }; + + Ok(ToolOutput::Function { + content, + success: Some(true), + }) + } +} diff --git a/codex-rs/core/src/tools/handlers/mcp.rs b/codex-rs/core/src/tools/handlers/mcp.rs new file mode 100644 index 00000000..17eae7ea --- /dev/null +++ b/codex-rs/core/src/tools/handlers/mcp.rs @@ -0,0 +1,70 @@ +use async_trait::async_trait; + +use crate::function_tool::FunctionCallError; +use crate::mcp_tool_call::handle_mcp_tool_call; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +pub struct McpHandler; + +#[async_trait] +impl ToolHandler for McpHandler { + fn kind(&self) -> ToolKind { + ToolKind::Mcp + } + + async fn handle( + &self, + invocation: ToolInvocation<'_>, + ) -> Result { + let ToolInvocation { + session, + sub_id, + call_id, + payload, + .. + } = invocation; + + let payload = match payload { + ToolPayload::Mcp { + server, + tool, + raw_arguments, + } => (server, tool, raw_arguments), + _ => { + return Err(FunctionCallError::RespondToModel( + "mcp handler received unsupported payload".to_string(), + )); + } + }; + + let (server, tool, raw_arguments) = payload; + let arguments_str = raw_arguments; + + let response = handle_mcp_tool_call( + session, + sub_id, + call_id.clone(), + server, + tool, + arguments_str, + ) + .await; + + match response { + codex_protocol::models::ResponseInputItem::McpToolCallOutput { result, .. } => { + Ok(ToolOutput::Mcp { result }) + } + codex_protocol::models::ResponseInputItem::FunctionCallOutput { output, .. 
} => { + let codex_protocol::models::FunctionCallOutputPayload { content, success } = output; + Ok(ToolOutput::Function { content, success }) + } + _ => Err(FunctionCallError::RespondToModel( + "mcp handler received unexpected response variant".to_string(), + )), + } + } +} diff --git a/codex-rs/core/src/tools/handlers/mod.rs b/codex-rs/core/src/tools/handlers/mod.rs new file mode 100644 index 00000000..af410b99 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/mod.rs @@ -0,0 +1,19 @@ +pub mod apply_patch; +mod exec_stream; +mod mcp; +mod plan; +mod read_file; +mod shell; +mod unified_exec; +mod view_image; + +pub use plan::PLAN_TOOL; + +pub use apply_patch::ApplyPatchHandler; +pub use exec_stream::ExecStreamHandler; +pub use mcp::McpHandler; +pub use plan::PlanHandler; +pub use read_file::ReadFileHandler; +pub use shell::ShellHandler; +pub use unified_exec::UnifiedExecHandler; +pub use view_image::ViewImageHandler; diff --git a/codex-rs/core/src/plan_tool.rs b/codex-rs/core/src/tools/handlers/plan.rs similarity index 63% rename from codex-rs/core/src/plan_tool.rs rename to codex-rs/core/src/tools/handlers/plan.rs index e0fdb565..f5208030 100644 --- a/codex-rs/core/src/plan_tool.rs +++ b/codex-rs/core/src/tools/handlers/plan.rs @@ -1,23 +1,23 @@ -use std::collections::BTreeMap; -use std::sync::LazyLock; - +use crate::client_common::tools::ResponsesApiTool; +use crate::client_common::tools::ToolSpec; use crate::codex::Session; use crate::function_tool::FunctionCallError; use crate::openai_tools::JsonSchema; -use crate::openai_tools::OpenAiTool; -use crate::openai_tools::ResponsesApiTool; -use crate::protocol::Event; -use crate::protocol::EventMsg; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use async_trait::async_trait; +use codex_protocol::plan_tool::UpdatePlanArgs; +use 
codex_protocol::protocol::Event; +use codex_protocol::protocol::EventMsg; +use std::collections::BTreeMap; +use std::sync::LazyLock; -// Use the canonical plan tool types from the protocol crate to ensure -// type-identity matches events transported via `codex_protocol`. -pub use codex_protocol::plan_tool::PlanItemArg; -pub use codex_protocol::plan_tool::StepStatus; -pub use codex_protocol::plan_tool::UpdatePlanArgs; +pub struct PlanHandler; -// Types for the TODO tool arguments matching codex-vscode/todo-mcp/src/main.rs - -pub(crate) static PLAN_TOOL: LazyLock = LazyLock::new(|| { +pub static PLAN_TOOL: LazyLock = LazyLock::new(|| { let mut plan_item_props = BTreeMap::new(); plan_item_props.insert("step".to_string(), JsonSchema::String { description: None }); plan_item_props.insert( @@ -43,7 +43,7 @@ pub(crate) static PLAN_TOOL: LazyLock = LazyLock::new(|| { ); properties.insert("plan".to_string(), plan_items_schema); - OpenAiTool::Function(ResponsesApiTool { + ToolSpec::Function(ResponsesApiTool { name: "update_plan".to_string(), description: r#"Updates the task plan. Provide an optional explanation and a list of plan items, each with a step and status. @@ -59,6 +59,42 @@ At most one step can be in_progress at a time. }) }); +#[async_trait] +impl ToolHandler for PlanHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle( + &self, + invocation: ToolInvocation<'_>, + ) -> Result { + let ToolInvocation { + session, + sub_id, + call_id, + payload, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "update_plan handler received unsupported payload".to_string(), + )); + } + }; + + let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?; + + Ok(ToolOutput::Function { + content, + success: Some(true), + }) + } +} + /// This function doesn't do anything useful. 
However, it gives the model a structured way to record its plan that clients can read and render. /// So it's the _inputs_ to this function that are useful to clients, not the outputs and neither are actually useful for the model other /// than forcing it to come up and document a plan (TBD how that affects performance). diff --git a/codex-rs/core/src/tools/handlers/read_file.rs b/codex-rs/core/src/tools/handlers/read_file.rs new file mode 100644 index 00000000..4988593b --- /dev/null +++ b/codex-rs/core/src/tools/handlers/read_file.rs @@ -0,0 +1,255 @@ +use std::path::Path; +use std::path::PathBuf; + +use async_trait::async_trait; +use codex_utils_string::take_bytes_at_char_boundary; +use serde::Deserialize; +use tokio::fs::File; +use tokio::io::AsyncBufReadExt; +use tokio::io::BufReader; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +pub struct ReadFileHandler; + +const MAX_LINE_LENGTH: usize = 500; + +fn default_offset() -> usize { + 1 +} + +fn default_limit() -> usize { + 2000 +} + +#[derive(Deserialize)] +struct ReadFileArgs { + file_path: String, + #[serde(default = "default_offset")] + offset: usize, + #[serde(default = "default_limit")] + limit: usize, +} + +#[async_trait] +impl ToolHandler for ReadFileHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle( + &self, + invocation: ToolInvocation<'_>, + ) -> Result { + let ToolInvocation { payload, .. 
} = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "read_file handler received unsupported payload".to_string(), + )); + } + }; + + let args: ReadFileArgs = serde_json::from_str(&arguments).map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {err:?}" + )) + })?; + + let ReadFileArgs { + file_path, + offset, + limit, + } = args; + + if offset == 0 { + return Err(FunctionCallError::RespondToModel( + "offset must be a 1-indexed line number".to_string(), + )); + } + + if limit == 0 { + return Err(FunctionCallError::RespondToModel( + "limit must be greater than zero".to_string(), + )); + } + + let path = PathBuf::from(&file_path); + if !path.is_absolute() { + return Err(FunctionCallError::RespondToModel( + "file_path must be an absolute path".to_string(), + )); + } + + let collected = read_file_slice(&path, offset, limit).await?; + Ok(ToolOutput::Function { + content: collected.join("\n"), + success: Some(true), + }) + } +} + +async fn read_file_slice( + path: &Path, + offset: usize, + limit: usize, +) -> Result, FunctionCallError> { + let file = File::open(path) + .await + .map_err(|err| FunctionCallError::RespondToModel(format!("failed to read file: {err}")))?; + + let mut reader = BufReader::new(file); + let mut collected = Vec::new(); + let mut seen = 0usize; + let mut buffer = Vec::new(); + + loop { + buffer.clear(); + let bytes_read = reader.read_until(b'\n', &mut buffer).await.map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to read file: {err}")) + })?; + + if bytes_read == 0 { + break; + } + + if buffer.last() == Some(&b'\n') { + buffer.pop(); + if buffer.last() == Some(&b'\r') { + buffer.pop(); + } + } + + seen += 1; + + if seen < offset { + continue; + } + + if collected.len() == limit { + break; + } + + let formatted = format_line(&buffer); + collected.push(format!("L{seen}: 
{formatted}")); + + if collected.len() == limit { + break; + } + } + + if seen < offset { + return Err(FunctionCallError::RespondToModel( + "offset exceeds file length".to_string(), + )); + } + + Ok(collected) +} + +fn format_line(bytes: &[u8]) -> String { + let decoded = String::from_utf8_lossy(bytes); + if decoded.len() > MAX_LINE_LENGTH { + take_bytes_at_char_boundary(&decoded, MAX_LINE_LENGTH).to_string() + } else { + decoded.into_owned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + #[tokio::test] + async fn reads_requested_range() { + let mut temp = NamedTempFile::new().expect("create temp file"); + use std::io::Write as _; + writeln!(temp, "alpha").unwrap(); + writeln!(temp, "beta").unwrap(); + writeln!(temp, "gamma").unwrap(); + + let lines = read_file_slice(temp.path(), 2, 2) + .await + .expect("read slice"); + assert_eq!(lines, vec!["L2: beta".to_string(), "L3: gamma".to_string()]); + } + + #[tokio::test] + async fn errors_when_offset_exceeds_length() { + let mut temp = NamedTempFile::new().expect("create temp file"); + use std::io::Write as _; + writeln!(temp, "only").unwrap(); + + let err = read_file_slice(temp.path(), 3, 1) + .await + .expect_err("offset exceeds length"); + assert_eq!( + err, + FunctionCallError::RespondToModel("offset exceeds file length".to_string()) + ); + } + + #[tokio::test] + async fn reads_non_utf8_lines() { + let mut temp = NamedTempFile::new().expect("create temp file"); + use std::io::Write as _; + temp.as_file_mut().write_all(b"\xff\xfe\nplain\n").unwrap(); + + let lines = read_file_slice(temp.path(), 1, 2) + .await + .expect("read slice"); + let expected_first = format!("L1: {}{}", '\u{FFFD}', '\u{FFFD}'); + assert_eq!(lines, vec![expected_first, "L2: plain".to_string()]); + } + + #[tokio::test] + async fn trims_crlf_endings() { + let mut temp = NamedTempFile::new().expect("create temp file"); + use std::io::Write as _; + write!(temp, "one\r\ntwo\r\n").unwrap(); + + let lines = 
read_file_slice(temp.path(), 1, 2) + .await + .expect("read slice"); + assert_eq!(lines, vec!["L1: one".to_string(), "L2: two".to_string()]); + } + + #[tokio::test] + async fn respects_limit_even_with_more_lines() { + let mut temp = NamedTempFile::new().expect("create temp file"); + use std::io::Write as _; + writeln!(temp, "first").unwrap(); + writeln!(temp, "second").unwrap(); + writeln!(temp, "third").unwrap(); + + let lines = read_file_slice(temp.path(), 1, 2) + .await + .expect("read slice"); + assert_eq!( + lines, + vec!["L1: first".to_string(), "L2: second".to_string()] + ); + } + + #[tokio::test] + async fn truncates_lines_longer_than_max_length() { + let mut temp = NamedTempFile::new().expect("create temp file"); + use std::io::Write as _; + let long_line = "x".repeat(MAX_LINE_LENGTH + 50); + writeln!(temp, "{long_line}").unwrap(); + + let lines = read_file_slice(temp.path(), 1, 1) + .await + .expect("read slice"); + let expected = "x".repeat(MAX_LINE_LENGTH); + assert_eq!(lines, vec![format!("L1: {expected}")]); + } +} diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs new file mode 100644 index 00000000..fbcb493e --- /dev/null +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -0,0 +1,103 @@ +use async_trait::async_trait; +use codex_protocol::models::ShellToolCallParams; + +use crate::codex::TurnContext; +use crate::exec::ExecParams; +use crate::exec_env::create_env; +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::handle_container_exec_with_params; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +pub struct ShellHandler; + +impl ShellHandler { + fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams { + ExecParams { + command: params.command, + cwd: turn_context.resolve_path(params.workdir.clone()), 
+ timeout_ms: params.timeout_ms, + env: create_env(&turn_context.shell_environment_policy), + with_escalated_permissions: params.with_escalated_permissions, + justification: params.justification, + } + } +} + +#[async_trait] +impl ToolHandler for ShellHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!( + payload, + ToolPayload::Function { .. } | ToolPayload::LocalShell { .. } + ) + } + + async fn handle( + &self, + invocation: ToolInvocation<'_>, + ) -> Result { + let ToolInvocation { + session, + turn, + tracker, + sub_id, + call_id, + tool_name, + payload, + } = invocation; + + match payload { + ToolPayload::Function { arguments } => { + let params: ShellToolCallParams = + serde_json::from_str(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {e:?}" + )) + })?; + let exec_params = Self::to_exec_params(params, turn); + let content = handle_container_exec_with_params( + tool_name.as_str(), + exec_params, + session, + turn, + tracker, + sub_id.to_string(), + call_id.clone(), + ) + .await?; + Ok(ToolOutput::Function { + content, + success: Some(true), + }) + } + ToolPayload::LocalShell { params } => { + let exec_params = Self::to_exec_params(params, turn); + let content = handle_container_exec_with_params( + tool_name.as_str(), + exec_params, + session, + turn, + tracker, + sub_id.to_string(), + call_id.clone(), + ) + .await?; + Ok(ToolOutput::Function { + content, + success: Some(true), + }) + } + _ => Err(FunctionCallError::RespondToModel(format!( + "unsupported payload for shell handler: {tool_name}" + ))), + } + } +} diff --git a/codex-rs/core/src/tool_apply_patch.lark b/codex-rs/core/src/tools/handlers/tool_apply_patch.lark similarity index 100% rename from codex-rs/core/src/tool_apply_patch.lark rename to codex-rs/core/src/tools/handlers/tool_apply_patch.lark diff --git 
a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs new file mode 100644 index 00000000..7175afb9 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -0,0 +1,112 @@ +use async_trait::async_trait; +use serde::Deserialize; + +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; +use crate::unified_exec::UnifiedExecRequest; + +pub struct UnifiedExecHandler; + +#[derive(Deserialize)] +struct UnifiedExecArgs { + input: Vec, + #[serde(default)] + session_id: Option, + #[serde(default)] + timeout_ms: Option, +} + +#[async_trait] +impl ToolHandler for UnifiedExecHandler { + fn kind(&self) -> ToolKind { + ToolKind::UnifiedExec + } + + fn matches_kind(&self, payload: &ToolPayload) -> bool { + matches!( + payload, + ToolPayload::UnifiedExec { .. } | ToolPayload::Function { .. } + ) + } + + async fn handle( + &self, + invocation: ToolInvocation<'_>, + ) -> Result { + let ToolInvocation { + session, payload, .. + } = invocation; + + let args = match payload { + ToolPayload::UnifiedExec { arguments } | ToolPayload::Function { arguments } => { + serde_json::from_str::(&arguments).map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to parse function arguments: {err:?}" + )) + })? 
+ } + _ => { + return Err(FunctionCallError::RespondToModel( + "unified_exec handler received unsupported payload".to_string(), + )); + } + }; + + let UnifiedExecArgs { + input, + session_id, + timeout_ms, + } = args; + + let parsed_session_id = if let Some(session_id) = session_id { + match session_id.parse::() { + Ok(parsed) => Some(parsed), + Err(output) => { + return Err(FunctionCallError::RespondToModel(format!( + "invalid session_id: {session_id} due to error {output:?}" + ))); + } + } + } else { + None + }; + + let request = UnifiedExecRequest { + session_id: parsed_session_id, + input_chunks: &input, + timeout_ms, + }; + + let value = session + .run_unified_exec_request(request) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("unified exec failed: {err:?}")) + })?; + + #[derive(serde::Serialize)] + struct SerializedUnifiedExecResult { + session_id: Option, + output: String, + } + + let content = serde_json::to_string(&SerializedUnifiedExecResult { + session_id: value.session_id.map(|id| id.to_string()), + output: value.output, + }) + .map_err(|err| { + FunctionCallError::RespondToModel(format!( + "failed to serialize unified exec output: {err:?}" + )) + })?; + + Ok(ToolOutput::Function { + content, + success: Some(true), + }) + } +} diff --git a/codex-rs/core/src/tools/handlers/view_image.rs b/codex-rs/core/src/tools/handlers/view_image.rs new file mode 100644 index 00000000..4ebfd8f3 --- /dev/null +++ b/codex-rs/core/src/tools/handlers/view_image.rs @@ -0,0 +1,96 @@ +use async_trait::async_trait; +use serde::Deserialize; +use tokio::fs; + +use crate::function_tool::FunctionCallError; +use crate::protocol::Event; +use crate::protocol::EventMsg; +use crate::protocol::InputItem; +use crate::protocol::ViewImageToolCallEvent; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolHandler; +use crate::tools::registry::ToolKind; + +pub 
struct ViewImageHandler; + +#[derive(Deserialize)] +struct ViewImageArgs { + path: String, +} + +#[async_trait] +impl ToolHandler for ViewImageHandler { + fn kind(&self) -> ToolKind { + ToolKind::Function + } + + async fn handle( + &self, + invocation: ToolInvocation<'_>, + ) -> Result { + let ToolInvocation { + session, + turn, + payload, + sub_id, + call_id, + .. + } = invocation; + + let arguments = match payload { + ToolPayload::Function { arguments } => arguments, + _ => { + return Err(FunctionCallError::RespondToModel( + "view_image handler received unsupported payload".to_string(), + )); + } + }; + + let args: ViewImageArgs = serde_json::from_str(&arguments).map_err(|e| { + FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e:?}")) + })?; + + let abs_path = turn.resolve_path(Some(args.path)); + + let metadata = fs::metadata(&abs_path).await.map_err(|error| { + FunctionCallError::RespondToModel(format!( + "unable to locate image at `{}`: {error}", + abs_path.display() + )) + })?; + + if !metadata.is_file() { + return Err(FunctionCallError::RespondToModel(format!( + "image path `{}` is not a file", + abs_path.display() + ))); + } + let event_path = abs_path.clone(); + + session + .inject_input(vec![InputItem::LocalImage { path: abs_path }]) + .await + .map_err(|_| { + FunctionCallError::RespondToModel( + "unable to attach image (no active task)".to_string(), + ) + })?; + + session + .send_event(Event { + id: sub_id.to_string(), + msg: EventMsg::ViewImageToolCall(ViewImageToolCallEvent { + call_id, + path: event_path, + }), + }) + .await; + + Ok(ToolOutput::Function { + content: "attached local image path".to_string(), + success: Some(true), + }) + } +} diff --git a/codex-rs/core/src/tools/mod.rs b/codex-rs/core/src/tools/mod.rs new file mode 100644 index 00000000..5a120d09 --- /dev/null +++ b/codex-rs/core/src/tools/mod.rs @@ -0,0 +1,280 @@ +pub mod context; +pub(crate) mod handlers; +pub mod registry; +pub mod router; +pub mod 
spec; + +use crate::apply_patch; +use crate::apply_patch::ApplyPatchExec; +use crate::apply_patch::InternalApplyPatchInvocation; +use crate::apply_patch::convert_apply_patch_to_protocol; +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::error::CodexErr; +use crate::error::SandboxErr; +use crate::exec::ExecParams; +use crate::exec::ExecToolCallOutput; +use crate::exec::StdoutStream; +use crate::executor::ExecutionMode; +use crate::executor::errors::ExecError; +use crate::executor::linkers::PreparedExec; +use crate::function_tool::FunctionCallError; +use crate::tools::context::ApplyPatchCommandContext; +use crate::tools::context::ExecCommandContext; +use crate::turn_diff_tracker::TurnDiffTracker; +use codex_apply_patch::MaybeApplyPatchVerified; +use codex_apply_patch::maybe_parse_apply_patch_verified; +use codex_protocol::protocol::AskForApproval; +use codex_utils_string::take_bytes_at_char_boundary; +use codex_utils_string::take_last_bytes_at_char_boundary; +pub use router::ToolRouter; +use serde::Serialize; +use tracing::trace; + +// Model-formatting limits: clients get full streams; only content sent to the model is truncated. +pub(crate) const MODEL_FORMAT_MAX_BYTES: usize = 10 * 1024; // 10 KiB +pub(crate) const MODEL_FORMAT_MAX_LINES: usize = 256; // lines +pub(crate) const MODEL_FORMAT_HEAD_LINES: usize = MODEL_FORMAT_MAX_LINES / 2; +pub(crate) const MODEL_FORMAT_TAIL_LINES: usize = MODEL_FORMAT_MAX_LINES - MODEL_FORMAT_HEAD_LINES; // 128 +pub(crate) const MODEL_FORMAT_HEAD_BYTES: usize = MODEL_FORMAT_MAX_BYTES / 2; + +// Telemetry preview limits: keep log events smaller than model budgets. +pub(crate) const TELEMETRY_PREVIEW_MAX_BYTES: usize = 2 * 1024; // 2 KiB +pub(crate) const TELEMETRY_PREVIEW_MAX_LINES: usize = 64; // lines +pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str = + "[... 
telemetry preview truncated ...]"; + +// TODO(jif) break this down +pub(crate) async fn handle_container_exec_with_params( + tool_name: &str, + params: ExecParams, + sess: &Session, + turn_context: &TurnContext, + turn_diff_tracker: &mut TurnDiffTracker, + sub_id: String, + call_id: String, +) -> Result { + let otel_event_manager = turn_context.client.get_otel_event_manager(); + + if params.with_escalated_permissions.unwrap_or(false) + && !matches!(turn_context.approval_policy, AskForApproval::OnRequest) + { + return Err(FunctionCallError::RespondToModel(format!( + "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}", + policy = turn_context.approval_policy + ))); + } + + // check if this was a patch, and apply it if so + let apply_patch_exec = match maybe_parse_apply_patch_verified(¶ms.command, ¶ms.cwd) { + MaybeApplyPatchVerified::Body(changes) => { + match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await { + InternalApplyPatchInvocation::Output(item) => return item, + InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => { + Some(apply_patch_exec) + } + } + } + MaybeApplyPatchVerified::CorrectnessError(parse_error) => { + // It looks like an invocation of `apply_patch`, but we + // could not resolve it into a patch that would apply + // cleanly. Return to model for resample. 
+ return Err(FunctionCallError::RespondToModel(format!( + "apply_patch verification failed: {parse_error}" + ))); + } + MaybeApplyPatchVerified::ShellParseError(error) => { + trace!("Failed to parse shell command, {error:?}"); + None + } + MaybeApplyPatchVerified::NotApplyPatch => None, + }; + + let command_for_display = if let Some(exec) = apply_patch_exec.as_ref() { + vec!["apply_patch".to_string(), exec.action.patch.clone()] + } else { + params.command.clone() + }; + + let exec_command_context = ExecCommandContext { + sub_id: sub_id.clone(), + call_id: call_id.clone(), + command_for_display: command_for_display.clone(), + cwd: params.cwd.clone(), + apply_patch: apply_patch_exec.as_ref().map( + |ApplyPatchExec { + action, + user_explicitly_approved_this_action, + }| ApplyPatchCommandContext { + user_explicitly_approved_this_action: *user_explicitly_approved_this_action, + changes: convert_apply_patch_to_protocol(action), + }, + ), + tool_name: tool_name.to_string(), + otel_event_manager, + }; + + let mode = match apply_patch_exec { + Some(exec) => ExecutionMode::ApplyPatch(exec), + None => ExecutionMode::Shell, + }; + + sess.services.executor.update_environment( + turn_context.sandbox_policy.clone(), + turn_context.cwd.clone(), + ); + + let prepared_exec = PreparedExec::new( + exec_command_context, + params, + command_for_display, + mode, + Some(StdoutStream { + sub_id: sub_id.clone(), + call_id: call_id.clone(), + tx_event: sess.get_tx_event(), + }), + turn_context.shell_environment_policy.use_profile, + ); + + let output_result = sess + .run_exec_with_events( + turn_diff_tracker, + prepared_exec, + turn_context.approval_policy, + ) + .await; + + match output_result { + Ok(output) => { + let ExecToolCallOutput { exit_code, .. 
} = &output; + let content = format_exec_output_apply_patch(&output); + if *exit_code == 0 { + Ok(content) + } else { + Err(FunctionCallError::RespondToModel(content)) + } + } + Err(ExecError::Function(err)) => Err(err), + Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => Err( + FunctionCallError::RespondToModel(format_exec_output_apply_patch(&output)), + ), + Err(ExecError::Codex(err)) => Err(FunctionCallError::RespondToModel(format!( + "execution error: {err:?}" + ))), + } +} + +pub fn format_exec_output_apply_patch(exec_output: &ExecToolCallOutput) -> String { + let ExecToolCallOutput { + exit_code, + duration, + .. + } = exec_output; + + #[derive(Serialize)] + struct ExecMetadata { + exit_code: i32, + duration_seconds: f32, + } + + #[derive(Serialize)] + struct ExecOutput<'a> { + output: &'a str, + metadata: ExecMetadata, + } + + // round to 1 decimal place + let duration_seconds = ((duration.as_secs_f32()) * 10.0).round() / 10.0; + + let formatted_output = format_exec_output_str(exec_output); + + let payload = ExecOutput { + output: &formatted_output, + metadata: ExecMetadata { + exit_code: *exit_code, + duration_seconds, + }, + }; + + #[expect(clippy::expect_used)] + serde_json::to_string(&payload).expect("serialize ExecOutput") +} + +pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String { + let ExecToolCallOutput { + aggregated_output, .. + } = exec_output; + + // Head+tail truncation for the model: show the beginning and end with an elision. + // Clients still receive full streams; only this formatted summary is capped. 
+ + let mut s = &aggregated_output.text; + let prefixed_str: String; + + if exec_output.timed_out { + prefixed_str = format!( + "command timed out after {} milliseconds\n", + exec_output.duration.as_millis() + ) + s; + s = &prefixed_str; + } + + let total_lines = s.lines().count(); + if s.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES { + return s.to_string(); + } + + let segments: Vec<&str> = s.split_inclusive('\n').collect(); + let head_take = MODEL_FORMAT_HEAD_LINES.min(segments.len()); + let tail_take = MODEL_FORMAT_TAIL_LINES.min(segments.len().saturating_sub(head_take)); + let omitted = segments.len().saturating_sub(head_take + tail_take); + + let head_slice_end: usize = segments + .iter() + .take(head_take) + .map(|segment| segment.len()) + .sum(); + let tail_slice_start: usize = if tail_take == 0 { + s.len() + } else { + s.len() + - segments + .iter() + .rev() + .take(tail_take) + .map(|segment| segment.len()) + .sum::() + }; + let marker = format!("\n[... 
omitted {omitted} of {total_lines} lines ...]\n\n"); + + // Byte budgets for head/tail around the marker + let mut head_budget = MODEL_FORMAT_HEAD_BYTES.min(MODEL_FORMAT_MAX_BYTES); + let tail_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(head_budget + marker.len()); + if tail_budget == 0 && marker.len() >= MODEL_FORMAT_MAX_BYTES { + // Degenerate case: marker alone exceeds budget; return a clipped marker + return take_bytes_at_char_boundary(&marker, MODEL_FORMAT_MAX_BYTES).to_string(); + } + if tail_budget == 0 { + // Make room for the marker by shrinking head + head_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(marker.len()); + } + + let head_slice = &s[..head_slice_end]; + let head_part = take_bytes_at_char_boundary(head_slice, head_budget); + let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(s.len())); + + result.push_str(head_part); + result.push_str(&marker); + + let remaining = MODEL_FORMAT_MAX_BYTES.saturating_sub(result.len()); + if remaining == 0 { + return result; + } + + let tail_slice = &s[tail_slice_start..]; + let tail_part = take_last_bytes_at_char_boundary(tail_slice, remaining); + result.push_str(tail_part); + + result +} diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs new file mode 100644 index 00000000..7c7b1d25 --- /dev/null +++ b/codex-rs/core/src/tools/registry.rs @@ -0,0 +1,197 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use codex_protocol::models::ResponseInputItem; +use tracing::warn; + +use crate::client_common::tools::ToolSpec; +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolOutput; +use crate::tools::context::ToolPayload; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum ToolKind { + Function, + UnifiedExec, + Mcp, +} + +#[async_trait] +pub trait ToolHandler: Send + Sync { + fn kind(&self) -> ToolKind; + + fn 
matches_kind(&self, payload: &ToolPayload) -> bool { + matches!( + (self.kind(), payload), + (ToolKind::Function, ToolPayload::Function { .. }) + | (ToolKind::UnifiedExec, ToolPayload::UnifiedExec { .. }) + | (ToolKind::Mcp, ToolPayload::Mcp { .. }) + ) + } + + async fn handle(&self, invocation: ToolInvocation<'_>) + -> Result; +} + +pub struct ToolRegistry { + handlers: HashMap>, +} + +impl ToolRegistry { + pub fn new(handlers: HashMap>) -> Self { + Self { handlers } + } + + pub fn handler(&self, name: &str) -> Option> { + self.handlers.get(name).map(Arc::clone) + } + + // TODO(jif) for dynamic tools. + // pub fn register(&mut self, name: impl Into, handler: Arc) { + // let name = name.into(); + // if self.handlers.insert(name.clone(), handler).is_some() { + // warn!("overwriting handler for tool {name}"); + // } + // } + + pub async fn dispatch<'a>( + &self, + invocation: ToolInvocation<'a>, + ) -> Result { + let tool_name = invocation.tool_name.clone(); + let call_id_owned = invocation.call_id.clone(); + let otel = invocation.turn.client.get_otel_event_manager(); + let payload_for_response = invocation.payload.clone(); + let log_payload = payload_for_response.log_payload(); + + let handler = match self.handler(tool_name.as_ref()) { + Some(handler) => handler, + None => { + let message = + unsupported_tool_call_message(&invocation.payload, tool_name.as_ref()); + otel.tool_result( + tool_name.as_ref(), + &call_id_owned, + log_payload.as_ref(), + Duration::ZERO, + false, + &message, + ); + return Err(FunctionCallError::RespondToModel(message)); + } + }; + + if !handler.matches_kind(&invocation.payload) { + let message = format!("tool {tool_name} invoked with incompatible payload"); + otel.tool_result( + tool_name.as_ref(), + &call_id_owned, + log_payload.as_ref(), + Duration::ZERO, + false, + &message, + ); + return Err(FunctionCallError::Fatal(message)); + } + + let output_cell = tokio::sync::Mutex::new(None); + + let result = otel + .log_tool_result( + 
tool_name.as_ref(), + &call_id_owned, + log_payload.as_ref(), + || { + let handler = handler.clone(); + let output_cell = &output_cell; + let invocation = invocation; + async move { + match handler.handle(invocation).await { + Ok(output) => { + let preview = output.log_preview(); + let success = output.success_for_logging(); + let mut guard = output_cell.lock().await; + *guard = Some(output); + Ok((preview, success)) + } + Err(err) => Err(err), + } + } + }, + ) + .await; + + match result { + Ok(_) => { + let mut guard = output_cell.lock().await; + let output = guard.take().ok_or_else(|| { + FunctionCallError::Fatal("tool produced no output".to_string()) + })?; + Ok(output.into_response(&call_id_owned, &payload_for_response)) + } + Err(err) => Err(err), + } + } +} + +pub struct ToolRegistryBuilder { + handlers: HashMap>, + specs: Vec, +} + +impl ToolRegistryBuilder { + pub fn new() -> Self { + Self { + handlers: HashMap::new(), + specs: Vec::new(), + } + } + + pub fn push_spec(&mut self, spec: ToolSpec) { + self.specs.push(spec); + } + + pub fn register_handler(&mut self, name: impl Into, handler: Arc) { + let name = name.into(); + if self + .handlers + .insert(name.clone(), handler.clone()) + .is_some() + { + warn!("overwriting handler for tool {name}"); + } + } + + // TODO(jif) for dynamic tools. + // pub fn register_many(&mut self, names: I, handler: Arc) + // where + // I: IntoIterator, + // I::Item: Into, + // { + // for name in names { + // let name = name.into(); + // if self + // .handlers + // .insert(name.clone(), handler.clone()) + // .is_some() + // { + // warn!("overwriting handler for tool {name}"); + // } + // } + // } + + pub fn build(self) -> (Vec, ToolRegistry) { + let registry = ToolRegistry::new(self.handlers); + (self.specs, registry) + } +} + +fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &str) -> String { + match payload { + ToolPayload::Custom { .. 
} => format!("unsupported custom tool call: {tool_name}"), + _ => format!("unsupported call: {tool_name}"), + } +} diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs new file mode 100644 index 00000000..6ec62e20 --- /dev/null +++ b/codex-rs/core/src/tools/router.rs @@ -0,0 +1,177 @@ +use std::collections::HashMap; + +use crate::client_common::tools::ToolSpec; +use crate::codex::Session; +use crate::codex::TurnContext; +use crate::function_tool::FunctionCallError; +use crate::tools::context::ToolInvocation; +use crate::tools::context::ToolPayload; +use crate::tools::registry::ToolRegistry; +use crate::tools::spec::ToolsConfig; +use crate::tools::spec::build_specs; +use crate::turn_diff_tracker::TurnDiffTracker; +use codex_protocol::models::LocalShellAction; +use codex_protocol::models::ResponseInputItem; +use codex_protocol::models::ResponseItem; +use codex_protocol::models::ShellToolCallParams; + +#[derive(Clone)] +pub struct ToolCall { + pub tool_name: String, + pub call_id: String, + pub payload: ToolPayload, +} + +pub struct ToolRouter { + registry: ToolRegistry, + specs: Vec, +} + +impl ToolRouter { + pub fn from_config( + config: &ToolsConfig, + mcp_tools: Option>, + ) -> Self { + let builder = build_specs(config, mcp_tools); + let (specs, registry) = builder.build(); + Self { registry, specs } + } + + pub fn specs(&self) -> &[ToolSpec] { + &self.specs + } + + pub fn build_tool_call( + session: &Session, + item: ResponseItem, + ) -> Result, FunctionCallError> { + match item { + ResponseItem::FunctionCall { + name, + arguments, + call_id, + .. 
+ } => { + if let Some((server, tool)) = session.parse_mcp_tool_name(&name) { + Ok(Some(ToolCall { + tool_name: name, + call_id, + payload: ToolPayload::Mcp { + server, + tool, + raw_arguments: arguments, + }, + })) + } else { + let payload = if name == "unified_exec" { + ToolPayload::UnifiedExec { arguments } + } else { + ToolPayload::Function { arguments } + }; + Ok(Some(ToolCall { + tool_name: name, + call_id, + payload, + })) + } + } + ResponseItem::CustomToolCall { + name, + input, + call_id, + .. + } => Ok(Some(ToolCall { + tool_name: name, + call_id, + payload: ToolPayload::Custom { input }, + })), + ResponseItem::LocalShellCall { + id, + call_id, + action, + .. + } => { + let call_id = call_id + .or(id) + .ok_or(FunctionCallError::MissingLocalShellCallId)?; + + match action { + LocalShellAction::Exec(exec) => { + let params = ShellToolCallParams { + command: exec.command, + workdir: exec.working_directory, + timeout_ms: exec.timeout_ms, + with_escalated_permissions: None, + justification: None, + }; + Ok(Some(ToolCall { + tool_name: "local_shell".to_string(), + call_id, + payload: ToolPayload::LocalShell { params }, + })) + } + } + } + _ => Ok(None), + } + } + + pub async fn dispatch_tool_call( + &self, + session: &Session, + turn: &TurnContext, + tracker: &mut TurnDiffTracker, + sub_id: &str, + call: ToolCall, + ) -> Result { + let ToolCall { + tool_name, + call_id, + payload, + } = call; + let payload_outputs_custom = matches!(payload, ToolPayload::Custom { .. 
}); + let failure_call_id = call_id.clone(); + + let invocation = ToolInvocation { + session, + turn, + tracker, + sub_id, + call_id, + tool_name, + payload, + }; + + match self.registry.dispatch(invocation).await { + Ok(response) => Ok(response), + Err(FunctionCallError::Fatal(message)) => Err(FunctionCallError::Fatal(message)), + Err(err) => Ok(Self::failure_response( + failure_call_id, + payload_outputs_custom, + err, + )), + } + } + + fn failure_response( + call_id: String, + payload_outputs_custom: bool, + err: FunctionCallError, + ) -> ResponseInputItem { + let message = err.to_string(); + if payload_outputs_custom { + ResponseInputItem::CustomToolCallOutput { + call_id, + output: message, + } + } else { + ResponseInputItem::FunctionCallOutput { + call_id, + output: codex_protocol::models::FunctionCallOutputPayload { + content: message, + success: Some(false), + }, + } + } + } +} diff --git a/codex-rs/core/src/tools/spec.rs b/codex-rs/core/src/tools/spec.rs new file mode 100644 index 00000000..18d8555a --- /dev/null +++ b/codex-rs/core/src/tools/spec.rs @@ -0,0 +1,1269 @@ +use crate::client_common::tools::ResponsesApiTool; +use crate::client_common::tools::ToolSpec; +use crate::model_family::ModelFamily; +use crate::tools::handlers::PLAN_TOOL; +use crate::tools::handlers::apply_patch::ApplyPatchToolType; +use crate::tools::handlers::apply_patch::create_apply_patch_freeform_tool; +use crate::tools::handlers::apply_patch::create_apply_patch_json_tool; +use crate::tools::registry::ToolRegistryBuilder; +use serde::Deserialize; +use serde::Serialize; +use serde_json::Value as JsonValue; +use serde_json::json; +use std::collections::BTreeMap; +use std::collections::HashMap; + +#[derive(Debug, Clone)] +pub enum ConfigShellToolType { + Default, + Local, + Streamable, +} + +#[derive(Debug, Clone)] +pub(crate) struct ToolsConfig { + pub shell_type: ConfigShellToolType, + pub plan_tool: bool, + pub apply_patch_tool_type: Option, + pub web_search_request: bool, + pub 
include_view_image_tool: bool, + pub experimental_unified_exec_tool: bool, +} + +pub(crate) struct ToolsConfigParams<'a> { + pub(crate) model_family: &'a ModelFamily, + pub(crate) include_plan_tool: bool, + pub(crate) include_apply_patch_tool: bool, + pub(crate) include_web_search_request: bool, + pub(crate) use_streamable_shell_tool: bool, + pub(crate) include_view_image_tool: bool, + pub(crate) experimental_unified_exec_tool: bool, +} + +impl ToolsConfig { + pub fn new(params: &ToolsConfigParams) -> Self { + let ToolsConfigParams { + model_family, + include_plan_tool, + include_apply_patch_tool, + include_web_search_request, + use_streamable_shell_tool, + include_view_image_tool, + experimental_unified_exec_tool, + } = params; + let shell_type = if *use_streamable_shell_tool { + ConfigShellToolType::Streamable + } else if model_family.uses_local_shell_tool { + ConfigShellToolType::Local + } else { + ConfigShellToolType::Default + }; + + let apply_patch_tool_type = match model_family.apply_patch_tool_type { + Some(ApplyPatchToolType::Freeform) => Some(ApplyPatchToolType::Freeform), + Some(ApplyPatchToolType::Function) => Some(ApplyPatchToolType::Function), + None => { + if *include_apply_patch_tool { + Some(ApplyPatchToolType::Freeform) + } else { + None + } + } + }; + + Self { + shell_type, + plan_tool: *include_plan_tool, + apply_patch_tool_type, + web_search_request: *include_web_search_request, + include_view_image_tool: *include_view_image_tool, + experimental_unified_exec_tool: *experimental_unified_exec_tool, + } + } +} + +/// Generic JSON‑Schema subset needed for our tool definitions +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub(crate) enum JsonSchema { + Boolean { + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + }, + String { + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + }, + /// MCP schema allows "number" | "integer" for Number 
+    #[serde(alias = "integer")]
+    Number {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        description: Option<String>,
+    },
+    Array {
+        items: Box<JsonSchema>,
+
+        #[serde(skip_serializing_if = "Option::is_none")]
+        description: Option<String>,
+    },
+    Object {
+        properties: BTreeMap<String, JsonSchema>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        required: Option<Vec<String>>,
+        #[serde(
+            rename = "additionalProperties",
+            skip_serializing_if = "Option::is_none"
+        )]
+        additional_properties: Option<AdditionalProperties>,
+    },
+}
+
+/// Whether additional properties are allowed, and if so, any required schema
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(untagged)]
+pub(crate) enum AdditionalProperties {
+    Boolean(bool),
+    Schema(Box<JsonSchema>),
+}
+
+impl From<bool> for AdditionalProperties {
+    fn from(b: bool) -> Self {
+        Self::Boolean(b)
+    }
+}
+
+impl From<JsonSchema> for AdditionalProperties {
+    fn from(s: JsonSchema) -> Self {
+        Self::Schema(Box::new(s))
+    }
+}
+
+fn create_unified_exec_tool() -> ToolSpec {
+    let mut properties = BTreeMap::new();
+    properties.insert(
+        "input".to_string(),
+        JsonSchema::Array {
+            items: Box::new(JsonSchema::String { description: None }),
+            description: Some(
+                "When no session_id is provided, treat the array as the command and arguments \
+                 to launch. When session_id is set, concatenate the strings (in order) and write \
+                 them to the session's stdin."
+                    .to_string(),
+            ),
+        },
+    );
+    properties.insert(
+        "session_id".to_string(),
+        JsonSchema::String {
+            description: Some(
+                "Identifier for an existing interactive session. If omitted, a new command \
+                 is spawned."
+                    .to_string(),
+            ),
+        },
+    );
+    properties.insert(
+        "timeout_ms".to_string(),
+        JsonSchema::Number {
+            description: Some(
+                "Maximum time in milliseconds to wait for output after writing the input."
+                    .to_string(),
+            ),
+        },
+    );
+
+    ToolSpec::Function(ResponsesApiTool {
+        name: "unified_exec".to_string(),
+        description:
+            "Runs a command in a PTY. 
Provide a session_id to reuse an existing interactive session.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["input".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + +fn create_shell_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "command".to_string(), + JsonSchema::Array { + items: Box::new(JsonSchema::String { description: None }), + description: Some("The command to execute".to_string()), + }, + ); + properties.insert( + "workdir".to_string(), + JsonSchema::String { + description: Some("The working directory to execute the command in".to_string()), + }, + ); + properties.insert( + "timeout_ms".to_string(), + JsonSchema::Number { + description: Some("The timeout for the command in milliseconds".to_string()), + }, + ); + + properties.insert( + "with_escalated_permissions".to_string(), + JsonSchema::Boolean { + description: Some("Whether to request escalated permissions. Set to true if command needs to be run without sandbox restrictions".to_string()), + }, + ); + properties.insert( + "justification".to_string(), + JsonSchema::String { + description: Some("Only set if with_escalated_permissions is true. 1-sentence explanation of why we want to run this command.".to_string()), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "shell".to_string(), + description: "Runs a shell command and returns its output.".to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["command".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + +fn create_view_image_tool() -> ToolSpec { + // Support only local filesystem path. 
+ let mut properties = BTreeMap::new(); + properties.insert( + "path".to_string(), + JsonSchema::String { + description: Some("Local filesystem path to an image file".to_string()), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "view_image".to_string(), + description: + "Attach a local image (by filesystem path) to the conversation context for this turn." + .to_string(), + strict: false, + parameters: JsonSchema::Object { + properties, + required: Some(vec!["path".to_string()]), + additional_properties: Some(false.into()), + }, + }) +} + +fn create_read_file_tool() -> ToolSpec { + let mut properties = BTreeMap::new(); + properties.insert( + "file_path".to_string(), + JsonSchema::String { + description: Some("Absolute path to the file".to_string()), + }, + ); + properties.insert( + "offset".to_string(), + JsonSchema::Number { + description: Some( + "The line number to start reading from. Must be 1 or greater.".to_string(), + ), + }, + ); + properties.insert( + "limit".to_string(), + JsonSchema::Number { + description: Some("The maximum number of lines to return.".to_string()), + }, + ); + + ToolSpec::Function(ResponsesApiTool { + name: "read_file".to_string(), + description: + "Reads a local file with 1-indexed line numbers and returns up to the requested number of lines." 
+                .to_string(),
+        strict: false,
+        parameters: JsonSchema::Object {
+            properties,
+            required: Some(vec!["file_path".to_string()]),
+            additional_properties: Some(false.into()),
+        },
+    })
+}
+/// TODO(dylan): deprecate once we get rid of json tool
+#[derive(Serialize, Deserialize)]
+pub(crate) struct ApplyPatchToolArgs {
+    pub(crate) input: String,
+}
+
+/// Returns JSON values that are compatible with Function Calling in the
+/// Responses API:
+/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
+pub fn create_tools_json_for_responses_api(
+    tools: &[ToolSpec],
+) -> crate::error::Result<Vec<serde_json::Value>> {
+    let mut tools_json = Vec::new();
+
+    for tool in tools {
+        let json = serde_json::to_value(tool)?;
+        tools_json.push(json);
+    }
+
+    Ok(tools_json)
+}
+/// Returns JSON values that are compatible with Function Calling in the
+/// Chat Completions API:
+/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
+pub(crate) fn create_tools_json_for_chat_completions_api(
+    tools: &[ToolSpec],
+) -> crate::error::Result<Vec<serde_json::Value>> {
+    // We start with the JSON for the Responses API and then rewrite it to match
+    // the chat completions tool call format.
+    let responses_api_tools_json = create_tools_json_for_responses_api(tools)?;
+    let tools_json = responses_api_tools_json
+        .into_iter()
+        .filter_map(|mut tool| {
+            if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) {
+                return None;
+            }
+
+            if let Some(map) = tool.as_object_mut() {
+                // Remove "type" field as it is not needed in chat completions.
+                map.remove("type");
+                Some(json!({
+                    "type": "function",
+                    "function": map,
+                }))
+            } else {
+                None
+            }
+        })
+        .collect::<Vec<serde_json::Value>>();
+    Ok(tools_json)
+}
+
+pub(crate) fn mcp_tool_to_openai_tool(
+    fully_qualified_name: String,
+    tool: mcp_types::Tool,
+) -> Result<ResponsesApiTool, serde_json::Error> {
+    let mcp_types::Tool {
+        description,
+        mut input_schema,
+        ..
+    } = tool;
+
+    // OpenAI models mandate the "properties" field in the schema.
The Agents + // SDK fixed this by inserting an empty object for "properties" if it is not + // already present https://github.com/openai/openai-agents-python/issues/449 + // so here we do the same. + if input_schema.properties.is_none() { + input_schema.properties = Some(serde_json::Value::Object(serde_json::Map::new())); + } + + // Serialize to a raw JSON value so we can sanitize schemas coming from MCP + // servers. Some servers omit the top-level or nested `type` in JSON + // Schemas (e.g. using enum/anyOf), or use unsupported variants like + // `integer`. Our internal JsonSchema is a small subset and requires + // `type`, so we coerce/sanitize here for compatibility. + let mut serialized_input_schema = serde_json::to_value(input_schema)?; + sanitize_json_schema(&mut serialized_input_schema); + let input_schema = serde_json::from_value::(serialized_input_schema)?; + + Ok(ResponsesApiTool { + name: fully_qualified_name, + description: description.unwrap_or_default(), + strict: false, + parameters: input_schema, + }) +} + +/// Sanitize a JSON Schema (as serde_json::Value) so it can fit our limited +/// JsonSchema enum. This function: +/// - Ensures every schema object has a "type". If missing, infers it from +/// common keywords (properties => object, items => array, enum/const/format => string) +/// and otherwise defaults to "string". +/// - Fills required child fields (e.g. array items, object properties) with +/// permissive defaults when absent. +fn sanitize_json_schema(value: &mut JsonValue) { + match value { + JsonValue::Bool(_) => { + // JSON Schema boolean form: true/false. Coerce to an accept-all string. 
+ *value = json!({ "type": "string" }); + } + JsonValue::Array(arr) => { + for v in arr.iter_mut() { + sanitize_json_schema(v); + } + } + JsonValue::Object(map) => { + // First, recursively sanitize known nested schema holders + if let Some(props) = map.get_mut("properties") + && let Some(props_map) = props.as_object_mut() + { + for (_k, v) in props_map.iter_mut() { + sanitize_json_schema(v); + } + } + if let Some(items) = map.get_mut("items") { + sanitize_json_schema(items); + } + // Some schemas use oneOf/anyOf/allOf - sanitize their entries + for combiner in ["oneOf", "anyOf", "allOf", "prefixItems"] { + if let Some(v) = map.get_mut(combiner) { + sanitize_json_schema(v); + } + } + + // Normalize/ensure type + let mut ty = map.get("type").and_then(|v| v.as_str()).map(str::to_string); + + // If type is an array (union), pick first supported; else leave to inference + if ty.is_none() + && let Some(JsonValue::Array(types)) = map.get("type") + { + for t in types { + if let Some(tt) = t.as_str() + && matches!( + tt, + "object" | "array" | "string" | "number" | "integer" | "boolean" + ) + { + ty = Some(tt.to_string()); + break; + } + } + } + + // Infer type if still missing + if ty.is_none() { + if map.contains_key("properties") + || map.contains_key("required") + || map.contains_key("additionalProperties") + { + ty = Some("object".to_string()); + } else if map.contains_key("items") || map.contains_key("prefixItems") { + ty = Some("array".to_string()); + } else if map.contains_key("enum") + || map.contains_key("const") + || map.contains_key("format") + { + ty = Some("string".to_string()); + } else if map.contains_key("minimum") + || map.contains_key("maximum") + || map.contains_key("exclusiveMinimum") + || map.contains_key("exclusiveMaximum") + || map.contains_key("multipleOf") + { + ty = Some("number".to_string()); + } + } + // If we still couldn't infer, default to string + let ty = ty.unwrap_or_else(|| "string".to_string()); + map.insert("type".to_string(), 
JsonValue::String(ty.to_string()));
+
+            // Ensure object schemas have properties map
+            if ty == "object" {
+                if !map.contains_key("properties") {
+                    map.insert(
+                        "properties".to_string(),
+                        JsonValue::Object(serde_json::Map::new()),
+                    );
+                }
+                // If additionalProperties is an object schema, sanitize it too.
+                // Leave booleans as-is, since JSON Schema allows boolean here.
+                if let Some(ap) = map.get_mut("additionalProperties") {
+                    let is_bool = matches!(ap, JsonValue::Bool(_));
+                    if !is_bool {
+                        sanitize_json_schema(ap);
+                    }
+                }
+            }
+
+            // Ensure array schemas have items
+            if ty == "array" && !map.contains_key("items") {
+                map.insert("items".to_string(), json!({ "type": "string" }));
+            }
+        }
+        _ => {}
+    }
+}
+
+/// Builds the tool registry builder while collecting tool specs for later serialization.
+pub(crate) fn build_specs(
+    config: &ToolsConfig,
+    mcp_tools: Option<HashMap<String, mcp_types::Tool>>,
+) -> ToolRegistryBuilder {
+    use crate::exec_command::EXEC_COMMAND_TOOL_NAME;
+    use crate::exec_command::WRITE_STDIN_TOOL_NAME;
+    use crate::exec_command::create_exec_command_tool_for_responses_api;
+    use crate::exec_command::create_write_stdin_tool_for_responses_api;
+    use crate::tools::handlers::ApplyPatchHandler;
+    use crate::tools::handlers::ExecStreamHandler;
+    use crate::tools::handlers::McpHandler;
+    use crate::tools::handlers::PlanHandler;
+    use crate::tools::handlers::ReadFileHandler;
+    use crate::tools::handlers::ShellHandler;
+    use crate::tools::handlers::UnifiedExecHandler;
+    use crate::tools::handlers::ViewImageHandler;
+    use std::sync::Arc;
+
+    let mut builder = ToolRegistryBuilder::new();
+
+    let shell_handler = Arc::new(ShellHandler);
+    let exec_stream_handler = Arc::new(ExecStreamHandler);
+    let unified_exec_handler = Arc::new(UnifiedExecHandler);
+    let plan_handler = Arc::new(PlanHandler);
+    let read_file_handler = Arc::new(ReadFileHandler);
+    let apply_patch_handler = Arc::new(ApplyPatchHandler);
+    let view_image_handler = Arc::new(ViewImageHandler);
+    let mcp_handler = 
Arc::new(McpHandler); + + if config.experimental_unified_exec_tool { + builder.push_spec(create_unified_exec_tool()); + builder.register_handler("unified_exec", unified_exec_handler); + } else { + match &config.shell_type { + ConfigShellToolType::Default => { + builder.push_spec(create_shell_tool()); + } + ConfigShellToolType::Local => { + builder.push_spec(ToolSpec::LocalShell {}); + } + ConfigShellToolType::Streamable => { + builder.push_spec(ToolSpec::Function( + create_exec_command_tool_for_responses_api(), + )); + builder.push_spec(ToolSpec::Function( + create_write_stdin_tool_for_responses_api(), + )); + builder.register_handler(EXEC_COMMAND_TOOL_NAME, exec_stream_handler.clone()); + builder.register_handler(WRITE_STDIN_TOOL_NAME, exec_stream_handler); + } + } + } + + // Always register shell aliases so older prompts remain compatible. + builder.register_handler("shell", shell_handler.clone()); + builder.register_handler("container.exec", shell_handler.clone()); + builder.register_handler("local_shell", shell_handler); + + if config.plan_tool { + builder.push_spec(PLAN_TOOL.clone()); + builder.register_handler("update_plan", plan_handler); + } + + if let Some(apply_patch_tool_type) = &config.apply_patch_tool_type { + match apply_patch_tool_type { + ApplyPatchToolType::Freeform => { + builder.push_spec(create_apply_patch_freeform_tool()); + } + ApplyPatchToolType::Function => { + builder.push_spec(create_apply_patch_json_tool()); + } + } + builder.register_handler("apply_patch", apply_patch_handler); + } + + builder.push_spec(create_read_file_tool()); + builder.register_handler("read_file", read_file_handler); + + if config.web_search_request { + builder.push_spec(ToolSpec::WebSearch {}); + } + + if config.include_view_image_tool { + builder.push_spec(create_view_image_tool()); + builder.register_handler("view_image", view_image_handler); + } + + if let Some(mcp_tools) = mcp_tools { + let mut entries: Vec<(String, mcp_types::Tool)> = 
mcp_tools.into_iter().collect(); + entries.sort_by(|a, b| a.0.cmp(&b.0)); + + for (name, tool) in entries.into_iter() { + match mcp_tool_to_openai_tool(name.clone(), tool.clone()) { + Ok(converted_tool) => { + builder.push_spec(ToolSpec::Function(converted_tool)); + builder.register_handler(name, mcp_handler.clone()); + } + Err(e) => { + tracing::error!("Failed to convert {name:?} MCP tool to OpenAI tool: {e:?}"); + } + } + } + } + + builder +} + +#[cfg(test)] +mod tests { + use crate::client_common::tools::FreeformTool; + use crate::model_family::find_family_for_model; + use mcp_types::ToolInputSchema; + use pretty_assertions::assert_eq; + + use super::*; + + fn assert_eq_tool_names(tools: &[ToolSpec], expected_names: &[&str]) { + let tool_names = tools + .iter() + .map(|tool| match tool { + ToolSpec::Function(ResponsesApiTool { name, .. }) => name, + ToolSpec::LocalShell {} => "local_shell", + ToolSpec::WebSearch {} => "web_search", + ToolSpec::Freeform(FreeformTool { name, .. }) => name, + }) + .collect::>(); + + assert_eq!( + tool_names.len(), + expected_names.len(), + "tool_name mismatch, {tool_names:?}, {expected_names:?}", + ); + for (name, expected_name) in tool_names.iter().zip(expected_names.iter()) { + assert_eq!( + name, expected_name, + "tool_name mismatch, {name:?}, {expected_name:?}" + ); + } + } + + #[test] + fn test_build_specs() { + let model_family = find_family_for_model("codex-mini-latest") + .expect("codex-mini-latest should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: true, + include_apply_patch_tool: false, + include_web_search_request: true, + use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + let (tools, _) = build_specs(&config, Some(HashMap::new())).build(); + + assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "update_plan", + "read_file", + "web_search", + "view_image", + ], + ); 
+ } + + #[test] + fn test_build_specs_default_shell() { + let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: true, + include_apply_patch_tool: false, + include_web_search_request: true, + use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + let (tools, _) = build_specs(&config, Some(HashMap::new())).build(); + + assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "update_plan", + "read_file", + "web_search", + "view_image", + ], + ); + } + + #[test] + fn test_build_specs_mcp_tools() { + let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: true, + use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + let (tools, _) = build_specs( + &config, + Some(HashMap::from([( + "test_server/do_something_cool".to_string(), + mcp_types::Tool { + name: "do_something_cool".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({ + "string_argument": { + "type": "string", + }, + "number_argument": { + "type": "number", + }, + "object_argument": { + "type": "object", + "properties": { + "string_property": { "type": "string" }, + "number_property": { "type": "number" }, + }, + "required": [ + "string_property", + "number_property", + ], + "additionalProperties": Some(false), + }, + })), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("Do something cool".to_string()), + }, + )])), + ) + .build(); + + assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "read_file", + "web_search", + 
"view_image", + "test_server/do_something_cool", + ], + ); + + assert_eq!( + tools[4], + ToolSpec::Function(ResponsesApiTool { + name: "test_server/do_something_cool".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([ + ( + "string_argument".to_string(), + JsonSchema::String { description: None } + ), + ( + "number_argument".to_string(), + JsonSchema::Number { description: None } + ), + ( + "object_argument".to_string(), + JsonSchema::Object { + properties: BTreeMap::from([ + ( + "string_property".to_string(), + JsonSchema::String { description: None } + ), + ( + "number_property".to_string(), + JsonSchema::Number { description: None } + ), + ]), + required: Some(vec![ + "string_property".to_string(), + "number_property".to_string(), + ]), + additional_properties: Some(false.into()), + }, + ), + ]), + required: None, + additional_properties: None, + }, + description: "Do something cool".to_string(), + strict: false, + }) + ); + } + + #[test] + fn test_build_specs_mcp_tools_sorted_by_name() { + let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: false, + use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + + // Intentionally construct a map with keys that would sort alphabetically. 
+ let tools_map: HashMap = HashMap::from([ + ( + "test_server/do".to_string(), + mcp_types::Tool { + name: "a".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({})), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("a".to_string()), + }, + ), + ( + "test_server/something".to_string(), + mcp_types::Tool { + name: "b".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({})), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("b".to_string()), + }, + ), + ( + "test_server/cool".to_string(), + mcp_types::Tool { + name: "c".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({})), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("c".to_string()), + }, + ), + ]); + + let (tools, _) = build_specs(&config, Some(tools_map)).build(); + // Expect unified_exec first, followed by MCP tools sorted by fully-qualified name. 
+ assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "read_file", + "view_image", + "test_server/cool", + "test_server/do", + "test_server/something", + ], + ); + } + + #[test] + fn test_mcp_tool_property_missing_type_defaults_to_string() { + let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: true, + use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + + let (tools, _) = build_specs( + &config, + Some(HashMap::from([( + "dash/search".to_string(), + mcp_types::Tool { + name: "search".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({ + "query": { + "description": "search query" + } + })), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("Search docs".to_string()), + }, + )])), + ) + .build(); + + assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "read_file", + "web_search", + "view_image", + "dash/search", + ], + ); + + assert_eq!( + tools[4], + ToolSpec::Function(ResponsesApiTool { + name: "dash/search".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([( + "query".to_string(), + JsonSchema::String { + description: Some("search query".to_string()) + } + )]), + required: None, + additional_properties: None, + }, + description: "Search docs".to_string(), + strict: false, + }) + ); + } + + #[test] + fn test_mcp_tool_integer_normalized_to_number() { + let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: true, + 
use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + + let (tools, _) = build_specs( + &config, + Some(HashMap::from([( + "dash/paginate".to_string(), + mcp_types::Tool { + name: "paginate".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({ + "page": { "type": "integer" } + })), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("Pagination".to_string()), + }, + )])), + ) + .build(); + + assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "read_file", + "web_search", + "view_image", + "dash/paginate", + ], + ); + assert_eq!( + tools[4], + ToolSpec::Function(ResponsesApiTool { + name: "dash/paginate".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([( + "page".to_string(), + JsonSchema::Number { description: None } + )]), + required: None, + additional_properties: None, + }, + description: "Pagination".to_string(), + strict: false, + }) + ); + } + + #[test] + fn test_mcp_tool_array_without_items_gets_default_string_items() { + let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: true, + use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + + let (tools, _) = build_specs( + &config, + Some(HashMap::from([( + "dash/tags".to_string(), + mcp_types::Tool { + name: "tags".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({ + "tags": { "type": "array" } + })), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("Tags".to_string()), + }, + )])), + ) + .build(); + + 
assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "read_file", + "web_search", + "view_image", + "dash/tags", + ], + ); + assert_eq!( + tools[4], + ToolSpec::Function(ResponsesApiTool { + name: "dash/tags".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([( + "tags".to_string(), + JsonSchema::Array { + items: Box::new(JsonSchema::String { description: None }), + description: None + } + )]), + required: None, + additional_properties: None, + }, + description: "Tags".to_string(), + strict: false, + }) + ); + } + + #[test] + fn test_mcp_tool_anyof_defaults_to_string() { + let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: true, + use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + + let (tools, _) = build_specs( + &config, + Some(HashMap::from([( + "dash/value".to_string(), + mcp_types::Tool { + name: "value".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({ + "value": { "anyOf": [ { "type": "string" }, { "type": "number" } ] } + })), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("AnyOf Value".to_string()), + }, + )])), + ) + .build(); + + assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "read_file", + "web_search", + "view_image", + "dash/value", + ], + ); + assert_eq!( + tools[4], + ToolSpec::Function(ResponsesApiTool { + name: "dash/value".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([( + "value".to_string(), + JsonSchema::String { description: None } + )]), + required: None, + additional_properties: None, + }, + description: "AnyOf Value".to_string(), + strict: false, + }) + ); + } + + #[test] 
+ fn test_shell_tool() { + let tool = super::create_shell_tool(); + let ToolSpec::Function(ResponsesApiTool { + description, name, .. + }) = &tool + else { + panic!("expected function tool"); + }; + assert_eq!(name, "shell"); + + let expected = "Runs a shell command and returns its output."; + assert_eq!(description, expected); + } + + #[test] + fn test_get_openai_tools_mcp_tools_with_additional_properties_schema() { + let model_family = find_family_for_model("o3").expect("o3 should be a valid model family"); + let config = ToolsConfig::new(&ToolsConfigParams { + model_family: &model_family, + include_plan_tool: false, + include_apply_patch_tool: false, + include_web_search_request: true, + use_streamable_shell_tool: false, + include_view_image_tool: true, + experimental_unified_exec_tool: true, + }); + let (tools, _) = build_specs( + &config, + Some(HashMap::from([( + "test_server/do_something_cool".to_string(), + mcp_types::Tool { + name: "do_something_cool".to_string(), + input_schema: ToolInputSchema { + properties: Some(serde_json::json!({ + "string_argument": { + "type": "string", + }, + "number_argument": { + "type": "number", + }, + "object_argument": { + "type": "object", + "properties": { + "string_property": { "type": "string" }, + "number_property": { "type": "number" }, + }, + "required": [ + "string_property", + "number_property", + ], + "additionalProperties": { + "type": "object", + "properties": { + "addtl_prop": { "type": "string" }, + }, + "required": [ + "addtl_prop", + ], + "additionalProperties": false, + }, + }, + })), + required: None, + r#type: "object".to_string(), + }, + output_schema: None, + title: None, + annotations: None, + description: Some("Do something cool".to_string()), + }, + )])), + ) + .build(); + + assert_eq_tool_names( + &tools, + &[ + "unified_exec", + "read_file", + "web_search", + "view_image", + "test_server/do_something_cool", + ], + ); + + assert_eq!( + tools[4], + ToolSpec::Function(ResponsesApiTool { + name: 
"test_server/do_something_cool".to_string(), + parameters: JsonSchema::Object { + properties: BTreeMap::from([ + ( + "string_argument".to_string(), + JsonSchema::String { description: None } + ), + ( + "number_argument".to_string(), + JsonSchema::Number { description: None } + ), + ( + "object_argument".to_string(), + JsonSchema::Object { + properties: BTreeMap::from([ + ( + "string_property".to_string(), + JsonSchema::String { description: None } + ), + ( + "number_property".to_string(), + JsonSchema::Number { description: None } + ), + ]), + required: Some(vec![ + "string_property".to_string(), + "number_property".to_string(), + ]), + additional_properties: Some( + JsonSchema::Object { + properties: BTreeMap::from([( + "addtl_prop".to_string(), + JsonSchema::String { description: None } + ),]), + required: Some(vec!["addtl_prop".to_string(),]), + additional_properties: Some(false.into()), + } + .into() + ), + }, + ), + ]), + required: None, + additional_properties: None, + }, + description: "Do something cool".to_string(), + strict: false, + }) + ); + } +} diff --git a/codex-rs/core/tests/common/lib.rs b/codex-rs/core/tests/common/lib.rs index ce90f397..500c5955 100644 --- a/codex-rs/core/tests/common/lib.rs +++ b/codex-rs/core/tests/common/lib.rs @@ -7,6 +7,9 @@ use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config::ConfigToml; +#[cfg(target_os = "linux")] +use assert_cmd::cargo::cargo_bin; + pub mod responses; pub mod test_codex; pub mod test_codex_exec; @@ -17,12 +20,25 @@ pub mod test_codex_exec; pub fn load_default_config_for_test(codex_home: &TempDir) -> Config { Config::load_from_base_config_with_overrides( ConfigToml::default(), - ConfigOverrides::default(), + default_test_overrides(), codex_home.path().to_path_buf(), ) .expect("defaults for test should always succeed") } +#[cfg(target_os = "linux")] +fn default_test_overrides() -> ConfigOverrides { + ConfigOverrides { + codex_linux_sandbox_exe: 
Some(cargo_bin("codex-linux-sandbox")), + ..ConfigOverrides::default() + } +} + +#[cfg(not(target_os = "linux"))] +fn default_test_overrides() -> ConfigOverrides { + ConfigOverrides::default() +} + /// Builds an SSE stream body from a JSON fixture. /// /// The fixture must contain an array of objects where each object represents a diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index b3a90ff3..5f2892cb 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -12,12 +12,18 @@ mod fork_conversation; mod json_result; mod live_cli; mod model_overrides; +mod model_tools; mod otel; mod prompt_caching; +mod read_file; mod review; mod rmcp_client; mod rollout_list_find; mod seatbelt; mod stream_error_allows_next_turn; mod stream_no_completed; +mod tool_harness; +mod tools; +mod unified_exec; mod user_notification; +mod view_image; diff --git a/codex-rs/core/tests/suite/model_tools.rs b/codex-rs/core/tests/suite/model_tools.rs new file mode 100644 index 00000000..29a31911 --- /dev/null +++ b/codex-rs/core/tests/suite/model_tools.rs @@ -0,0 +1,124 @@ +#![allow(clippy::unwrap_used)] + +use codex_core::CodexAuth; +use codex_core::ConversationManager; +use codex_core::ModelProviderInfo; +use codex_core::built_in_model_providers; +use codex_core::model_family::find_family_for_model; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use core_test_support::load_default_config_for_test; +use core_test_support::load_sse_fixture_with_id; +use core_test_support::skip_if_no_network; +use core_test_support::wait_for_event; +use tempfile::TempDir; +use wiremock::Mock; +use wiremock::MockServer; +use wiremock::ResponseTemplate; +use wiremock::matchers::method; +use wiremock::matchers::path; + +fn sse_completed(id: &str) -> String { + load_sse_fixture_with_id("tests/fixtures/completed_template.json", id) +} + +#[allow(clippy::expect_used)] +fn tool_identifiers(body: 
&serde_json::Value) -> Vec { + body["tools"] + .as_array() + .unwrap() + .iter() + .map(|tool| { + tool.get("name") + .and_then(|v| v.as_str()) + .or_else(|| tool.get("type").and_then(|v| v.as_str())) + .map(std::string::ToString::to_string) + .expect("tool should have either name or type") + }) + .collect() +} + +#[allow(clippy::expect_used)] +async fn collect_tool_identifiers_for_model(model: &str) -> Vec { + let server = MockServer::start().await; + + let sse = sse_completed(model); + let template = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse, "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(template) + .expect(1) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + base_url: Some(format!("{}/v1", server.uri())), + ..built_in_model_providers()["openai"].clone() + }; + + let cwd = TempDir::new().unwrap(); + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + config.cwd = cwd.path().to_path_buf(); + config.model_provider = model_provider; + config.model = model.to_string(); + config.model_family = + find_family_for_model(model).unwrap_or_else(|| panic!("unknown model family for {model}")); + config.include_plan_tool = false; + config.include_apply_patch_tool = false; + config.include_view_image_tool = false; + config.tools_web_search_request = false; + config.use_experimental_streamable_shell_tool = false; + config.use_experimental_unified_exec_tool = false; + + let conversation_manager = + ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key")); + let codex = conversation_manager + .new_conversation(config) + .await + .expect("create new conversation") + .conversation; + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello tools".into(), + }], + }) + .await + .unwrap(); + wait_for_event(&codex, |ev| matches!(ev, 
EventMsg::TaskComplete(_))).await; + + let requests = server.received_requests().await.unwrap(); + assert_eq!( + requests.len(), + 1, + "expected a single request for model {model}" + ); + let body = requests[0].body_json::().unwrap(); + tool_identifiers(&body) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn model_selects_expected_tools() { + skip_if_no_network!(); + use pretty_assertions::assert_eq; + + let codex_tools = collect_tool_identifiers_for_model("codex-mini-latest").await; + assert_eq!( + codex_tools, + vec!["local_shell".to_string(), "read_file".to_string()], + "codex-mini-latest should expose the local shell tool", + ); + + let o3_tools = collect_tool_identifiers_for_model("o3").await; + assert_eq!( + o3_tools, + vec!["shell".to_string(), "read_file".to_string()], + "o3 should expose the generic shell tool", + ); +} diff --git a/codex-rs/core/tests/suite/prompt_caching.rs b/codex-rs/core/tests/suite/prompt_caching.rs index 79be6083..bc66be18 100644 --- a/codex-rs/core/tests/suite/prompt_caching.rs +++ b/codex-rs/core/tests/suite/prompt_caching.rs @@ -219,7 +219,13 @@ async fn prompt_tools_are_consistent_across_requests() { // our internal implementation is responsible for keeping tools in sync // with the OpenAI schema, so we just verify the tool presence here - let expected_tools_names: &[&str] = &["shell", "update_plan", "apply_patch", "view_image"]; + let expected_tools_names: &[&str] = &[ + "shell", + "update_plan", + "apply_patch", + "read_file", + "view_image", + ]; let body0 = requests[0].body_json::().unwrap(); assert_eq!( body0["instructions"], diff --git a/codex-rs/core/tests/suite/read_file.rs b/codex-rs/core/tests/suite/read_file.rs new file mode 100644 index 00000000..d72f53e3 --- /dev/null +++ b/codex-rs/core/tests/suite/read_file.rs @@ -0,0 +1,124 @@ +#![cfg(not(target_os = "windows"))] + +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use 
codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use core_test_support::responses; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use pretty_assertions::assert_eq; +use serde_json::Value; +use wiremock::matchers::any; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let TestCodex { + codex, + cwd, + session_configured, + .. + } = test_codex().build(&server).await?; + + let file_path = cwd.path().join("sample.txt"); + std::fs::write(&file_path, "first\nsecond\nthird\nfourth\n")?; + let file_path = file_path.to_string_lossy().to_string(); + + let call_id = "read-file-call"; + let arguments = serde_json::json!({ + "file_path": file_path, + "offset": 2, + "limit": 2, + }) + .to_string(); + + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_function_call(call_id, "read_file", &arguments), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please inspect sample.txt".into(), + }], + final_output_json_schema: None, + 
cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + let requests = server.received_requests().await.expect("recorded requests"); + let request_bodies = requests + .iter() + .map(|req| req.body_json::().unwrap()) + .collect::>(); + assert!( + !request_bodies.is_empty(), + "expected at least one request body" + ); + + let tool_output_item = request_bodies + .iter() + .find_map(|body| { + body.get("input") + .and_then(Value::as_array) + .and_then(|items| { + items.iter().find(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call_output") + }) + }) + }) + .unwrap_or_else(|| { + panic!("function_call_output item not found in requests: {request_bodies:#?}") + }); + + assert_eq!( + tool_output_item.get("call_id").and_then(Value::as_str), + Some(call_id) + ); + + let output_text = tool_output_item + .get("output") + .and_then(|value| match value { + Value::String(text) => Some(text.as_str()), + Value::Object(obj) => obj.get("content").and_then(Value::as_str), + _ => None, + }) + .expect("output text present"); + assert_eq!(output_text, "L2: second\nL3: third"); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/tool_harness.rs b/codex-rs/core/tests/suite/tool_harness.rs new file mode 100644 index 00000000..317e530c --- /dev/null +++ b/codex-rs/core/tests/suite/tool_harness.rs @@ -0,0 +1,568 @@ +#![cfg(not(target_os = "windows"))] + +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::plan_tool::StepStatus; +use core_test_support::responses; +use 
core_test_support::responses::ev_apply_patch_function_call; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::ev_local_shell_call; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use serde_json::Value; +use serde_json::json; +use wiremock::matchers::any; + +fn function_call_output(body: &Value) -> Option<&Value> { + body.get("input") + .and_then(Value::as_array) + .and_then(|items| { + items.iter().find(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call_output") + }) + }) +} + +fn extract_output_text(item: &Value) -> Option<&str> { + item.get("output").and_then(|value| match value { + Value::String(text) => Some(text.as_str()), + Value::Object(obj) => obj.get("content").and_then(Value::as_str), + _ => None, + }) +} + +fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> { + requests + .iter() + .find(|body| function_call_output(body).is_some()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.include_apply_patch_tool = true; + }); + let TestCodex { + codex, + cwd, + session_configured, + .. 
+ } = builder.build(&server).await?; + + let call_id = "shell-tool-call"; + let command = vec!["/bin/echo", "tool harness"]; + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_local_shell_call(call_id, "completed", command), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "all done"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please run the shell command".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + loop { + let event = codex.next_event().await.expect("event"); + if matches!(event.msg, EventMsg::TaskComplete(_)) { + break; + } + } + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let request_bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let body_with_tool_output = find_request_with_function_call_output(&request_bodies) + .expect("function_call_output item not found in requests"); + let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + let output_text = extract_output_text(output_item).expect("output text present"); + let exec_output: Value = serde_json::from_str(output_text)?; + assert_eq!(exec_output["metadata"]["exit_code"], 0); + let stdout = exec_output["output"].as_str().expect("stdout field"); + assert!( + stdout.contains("tool harness"), + "expected 
stdout to contain command output, got {stdout:?}" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.include_plan_tool = true; + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let call_id = "plan-tool-call"; + let plan_args = json!({ + "explanation": "Tool harness check", + "plan": [ + {"step": "Inspect workspace", "status": "in_progress"}, + {"step": "Report results", "status": "pending"}, + ], + }) + .to_string(); + + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_function_call(call_id, "update_plan", &plan_args), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "plan acknowledged"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please update the plan".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let mut saw_plan_update = false; + + loop { + let event = codex.next_event().await.expect("event"); + match event.msg { + EventMsg::PlanUpdate(update) => { + saw_plan_update = true; + assert_eq!(update.explanation.as_deref(), Some("Tool harness check")); + assert_eq!(update.plan.len(), 2); + assert_eq!(update.plan[0].step, "Inspect workspace"); + 
assert!(matches!(update.plan[0].status, StepStatus::InProgress)); + assert_eq!(update.plan[1].step, "Report results"); + assert!(matches!(update.plan[1].status, StepStatus::Pending)); + } + EventMsg::TaskComplete(_) => break, + _ => {} + } + } + + assert!(saw_plan_update, "expected PlanUpdate event"); + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let request_bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let body_with_tool_output = find_request_with_function_call_output(&request_bodies) + .expect("function_call_output item not found in requests"); + let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + assert_eq!( + output_item.get("call_id").and_then(Value::as_str), + Some(call_id) + ); + let output_text = extract_output_text(output_item).expect("output text present"); + assert_eq!(output_text, "Plan updated"); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.include_plan_tool = true; + }); + let TestCodex { + codex, + cwd, + session_configured, + .. 
+ } = builder.build(&server).await?; + + let call_id = "plan-tool-invalid"; + let invalid_args = json!({ + "explanation": "Missing plan data" + }) + .to_string(); + + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_function_call(call_id, "update_plan", &invalid_args), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "malformed plan payload"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please update the plan".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let mut saw_plan_update = false; + + loop { + let event = codex.next_event().await.expect("event"); + match event.msg { + EventMsg::PlanUpdate(_) => saw_plan_update = true, + EventMsg::TaskComplete(_) => break, + _ => {} + } + } + + assert!( + !saw_plan_update, + "did not expect PlanUpdate event for malformed payload" + ); + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let request_bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let body_with_tool_output = find_request_with_function_call_output(&request_bodies) + .expect("function_call_output item not found in requests"); + let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + assert_eq!( + output_item.get("call_id").and_then(Value::as_str), + Some(call_id) + 
); + let output_text = extract_output_text(output_item).expect("output text present"); + assert!( + output_text.contains("failed to parse function arguments"), + "expected parse error message in output text, got {output_text:?}" + ); + if let Some(success_flag) = output_item + .get("output") + .and_then(|value| value.as_object()) + .and_then(|obj| obj.get("success")) + .and_then(serde_json::Value::as_bool) + { + assert!( + !success_flag, + "expected tool output to mark success=false for malformed payload" + ); + } + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.include_apply_patch_tool = true; + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let call_id = "apply-patch-call"; + let patch_content = r#"*** Begin Patch +*** Add File: notes.txt ++Tool harness apply patch +*** End Patch"#; + + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_apply_patch_function_call(call_id, patch_content), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "patch complete"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please apply a patch".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + 
.await?; + + let mut saw_patch_begin = false; + let mut patch_end_success = None; + + loop { + let event = codex.next_event().await.expect("event"); + match event.msg { + EventMsg::PatchApplyBegin(begin) => { + saw_patch_begin = true; + assert_eq!(begin.call_id, call_id); + } + EventMsg::PatchApplyEnd(end) => { + assert_eq!(end.call_id, call_id); + patch_end_success = Some(end.success); + } + EventMsg::TaskComplete(_) => break, + _ => {} + } + } + + assert!(saw_patch_begin, "expected PatchApplyBegin event"); + let patch_end_success = + patch_end_success.expect("expected PatchApplyEnd event to capture success flag"); + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let request_bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let body_with_tool_output = find_request_with_function_call_output(&request_bodies) + .expect("function_call_output item not found in requests"); + let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + assert_eq!( + output_item.get("call_id").and_then(Value::as_str), + Some(call_id) + ); + let output_text = extract_output_text(output_item).expect("output text present"); + + if let Ok(exec_output) = serde_json::from_str::(output_text) { + let exit_code = exec_output["metadata"]["exit_code"] + .as_i64() + .expect("exit_code present"); + let summary = exec_output["output"].as_str().expect("output field"); + assert_eq!( + exit_code, 0, + "expected apply_patch exit_code=0, got {exit_code}, summary: {summary:?}" + ); + assert!( + patch_end_success, + "expected PatchApplyEnd success flag, summary: {summary:?}" + ); + assert!( + summary.contains("Success."), + "expected apply_patch summary to note success, got {summary:?}" + ); + + let patched_path = cwd.path().join("notes.txt"); + let contents = std::fs::read_to_string(&patched_path) + .unwrap_or_else(|e| 
panic!("failed reading {}: {e}", patched_path.display())); + assert_eq!(contents, "Tool harness apply patch\n"); + } else { + assert!( + output_text.contains("codex-run-as-apply-patch"), + "expected apply_patch failure message to mention codex-run-as-apply-patch, got {output_text:?}" + ); + assert!( + !patch_end_success, + "expected PatchApplyEnd to report success=false when apply_patch invocation fails" + ); + } + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.include_apply_patch_tool = true; + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let call_id = "apply-patch-parse-error"; + let patch_content = r"*** Begin Patch +*** Update File: broken.txt +*** End Patch"; + + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_apply_patch_function_call(call_id, patch_content), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "failed"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please apply a patch".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + loop { + let event = codex.next_event().await.expect("event"); + if matches!(event.msg, EventMsg::TaskComplete(_)) { + break; 
+ } + } + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let request_bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let body_with_tool_output = find_request_with_function_call_output(&request_bodies) + .expect("function_call_output item not found in requests"); + let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + assert_eq!( + output_item.get("call_id").and_then(Value::as_str), + Some(call_id) + ); + let output_text = extract_output_text(output_item).expect("output text present"); + + assert!( + output_text.contains("apply_patch verification failed"), + "expected apply_patch verification failure message, got {output_text:?}" + ); + assert!( + output_text.contains("invalid hunk"), + "expected parse diagnostics in output text, got {output_text:?}" + ); + + if let Some(success_flag) = output_item + .get("output") + .and_then(|value| value.as_object()) + .and_then(|obj| obj.get("success")) + .and_then(serde_json::Value::as_bool) + { + assert!( + !success_flag, + "expected tool output to mark success=false for parse failures" + ); + } + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/tools.rs b/codex-rs/core/tests/suite/tools.rs new file mode 100644 index 00000000..4d6ccdac --- /dev/null +++ b/codex-rs/core/tests/suite/tools.rs @@ -0,0 +1,450 @@ +#![cfg(not(target_os = "windows"))] +#![allow(clippy::unwrap_used, clippy::expect_used)] + +use anyhow::Result; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_custom_tool_call; +use 
core_test_support::responses::ev_function_call; +use core_test_support::responses::mount_sse_sequence; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use serde_json::Value; +use serde_json::json; +use wiremock::Request; + +async fn submit_turn( + test: &TestCodex, + prompt: &str, + approval_policy: AskForApproval, + sandbox_policy: SandboxPolicy, +) -> Result<()> { + let session_model = test.session_configured.model.clone(); + + test.codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: prompt.into(), + }], + final_output_json_schema: None, + cwd: test.cwd.path().to_path_buf(), + approval_policy, + sandbox_policy, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + loop { + let event = test.codex.next_event().await?; + if matches!(event.msg, EventMsg::TaskComplete(_)) { + break; + } + } + + Ok(()) +} + +fn request_bodies(requests: &[Request]) -> Result> { + requests + .iter() + .map(|req| Ok(serde_json::from_slice::(&req.body)?)) + .collect() +} + +fn collect_output_items<'a>(bodies: &'a [Value], ty: &str) -> Vec<&'a Value> { + let mut out = Vec::new(); + for body in bodies { + if let Some(items) = body.get("input").and_then(Value::as_array) { + for item in items { + if item.get("type").and_then(Value::as_str) == Some(ty) { + out.push(item); + } + } + } + } + out +} + +fn tool_names(body: &Value) -> Vec { + body.get("tools") + .and_then(Value::as_array) + .map(|tools| { + tools + .iter() + .filter_map(|tool| { + tool.get("name") + .or_else(|| tool.get("type")) + .and_then(Value::as_str) + .map(str::to_string) + }) + .collect() + }) + .unwrap_or_default() +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> { + 
skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex(); + let test = builder.build(&server).await?; + + let call_id = "custom-unsupported"; + let tool_name = "unsupported_tool"; + + let responses = vec![ + sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_custom_tool_call(call_id, tool_name, "\"payload\""), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + submit_turn( + &test, + "invoke custom tool", + AskForApproval::Never, + SandboxPolicy::DangerFullAccess, + ) + .await?; + + let requests = server.received_requests().await.expect("recorded requests"); + let bodies = request_bodies(&requests)?; + let custom_items = collect_output_items(&bodies, "custom_tool_call_output"); + assert_eq!(custom_items.len(), 1, "expected single custom tool output"); + let item = custom_items[0]; + assert_eq!(item.get("call_id").and_then(Value::as_str), Some(call_id)); + + let output = item + .get("output") + .and_then(Value::as_str) + .unwrap_or_default(); + let expected = format!("unsupported custom tool call: {tool_name}"); + assert_eq!(output, expected); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex(); + let test = builder.build(&server).await?; + + let command = ["/bin/echo", "shell ok"]; + let call_id_blocked = "shell-blocked"; + let call_id_success = "shell-success"; + + let first_args = json!({ + "command": command, + "timeout_ms": 1_000, + "with_escalated_permissions": true, + }); + let second_args = json!({ + "command": command, + "timeout_ms": 1_000, + }); + + let responses = vec![ + sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + 
ev_function_call( + call_id_blocked, + "shell", + &serde_json::to_string(&first_args)?, + ), + ev_completed("resp-1"), + ]), + sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-2"}}), + ev_function_call( + call_id_success, + "shell", + &serde_json::to_string(&second_args)?, + ), + ev_completed("resp-2"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-3"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + submit_turn( + &test, + "run the shell command", + AskForApproval::Never, + SandboxPolicy::DangerFullAccess, + ) + .await?; + + let requests = server.received_requests().await.expect("recorded requests"); + let bodies = request_bodies(&requests)?; + let function_outputs = collect_output_items(&bodies, "function_call_output"); + for item in &function_outputs { + let call_id = item + .get("call_id") + .and_then(Value::as_str) + .unwrap_or_default(); + assert!( + call_id == call_id_blocked || call_id == call_id_success, + "unexpected call id {call_id}" + ); + } + + let policy = AskForApproval::Never; + let expected_message = format!( + "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}" + ); + + let blocked_outputs: Vec<&Value> = function_outputs + .iter() + .filter(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_blocked)) + .copied() + .collect(); + assert!( + !blocked_outputs.is_empty(), + "expected at least one rejection output for {call_id_blocked}" + ); + for item in blocked_outputs { + assert_eq!( + item.get("output").and_then(Value::as_str), + Some(expected_message.as_str()), + "unexpected rejection message" + ); + } + + let success_item = function_outputs + .iter() + .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_success)) + .expect("success output present"); + let output_json: Value = serde_json::from_str( + success_item + .get("output") + .and_then(Value::as_str) + 
.expect("success output string"), + )?; + assert_eq!( + output_json["metadata"]["exit_code"].as_i64(), + Some(0), + "expected exit code 0 after rerunning without escalation", + ); + let stdout = output_json["output"].as_str().unwrap_or_default(); + assert!( + stdout.contains("shell ok"), + "expected stdout to include command output, got {stdout:?}" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex(); + let test = builder.build(&server).await?; + + let local_shell_event = json!({ + "type": "response.output_item.done", + "item": { + "type": "local_shell_call", + "status": "completed", + "action": { + "type": "exec", + "command": ["/bin/echo", "hi"], + } + } + }); + + let responses = vec![ + sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + local_shell_event, + ev_completed("resp-1"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + submit_turn( + &test, + "check shell output", + AskForApproval::Never, + SandboxPolicy::DangerFullAccess, + ) + .await?; + + let requests = server.received_requests().await.expect("recorded requests"); + let bodies = request_bodies(&requests)?; + let function_outputs = collect_output_items(&bodies, "function_call_output"); + assert_eq!( + function_outputs.len(), + 1, + "expected a single function output" + ); + let item = function_outputs[0]; + assert_eq!(item.get("call_id").and_then(Value::as_str), Some("")); + assert_eq!( + item.get("output").and_then(Value::as_str), + Some("LocalShellCall without call_id or id"), + ); + + Ok(()) +} + +async fn collect_tools(use_unified_exec: bool) -> Result> { + let server = start_mock_server().await; + + let responses = vec![sse(vec![ + json!({"type": 
"response.created", "response": {"id": "resp-1"}}), + ev_assistant_message("msg-1", "done"), + ev_completed("resp-1"), + ])]; + mount_sse_sequence(&server, responses).await; + + let mut builder = test_codex().with_config(move |config| { + config.use_experimental_unified_exec_tool = use_unified_exec; + }); + let test = builder.build(&server).await?; + + submit_turn( + &test, + "list tools", + AskForApproval::Never, + SandboxPolicy::DangerFullAccess, + ) + .await?; + + let requests = server.received_requests().await.expect("recorded requests"); + assert_eq!( + requests.len(), + 1, + "expected a single request for tools collection" + ); + let bodies = request_bodies(&requests)?; + let first_body = bodies.first().expect("request body present"); + Ok(tool_names(first_body)) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_spec_toggle_end_to_end() -> Result<()> { + skip_if_no_network!(Ok(())); + + let tools_disabled = collect_tools(false).await?; + assert!( + !tools_disabled.iter().any(|name| name == "unified_exec"), + "tools list should not include unified_exec when disabled: {tools_disabled:?}" + ); + + let tools_enabled = collect_tools(true).await?; + assert!( + tools_enabled.iter().any(|name| name == "unified_exec"), + "tools list should include unified_exec when enabled: {tools_enabled:?}" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + let mut builder = test_codex(); + let test = builder.build(&server).await?; + + let call_id = "shell-timeout"; + let timeout_ms = 50u64; + let args = json!({ + "command": ["/bin/sh", "-c", "yes line | head -n 400; sleep 1"], + "timeout_ms": timeout_ms, + }); + + let responses = vec![ + sse(vec![ + json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_function_call(call_id, "shell", 
&serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + submit_turn( + &test, + "run a long command", + AskForApproval::Never, + SandboxPolicy::DangerFullAccess, + ) + .await?; + + let requests = server.received_requests().await.expect("recorded requests"); + let bodies = request_bodies(&requests)?; + let function_outputs = collect_output_items(&bodies, "function_call_output"); + let timeout_item = function_outputs + .iter() + .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id)) + .expect("timeout output present"); + + let output_json: Value = serde_json::from_str( + timeout_item + .get("output") + .and_then(Value::as_str) + .expect("timeout output string"), + )?; + assert_eq!( + output_json["metadata"]["exit_code"].as_i64(), + Some(124), + "expected timeout exit code 124", + ); + + let stdout = output_json["output"].as_str().unwrap_or_default(); + assert!( + stdout.starts_with("command timed out after "), + "expected timeout prefix, got {stdout:?}" + ); + let first_line = stdout.lines().next().unwrap_or_default(); + let duration_ms = first_line + .strip_prefix("command timed out after ") + .and_then(|line| line.strip_suffix(" milliseconds")) + .and_then(|value| value.parse::().ok()) + .unwrap_or_default(); + assert!( + duration_ms >= timeout_ms, + "expected duration >= configured timeout, got {duration_ms} (timeout {timeout_ms})" + ); + assert!( + stdout.contains("[... 
omitted"), + "expected truncated output marker, got {stdout:?}" + ); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/unified_exec.rs b/codex-rs/core/tests/suite/unified_exec.rs new file mode 100644 index 00000000..c81c0ba9 --- /dev/null +++ b/codex-rs/core/tests/suite/unified_exec.rs @@ -0,0 +1,280 @@ +#![cfg(not(target_os = "windows"))] + +use std::collections::HashMap; + +use anyhow::Result; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::mount_sse_sequence; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::skip_if_sandbox; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use serde_json::Value; + +fn extract_output_text(item: &Value) -> Option<&str> { + item.get("output").and_then(|value| match value { + Value::String(text) => Some(text.as_str()), + Value::Object(obj) => obj.get("content").and_then(Value::as_str), + _ => None, + }) +} + +fn collect_tool_outputs(bodies: &[Value]) -> Result> { + let mut outputs = HashMap::new(); + for body in bodies { + if let Some(items) = body.get("input").and_then(Value::as_array) { + for item in items { + if item.get("type").and_then(Value::as_str) != Some("function_call_output") { + continue; + } + if let Some(call_id) = item.get("call_id").and_then(Value::as_str) { + let content = extract_output_text(item) + .ok_or_else(|| anyhow::anyhow!("missing tool output content"))?; + let parsed: Value = serde_json::from_str(content)?; + outputs.insert(call_id.to_string(), parsed); + } + } + } 
+ } + Ok(outputs) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_reuses_session_via_stdin() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.use_experimental_unified_exec_tool = true; + }); + let TestCodex { + codex, + cwd, + session_configured, + .. + } = builder.build(&server).await?; + + let first_call_id = "uexec-start"; + let first_args = serde_json::json!({ + "input": ["/bin/cat"], + "timeout_ms": 200, + }); + + let second_call_id = "uexec-stdin"; + let second_args = serde_json::json!({ + "input": ["hello unified exec\n"], + "session_id": "0", + "timeout_ms": 500, + }); + + let responses = vec![ + sse(vec![ + serde_json::json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_function_call( + first_call_id, + "unified_exec", + &serde_json::to_string(&first_args)?, + ), + ev_completed("resp-1"), + ]), + sse(vec![ + serde_json::json!({"type": "response.created", "response": {"id": "resp-2"}}), + ev_function_call( + second_call_id, + "unified_exec", + &serde_json::to_string(&second_args)?, + ), + ev_completed("resp-2"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "all done"), + ev_completed("resp-3"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "run unified exec".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + loop { + let event = codex.next_event().await.expect("event"); + if matches!(event.msg, EventMsg::TaskComplete(_)) { + break; + } + } + + let requests = 
server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let outputs = collect_tool_outputs(&bodies)?; + + let start_output = outputs + .get(first_call_id) + .expect("missing first unified_exec output"); + let session_id = start_output["session_id"].as_str().unwrap_or_default(); + assert!( + !session_id.is_empty(), + "expected session id in first unified_exec response" + ); + assert!( + start_output["output"] + .as_str() + .unwrap_or_default() + .is_empty() + ); + + let reuse_output = outputs + .get(second_call_id) + .expect("missing reused unified_exec output"); + assert_eq!( + reuse_output["session_id"].as_str().unwrap_or_default(), + session_id + ); + let echoed = reuse_output["output"].as_str().unwrap_or_default(); + assert!( + echoed.contains("hello unified exec"), + "expected echoed output, got {echoed:?}" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unified_exec_timeout_and_followup_poll() -> Result<()> { + skip_if_no_network!(Ok(())); + skip_if_sandbox!(Ok(())); + + let server = start_mock_server().await; + + let mut builder = test_codex().with_config(|config| { + config.use_experimental_unified_exec_tool = true; + }); + let TestCodex { + codex, + cwd, + session_configured, + .. 
+ } = builder.build(&server).await?; + + let first_call_id = "uexec-timeout"; + let first_args = serde_json::json!({ + "input": ["/bin/sh", "-c", "sleep 0.1; echo ready"], + "timeout_ms": 10, + }); + + let second_call_id = "uexec-poll"; + let second_args = serde_json::json!({ + "input": Vec::::new(), + "session_id": "0", + "timeout_ms": 800, + }); + + let responses = vec![ + sse(vec![ + serde_json::json!({"type": "response.created", "response": {"id": "resp-1"}}), + ev_function_call( + first_call_id, + "unified_exec", + &serde_json::to_string(&first_args)?, + ), + ev_completed("resp-1"), + ]), + sse(vec![ + serde_json::json!({"type": "response.created", "response": {"id": "resp-2"}}), + ev_function_call( + second_call_id, + "unified_exec", + &serde_json::to_string(&second_args)?, + ), + ev_completed("resp-2"), + ]), + sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-3"), + ]), + ]; + mount_sse_sequence(&server, responses).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "check timeout".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + loop { + let event = codex.next_event().await.expect("event"); + if matches!(event.msg, EventMsg::TaskComplete(_)) { + break; + } + } + + let requests = server.received_requests().await.expect("recorded requests"); + assert!(!requests.is_empty(), "expected at least one POST request"); + + let bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let outputs = collect_tool_outputs(&bodies)?; + + let first_output = outputs.get(first_call_id).expect("missing timeout output"); + assert_eq!(first_output["session_id"], "0"); + assert!( + first_output["output"] + 
.as_str() + .unwrap_or_default() + .is_empty() + ); + + let poll_output = outputs.get(second_call_id).expect("missing poll output"); + let output_text = poll_output["output"].as_str().unwrap_or_default(); + assert!( + output_text.contains("ready"), + "expected ready output, got {output_text:?}" + ); + + Ok(()) +} diff --git a/codex-rs/core/tests/suite/view_image.rs b/codex-rs/core/tests/suite/view_image.rs new file mode 100644 index 00000000..92fbf4ad --- /dev/null +++ b/codex-rs/core/tests/suite/view_image.rs @@ -0,0 +1,351 @@ +#![cfg(not(target_os = "windows"))] + +use base64::Engine; +use base64::engine::general_purpose::STANDARD as BASE64_STANDARD; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use core_test_support::responses; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::sse; +use core_test_support::responses::start_mock_server; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::TestCodex; +use core_test_support::test_codex::test_codex; +use serde_json::Value; +use wiremock::matchers::any; + +fn function_call_output(body: &Value) -> Option<&Value> { + body.get("input") + .and_then(Value::as_array) + .and_then(|items| { + items.iter().find(|item| { + item.get("type").and_then(Value::as_str) == Some("function_call_output") + }) + }) +} + +fn find_image_message(body: &Value) -> Option<&Value> { + body.get("input") + .and_then(Value::as_array) + .and_then(|items| { + items.iter().find(|item| { + item.get("type").and_then(Value::as_str) == Some("message") + && item + .get("content") + .and_then(Value::as_array) + .map(|content| { + content.iter().any(|span| { + 
span.get("type").and_then(Value::as_str) == Some("input_image") + }) + }) + .unwrap_or(false) + }) + }) +} + +fn extract_output_text(item: &Value) -> Option<&str> { + item.get("output").and_then(|value| match value { + Value::String(text) => Some(text.as_str()), + Value::Object(obj) => obj.get("content").and_then(Value::as_str), + _ => None, + }) +} + +fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> { + requests + .iter() + .find(|body| function_call_output(body).is_some()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let TestCodex { + codex, + cwd, + session_configured, + .. + } = test_codex().build(&server).await?; + + let rel_path = "assets/example.png"; + let abs_path = cwd.path().join(rel_path); + if let Some(parent) = abs_path.parent() { + std::fs::create_dir_all(parent)?; + } + let image_bytes = b"fake_png_bytes".to_vec(); + std::fs::write(&abs_path, &image_bytes)?; + + let call_id = "view-image-call"; + let arguments = serde_json::json!({ "path": rel_path }).to_string(); + + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_function_call(call_id, "view_image", &arguments), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please add the screenshot".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: 
SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let mut tool_event = None; + loop { + let event = codex.next_event().await.expect("event"); + match event.msg { + EventMsg::ViewImageToolCall(ev) => tool_event = Some(ev), + EventMsg::TaskComplete(_) => break, + _ => {} + } + } + + let tool_event = tool_event.expect("view image tool event emitted"); + assert_eq!(tool_event.call_id, call_id); + assert_eq!(tool_event.path, abs_path); + + let requests = server.received_requests().await.expect("recorded requests"); + assert!( + requests.len() >= 2, + "expected at least two POST requests, got {}", + requests.len() + ); + let request_bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let body_with_tool_output = find_request_with_function_call_output(&request_bodies) + .expect("function_call_output item not found in requests"); + let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + let output_text = extract_output_text(output_item).expect("output text present"); + assert_eq!(output_text, "attached local image path"); + + let image_message = find_image_message(body_with_tool_output) + .expect("pending input image message not included in request"); + let image_url = image_message + .get("content") + .and_then(Value::as_array) + .and_then(|content| { + content.iter().find_map(|span| { + if span.get("type").and_then(Value::as_str) == Some("input_image") { + span.get("image_url").and_then(Value::as_str) + } else { + None + } + }) + }) + .expect("image_url present"); + + let expected_image_url = format!( + "data:image/png;base64,{}", + BASE64_STANDARD.encode(&image_bytes) + ); + assert_eq!(image_url, expected_image_url); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + 
let server = start_mock_server().await; + + let TestCodex { + codex, + cwd, + session_configured, + .. + } = test_codex().build(&server).await?; + + let rel_path = "assets"; + let abs_path = cwd.path().join(rel_path); + std::fs::create_dir_all(&abs_path)?; + + let call_id = "view-image-directory"; + let arguments = serde_json::json!({ "path": rel_path }).to_string(); + + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_function_call(call_id, "view_image", &arguments), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please attach the folder".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + loop { + let event = codex.next_event().await.expect("event"); + if matches!(event.msg, EventMsg::TaskComplete(_)) { + break; + } + } + + let requests = server.received_requests().await.expect("recorded requests"); + assert!( + requests.len() >= 2, + "expected at least two POST requests, got {}", + requests.len() + ); + let request_bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let body_with_tool_output = find_request_with_function_call_output(&request_bodies) + .expect("function_call_output item not found in requests"); + let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + let output_text = 
extract_output_text(output_item).expect("output text present"); + let expected_message = format!("image path `{}` is not a file", abs_path.display()); + assert_eq!(output_text, expected_message); + + assert!( + find_image_message(body_with_tool_output).is_none(), + "directory path should not produce an input_image message" + ); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = start_mock_server().await; + + let TestCodex { + codex, + cwd, + session_configured, + .. + } = test_codex().build(&server).await?; + + let rel_path = "missing/example.png"; + let abs_path = cwd.path().join(rel_path); + + let call_id = "view-image-missing"; + let arguments = serde_json::json!({ "path": rel_path }).to_string(); + + let first_response = sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + ev_function_call(call_id, "view_image", &arguments), + ev_completed("resp-1"), + ]); + responses::mount_sse_once_match(&server, any(), first_response).await; + + let second_response = sse(vec![ + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]); + responses::mount_sse_once_match(&server, any(), second_response).await; + + let session_model = session_configured.model.clone(); + + codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "please attach the missing image".into(), + }], + final_output_json_schema: None, + cwd: cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + loop { + let event = codex.next_event().await.expect("event"); + if matches!(event.msg, EventMsg::TaskComplete(_)) { + break; + } + } + + let requests = server.received_requests().await.expect("recorded requests"); + assert!( + 
requests.len() >= 2, + "expected at least two POST requests, got {}", + requests.len() + ); + let request_bodies = requests + .iter() + .map(|req| req.body_json::().expect("request json")) + .collect::>(); + + let body_with_tool_output = find_request_with_function_call_output(&request_bodies) + .expect("function_call_output item not found in requests"); + let output_item = function_call_output(body_with_tool_output).expect("tool output item"); + let output_text = extract_output_text(output_item).expect("output text present"); + let expected_prefix = format!("unable to locate image at `{}`:", abs_path.display()); + assert!( + output_text.starts_with(&expected_prefix), + "expected error to start with `{expected_prefix}` but got `{output_text}`" + ); + + assert!( + find_image_message(body_with_tool_output).is_none(), + "missing file should not produce an input_image message" + ); + + Ok(()) +} diff --git a/codex-rs/exec/src/event_processor_with_human_output.rs b/codex-rs/exec/src/event_processor_with_human_output.rs index a78139ee..71ec2b32 100644 --- a/codex-rs/exec/src/event_processor_with_human_output.rs +++ b/codex-rs/exec/src/event_processor_with_human_output.rs @@ -1,7 +1,6 @@ use codex_common::elapsed::format_duration; use codex_common::elapsed::format_elapsed; use codex_core::config::Config; -use codex_core::plan_tool::UpdatePlanArgs; use codex_core::protocol::AgentMessageEvent; use codex_core::protocol::AgentReasoningRawContentEvent; use codex_core::protocol::BackgroundEventEvent; @@ -35,6 +34,8 @@ use crate::event_processor::CodexStatus; use crate::event_processor::EventProcessor; use crate::event_processor::handle_last_message; use codex_common::create_config_summary_entries; +use codex_protocol::plan_tool::StepStatus; +use codex_protocol::plan_tool::UpdatePlanArgs; /// This should be configurable. When used in CI, users may not want to impose /// a limit so they can see the full transcript. 
@@ -456,7 +457,6 @@ impl EventProcessor for EventProcessorWithHumanOutput { // Pretty-print the plan items with simple status markers. for item in plan { - use codex_core::plan_tool::StepStatus; match item.status { StepStatus::Completed => { ts_println!(self, " {} {}", "✓".style(self.green), item.step); diff --git a/codex-rs/exec/src/event_processor_with_jsonl_output.rs b/codex-rs/exec/src/event_processor_with_jsonl_output.rs index 51ecd71a..ea58e033 100644 --- a/codex-rs/exec/src/event_processor_with_jsonl_output.rs +++ b/codex-rs/exec/src/event_processor_with_jsonl_output.rs @@ -31,8 +31,6 @@ use crate::exec_events::TurnStartedEvent; use crate::exec_events::Usage; use crate::exec_events::WebSearchItem; use codex_core::config::Config; -use codex_core::plan_tool::StepStatus; -use codex_core::plan_tool::UpdatePlanArgs; use codex_core::protocol::AgentMessageEvent; use codex_core::protocol::AgentReasoningEvent; use codex_core::protocol::Event; @@ -48,6 +46,8 @@ use codex_core::protocol::SessionConfiguredEvent; use codex_core::protocol::TaskCompleteEvent; use codex_core::protocol::TaskStartedEvent; use codex_core::protocol::WebSearchEndEvent; +use codex_protocol::plan_tool::StepStatus; +use codex_protocol::plan_tool::UpdatePlanArgs; use tracing::error; use tracing::warn; diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 488e0fcf..faac8127 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -171,7 +171,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any codex_linux_sandbox_exe, base_instructions: None, include_plan_tool: Some(include_plan_tool), - include_apply_patch_tool: None, + include_apply_patch_tool: Some(true), include_view_image_tool: None, show_raw_agent_reasoning: oss.then_some(true), tools_web_search_request: None, diff --git a/codex-rs/exec/tests/event_processor_with_json_output.rs b/codex-rs/exec/tests/event_processor_with_json_output.rs index b6475ad1..a995b463 100644 --- 
a/codex-rs/exec/tests/event_processor_with_json_output.rs +++ b/codex-rs/exec/tests/event_processor_with_json_output.rs @@ -37,6 +37,9 @@ use codex_exec::exec_events::TurnFailedEvent; use codex_exec::exec_events::TurnStartedEvent; use codex_exec::exec_events::Usage; use codex_exec::exec_events::WebSearchItem; +use codex_protocol::plan_tool::PlanItemArg; +use codex_protocol::plan_tool::StepStatus; +use codex_protocol::plan_tool::UpdatePlanArgs; use mcp_types::CallToolResult; use pretty_assertions::assert_eq; use std::path::PathBuf; @@ -115,10 +118,6 @@ fn web_search_end_emits_item_completed() { #[test] fn plan_update_emits_todo_list_started_updated_and_completed() { - use codex_core::plan_tool::PlanItemArg; - use codex_core::plan_tool::StepStatus; - use codex_core::plan_tool::UpdatePlanArgs; - let mut ep = EventProcessorWithJsonOutput::new(None); // First plan update => item.started (todo_list) @@ -339,10 +338,6 @@ fn mcp_tool_call_failure_sets_failed_status() { #[test] fn plan_update_after_complete_starts_new_todo_list_with_new_id() { - use codex_core::plan_tool::PlanItemArg; - use codex_core::plan_tool::StepStatus; - use codex_core::plan_tool::UpdatePlanArgs; - let mut ep = EventProcessorWithJsonOutput::new(None); // First turn: start + complete diff --git a/codex-rs/otel/src/otel_event_manager.rs b/codex-rs/otel/src/otel_event_manager.rs index 3e2ffeb7..bda23433 100644 --- a/codex-rs/otel/src/otel_event_manager.rs +++ b/codex-rs/otel/src/otel_event_manager.rs @@ -14,6 +14,7 @@ use eventsource_stream::EventStreamError as StreamError; use reqwest::Error; use reqwest::Response; use serde::Serialize; +use std::borrow::Cow; use std::fmt::Display; use std::time::Duration; use std::time::Instant; @@ -366,10 +367,10 @@ impl OtelEventManager { call_id: &str, arguments: &str, f: F, - ) -> Result + ) -> Result<(String, bool), E> where F: FnOnce() -> Fut, - Fut: Future>, + Fut: Future>, E: Display, { let start = Instant::now(); @@ -377,10 +378,12 @@ impl OtelEventManager { 
let duration = start.elapsed(); let (output, success) = match &result { - Ok(content) => (content, true), - Err(error) => (&error.to_string(), false), + Ok((preview, success)) => (Cow::Borrowed(preview.as_str()), *success), + Err(error) => (Cow::Owned(error.to_string()), false), }; + let success_str = if success { "true" } else { "false" }; + tracing::event!( tracing::Level::INFO, event.name = "codex.tool_result", @@ -396,7 +399,8 @@ impl OtelEventManager { call_id = %call_id, arguments = %arguments, duration_ms = %duration.as_millis(), - success = %success, + success = %success_str, + // `output` is truncated by the tool layer before reaching telemetry. output = %output, ); diff --git a/codex-rs/protocol/src/models.rs b/codex-rs/protocol/src/models.rs index f6e5599c..4952aa01 100644 --- a/codex-rs/protocol/src/models.rs +++ b/codex-rs/protocol/src/models.rs @@ -259,6 +259,7 @@ pub struct ShellToolCallParams { #[derive(Debug, Clone, PartialEq, TS)] pub struct FunctionCallOutputPayload { pub content: String, + // TODO(jif) drop this. 
pub success: Option, } diff --git a/codex-rs/tui/src/chatwidget.rs b/codex-rs/tui/src/chatwidget.rs index de40edf1..79ee810f 100644 --- a/codex-rs/tui/src/chatwidget.rs +++ b/codex-rs/tui/src/chatwidget.rs @@ -111,6 +111,7 @@ use codex_git_tooling::GhostCommit; use codex_git_tooling::GitToolingError; use codex_git_tooling::create_ghost_commit; use codex_git_tooling::restore_ghost_commit; +use codex_protocol::plan_tool::UpdatePlanArgs; use strum::IntoEnumIterator; const MAX_TRACKED_GHOST_COMMITS: usize = 20; @@ -508,7 +509,7 @@ impl ChatWidget { self.request_redraw(); } - fn on_plan_update(&mut self, update: codex_core::plan_tool::UpdatePlanArgs) { + fn on_plan_update(&mut self, update: UpdatePlanArgs) { self.add_to_history(history_cell::new_plan_update(update)); } diff --git a/codex-rs/tui/src/chatwidget/tests.rs b/codex-rs/tui/src/chatwidget/tests.rs index cd31d3a6..7a8015b1 100644 --- a/codex-rs/tui/src/chatwidget/tests.rs +++ b/codex-rs/tui/src/chatwidget/tests.rs @@ -8,9 +8,6 @@ use codex_core::CodexAuth; use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config::ConfigToml; -use codex_core::plan_tool::PlanItemArg; -use codex_core::plan_tool::StepStatus; -use codex_core::plan_tool::UpdatePlanArgs; use codex_core::protocol::AgentMessageDeltaEvent; use codex_core::protocol::AgentMessageEvent; use codex_core::protocol::AgentReasoningDeltaEvent; @@ -37,6 +34,9 @@ use codex_core::protocol::TaskCompleteEvent; use codex_core::protocol::TaskStartedEvent; use codex_core::protocol::ViewImageToolCallEvent; use codex_protocol::ConversationId; +use codex_protocol::plan_tool::PlanItemArg; +use codex_protocol::plan_tool::StepStatus; +use codex_protocol::plan_tool::UpdatePlanArgs; use crossterm::event::KeyCode; use crossterm::event::KeyEvent; use crossterm::event::KeyModifiers; diff --git a/codex-rs/tui/src/history_cell.rs b/codex-rs/tui/src/history_cell.rs index e91fce65..e3a89da3 100644 --- a/codex-rs/tui/src/history_cell.rs +++ 
b/codex-rs/tui/src/history_cell.rs @@ -21,13 +21,13 @@ use base64::Engine; use codex_core::config::Config; use codex_core::config_types::McpServerTransportConfig; use codex_core::config_types::ReasoningSummaryFormat; -use codex_core::plan_tool::PlanItemArg; -use codex_core::plan_tool::StepStatus; -use codex_core::plan_tool::UpdatePlanArgs; use codex_core::protocol::FileChange; use codex_core::protocol::McpInvocation; use codex_core::protocol::SessionConfiguredEvent; use codex_core::protocol_config_types::ReasoningEffort as ReasoningEffortConfig; +use codex_protocol::plan_tool::PlanItemArg; +use codex_protocol::plan_tool::StepStatus; +use codex_protocol::plan_tool::UpdatePlanArgs; use image::DynamicImage; use image::ImageReader; use mcp_types::EmbeddedResourceResource; diff --git a/codex-rs/utils/string/Cargo.toml b/codex-rs/utils/string/Cargo.toml new file mode 100644 index 00000000..698c4b2f --- /dev/null +++ b/codex-rs/utils/string/Cargo.toml @@ -0,0 +1,7 @@ +[package] +edition.workspace = true +name = "codex-utils-string" +version.workspace = true + +[lints] +workspace = true diff --git a/codex-rs/utils/string/src/lib.rs b/codex-rs/utils/string/src/lib.rs new file mode 100644 index 00000000..f7299d43 --- /dev/null +++ b/codex-rs/utils/string/src/lib.rs @@ -0,0 +1,38 @@ +// Truncate a &str to a byte budget at a char boundary (prefix) +#[inline] +pub fn take_bytes_at_char_boundary(s: &str, maxb: usize) -> &str { + if s.len() <= maxb { + return s; + } + let mut last_ok = 0; + for (i, ch) in s.char_indices() { + let nb = i + ch.len_utf8(); + if nb > maxb { + break; + } + last_ok = nb; + } + &s[..last_ok] +} + +// Take a suffix of a &str within a byte budget at a char boundary +#[inline] +pub fn take_last_bytes_at_char_boundary(s: &str, maxb: usize) -> &str { + if s.len() <= maxb { + return s; + } + let mut start = s.len(); + let mut used = 0usize; + for (i, ch) in s.char_indices().rev() { + let nb = ch.len_utf8(); + if used + nb > maxb { + break; + } + start = i; + 
used += nb; + if start == 0 { + break; + } + } + &s[start..] +}