chore: refactor tool handling (#4510)
# Tool System Refactor

- Centralizes tool definitions and execution in `core/src/tools/*`: specs (`spec.rs`), handlers (`handlers/*`), router (`router.rs`), registry/dispatch (`registry.rs`), and shared context (`context.rs`). One registry now builds the model-visible tool list and binds handlers.
- Router converts model responses to tool calls; Registry dispatches with consistent telemetry via `codex-rs/otel` and unified error handling. Function, Local Shell, MCP, and experimental `unified_exec` all flow through this path; legacy shell aliases still work.
- Rationale: reduce per‑tool boilerplate, keep spec/handler in sync, and make adding tools predictable and testable.

Example: `read_file`
- Spec: `core/src/tools/spec.rs` (see `create_read_file_tool`, registered by `build_specs`).
- Handler: `core/src/tools/handlers/read_file.rs` (absolute `file_path`, 1‑indexed `offset`, `limit`, `L#: ` prefixes, safe truncation).
- E2E test: `core/tests/suite/read_file.rs` validates the tool returns the requested lines.

## Next steps:
- Decompose `handle_container_exec_with_params`
- Add parallel tool calls
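For reviewers skimming the diff below, a minimal sketch of what a handler looks like under the new layout. The `EchoHandler` name and its behavior are illustrative only and not part of this PR; the trait shape follows `ToolHandler` as implemented by the handlers in this diff (`kind()` plus an async `handle()` that consumes a `ToolInvocation` and returns a `ToolOutput`).

```rust
use async_trait::async_trait;

use crate::function_tool::FunctionCallError;
use crate::tools::context::{ToolInvocation, ToolOutput, ToolPayload};
use crate::tools::registry::{ToolHandler, ToolKind};

/// Hypothetical handler, used only to illustrate the shape of the new API.
pub struct EchoHandler;

#[async_trait]
impl ToolHandler for EchoHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
        // The router has already classified the call; function tools carry raw JSON arguments.
        let ToolInvocation { payload, .. } = invocation;
        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "echo handler received unsupported payload".to_string(),
                ));
            }
        };

        // Echo the arguments back to the model as the tool output.
        Ok(ToolOutput::Function {
            content: arguments,
            success: Some(true),
        })
    }
}
```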
codex-rs/Cargo.lock (generated): 5 changed lines
@@ -861,6 +861,7 @@ dependencies = [
  "codex-otel",
  "codex-protocol",
  "codex-rmcp-client",
+ "codex-utils-string",
  "core_test_support",
  "dirs",
  "dunce",
@@ -1254,6 +1255,10 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "codex-utils-string"
+version = "0.0.0"
+
 [[package]]
 name = "color-eyre"
 version = "0.6.5"
@@ -32,6 +32,7 @@ members = [
     "git-apply",
     "utils/json-to-toml",
     "utils/readiness",
+    "utils/string",
 ]
 resolver = "2"
 
@@ -71,6 +72,7 @@ codex-rmcp-client = { path = "rmcp-client" }
 codex-tui = { path = "tui" }
 codex-utils-json-to-toml = { path = "utils/json-to-toml" }
 codex-utils-readiness = { path = "utils/readiness" }
+codex-utils-string = { path = "utils/string" }
 core_test_support = { path = "core/tests/common" }
 mcp-types = { path = "mcp-types" }
 mcp_test_support = { path = "mcp-server/tests/common" }
@@ -26,6 +26,7 @@ codex-rmcp-client = { workspace = true }
 codex-protocol = { workspace = true }
 codex-app-server-protocol = { workspace = true }
 codex-otel = { workspace = true, features = ["otel"] }
+codex-utils-string = { workspace = true }
 dirs = { workspace = true }
 dunce = { workspace = true }
 env-flags = { workspace = true }
@@ -1,6 +1,6 @@
+use crate::client_common::tools::ToolSpec;
 use crate::error::Result;
 use crate::model_family::ModelFamily;
-use crate::openai_tools::OpenAiTool;
 use crate::protocol::RateLimitSnapshot;
 use crate::protocol::TokenUsage;
 use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
@@ -29,7 +29,7 @@ pub struct Prompt {
 
     /// Tools available to the model, including additional tools sourced from
     /// external MCP servers.
-    pub(crate) tools: Vec<OpenAiTool>,
+    pub(crate) tools: Vec<ToolSpec>,
 
     /// Optional override for the built-in BASE_INSTRUCTIONS.
     pub base_instructions_override: Option<String>,
@@ -49,8 +49,8 @@ impl Prompt {
         // AND
         // - there is no apply_patch tool present
         let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
-            OpenAiTool::Function(f) => f.name == "apply_patch",
-            OpenAiTool::Freeform(f) => f.name == "apply_patch",
+            ToolSpec::Function(f) => f.name == "apply_patch",
+            ToolSpec::Freeform(f) => f.name == "apply_patch",
             _ => false,
         });
         if self.base_instructions_override.is_none()
@@ -160,6 +160,54 @@ pub(crate) struct ResponsesApiRequest<'a> {
     pub(crate) text: Option<TextControls>,
 }
 
+pub(crate) mod tools {
+    use crate::openai_tools::JsonSchema;
+    use serde::Deserialize;
+    use serde::Serialize;
+
+    /// When serialized as JSON, this produces a valid "Tool" in the OpenAI
+    /// Responses API.
+    #[derive(Debug, Clone, Serialize, PartialEq)]
+    #[serde(tag = "type")]
+    pub(crate) enum ToolSpec {
+        #[serde(rename = "function")]
+        Function(ResponsesApiTool),
+        #[serde(rename = "local_shell")]
+        LocalShell {},
+        // TODO: Understand why we get an error on web_search although the API docs say it's supported.
+        // https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
+        #[serde(rename = "web_search")]
+        WebSearch {},
+        #[serde(rename = "custom")]
+        Freeform(FreeformTool),
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+    pub struct FreeformTool {
+        pub(crate) name: String,
+        pub(crate) description: String,
+        pub(crate) format: FreeformToolFormat,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+    pub struct FreeformToolFormat {
+        pub(crate) r#type: String,
+        pub(crate) syntax: String,
+        pub(crate) definition: String,
+    }
+
+    #[derive(Debug, Clone, Serialize, PartialEq)]
+    pub struct ResponsesApiTool {
+        pub(crate) name: String,
+        pub(crate) description: String,
+        /// TODO: Validation. When strict is set to true, the JSON schema,
+        /// `required` and `additional_properties` must be present. All fields in
+        /// `properties` must be present in `required`.
+        pub(crate) strict: bool,
+        pub(crate) parameters: JsonSchema,
+    }
+}
+
 pub(crate) fn create_reasoning_param_for_request(
     model_family: &ModelFamily,
     effort: Option<ReasoningEffortConfig>,
File diff suppressed because it is too large
@@ -108,6 +108,9 @@ pub enum CodexErr {
     #[error("unsupported operation: {0}")]
     UnsupportedOperation(String),
 
+    #[error("Fatal error: {0}")]
+    Fatal(String),
+
     // -----------------------------------------------------------------
     // Automatic conversions for common external error types
     // -----------------------------------------------------------------
@@ -1,7 +1,7 @@
 use std::collections::BTreeMap;
 
+use crate::client_common::tools::ResponsesApiTool;
 use crate::openai_tools::JsonSchema;
-use crate::openai_tools::ResponsesApiTool;
 
 pub const EXEC_COMMAND_TOOL_NAME: &str = "exec_command";
 pub const WRITE_STDIN_TOOL_NAME: &str = "write_stdin";
@@ -10,11 +10,11 @@ pub(crate) use runner::ExecutorConfig;
 pub(crate) use runner::normalize_exec_result;
 
 pub(crate) mod linkers {
-    use crate::codex::ExecCommandContext;
     use crate::exec::ExecParams;
     use crate::exec::StdoutStream;
     use crate::executor::backends::ExecutionMode;
     use crate::executor::runner::ExecutionRequest;
+    use crate::tools::context::ExecCommandContext;
 
     pub struct PreparedExec {
         pub(crate) context: ExecCommandContext,
@@ -6,7 +6,6 @@ use std::time::Duration;
 use super::backends::ExecutionMode;
 use super::backends::backend_for_mode;
 use super::cache::ApprovalCache;
-use crate::codex::ExecCommandContext;
 use crate::codex::Session;
 use crate::error::CodexErr;
 use crate::error::SandboxErr;
@@ -24,6 +23,7 @@ use crate::protocol::AskForApproval;
 use crate::protocol::ReviewDecision;
 use crate::protocol::SandboxPolicy;
 use crate::shell;
+use crate::tools::context::ExecCommandContext;
 use codex_otel::otel_event_manager::ToolDecisionSource;
 
 #[derive(Clone, Debug)]
@@ -303,6 +303,7 @@ pub(crate) fn normalize_exec_result(
     let message = match err {
         ExecError::Function(FunctionCallError::RespondToModel(msg)) => msg.clone(),
         ExecError::Codex(e) => get_error_message_ui(e),
+        err => err.to_string(),
     };
     let synthetic = ExecToolCallOutput {
         exit_code: -1,
@@ -4,4 +4,8 @@ use thiserror::Error;
 pub enum FunctionCallError {
     #[error("{0}")]
     RespondToModel(String),
+    #[error("LocalShellCall without call_id or id")]
+    MissingLocalShellCallId,
+    #[error("Fatal error: {0}")]
+    Fatal(String),
 }
@@ -57,7 +57,6 @@ pub mod default_client;
 pub mod model_family;
 mod openai_model_info;
 mod openai_tools;
-pub mod plan_tool;
 pub mod project_doc;
 mod rollout;
 pub(crate) mod safety;
@@ -65,7 +64,7 @@ pub mod seatbelt;
 pub mod shell;
 pub mod spawn;
 pub mod terminal;
-mod tool_apply_patch;
+mod tools;
 pub mod turn_diff_tracker;
 pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
 pub use rollout::INTERACTIVE_SESSION_SOURCES;
@@ -1,5 +1,5 @@
 use crate::config_types::ReasoningSummaryFormat;
-use crate::tool_apply_patch::ApplyPatchToolType;
+use crate::tools::handlers::apply_patch::ApplyPatchToolType;
 
 /// The `instructions` field in the payload sent to a model should always start
 /// with this content.
File diff suppressed because it is too large
codex-rs/core/src/tools/context.rs (new file): 244 lines
@@ -0,0 +1,244 @@
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::tools::TELEMETRY_PREVIEW_MAX_BYTES;
use crate::tools::TELEMETRY_PREVIEW_MAX_LINES;
use crate::tools::TELEMETRY_PREVIEW_TRUNCATION_NOTICE;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ShellToolCallParams;
use codex_protocol::protocol::FileChange;
use codex_utils_string::take_bytes_at_char_boundary;
use mcp_types::CallToolResult;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;

pub struct ToolInvocation<'a> {
    pub session: &'a Session,
    pub turn: &'a TurnContext,
    pub tracker: &'a mut TurnDiffTracker,
    pub sub_id: &'a str,
    pub call_id: String,
    pub tool_name: String,
    pub payload: ToolPayload,
}

#[derive(Clone)]
pub enum ToolPayload {
    Function {
        arguments: String,
    },
    Custom {
        input: String,
    },
    LocalShell {
        params: ShellToolCallParams,
    },
    UnifiedExec {
        arguments: String,
    },
    Mcp {
        server: String,
        tool: String,
        raw_arguments: String,
    },
}

impl ToolPayload {
    pub fn log_payload(&self) -> Cow<'_, str> {
        match self {
            ToolPayload::Function { arguments } => Cow::Borrowed(arguments),
            ToolPayload::Custom { input } => Cow::Borrowed(input),
            ToolPayload::LocalShell { params } => Cow::Owned(params.command.join(" ")),
            ToolPayload::UnifiedExec { arguments } => Cow::Borrowed(arguments),
            ToolPayload::Mcp { raw_arguments, .. } => Cow::Borrowed(raw_arguments),
        }
    }
}

#[derive(Clone)]
pub enum ToolOutput {
    Function {
        content: String,
        success: Option<bool>,
    },
    Mcp {
        result: Result<CallToolResult, String>,
    },
}

impl ToolOutput {
    pub fn log_preview(&self) -> String {
        match self {
            ToolOutput::Function { content, .. } => telemetry_preview(content),
            ToolOutput::Mcp { result } => format!("{result:?}"),
        }
    }

    pub fn success_for_logging(&self) -> bool {
        match self {
            ToolOutput::Function { success, .. } => success.unwrap_or(true),
            ToolOutput::Mcp { result } => result.is_ok(),
        }
    }

    pub fn into_response(self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
        match self {
            ToolOutput::Function { content, success } => {
                if matches!(payload, ToolPayload::Custom { .. }) {
                    ResponseInputItem::CustomToolCallOutput {
                        call_id: call_id.to_string(),
                        output: content,
                    }
                } else {
                    ResponseInputItem::FunctionCallOutput {
                        call_id: call_id.to_string(),
                        output: FunctionCallOutputPayload { content, success },
                    }
                }
            }
            ToolOutput::Mcp { result } => ResponseInputItem::McpToolCallOutput {
                call_id: call_id.to_string(),
                result,
            },
        }
    }
}

fn telemetry_preview(content: &str) -> String {
    let truncated_slice = take_bytes_at_char_boundary(content, TELEMETRY_PREVIEW_MAX_BYTES);
    let truncated_by_bytes = truncated_slice.len() < content.len();

    let mut preview = String::new();
    let mut lines_iter = truncated_slice.lines();
    for idx in 0..TELEMETRY_PREVIEW_MAX_LINES {
        match lines_iter.next() {
            Some(line) => {
                if idx > 0 {
                    preview.push('\n');
                }
                preview.push_str(line);
            }
            None => break,
        }
    }
    let truncated_by_lines = lines_iter.next().is_some();

    if !truncated_by_bytes && !truncated_by_lines {
        return content.to_string();
    }

    if preview.len() < truncated_slice.len()
        && truncated_slice
            .as_bytes()
            .get(preview.len())
            .is_some_and(|byte| *byte == b'\n')
    {
        preview.push('\n');
    }

    if !preview.is_empty() && !preview.ends_with('\n') {
        preview.push('\n');
    }
    preview.push_str(TELEMETRY_PREVIEW_TRUNCATION_NOTICE);

    preview
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn custom_tool_calls_should_roundtrip_as_custom_outputs() {
        let payload = ToolPayload::Custom {
            input: "patch".to_string(),
        };
        let response = ToolOutput::Function {
            content: "patched".to_string(),
            success: Some(true),
        }
        .into_response("call-42", &payload);

        match response {
            ResponseInputItem::CustomToolCallOutput { call_id, output } => {
                assert_eq!(call_id, "call-42");
                assert_eq!(output, "patched");
            }
            other => panic!("expected CustomToolCallOutput, got {other:?}"),
        }
    }

    #[test]
    fn function_payloads_remain_function_outputs() {
        let payload = ToolPayload::Function {
            arguments: "{}".to_string(),
        };
        let response = ToolOutput::Function {
            content: "ok".to_string(),
            success: Some(true),
        }
        .into_response("fn-1", &payload);

        match response {
            ResponseInputItem::FunctionCallOutput { call_id, output } => {
                assert_eq!(call_id, "fn-1");
                assert_eq!(output.content, "ok");
                assert_eq!(output.success, Some(true));
            }
            other => panic!("expected FunctionCallOutput, got {other:?}"),
        }
    }

    #[test]
    fn telemetry_preview_returns_original_within_limits() {
        let content = "short output";
        assert_eq!(telemetry_preview(content), content);
    }

    #[test]
    fn telemetry_preview_truncates_by_bytes() {
        let content = "x".repeat(TELEMETRY_PREVIEW_MAX_BYTES + 8);
        let preview = telemetry_preview(&content);

        assert!(preview.contains(TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
        assert!(
            preview.len()
                <= TELEMETRY_PREVIEW_MAX_BYTES + TELEMETRY_PREVIEW_TRUNCATION_NOTICE.len() + 1
        );
    }

    #[test]
    fn telemetry_preview_truncates_by_lines() {
        let content = (0..(TELEMETRY_PREVIEW_MAX_LINES + 5))
            .map(|idx| format!("line {idx}"))
            .collect::<Vec<_>>()
            .join("\n");

        let preview = telemetry_preview(&content);
        let lines: Vec<&str> = preview.lines().collect();

        assert!(lines.len() <= TELEMETRY_PREVIEW_MAX_LINES + 1);
        assert_eq!(lines.last(), Some(&TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
    }
}

#[derive(Clone, Debug)]
pub(crate) struct ExecCommandContext {
    pub(crate) sub_id: String,
    pub(crate) call_id: String,
    pub(crate) command_for_display: Vec<String>,
    pub(crate) cwd: PathBuf,
    pub(crate) apply_patch: Option<ApplyPatchCommandContext>,
    pub(crate) tool_name: String,
    pub(crate) otel_event_manager: OtelEventManager,
}

#[derive(Clone, Debug)]
pub(crate) struct ApplyPatchCommandContext {
    pub(crate) user_explicitly_approved_this_action: bool,
    pub(crate) changes: HashMap<PathBuf, FileChange>,
}
@@ -1,15 +1,99 @@
+use std::collections::BTreeMap;
+use std::collections::HashMap;
+
+use crate::client_common::tools::FreeformTool;
+use crate::client_common::tools::FreeformToolFormat;
+use crate::client_common::tools::ResponsesApiTool;
+use crate::client_common::tools::ToolSpec;
+use crate::exec::ExecParams;
+use crate::function_tool::FunctionCallError;
+use crate::openai_tools::JsonSchema;
+use crate::tools::context::ToolInvocation;
+use crate::tools::context::ToolOutput;
+use crate::tools::context::ToolPayload;
+use crate::tools::handle_container_exec_with_params;
+use crate::tools::registry::ToolHandler;
+use crate::tools::registry::ToolKind;
+use crate::tools::spec::ApplyPatchToolArgs;
+use async_trait::async_trait;
 use serde::Deserialize;
 use serde::Serialize;
-use std::collections::BTreeMap;
 
-use crate::openai_tools::FreeformTool;
-use crate::openai_tools::FreeformToolFormat;
-use crate::openai_tools::JsonSchema;
-use crate::openai_tools::OpenAiTool;
-use crate::openai_tools::ResponsesApiTool;
+pub struct ApplyPatchHandler;
 
 const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");
 
+#[async_trait]
+impl ToolHandler for ApplyPatchHandler {
+    fn kind(&self) -> ToolKind {
+        ToolKind::Function
+    }
+
+    fn matches_kind(&self, payload: &ToolPayload) -> bool {
+        matches!(
+            payload,
+            ToolPayload::Function { .. } | ToolPayload::Custom { .. }
+        )
+    }
+
+    async fn handle(
+        &self,
+        invocation: ToolInvocation<'_>,
+    ) -> Result<ToolOutput, FunctionCallError> {
+        let ToolInvocation {
+            session,
+            turn,
+            tracker,
+            sub_id,
+            call_id,
+            tool_name,
+            payload,
+        } = invocation;
+
+        let patch_input = match payload {
+            ToolPayload::Function { arguments } => {
+                let args: ApplyPatchToolArgs = serde_json::from_str(&arguments).map_err(|e| {
+                    FunctionCallError::RespondToModel(format!(
+                        "failed to parse function arguments: {e:?}"
+                    ))
+                })?;
+                args.input
+            }
+            ToolPayload::Custom { input } => input,
+            _ => {
+                return Err(FunctionCallError::RespondToModel(
+                    "apply_patch handler received unsupported payload".to_string(),
+                ));
+            }
+        };
+
+        let exec_params = ExecParams {
+            command: vec!["apply_patch".to_string(), patch_input.clone()],
+            cwd: turn.cwd.clone(),
+            timeout_ms: None,
+            env: HashMap::new(),
+            with_escalated_permissions: None,
+            justification: None,
+        };
+
+        let content = handle_container_exec_with_params(
+            tool_name.as_str(),
+            exec_params,
+            session,
+            turn,
+            tracker,
+            sub_id.to_string(),
+            call_id.clone(),
+        )
+        .await?;
+
+        Ok(ToolOutput::Function {
+            content,
+            success: Some(true),
+        })
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
 #[serde(rename_all = "snake_case")]
 pub enum ApplyPatchToolType {
|
|||||||
|
|
||||||
/// Returns a custom tool that can be used to edit files. Well-suited for GPT-5 models
|
/// Returns a custom tool that can be used to edit files. Well-suited for GPT-5 models
|
||||||
/// https://platform.openai.com/docs/guides/function-calling#custom-tools
|
/// https://platform.openai.com/docs/guides/function-calling#custom-tools
|
||||||
pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool {
|
pub(crate) fn create_apply_patch_freeform_tool() -> ToolSpec {
|
||||||
OpenAiTool::Freeform(FreeformTool {
|
ToolSpec::Freeform(FreeformTool {
|
||||||
name: "apply_patch".to_string(),
|
name: "apply_patch".to_string(),
|
||||||
description: "Use the `apply_patch` tool to edit files".to_string(),
|
description: "Use the `apply_patch` tool to edit files".to_string(),
|
||||||
format: FreeformToolFormat {
|
format: FreeformToolFormat {
|
||||||
@@ -32,7 +116,7 @@ pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a json tool that can be used to edit files. Should only be used with gpt-oss models
|
/// Returns a json tool that can be used to edit files. Should only be used with gpt-oss models
|
||||||
pub(crate) fn create_apply_patch_json_tool() -> OpenAiTool {
|
pub(crate) fn create_apply_patch_json_tool() -> ToolSpec {
|
||||||
let mut properties = BTreeMap::new();
|
let mut properties = BTreeMap::new();
|
||||||
properties.insert(
|
properties.insert(
|
||||||
"input".to_string(),
|
"input".to_string(),
|
||||||
@@ -41,7 +125,7 @@ pub(crate) fn create_apply_patch_json_tool() -> OpenAiTool {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
OpenAiTool::Function(ResponsesApiTool {
|
ToolSpec::Function(ResponsesApiTool {
|
||||||
name: "apply_patch".to_string(),
|
name: "apply_patch".to_string(),
|
||||||
description: r#"Use the `apply_patch` tool to edit files.
|
description: r#"Use the `apply_patch` tool to edit files.
|
||||||
Your patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope:
|
Your patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope:
|
||||||
@@ -111,7 +195,7 @@ It is important to remember:
|
|||||||
- You must prefix new lines with `+` even when creating a new file
|
- You must prefix new lines with `+` even when creating a new file
|
||||||
- File references can only be relative, NEVER ABSOLUTE.
|
- File references can only be relative, NEVER ABSOLUTE.
|
||||||
"#
|
"#
|
||||||
.to_string(),
|
.to_string(),
|
||||||
strict: false,
|
strict: false,
|
||||||
parameters: JsonSchema::Object {
|
parameters: JsonSchema::Object {
|
||||||
properties,
|
properties,
|
||||||
codex-rs/core/src/tools/handlers/exec_stream.rs (new file): 71 lines
@@ -0,0 +1,71 @@
use async_trait::async_trait;

use crate::exec_command::EXEC_COMMAND_TOOL_NAME;
use crate::exec_command::ExecCommandParams;
use crate::exec_command::WRITE_STDIN_TOOL_NAME;
use crate::exec_command::WriteStdinParams;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ExecStreamHandler;

#[async_trait]
impl ToolHandler for ExecStreamHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            tool_name,
            payload,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "exec_stream handler received unsupported payload".to_string(),
                ));
            }
        };

        let content = match tool_name.as_str() {
            EXEC_COMMAND_TOOL_NAME => {
                let params: ExecCommandParams = serde_json::from_str(&arguments).map_err(|e| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse function arguments: {e:?}"
                    ))
                })?;
                session.handle_exec_command_tool(params).await?
            }
            WRITE_STDIN_TOOL_NAME => {
                let params: WriteStdinParams = serde_json::from_str(&arguments).map_err(|e| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse function arguments: {e:?}"
                    ))
                })?;
                session.handle_write_stdin_tool(params).await?
            }
            _ => {
                return Err(FunctionCallError::RespondToModel(format!(
                    "exec_stream handler does not support tool {tool_name}"
                )));
            }
        };

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
        })
    }
}
codex-rs/core/src/tools/handlers/mcp.rs (new file): 70 lines
@@ -0,0 +1,70 @@
use async_trait::async_trait;

use crate::function_tool::FunctionCallError;
use crate::mcp_tool_call::handle_mcp_tool_call;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct McpHandler;

#[async_trait]
impl ToolHandler for McpHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Mcp
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            sub_id,
            call_id,
            payload,
            ..
        } = invocation;

        let payload = match payload {
            ToolPayload::Mcp {
                server,
                tool,
                raw_arguments,
            } => (server, tool, raw_arguments),
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "mcp handler received unsupported payload".to_string(),
                ));
            }
        };

        let (server, tool, raw_arguments) = payload;
        let arguments_str = raw_arguments;

        let response = handle_mcp_tool_call(
            session,
            sub_id,
            call_id.clone(),
            server,
            tool,
            arguments_str,
        )
        .await;

        match response {
            codex_protocol::models::ResponseInputItem::McpToolCallOutput { result, .. } => {
                Ok(ToolOutput::Mcp { result })
            }
            codex_protocol::models::ResponseInputItem::FunctionCallOutput { output, .. } => {
                let codex_protocol::models::FunctionCallOutputPayload { content, success } = output;
                Ok(ToolOutput::Function { content, success })
            }
            _ => Err(FunctionCallError::RespondToModel(
                "mcp handler received unexpected response variant".to_string(),
            )),
        }
    }
}
codex-rs/core/src/tools/handlers/mod.rs (new file): 19 lines
@@ -0,0 +1,19 @@
pub mod apply_patch;
mod exec_stream;
mod mcp;
mod plan;
mod read_file;
mod shell;
mod unified_exec;
mod view_image;

pub use plan::PLAN_TOOL;

pub use apply_patch::ApplyPatchHandler;
pub use exec_stream::ExecStreamHandler;
pub use mcp::McpHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler;
@@ -1,23 +1,23 @@
-use std::collections::BTreeMap;
-use std::sync::LazyLock;
+use crate::client_common::tools::ResponsesApiTool;
+use crate::client_common::tools::ToolSpec;
 
 use crate::codex::Session;
 use crate::function_tool::FunctionCallError;
 use crate::openai_tools::JsonSchema;
-use crate::openai_tools::OpenAiTool;
-use crate::openai_tools::ResponsesApiTool;
-use crate::protocol::Event;
-use crate::protocol::EventMsg;
+use crate::tools::context::ToolInvocation;
+use crate::tools::context::ToolOutput;
+use crate::tools::context::ToolPayload;
+use crate::tools::registry::ToolHandler;
+use crate::tools::registry::ToolKind;
+use async_trait::async_trait;
+use codex_protocol::plan_tool::UpdatePlanArgs;
+use codex_protocol::protocol::Event;
+use codex_protocol::protocol::EventMsg;
+use std::collections::BTreeMap;
+use std::sync::LazyLock;
 
-// Use the canonical plan tool types from the protocol crate to ensure
-// type-identity matches events transported via `codex_protocol`.
-pub use codex_protocol::plan_tool::PlanItemArg;
-pub use codex_protocol::plan_tool::StepStatus;
-pub use codex_protocol::plan_tool::UpdatePlanArgs;
+pub struct PlanHandler;
 
-// Types for the TODO tool arguments matching codex-vscode/todo-mcp/src/main.rs
-
-pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
+pub static PLAN_TOOL: LazyLock<ToolSpec> = LazyLock::new(|| {
     let mut plan_item_props = BTreeMap::new();
     plan_item_props.insert("step".to_string(), JsonSchema::String { description: None });
     plan_item_props.insert(
@@ -43,7 +43,7 @@ pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
     );
     properties.insert("plan".to_string(), plan_items_schema);
 
-    OpenAiTool::Function(ResponsesApiTool {
+    ToolSpec::Function(ResponsesApiTool {
         name: "update_plan".to_string(),
         description: r#"Updates the task plan.
 Provide an optional explanation and a list of plan items, each with a step and status.
@@ -59,6 +59,42 @@ At most one step can be in_progress at a time.
     })
 });
 
+#[async_trait]
+impl ToolHandler for PlanHandler {
+    fn kind(&self) -> ToolKind {
+        ToolKind::Function
+    }
+
+    async fn handle(
+        &self,
+        invocation: ToolInvocation<'_>,
+    ) -> Result<ToolOutput, FunctionCallError> {
+        let ToolInvocation {
+            session,
+            sub_id,
+            call_id,
+            payload,
+            ..
+        } = invocation;
+
+        let arguments = match payload {
+            ToolPayload::Function { arguments } => arguments,
+            _ => {
+                return Err(FunctionCallError::RespondToModel(
+                    "update_plan handler received unsupported payload".to_string(),
+                ));
+            }
+        };
+
+        let content = handle_update_plan(session, arguments, sub_id.to_string(), call_id).await?;
+
+        Ok(ToolOutput::Function {
+            content,
+            success: Some(true),
+        })
+    }
+}
+
 /// This function doesn't do anything useful. However, it gives the model a structured way to record its plan that clients can read and render.
 /// So it's the _inputs_ to this function that are useful to clients, not the outputs and neither are actually useful for the model other
 /// than forcing it to come up and document a plan (TBD how that affects performance).
codex-rs/core/src/tools/handlers/read_file.rs (new file): 255 lines
@@ -0,0 +1,255 @@
use std::path::Path;
use std::path::PathBuf;

use async_trait::async_trait;
use codex_utils_string::take_bytes_at_char_boundary;
use serde::Deserialize;
use tokio::fs::File;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ReadFileHandler;

const MAX_LINE_LENGTH: usize = 500;

fn default_offset() -> usize {
    1
}

fn default_limit() -> usize {
    2000
}

#[derive(Deserialize)]
struct ReadFileArgs {
    file_path: String,
    #[serde(default = "default_offset")]
    offset: usize,
    #[serde(default = "default_limit")]
    limit: usize,
}

#[async_trait]
impl ToolHandler for ReadFileHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "read_file handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: ReadFileArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        let ReadFileArgs {
            file_path,
            offset,
            limit,
        } = args;

        if offset == 0 {
            return Err(FunctionCallError::RespondToModel(
                "offset must be a 1-indexed line number".to_string(),
            ));
        }

        if limit == 0 {
            return Err(FunctionCallError::RespondToModel(
                "limit must be greater than zero".to_string(),
            ));
        }

        let path = PathBuf::from(&file_path);
        if !path.is_absolute() {
            return Err(FunctionCallError::RespondToModel(
                "file_path must be an absolute path".to_string(),
            ));
        }

        let collected = read_file_slice(&path, offset, limit).await?;
        Ok(ToolOutput::Function {
            content: collected.join("\n"),
            success: Some(true),
        })
    }
}

async fn read_file_slice(
    path: &Path,
    offset: usize,
    limit: usize,
) -> Result<Vec<String>, FunctionCallError> {
    let file = File::open(path)
        .await
        .map_err(|err| FunctionCallError::RespondToModel(format!("failed to read file: {err}")))?;

    let mut reader = BufReader::new(file);
    let mut collected = Vec::new();
    let mut seen = 0usize;
    let mut buffer = Vec::new();

    loop {
        buffer.clear();
        let bytes_read = reader.read_until(b'\n', &mut buffer).await.map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
        })?;

        if bytes_read == 0 {
            break;
        }

        if buffer.last() == Some(&b'\n') {
            buffer.pop();
            if buffer.last() == Some(&b'\r') {
                buffer.pop();
            }
        }

        seen += 1;

        if seen < offset {
            continue;
        }

        if collected.len() == limit {
            break;
        }

        let formatted = format_line(&buffer);
        collected.push(format!("L{seen}: {formatted}"));

        if collected.len() == limit {
            break;
        }
    }

    if seen < offset {
        return Err(FunctionCallError::RespondToModel(
            "offset exceeds file length".to_string(),
        ));
    }

    Ok(collected)
}

fn format_line(bytes: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(bytes);
    if decoded.len() > MAX_LINE_LENGTH {
        take_bytes_at_char_boundary(&decoded, MAX_LINE_LENGTH).to_string()
    } else {
        decoded.into_owned()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[tokio::test]
    async fn reads_requested_range() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        writeln!(temp, "alpha").unwrap();
        writeln!(temp, "beta").unwrap();
        writeln!(temp, "gamma").unwrap();

        let lines = read_file_slice(temp.path(), 2, 2)
            .await
            .expect("read slice");
        assert_eq!(lines, vec!["L2: beta".to_string(), "L3: gamma".to_string()]);
    }

    #[tokio::test]
    async fn errors_when_offset_exceeds_length() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        writeln!(temp, "only").unwrap();

        let err = read_file_slice(temp.path(), 3, 1)
            .await
            .expect_err("offset exceeds length");
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("offset exceeds file length".to_string())
        );
    }

    #[tokio::test]
    async fn reads_non_utf8_lines() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        temp.as_file_mut().write_all(b"\xff\xfe\nplain\n").unwrap();

        let lines = read_file_slice(temp.path(), 1, 2)
            .await
            .expect("read slice");
        let expected_first = format!("L1: {}{}", '\u{FFFD}', '\u{FFFD}');
        assert_eq!(lines, vec![expected_first, "L2: plain".to_string()]);
    }

    #[tokio::test]
    async fn trims_crlf_endings() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        write!(temp, "one\r\ntwo\r\n").unwrap();

        let lines = read_file_slice(temp.path(), 1, 2)
            .await
            .expect("read slice");
        assert_eq!(lines, vec!["L1: one".to_string(), "L2: two".to_string()]);
    }

    #[tokio::test]
    async fn respects_limit_even_with_more_lines() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        writeln!(temp, "first").unwrap();
        writeln!(temp, "second").unwrap();
        writeln!(temp, "third").unwrap();

        let lines = read_file_slice(temp.path(), 1, 2)
            .await
            .expect("read slice");
        assert_eq!(
            lines,
            vec!["L1: first".to_string(), "L2: second".to_string()]
        );
    }

    #[tokio::test]
    async fn truncates_lines_longer_than_max_length() {
        let mut temp = NamedTempFile::new().expect("create temp file");
        use std::io::Write as _;
        let long_line = "x".repeat(MAX_LINE_LENGTH + 50);
        writeln!(temp, "{long_line}").unwrap();

        let lines = read_file_slice(temp.path(), 1, 1)
            .await
            .expect("read slice");
        let expected = "x".repeat(MAX_LINE_LENGTH);
        assert_eq!(lines, vec![format!("L1: {expected}")]);
    }
}
codex-rs/core/src/tools/handlers/shell.rs (new file): 103 lines
@@ -0,0 +1,103 @@
use async_trait::async_trait;
use codex_protocol::models::ShellToolCallParams;

use crate::codex::TurnContext;
use crate::exec::ExecParams;
use crate::exec_env::create_env;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::handle_container_exec_with_params;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ShellHandler;

impl ShellHandler {
    fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams {
        ExecParams {
            command: params.command,
            cwd: turn_context.resolve_path(params.workdir.clone()),
            timeout_ms: params.timeout_ms,
            env: create_env(&turn_context.shell_environment_policy),
            with_escalated_permissions: params.with_escalated_permissions,
            justification: params.justification,
        }
    }
}

#[async_trait]
impl ToolHandler for ShellHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            payload,
            ToolPayload::Function { .. } | ToolPayload::LocalShell { .. }
        )
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            tracker,
            sub_id,
            call_id,
            tool_name,
            payload,
        } = invocation;

        match payload {
            ToolPayload::Function { arguments } => {
                let params: ShellToolCallParams =
                    serde_json::from_str(&arguments).map_err(|e| {
                        FunctionCallError::RespondToModel(format!(
                            "failed to parse function arguments: {e:?}"
                        ))
                    })?;
                let exec_params = Self::to_exec_params(params, turn);
                let content = handle_container_exec_with_params(
                    tool_name.as_str(),
                    exec_params,
                    session,
                    turn,
                    tracker,
                    sub_id.to_string(),
                    call_id.clone(),
                )
                .await?;
                Ok(ToolOutput::Function {
                    content,
                    success: Some(true),
                })
            }
            ToolPayload::LocalShell { params } => {
                let exec_params = Self::to_exec_params(params, turn);
                let content = handle_container_exec_with_params(
                    tool_name.as_str(),
                    exec_params,
                    session,
                    turn,
                    tracker,
                    sub_id.to_string(),
                    call_id.clone(),
                )
                .await?;
                Ok(ToolOutput::Function {
                    content,
                    success: Some(true),
                })
            }
            _ => Err(FunctionCallError::RespondToModel(format!(
                "unsupported payload for shell handler: {tool_name}"
            ))),
        }
    }
}
codex-rs/core/src/tools/handlers/unified_exec.rs (new file): 112 lines
@@ -0,0 +1,112 @@
use async_trait::async_trait;
use serde::Deserialize;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::unified_exec::UnifiedExecRequest;

pub struct UnifiedExecHandler;

#[derive(Deserialize)]
struct UnifiedExecArgs {
    input: Vec<String>,
    #[serde(default)]
    session_id: Option<String>,
    #[serde(default)]
    timeout_ms: Option<u64>,
}

#[async_trait]
impl ToolHandler for UnifiedExecHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::UnifiedExec
    }

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            payload,
            ToolPayload::UnifiedExec { .. } | ToolPayload::Function { .. }
        )
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session, payload, ..
        } = invocation;

        let args = match payload {
            ToolPayload::UnifiedExec { arguments } | ToolPayload::Function { arguments } => {
                serde_json::from_str::<UnifiedExecArgs>(&arguments).map_err(|err| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse function arguments: {err:?}"
                    ))
                })?
            }
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "unified_exec handler received unsupported payload".to_string(),
                ));
            }
        };

        let UnifiedExecArgs {
            input,
            session_id,
            timeout_ms,
        } = args;

        let parsed_session_id = if let Some(session_id) = session_id {
            match session_id.parse::<i32>() {
                Ok(parsed) => Some(parsed),
                Err(output) => {
                    return Err(FunctionCallError::RespondToModel(format!(
                        "invalid session_id: {session_id} due to error {output:?}"
                    )));
                }
            }
        } else {
            None
        };

        let request = UnifiedExecRequest {
            session_id: parsed_session_id,
            input_chunks: &input,
            timeout_ms,
        };

        let value = session
            .run_unified_exec_request(request)
            .await
            .map_err(|err| {
                FunctionCallError::RespondToModel(format!("unified exec failed: {err:?}"))
            })?;

        #[derive(serde::Serialize)]
        struct SerializedUnifiedExecResult {
            session_id: Option<String>,
            output: String,
        }

        let content = serde_json::to_string(&SerializedUnifiedExecResult {
            session_id: value.session_id.map(|id| id.to_string()),
            output: value.output,
        })
        .map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to serialize unified exec output: {err:?}"
            ))
        })?;

        Ok(ToolOutput::Function {
            content,
            success: Some(true),
        })
    }
}
codex-rs/core/src/tools/handlers/view_image.rs (new file): 96 lines
@@ -0,0 +1,96 @@
use async_trait::async_trait;
use serde::Deserialize;
use tokio::fs;

use crate::function_tool::FunctionCallError;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::InputItem;
use crate::protocol::ViewImageToolCallEvent;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ViewImageHandler;

#[derive(Deserialize)]
struct ViewImageArgs {
    path: String,
}

#[async_trait]
impl ToolHandler for ViewImageHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(
        &self,
        invocation: ToolInvocation<'_>,
    ) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            payload,
            sub_id,
            call_id,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "view_image handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: ViewImageArgs = serde_json::from_str(&arguments).map_err(|e| {
            FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e:?}"))
        })?;

        let abs_path = turn.resolve_path(Some(args.path));

        let metadata = fs::metadata(&abs_path).await.map_err(|error| {
            FunctionCallError::RespondToModel(format!(
                "unable to locate image at `{}`: {error}",
                abs_path.display()
            ))
        })?;

        if !metadata.is_file() {
            return Err(FunctionCallError::RespondToModel(format!(
                "image path `{}` is not a file",
                abs_path.display()
            )));
        }
        let event_path = abs_path.clone();

        session
            .inject_input(vec![InputItem::LocalImage { path: abs_path }])
            .await
            .map_err(|_| {
                FunctionCallError::RespondToModel(
                    "unable to attach image (no active task)".to_string(),
                )
            })?;

        session
            .send_event(Event {
                id: sub_id.to_string(),
                msg: EventMsg::ViewImageToolCall(ViewImageToolCallEvent {
                    call_id,
                    path: event_path,
                }),
            })
            .await;

        Ok(ToolOutput::Function {
            content: "attached local image path".to_string(),
            success: Some(true),
        })
    }
}
280
codex-rs/core/src/tools/mod.rs
Normal file
@@ -0,0 +1,280 @@
pub mod context;
pub(crate) mod handlers;
pub mod registry;
pub mod router;
pub mod spec;

use crate::apply_patch;
use crate::apply_patch::ApplyPatchExec;
use crate::apply_patch::InternalApplyPatchInvocation;
use crate::apply_patch::convert_apply_patch_to_protocol;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::error::SandboxErr;
use crate::exec::ExecParams;
use crate::exec::ExecToolCallOutput;
use crate::exec::StdoutStream;
use crate::executor::ExecutionMode;
use crate::executor::errors::ExecError;
use crate::executor::linkers::PreparedExec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ApplyPatchCommandContext;
use crate::tools::context::ExecCommandContext;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_apply_patch::MaybeApplyPatchVerified;
use codex_apply_patch::maybe_parse_apply_patch_verified;
use codex_protocol::protocol::AskForApproval;
use codex_utils_string::take_bytes_at_char_boundary;
use codex_utils_string::take_last_bytes_at_char_boundary;
pub use router::ToolRouter;
use serde::Serialize;
use tracing::trace;

// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
pub(crate) const MODEL_FORMAT_MAX_BYTES: usize = 10 * 1024; // 10 KiB
pub(crate) const MODEL_FORMAT_MAX_LINES: usize = 256; // lines
pub(crate) const MODEL_FORMAT_HEAD_LINES: usize = MODEL_FORMAT_MAX_LINES / 2;
pub(crate) const MODEL_FORMAT_TAIL_LINES: usize = MODEL_FORMAT_MAX_LINES - MODEL_FORMAT_HEAD_LINES; // 128
pub(crate) const MODEL_FORMAT_HEAD_BYTES: usize = MODEL_FORMAT_MAX_BYTES / 2;

// Telemetry preview limits: keep log events smaller than model budgets.
pub(crate) const TELEMETRY_PREVIEW_MAX_BYTES: usize = 2 * 1024; // 2 KiB
pub(crate) const TELEMETRY_PREVIEW_MAX_LINES: usize = 64; // lines
pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
    "[... telemetry preview truncated ...]";

// TODO(jif) break this down
pub(crate) async fn handle_container_exec_with_params(
    tool_name: &str,
    params: ExecParams,
    sess: &Session,
    turn_context: &TurnContext,
    turn_diff_tracker: &mut TurnDiffTracker,
    sub_id: String,
    call_id: String,
) -> Result<String, FunctionCallError> {
    let otel_event_manager = turn_context.client.get_otel_event_manager();

    if params.with_escalated_permissions.unwrap_or(false)
        && !matches!(turn_context.approval_policy, AskForApproval::OnRequest)
    {
        return Err(FunctionCallError::RespondToModel(format!(
            "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}",
            policy = turn_context.approval_policy
        )));
    }

    // check if this was a patch, and apply it if so
    let apply_patch_exec = match maybe_parse_apply_patch_verified(&params.command, &params.cwd) {
        MaybeApplyPatchVerified::Body(changes) => {
            match apply_patch::apply_patch(sess, turn_context, &sub_id, &call_id, changes).await {
                InternalApplyPatchInvocation::Output(item) => return item,
                InternalApplyPatchInvocation::DelegateToExec(apply_patch_exec) => {
                    Some(apply_patch_exec)
                }
            }
        }
        MaybeApplyPatchVerified::CorrectnessError(parse_error) => {
            // It looks like an invocation of `apply_patch`, but we
            // could not resolve it into a patch that would apply
            // cleanly. Return to model for resample.
            return Err(FunctionCallError::RespondToModel(format!(
                "apply_patch verification failed: {parse_error}"
            )));
        }
        MaybeApplyPatchVerified::ShellParseError(error) => {
            trace!("Failed to parse shell command, {error:?}");
            None
        }
        MaybeApplyPatchVerified::NotApplyPatch => None,
    };

    let command_for_display = if let Some(exec) = apply_patch_exec.as_ref() {
        vec!["apply_patch".to_string(), exec.action.patch.clone()]
    } else {
        params.command.clone()
    };

    let exec_command_context = ExecCommandContext {
        sub_id: sub_id.clone(),
        call_id: call_id.clone(),
        command_for_display: command_for_display.clone(),
        cwd: params.cwd.clone(),
        apply_patch: apply_patch_exec.as_ref().map(
            |ApplyPatchExec {
                 action,
                 user_explicitly_approved_this_action,
             }| ApplyPatchCommandContext {
                user_explicitly_approved_this_action: *user_explicitly_approved_this_action,
                changes: convert_apply_patch_to_protocol(action),
            },
        ),
        tool_name: tool_name.to_string(),
        otel_event_manager,
    };

    let mode = match apply_patch_exec {
        Some(exec) => ExecutionMode::ApplyPatch(exec),
        None => ExecutionMode::Shell,
    };

    sess.services.executor.update_environment(
        turn_context.sandbox_policy.clone(),
        turn_context.cwd.clone(),
    );

    let prepared_exec = PreparedExec::new(
        exec_command_context,
        params,
        command_for_display,
        mode,
        Some(StdoutStream {
            sub_id: sub_id.clone(),
            call_id: call_id.clone(),
            tx_event: sess.get_tx_event(),
        }),
        turn_context.shell_environment_policy.use_profile,
    );

    let output_result = sess
        .run_exec_with_events(
            turn_diff_tracker,
            prepared_exec,
            turn_context.approval_policy,
        )
        .await;

    match output_result {
        Ok(output) => {
            let ExecToolCallOutput { exit_code, .. } = &output;
            let content = format_exec_output_apply_patch(&output);
            if *exit_code == 0 {
                Ok(content)
            } else {
                Err(FunctionCallError::RespondToModel(content))
            }
        }
        Err(ExecError::Function(err)) => Err(err),
        Err(ExecError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output }))) => Err(
            FunctionCallError::RespondToModel(format_exec_output_apply_patch(&output)),
        ),
        Err(ExecError::Codex(err)) => Err(FunctionCallError::RespondToModel(format!(
            "execution error: {err:?}"
        ))),
    }
}

pub fn format_exec_output_apply_patch(exec_output: &ExecToolCallOutput) -> String {
    let ExecToolCallOutput {
        exit_code,
        duration,
        ..
    } = exec_output;

    #[derive(Serialize)]
    struct ExecMetadata {
        exit_code: i32,
        duration_seconds: f32,
    }

    #[derive(Serialize)]
    struct ExecOutput<'a> {
        output: &'a str,
        metadata: ExecMetadata,
    }

    // round to 1 decimal place
    let duration_seconds = ((duration.as_secs_f32()) * 10.0).round() / 10.0;

    let formatted_output = format_exec_output_str(exec_output);

    let payload = ExecOutput {
        output: &formatted_output,
        metadata: ExecMetadata {
            exit_code: *exit_code,
            duration_seconds,
        },
    };

    #[expect(clippy::expect_used)]
    serde_json::to_string(&payload).expect("serialize ExecOutput")
}

pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
    let ExecToolCallOutput {
        aggregated_output, ..
    } = exec_output;

    // Head+tail truncation for the model: show the beginning and end with an elision.
    // Clients still receive full streams; only this formatted summary is capped.

    let mut s = &aggregated_output.text;
    let prefixed_str: String;

    if exec_output.timed_out {
        prefixed_str = format!(
            "command timed out after {} milliseconds\n",
            exec_output.duration.as_millis()
        ) + s;
        s = &prefixed_str;
    }

    let total_lines = s.lines().count();
    if s.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES {
        return s.to_string();
    }

    let segments: Vec<&str> = s.split_inclusive('\n').collect();
    let head_take = MODEL_FORMAT_HEAD_LINES.min(segments.len());
    let tail_take = MODEL_FORMAT_TAIL_LINES.min(segments.len().saturating_sub(head_take));
    let omitted = segments.len().saturating_sub(head_take + tail_take);

    let head_slice_end: usize = segments
        .iter()
        .take(head_take)
        .map(|segment| segment.len())
        .sum();
    let tail_slice_start: usize = if tail_take == 0 {
        s.len()
    } else {
        s.len()
            - segments
                .iter()
                .rev()
                .take(tail_take)
                .map(|segment| segment.len())
                .sum::<usize>()
    };
    let marker = format!("\n[... omitted {omitted} of {total_lines} lines ...]\n\n");

    // Byte budgets for head/tail around the marker
    let mut head_budget = MODEL_FORMAT_HEAD_BYTES.min(MODEL_FORMAT_MAX_BYTES);
    let tail_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(head_budget + marker.len());
    if tail_budget == 0 && marker.len() >= MODEL_FORMAT_MAX_BYTES {
        // Degenerate case: marker alone exceeds budget; return a clipped marker
        return take_bytes_at_char_boundary(&marker, MODEL_FORMAT_MAX_BYTES).to_string();
    }
    if tail_budget == 0 {
        // Make room for the marker by shrinking head
        head_budget = MODEL_FORMAT_MAX_BYTES.saturating_sub(marker.len());
    }

    let head_slice = &s[..head_slice_end];
    let head_part = take_bytes_at_char_boundary(head_slice, head_budget);
    let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(s.len()));

    result.push_str(head_part);
    result.push_str(&marker);

    let remaining = MODEL_FORMAT_MAX_BYTES.saturating_sub(result.len());
    if remaining == 0 {
        return result;
    }

    let tail_slice = &s[tail_slice_start..];
    let tail_part = take_last_bytes_at_char_boundary(tail_slice, remaining);
    result.push_str(tail_part);

    result
}
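To make the head+tail budgeting in `format_exec_output_str` concrete, here is a scaled-down sketch with toy limits; `clip` is a simplified stand-in for `take_bytes_at_char_boundary`, and the constants are illustrative rather than the real 10 KiB / 256-line budgets.

```rust
// Toy limits standing in for MODEL_FORMAT_MAX_BYTES / MODEL_FORMAT_HEAD_BYTES.
const MAX_BYTES: usize = 96;
const HEAD_BYTES: usize = MAX_BYTES / 2;

// Simplified stand-in for take_bytes_at_char_boundary: clip to a byte budget
// without splitting a UTF-8 character.
fn clip(s: &str, budget: usize) -> &str {
    let mut end = budget.min(s.len());
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}

fn main() {
    let long_output: String = (1..=40).map(|i| format!("line {i}\n")).collect();
    let total_lines = long_output.lines().count();
    let omitted = total_lines.saturating_sub(8); // pretend head+tail keep 4 lines each
    let marker = format!("\n[... omitted {omitted} of {total_lines} lines ...]\n\n");

    // Head gets half the byte budget, the tail whatever remains after the marker.
    let head = clip(&long_output, HEAD_BYTES);
    let tail_budget = MAX_BYTES.saturating_sub(head.len() + marker.len());
    let tail = &long_output[long_output.len().saturating_sub(tail_budget)..];

    println!("{head}{marker}{tail}");
}
```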
197
codex-rs/core/src/tools/registry.rs
Normal file
@@ -0,0 +1,197 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use codex_protocol::models::ResponseInputItem;
use tracing::warn;

use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {
    Function,
    UnifiedExec,
    Mcp,
}

#[async_trait]
pub trait ToolHandler: Send + Sync {
    fn kind(&self) -> ToolKind;

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            (self.kind(), payload),
            (ToolKind::Function, ToolPayload::Function { .. })
                | (ToolKind::UnifiedExec, ToolPayload::UnifiedExec { .. })
                | (ToolKind::Mcp, ToolPayload::Mcp { .. })
        )
    }

    async fn handle(&self, invocation: ToolInvocation<'_>)
    -> Result<ToolOutput, FunctionCallError>;
}

pub struct ToolRegistry {
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
}

impl ToolRegistry {
    pub fn new(handlers: HashMap<String, Arc<dyn ToolHandler>>) -> Self {
        Self { handlers }
    }

    pub fn handler(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
        self.handlers.get(name).map(Arc::clone)
    }

    // TODO(jif) for dynamic tools.
    // pub fn register(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
    //     let name = name.into();
    //     if self.handlers.insert(name.clone(), handler).is_some() {
    //         warn!("overwriting handler for tool {name}");
    //     }
    // }

    pub async fn dispatch<'a>(
        &self,
        invocation: ToolInvocation<'a>,
    ) -> Result<ResponseInputItem, FunctionCallError> {
        let tool_name = invocation.tool_name.clone();
        let call_id_owned = invocation.call_id.clone();
        let otel = invocation.turn.client.get_otel_event_manager();
        let payload_for_response = invocation.payload.clone();
        let log_payload = payload_for_response.log_payload();

        let handler = match self.handler(tool_name.as_ref()) {
            Some(handler) => handler,
            None => {
                let message =
                    unsupported_tool_call_message(&invocation.payload, tool_name.as_ref());
                otel.tool_result(
                    tool_name.as_ref(),
                    &call_id_owned,
                    log_payload.as_ref(),
                    Duration::ZERO,
                    false,
                    &message,
                );
                return Err(FunctionCallError::RespondToModel(message));
            }
        };

        if !handler.matches_kind(&invocation.payload) {
            let message = format!("tool {tool_name} invoked with incompatible payload");
            otel.tool_result(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                Duration::ZERO,
                false,
                &message,
            );
            return Err(FunctionCallError::Fatal(message));
        }

        let output_cell = tokio::sync::Mutex::new(None);

        let result = otel
            .log_tool_result(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                || {
                    let handler = handler.clone();
                    let output_cell = &output_cell;
                    let invocation = invocation;
                    async move {
                        match handler.handle(invocation).await {
                            Ok(output) => {
                                let preview = output.log_preview();
                                let success = output.success_for_logging();
                                let mut guard = output_cell.lock().await;
                                *guard = Some(output);
                                Ok((preview, success))
                            }
                            Err(err) => Err(err),
                        }
                    }
                },
            )
            .await;

        match result {
            Ok(_) => {
                let mut guard = output_cell.lock().await;
                let output = guard.take().ok_or_else(|| {
                    FunctionCallError::Fatal("tool produced no output".to_string())
                })?;
                Ok(output.into_response(&call_id_owned, &payload_for_response))
            }
            Err(err) => Err(err),
        }
    }
}

pub struct ToolRegistryBuilder {
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
    specs: Vec<ToolSpec>,
}

impl ToolRegistryBuilder {
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
            specs: Vec::new(),
        }
    }

    pub fn push_spec(&mut self, spec: ToolSpec) {
        self.specs.push(spec);
    }

    pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
        let name = name.into();
        if self
            .handlers
            .insert(name.clone(), handler.clone())
            .is_some()
        {
            warn!("overwriting handler for tool {name}");
        }
    }

    // TODO(jif) for dynamic tools.
    // pub fn register_many<I>(&mut self, names: I, handler: Arc<dyn ToolHandler>)
    // where
    //     I: IntoIterator,
    //     I::Item: Into<String>,
    // {
    //     for name in names {
    //         let name = name.into();
    //         if self
    //             .handlers
    //             .insert(name.clone(), handler.clone())
    //             .is_some()
    //         {
    //             warn!("overwriting handler for tool {name}");
    //         }
    //     }
    // }

    pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
        let registry = ToolRegistry::new(self.handlers);
        (self.specs, registry)
    }
}

fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &str) -> String {
    match payload {
        ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
        _ => format!("unsupported call: {tool_name}"),
    }
}
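The register-then-dispatch pattern above can be pictured with a stripped-down sketch; the trait and types here are simplified stand-ins for `ToolHandler`/`ToolRegistry`, not the crate's API, and they omit async, telemetry, and payload kinds.

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Simplified stand-in for ToolHandler: one synchronous entry point.
trait Handler: Send + Sync {
    fn handle(&self, arguments: &str) -> Result<String, String>;
}

struct Echo;

impl Handler for Echo {
    fn handle(&self, arguments: &str) -> Result<String, String> {
        Ok(format!("echo: {arguments}"))
    }
}

fn main() {
    // Mirrors ToolRegistryBuilder::register_handler + build: a name-keyed map.
    let mut handlers: HashMap<String, Arc<dyn Handler>> = HashMap::new();
    handlers.insert("echo".to_string(), Arc::new(Echo));

    // Mirrors ToolRegistry::dispatch: look up by tool name, run, surface errors.
    let result = handlers
        .get("echo")
        .ok_or_else(|| "unsupported call: echo".to_string())
        .and_then(|handler| handler.handle("{\"text\":\"hi\"}"));
    println!("{result:?}");
}
```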
177
codex-rs/core/src/tools/router.rs
Normal file
@@ -0,0 +1,177 @@
use std::collections::HashMap;

use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolRegistry;
use crate::tools::spec::ToolsConfig;
use crate::tools::spec::build_specs;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_protocol::models::LocalShellAction;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::models::ShellToolCallParams;

#[derive(Clone)]
pub struct ToolCall {
    pub tool_name: String,
    pub call_id: String,
    pub payload: ToolPayload,
}

pub struct ToolRouter {
    registry: ToolRegistry,
    specs: Vec<ToolSpec>,
}

impl ToolRouter {
    pub fn from_config(
        config: &ToolsConfig,
        mcp_tools: Option<HashMap<String, mcp_types::Tool>>,
    ) -> Self {
        let builder = build_specs(config, mcp_tools);
        let (specs, registry) = builder.build();
        Self { registry, specs }
    }

    pub fn specs(&self) -> &[ToolSpec] {
        &self.specs
    }

    pub fn build_tool_call(
        session: &Session,
        item: ResponseItem,
    ) -> Result<Option<ToolCall>, FunctionCallError> {
        match item {
            ResponseItem::FunctionCall {
                name,
                arguments,
                call_id,
                ..
            } => {
                if let Some((server, tool)) = session.parse_mcp_tool_name(&name) {
                    Ok(Some(ToolCall {
                        tool_name: name,
                        call_id,
                        payload: ToolPayload::Mcp {
                            server,
                            tool,
                            raw_arguments: arguments,
                        },
                    }))
                } else {
                    let payload = if name == "unified_exec" {
                        ToolPayload::UnifiedExec { arguments }
                    } else {
                        ToolPayload::Function { arguments }
                    };
                    Ok(Some(ToolCall {
                        tool_name: name,
                        call_id,
                        payload,
                    }))
                }
            }
            ResponseItem::CustomToolCall {
                name,
                input,
                call_id,
                ..
            } => Ok(Some(ToolCall {
                tool_name: name,
                call_id,
                payload: ToolPayload::Custom { input },
            })),
            ResponseItem::LocalShellCall {
                id,
                call_id,
                action,
                ..
            } => {
                let call_id = call_id
                    .or(id)
                    .ok_or(FunctionCallError::MissingLocalShellCallId)?;

                match action {
                    LocalShellAction::Exec(exec) => {
                        let params = ShellToolCallParams {
                            command: exec.command,
                            workdir: exec.working_directory,
                            timeout_ms: exec.timeout_ms,
                            with_escalated_permissions: None,
                            justification: None,
                        };
                        Ok(Some(ToolCall {
                            tool_name: "local_shell".to_string(),
                            call_id,
                            payload: ToolPayload::LocalShell { params },
                        }))
                    }
                }
            }
            _ => Ok(None),
        }
    }

    pub async fn dispatch_tool_call(
        &self,
        session: &Session,
        turn: &TurnContext,
        tracker: &mut TurnDiffTracker,
        sub_id: &str,
        call: ToolCall,
    ) -> Result<ResponseInputItem, FunctionCallError> {
        let ToolCall {
            tool_name,
            call_id,
            payload,
        } = call;
        let payload_outputs_custom = matches!(payload, ToolPayload::Custom { .. });
        let failure_call_id = call_id.clone();

        let invocation = ToolInvocation {
            session,
            turn,
            tracker,
            sub_id,
            call_id,
            tool_name,
            payload,
        };

        match self.registry.dispatch(invocation).await {
            Ok(response) => Ok(response),
            Err(FunctionCallError::Fatal(message)) => Err(FunctionCallError::Fatal(message)),
            Err(err) => Ok(Self::failure_response(
                failure_call_id,
                payload_outputs_custom,
                err,
            )),
        }
    }

    fn failure_response(
        call_id: String,
        payload_outputs_custom: bool,
        err: FunctionCallError,
    ) -> ResponseInputItem {
        let message = err.to_string();
        if payload_outputs_custom {
            ResponseInputItem::CustomToolCallOutput {
                call_id,
                output: message,
            }
        } else {
            ResponseInputItem::FunctionCallOutput {
                call_id,
                output: codex_protocol::models::FunctionCallOutputPayload {
                    content: message,
                    success: Some(false),
                },
            }
        }
    }
}
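A toy sketch of the routing decision `build_tool_call` makes over the tool name; the `server__tool` name format and the enum below are assumptions for illustration, not the crate's actual MCP name parsing.

```rust
// Illustrative payload kinds mirroring the router's three function-call paths.
#[derive(Debug)]
enum Payload {
    Mcp { server: String, tool: String },
    UnifiedExec,
    Function,
}

// Route purely on the tool name: MCP-qualified names first, then the
// unified_exec special case, then everything else as a plain function call.
fn route(name: &str) -> Payload {
    if let Some((server, tool)) = name.split_once("__") {
        Payload::Mcp {
            server: server.to_string(),
            tool: tool.to_string(),
        }
    } else if name == "unified_exec" {
        Payload::UnifiedExec
    } else {
        Payload::Function
    }
}

fn main() {
    println!("{:?}", route("docs__search"));
    println!("{:?}", route("unified_exec"));
    println!("{:?}", route("read_file"));
}
```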
1269
codex-rs/core/src/tools/spec.rs
Normal file
File diff suppressed because it is too large
@@ -7,6 +7,9 @@ use codex_core::config::Config;
 use codex_core::config::ConfigOverrides;
 use codex_core::config::ConfigToml;
 
+#[cfg(target_os = "linux")]
+use assert_cmd::cargo::cargo_bin;
+
 pub mod responses;
 pub mod test_codex;
 pub mod test_codex_exec;
@@ -17,12 +20,25 @@ pub mod test_codex_exec;
 pub fn load_default_config_for_test(codex_home: &TempDir) -> Config {
     Config::load_from_base_config_with_overrides(
         ConfigToml::default(),
-        ConfigOverrides::default(),
+        default_test_overrides(),
         codex_home.path().to_path_buf(),
     )
     .expect("defaults for test should always succeed")
 }
+
+#[cfg(target_os = "linux")]
+fn default_test_overrides() -> ConfigOverrides {
+    ConfigOverrides {
+        codex_linux_sandbox_exe: Some(cargo_bin("codex-linux-sandbox")),
+        ..ConfigOverrides::default()
+    }
+}
+
+#[cfg(not(target_os = "linux"))]
+fn default_test_overrides() -> ConfigOverrides {
+    ConfigOverrides::default()
+}
 
 /// Builds an SSE stream body from a JSON fixture.
 ///
 /// The fixture must contain an array of objects where each object represents a
@@ -12,12 +12,18 @@ mod fork_conversation;
 mod json_result;
 mod live_cli;
 mod model_overrides;
+mod model_tools;
 mod otel;
 mod prompt_caching;
+mod read_file;
 mod review;
 mod rmcp_client;
 mod rollout_list_find;
 mod seatbelt;
 mod stream_error_allows_next_turn;
 mod stream_no_completed;
+mod tool_harness;
+mod tools;
+mod unified_exec;
 mod user_notification;
+mod view_image;
124
codex-rs/core/tests/suite/model_tools.rs
Normal file
@@ -0,0 +1,124 @@
#![allow(clippy::unwrap_used)]

use codex_core::CodexAuth;
use codex_core::ConversationManager;
use codex_core::ModelProviderInfo;
use codex_core::built_in_model_providers;
use codex_core::model_family::find_family_for_model;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use core_test_support::load_default_config_for_test;
use core_test_support::load_sse_fixture_with_id;
use core_test_support::skip_if_no_network;
use core_test_support::wait_for_event;
use tempfile::TempDir;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;

fn sse_completed(id: &str) -> String {
    load_sse_fixture_with_id("tests/fixtures/completed_template.json", id)
}

#[allow(clippy::expect_used)]
fn tool_identifiers(body: &serde_json::Value) -> Vec<String> {
    body["tools"]
        .as_array()
        .unwrap()
        .iter()
        .map(|tool| {
            tool.get("name")
                .and_then(|v| v.as_str())
                .or_else(|| tool.get("type").and_then(|v| v.as_str()))
                .map(std::string::ToString::to_string)
                .expect("tool should have either name or type")
        })
        .collect()
}

#[allow(clippy::expect_used)]
async fn collect_tool_identifiers_for_model(model: &str) -> Vec<String> {
    let server = MockServer::start().await;

    let sse = sse_completed(model);
    let template = ResponseTemplate::new(200)
        .insert_header("content-type", "text/event-stream")
        .set_body_raw(sse, "text/event-stream");

    Mock::given(method("POST"))
        .and(path("/v1/responses"))
        .respond_with(template)
        .expect(1)
        .mount(&server)
        .await;

    let model_provider = ModelProviderInfo {
        base_url: Some(format!("{}/v1", server.uri())),
        ..built_in_model_providers()["openai"].clone()
    };

    let cwd = TempDir::new().unwrap();
    let codex_home = TempDir::new().unwrap();
    let mut config = load_default_config_for_test(&codex_home);
    config.cwd = cwd.path().to_path_buf();
    config.model_provider = model_provider;
    config.model = model.to_string();
    config.model_family =
        find_family_for_model(model).unwrap_or_else(|| panic!("unknown model family for {model}"));
    config.include_plan_tool = false;
    config.include_apply_patch_tool = false;
    config.include_view_image_tool = false;
    config.tools_web_search_request = false;
    config.use_experimental_streamable_shell_tool = false;
    config.use_experimental_unified_exec_tool = false;

    let conversation_manager =
        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
    let codex = conversation_manager
        .new_conversation(config)
        .await
        .expect("create new conversation")
        .conversation;

    codex
        .submit(Op::UserInput {
            items: vec![InputItem::Text {
                text: "hello tools".into(),
            }],
        })
        .await
        .unwrap();
    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.unwrap();
    assert_eq!(
        requests.len(),
        1,
        "expected a single request for model {model}"
    );
    let body = requests[0].body_json::<serde_json::Value>().unwrap();
    tool_identifiers(&body)
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn model_selects_expected_tools() {
    skip_if_no_network!();
    use pretty_assertions::assert_eq;

    let codex_tools = collect_tool_identifiers_for_model("codex-mini-latest").await;
    assert_eq!(
        codex_tools,
        vec!["local_shell".to_string(), "read_file".to_string()],
        "codex-mini-latest should expose the local shell tool",
    );

    let o3_tools = collect_tool_identifiers_for_model("o3").await;
    assert_eq!(
        o3_tools,
        vec!["shell".to_string(), "read_file".to_string()],
        "o3 should expose the generic shell tool",
    );
}
@@ -219,7 +219,13 @@ async fn prompt_tools_are_consistent_across_requests() {
 
     // our internal implementation is responsible for keeping tools in sync
     // with the OpenAI schema, so we just verify the tool presence here
-    let expected_tools_names: &[&str] = &["shell", "update_plan", "apply_patch", "view_image"];
+    let expected_tools_names: &[&str] = &[
+        "shell",
+        "update_plan",
+        "apply_patch",
+        "read_file",
+        "view_image",
+    ];
     let body0 = requests[0].body_json::<serde_json::Value>().unwrap();
     assert_eq!(
         body0["instructions"],
124
codex-rs/core/tests/suite/read_file.rs
Normal file
@@ -0,0 +1,124 @@
#![cfg(not(target_os = "windows"))]

use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use core_test_support::wait_for_event;
use pretty_assertions::assert_eq;
use serde_json::Value;
use wiremock::matchers::any;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn read_file_tool_returns_requested_lines() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = test_codex().build(&server).await?;

    let file_path = cwd.path().join("sample.txt");
    std::fs::write(&file_path, "first\nsecond\nthird\nfourth\n")?;
    let file_path = file_path.to_string_lossy().to_string();

    let call_id = "read-file-call";
    let arguments = serde_json::json!({
        "file_path": file_path,
        "offset": 2,
        "limit": 2,
    })
    .to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_function_call(call_id, "read_file", &arguments),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please inspect sample.txt".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;

    let requests = server.received_requests().await.expect("recorded requests");
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().unwrap())
        .collect::<Vec<_>>();
    assert!(
        !request_bodies.is_empty(),
        "expected at least one request body"
    );

    let tool_output_item = request_bodies
        .iter()
        .find_map(|body| {
            body.get("input")
                .and_then(Value::as_array)
                .and_then(|items| {
                    items.iter().find(|item| {
                        item.get("type").and_then(Value::as_str) == Some("function_call_output")
                    })
                })
        })
        .unwrap_or_else(|| {
            panic!("function_call_output item not found in requests: {request_bodies:#?}")
        });

    assert_eq!(
        tool_output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );

    let output_text = tool_output_item
        .get("output")
        .and_then(|value| match value {
            Value::String(text) => Some(text.as_str()),
            Value::Object(obj) => obj.get("content").and_then(Value::as_str),
            _ => None,
        })
        .expect("output text present");
    assert_eq!(output_text, "L2: second\nL3: third");

    Ok(())
}
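The assertion above pins down the output format; what follows is a toy re-implementation of that `L#: ` prefixing (1-indexed `offset`, `limit` lines), written only to illustrate the expected string, not the handler's actual code.

```rust
// Toy sketch: take `limit` lines starting at 1-indexed `offset` and prefix
// each with "L<line>: ", matching the string the test asserts on.
fn format_lines(text: &str, offset: usize, limit: usize) -> String {
    text.lines()
        .enumerate()
        .skip(offset.saturating_sub(1))
        .take(limit)
        .map(|(idx, line)| format!("L{}: {line}", idx + 1))
        .collect::<Vec<_>>()
        .join("\n")
}

fn main() {
    let sample = "first\nsecond\nthird\nfourth\n";
    assert_eq!(format_lines(sample, 2, 2), "L2: second\nL3: third");
    println!("{}", format_lines(sample, 2, 2));
}
```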
568
codex-rs/core/tests/suite/tool_harness.rs
Normal file
@@ -0,0 +1,568 @@
#![cfg(not(target_os = "windows"))]

use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::plan_tool::StepStatus;
use core_test_support::responses;
use core_test_support::responses::ev_apply_patch_function_call;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::ev_local_shell_call;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use serde_json::Value;
use serde_json::json;
use wiremock::matchers::any;

fn function_call_output(body: &Value) -> Option<&Value> {
    body.get("input")
        .and_then(Value::as_array)
        .and_then(|items| {
            items.iter().find(|item| {
                item.get("type").and_then(Value::as_str) == Some("function_call_output")
            })
        })
}

fn extract_output_text(item: &Value) -> Option<&str> {
    item.get("output").and_then(|value| match value {
        Value::String(text) => Some(text.as_str()),
        Value::Object(obj) => obj.get("content").and_then(Value::as_str),
        _ => None,
    })
}

fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
    requests
        .iter()
        .find(|body| function_call_output(body).is_some())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_tool_executes_command_and_streams_output() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let call_id = "shell-tool-call";
    let command = vec!["/bin/echo", "tool harness"];
    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_local_shell_call(call_id, "completed", command),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "all done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please run the shell command".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let exec_output: Value = serde_json::from_str(output_text)?;
    assert_eq!(exec_output["metadata"]["exit_code"], 0);
    let stdout = exec_output["output"].as_str().expect("stdout field");
    assert!(
        stdout.contains("tool harness"),
        "expected stdout to contain command output, got {stdout:?}"
    );

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn update_plan_tool_emits_plan_update_event() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.include_plan_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let call_id = "plan-tool-call";
    let plan_args = json!({
        "explanation": "Tool harness check",
        "plan": [
            {"step": "Inspect workspace", "status": "in_progress"},
            {"step": "Report results", "status": "pending"},
        ],
    })
    .to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_function_call(call_id, "update_plan", &plan_args),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "plan acknowledged"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please update the plan".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let mut saw_plan_update = false;

    loop {
        let event = codex.next_event().await.expect("event");
        match event.msg {
            EventMsg::PlanUpdate(update) => {
                saw_plan_update = true;
                assert_eq!(update.explanation.as_deref(), Some("Tool harness check"));
                assert_eq!(update.plan.len(), 2);
                assert_eq!(update.plan[0].step, "Inspect workspace");
                assert!(matches!(update.plan[0].status, StepStatus::InProgress));
                assert_eq!(update.plan[1].step, "Report results");
                assert!(matches!(update.plan[1].status, StepStatus::Pending));
            }
            EventMsg::TaskComplete(_) => break,
            _ => {}
        }
    }

    assert!(saw_plan_update, "expected PlanUpdate event");

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");
    assert_eq!(output_text, "Plan updated");

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn update_plan_tool_rejects_malformed_payload() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.include_plan_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let call_id = "plan-tool-invalid";
    let invalid_args = json!({
        "explanation": "Missing plan data"
    })
    .to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_function_call(call_id, "update_plan", &invalid_args),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "malformed plan payload"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please update the plan".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let mut saw_plan_update = false;

    loop {
        let event = codex.next_event().await.expect("event");
        match event.msg {
            EventMsg::PlanUpdate(_) => saw_plan_update = true,
            EventMsg::TaskComplete(_) => break,
            _ => {}
        }
    }

    assert!(
        !saw_plan_update,
        "did not expect PlanUpdate event for malformed payload"
    );

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");
    assert!(
        output_text.contains("failed to parse function arguments"),
        "expected parse error message in output text, got {output_text:?}"
    );
    if let Some(success_flag) = output_item
        .get("output")
        .and_then(|value| value.as_object())
        .and_then(|obj| obj.get("success"))
        .and_then(serde_json::Value::as_bool)
    {
        assert!(
            !success_flag,
            "expected tool output to mark success=false for malformed payload"
        );
    }

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn apply_patch_tool_executes_and_emits_patch_events() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.include_apply_patch_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let call_id = "apply-patch-call";
    let patch_content = r#"*** Begin Patch
*** Add File: notes.txt
+Tool harness apply patch
*** End Patch"#;

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_apply_patch_function_call(call_id, patch_content),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "patch complete"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please apply a patch".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let mut saw_patch_begin = false;
    let mut patch_end_success = None;

    loop {
        let event = codex.next_event().await.expect("event");
        match event.msg {
            EventMsg::PatchApplyBegin(begin) => {
                saw_patch_begin = true;
                assert_eq!(begin.call_id, call_id);
            }
            EventMsg::PatchApplyEnd(end) => {
                assert_eq!(end.call_id, call_id);
                patch_end_success = Some(end.success);
            }
            EventMsg::TaskComplete(_) => break,
            _ => {}
        }
    }

    assert!(saw_patch_begin, "expected PatchApplyBegin event");
    let patch_end_success =
        patch_end_success.expect("expected PatchApplyEnd event to capture success flag");

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    assert_eq!(
        output_item.get("call_id").and_then(Value::as_str),
        Some(call_id)
    );
    let output_text = extract_output_text(output_item).expect("output text present");

    if let Ok(exec_output) = serde_json::from_str::<Value>(output_text) {
        let exit_code = exec_output["metadata"]["exit_code"]
            .as_i64()
            .expect("exit_code present");
        let summary = exec_output["output"].as_str().expect("output field");
        assert_eq!(
            exit_code, 0,
            "expected apply_patch exit_code=0, got {exit_code}, summary: {summary:?}"
        );
        assert!(
            patch_end_success,
            "expected PatchApplyEnd success flag, summary: {summary:?}"
        );
        assert!(
            summary.contains("Success."),
            "expected apply_patch summary to note success, got {summary:?}"
        );

        let patched_path = cwd.path().join("notes.txt");
        let contents = std::fs::read_to_string(&patched_path)
            .unwrap_or_else(|e| panic!("failed reading {}: {e}", patched_path.display()));
        assert_eq!(contents, "Tool harness apply patch\n");
    } else {
        assert!(
            output_text.contains("codex-run-as-apply-patch"),
            "expected apply_patch failure message to mention codex-run-as-apply-patch, got {output_text:?}"
        );
        assert!(
            !patch_end_success,
            "expected PatchApplyEnd to report success=false when apply_patch invocation fails"
        );
    }

    Ok(())
}
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn apply_patch_reports_parse_diagnostics() -> anyhow::Result<()> {
|
||||||
|
skip_if_no_network!(Ok(()));
|
||||||
|
|
||||||
|
let server = start_mock_server().await;
|
||||||
|
|
||||||
|
let mut builder = test_codex().with_config(|config| {
|
||||||
|
config.include_apply_patch_tool = true;
|
||||||
|
});
|
||||||
|
let TestCodex {
|
||||||
|
codex,
|
||||||
|
cwd,
|
||||||
|
session_configured,
|
||||||
|
..
|
||||||
|
} = builder.build(&server).await?;
|
||||||
|
|
||||||
|
let call_id = "apply-patch-parse-error";
|
||||||
|
let patch_content = r"*** Begin Patch
|
||||||
|
*** Update File: broken.txt
|
||||||
|
*** End Patch";
|
||||||
|
|
||||||
|
let first_response = sse(vec![
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.created",
|
||||||
|
"response": {"id": "resp-1"}
|
||||||
|
}),
|
||||||
|
ev_apply_patch_function_call(call_id, patch_content),
|
||||||
|
ev_completed("resp-1"),
|
||||||
|
]);
|
||||||
|
responses::mount_sse_once_match(&server, any(), first_response).await;
|
||||||
|
|
||||||
|
let second_response = sse(vec![
|
||||||
|
ev_assistant_message("msg-1", "failed"),
|
||||||
|
ev_completed("resp-2"),
|
||||||
|
]);
|
||||||
|
responses::mount_sse_once_match(&server, any(), second_response).await;
|
||||||
|
|
||||||
|
let session_model = session_configured.model.clone();
|
||||||
|
|
||||||
|
codex
|
||||||
|
.submit(Op::UserTurn {
|
||||||
|
items: vec![InputItem::Text {
|
||||||
|
text: "please apply a patch".into(),
|
||||||
|
}],
|
||||||
|
final_output_json_schema: None,
|
||||||
|
cwd: cwd.path().to_path_buf(),
|
||||||
|
approval_policy: AskForApproval::Never,
|
||||||
|
sandbox_policy: SandboxPolicy::DangerFullAccess,
|
||||||
|
model: session_model,
|
||||||
|
effort: None,
|
||||||
|
summary: ReasoningSummary::Auto,
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let event = codex.next_event().await.expect("event");
|
||||||
|
if matches!(event.msg, EventMsg::TaskComplete(_)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let requests = server.received_requests().await.expect("recorded requests");
|
||||||
|
assert!(!requests.is_empty(), "expected at least one POST request");
|
||||||
|
|
||||||
|
let request_bodies = requests
|
||||||
|
.iter()
|
||||||
|
.map(|req| req.body_json::<Value>().expect("request json"))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
|
||||||
|
.expect("function_call_output item not found in requests");
|
||||||
|
let output_item = function_call_output(body_with_tool_output).expect("tool output item");
|
||||||
|
assert_eq!(
|
||||||
|
output_item.get("call_id").and_then(Value::as_str),
|
||||||
|
Some(call_id)
|
||||||
|
);
|
||||||
|
let output_text = extract_output_text(output_item).expect("output text present");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
output_text.contains("apply_patch verification failed"),
|
||||||
|
"expected apply_patch verification failure message, got {output_text:?}"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
output_text.contains("invalid hunk"),
|
||||||
|
"expected parse diagnostics in output text, got {output_text:?}"
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(success_flag) = output_item
|
||||||
|
.get("output")
|
||||||
|
.and_then(|value| value.as_object())
|
||||||
|
.and_then(|obj| obj.get("success"))
|
||||||
|
.and_then(serde_json::Value::as_bool)
|
||||||
|
{
|
||||||
|
assert!(
|
||||||
|
!success_flag,
|
||||||
|
"expected tool output to mark success=false for parse failures"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
450 codex-rs/core/tests/suite/tools.rs Normal file
@@ -0,0 +1,450 @@
#![cfg(not(target_os = "windows"))]
#![allow(clippy::unwrap_used, clippy::expect_used)]

use anyhow::Result;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_custom_tool_call;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use serde_json::Value;
use serde_json::json;
use wiremock::Request;

async fn submit_turn(
    test: &TestCodex,
    prompt: &str,
    approval_policy: AskForApproval,
    sandbox_policy: SandboxPolicy,
) -> Result<()> {
    let session_model = test.session_configured.model.clone();

    test.codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: prompt.into(),
            }],
            final_output_json_schema: None,
            cwd: test.cwd.path().to_path_buf(),
            approval_policy,
            sandbox_policy,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    loop {
        let event = test.codex.next_event().await?;
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }

    Ok(())
}

fn request_bodies(requests: &[Request]) -> Result<Vec<Value>> {
    requests
        .iter()
        .map(|req| Ok(serde_json::from_slice::<Value>(&req.body)?))
        .collect()
}

fn collect_output_items<'a>(bodies: &'a [Value], ty: &str) -> Vec<&'a Value> {
    let mut out = Vec::new();
    for body in bodies {
        if let Some(items) = body.get("input").and_then(Value::as_array) {
            for item in items {
                if item.get("type").and_then(Value::as_str) == Some(ty) {
                    out.push(item);
                }
            }
        }
    }
    out
}

fn tool_names(body: &Value) -> Vec<String> {
    body.get("tools")
        .and_then(Value::as_array)
        .map(|tools| {
            tools
                .iter()
                .filter_map(|tool| {
                    tool.get("name")
                        .or_else(|| tool.get("type"))
                        .and_then(Value::as_str)
                        .map(str::to_string)
                })
                .collect()
        })
        .unwrap_or_default()
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn custom_tool_unknown_returns_custom_output_error() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex();
    let test = builder.build(&server).await?;

    let call_id = "custom-unsupported";
    let tool_name = "unsupported_tool";

    let responses = vec![
        sse(vec![
            json!({"type": "response.created", "response": {"id": "resp-1"}}),
            ev_custom_tool_call(call_id, tool_name, "\"payload\""),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "invoke custom tool",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let custom_items = collect_output_items(&bodies, "custom_tool_call_output");
    assert_eq!(custom_items.len(), 1, "expected single custom tool output");
    let item = custom_items[0];
    assert_eq!(item.get("call_id").and_then(Value::as_str), Some(call_id));

    let output = item
        .get("output")
        .and_then(Value::as_str)
        .unwrap_or_default();
    let expected = format!("unsupported custom tool call: {tool_name}");
    assert_eq!(output, expected);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_escalated_permissions_rejected_then_ok() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex();
    let test = builder.build(&server).await?;

    let command = ["/bin/echo", "shell ok"];
    let call_id_blocked = "shell-blocked";
    let call_id_success = "shell-success";

    let first_args = json!({
        "command": command,
        "timeout_ms": 1_000,
        "with_escalated_permissions": true,
    });
    let second_args = json!({
        "command": command,
        "timeout_ms": 1_000,
    });

    let responses = vec![
        sse(vec![
            json!({"type": "response.created", "response": {"id": "resp-1"}}),
            ev_function_call(
                call_id_blocked,
                "shell",
                &serde_json::to_string(&first_args)?,
            ),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            json!({"type": "response.created", "response": {"id": "resp-2"}}),
            ev_function_call(
                call_id_success,
                "shell",
                &serde_json::to_string(&second_args)?,
            ),
            ev_completed("resp-2"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-3"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run the shell command",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let function_outputs = collect_output_items(&bodies, "function_call_output");
    for item in &function_outputs {
        let call_id = item
            .get("call_id")
            .and_then(Value::as_str)
            .unwrap_or_default();
        assert!(
            call_id == call_id_blocked || call_id == call_id_success,
            "unexpected call id {call_id}"
        );
    }

    let policy = AskForApproval::Never;
    let expected_message = format!(
        "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}"
    );

    let blocked_outputs: Vec<&Value> = function_outputs
        .iter()
        .filter(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_blocked))
        .copied()
        .collect();
    assert!(
        !blocked_outputs.is_empty(),
        "expected at least one rejection output for {call_id_blocked}"
    );
    for item in blocked_outputs {
        assert_eq!(
            item.get("output").and_then(Value::as_str),
            Some(expected_message.as_str()),
            "unexpected rejection message"
        );
    }

    let success_item = function_outputs
        .iter()
        .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id_success))
        .expect("success output present");
    let output_json: Value = serde_json::from_str(
        success_item
            .get("output")
            .and_then(Value::as_str)
            .expect("success output string"),
    )?;
    assert_eq!(
        output_json["metadata"]["exit_code"].as_i64(),
        Some(0),
        "expected exit code 0 after rerunning without escalation",
    );
    let stdout = output_json["output"].as_str().unwrap_or_default();
    assert!(
        stdout.contains("shell ok"),
        "expected stdout to include command output, got {stdout:?}"
    );

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn local_shell_missing_ids_maps_to_function_output_error() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex();
    let test = builder.build(&server).await?;

    let local_shell_event = json!({
        "type": "response.output_item.done",
        "item": {
            "type": "local_shell_call",
            "status": "completed",
            "action": {
                "type": "exec",
                "command": ["/bin/echo", "hi"],
            }
        }
    });

    let responses = vec![
        sse(vec![
            json!({"type": "response.created", "response": {"id": "resp-1"}}),
            local_shell_event,
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "check shell output",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let function_outputs = collect_output_items(&bodies, "function_call_output");
    assert_eq!(
        function_outputs.len(),
        1,
        "expected a single function output"
    );
    let item = function_outputs[0];
    assert_eq!(item.get("call_id").and_then(Value::as_str), Some(""));
    assert_eq!(
        item.get("output").and_then(Value::as_str),
        Some("LocalShellCall without call_id or id"),
    );

    Ok(())
}

async fn collect_tools(use_unified_exec: bool) -> Result<Vec<String>> {
    let server = start_mock_server().await;

    let responses = vec![sse(vec![
        json!({"type": "response.created", "response": {"id": "resp-1"}}),
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-1"),
    ])];
    mount_sse_sequence(&server, responses).await;

    let mut builder = test_codex().with_config(move |config| {
        config.use_experimental_unified_exec_tool = use_unified_exec;
    });
    let test = builder.build(&server).await?;

    submit_turn(
        &test,
        "list tools",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    assert_eq!(
        requests.len(),
        1,
        "expected a single request for tools collection"
    );
    let bodies = request_bodies(&requests)?;
    let first_body = bodies.first().expect("request body present");
    Ok(tool_names(first_body))
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_spec_toggle_end_to_end() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let tools_disabled = collect_tools(false).await?;
    assert!(
        !tools_disabled.iter().any(|name| name == "unified_exec"),
        "tools list should not include unified_exec when disabled: {tools_disabled:?}"
    );

    let tools_enabled = collect_tools(true).await?;
    assert!(
        tools_enabled.iter().any(|name| name == "unified_exec"),
        "tools list should include unified_exec when enabled: {tools_enabled:?}"
    );

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn shell_timeout_includes_timeout_prefix_and_metadata() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;
    let mut builder = test_codex();
    let test = builder.build(&server).await?;

    let call_id = "shell-timeout";
    let timeout_ms = 50u64;
    let args = json!({
        "command": ["/bin/sh", "-c", "yes line | head -n 400; sleep 1"],
        "timeout_ms": timeout_ms,
    });

    let responses = vec![
        sse(vec![
            json!({"type": "response.created", "response": {"id": "resp-1"}}),
            ev_function_call(call_id, "shell", &serde_json::to_string(&args)?),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-2"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    submit_turn(
        &test,
        "run a long command",
        AskForApproval::Never,
        SandboxPolicy::DangerFullAccess,
    )
    .await?;

    let requests = server.received_requests().await.expect("recorded requests");
    let bodies = request_bodies(&requests)?;
    let function_outputs = collect_output_items(&bodies, "function_call_output");
    let timeout_item = function_outputs
        .iter()
        .find(|item| item.get("call_id").and_then(Value::as_str) == Some(call_id))
        .expect("timeout output present");

    let output_json: Value = serde_json::from_str(
        timeout_item
            .get("output")
            .and_then(Value::as_str)
            .expect("timeout output string"),
    )?;
    assert_eq!(
        output_json["metadata"]["exit_code"].as_i64(),
        Some(124),
        "expected timeout exit code 124",
    );

    let stdout = output_json["output"].as_str().unwrap_or_default();
    assert!(
        stdout.starts_with("command timed out after "),
        "expected timeout prefix, got {stdout:?}"
    );
    let first_line = stdout.lines().next().unwrap_or_default();
    let duration_ms = first_line
        .strip_prefix("command timed out after ")
        .and_then(|line| line.strip_suffix(" milliseconds"))
        .and_then(|value| value.parse::<u64>().ok())
        .unwrap_or_default();
    assert!(
        duration_ms >= timeout_ms,
        "expected duration >= configured timeout, got {duration_ms} (timeout {timeout_ms})"
    );
    assert!(
        stdout.contains("[... omitted"),
        "expected truncated output marker, got {stdout:?}"
    );

    Ok(())
}
280 codex-rs/core/tests/suite/unified_exec.rs Normal file
@@ -0,0 +1,280 @@
#![cfg(not(target_os = "windows"))]

use std::collections::HashMap;

use anyhow::Result;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::skip_if_sandbox;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use serde_json::Value;

fn extract_output_text(item: &Value) -> Option<&str> {
    item.get("output").and_then(|value| match value {
        Value::String(text) => Some(text.as_str()),
        Value::Object(obj) => obj.get("content").and_then(Value::as_str),
        _ => None,
    })
}

fn collect_tool_outputs(bodies: &[Value]) -> Result<HashMap<String, Value>> {
    let mut outputs = HashMap::new();
    for body in bodies {
        if let Some(items) = body.get("input").and_then(Value::as_array) {
            for item in items {
                if item.get("type").and_then(Value::as_str) != Some("function_call_output") {
                    continue;
                }
                if let Some(call_id) = item.get("call_id").and_then(Value::as_str) {
                    let content = extract_output_text(item)
                        .ok_or_else(|| anyhow::anyhow!("missing tool output content"))?;
                    let parsed: Value = serde_json::from_str(content)?;
                    outputs.insert(call_id.to_string(), parsed);
                }
            }
        }
    }
    Ok(outputs)
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_reuses_session_via_stdin() -> Result<()> {
    skip_if_no_network!(Ok(()));
    skip_if_sandbox!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.use_experimental_unified_exec_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let first_call_id = "uexec-start";
    let first_args = serde_json::json!({
        "input": ["/bin/cat"],
        "timeout_ms": 200,
    });

    let second_call_id = "uexec-stdin";
    let second_args = serde_json::json!({
        "input": ["hello unified exec\n"],
        "session_id": "0",
        "timeout_ms": 500,
    });

    let responses = vec![
        sse(vec![
            serde_json::json!({"type": "response.created", "response": {"id": "resp-1"}}),
            ev_function_call(
                first_call_id,
                "unified_exec",
                &serde_json::to_string(&first_args)?,
            ),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            serde_json::json!({"type": "response.created", "response": {"id": "resp-2"}}),
            ev_function_call(
                second_call_id,
                "unified_exec",
                &serde_json::to_string(&second_args)?,
            ),
            ev_completed("resp-2"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "all done"),
            ev_completed("resp-3"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "run unified exec".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let outputs = collect_tool_outputs(&bodies)?;

    let start_output = outputs
        .get(first_call_id)
        .expect("missing first unified_exec output");
    let session_id = start_output["session_id"].as_str().unwrap_or_default();
    assert!(
        !session_id.is_empty(),
        "expected session id in first unified_exec response"
    );
    assert!(
        start_output["output"]
            .as_str()
            .unwrap_or_default()
            .is_empty()
    );

    let reuse_output = outputs
        .get(second_call_id)
        .expect("missing reused unified_exec output");
    assert_eq!(
        reuse_output["session_id"].as_str().unwrap_or_default(),
        session_id
    );
    let echoed = reuse_output["output"].as_str().unwrap_or_default();
    assert!(
        echoed.contains("hello unified exec"),
        "expected echoed output, got {echoed:?}"
    );

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn unified_exec_timeout_and_followup_poll() -> Result<()> {
    skip_if_no_network!(Ok(()));
    skip_if_sandbox!(Ok(()));

    let server = start_mock_server().await;

    let mut builder = test_codex().with_config(|config| {
        config.use_experimental_unified_exec_tool = true;
    });
    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = builder.build(&server).await?;

    let first_call_id = "uexec-timeout";
    let first_args = serde_json::json!({
        "input": ["/bin/sh", "-c", "sleep 0.1; echo ready"],
        "timeout_ms": 10,
    });

    let second_call_id = "uexec-poll";
    let second_args = serde_json::json!({
        "input": Vec::<String>::new(),
        "session_id": "0",
        "timeout_ms": 800,
    });

    let responses = vec![
        sse(vec![
            serde_json::json!({"type": "response.created", "response": {"id": "resp-1"}}),
            ev_function_call(
                first_call_id,
                "unified_exec",
                &serde_json::to_string(&first_args)?,
            ),
            ev_completed("resp-1"),
        ]),
        sse(vec![
            serde_json::json!({"type": "response.created", "response": {"id": "resp-2"}}),
            ev_function_call(
                second_call_id,
                "unified_exec",
                &serde_json::to_string(&second_args)?,
            ),
            ev_completed("resp-2"),
        ]),
        sse(vec![
            ev_assistant_message("msg-1", "done"),
            ev_completed("resp-3"),
        ]),
    ];
    mount_sse_sequence(&server, responses).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "check timeout".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(!requests.is_empty(), "expected at least one POST request");

    let bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let outputs = collect_tool_outputs(&bodies)?;

    let first_output = outputs.get(first_call_id).expect("missing timeout output");
    assert_eq!(first_output["session_id"], "0");
    assert!(
        first_output["output"]
            .as_str()
            .unwrap_or_default()
            .is_empty()
    );

    let poll_output = outputs.get(second_call_id).expect("missing poll output");
    let output_text = poll_output["output"].as_str().unwrap_or_default();
    assert!(
        output_text.contains("ready"),
        "expected ready output, got {output_text:?}"
    );

    Ok(())
}
351 codex-rs/core/tests/suite/view_image.rs Normal file
@@ -0,0 +1,351 @@
#![cfg(not(target_os = "windows"))]

use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_core::protocol::AskForApproval;
use codex_core::protocol::EventMsg;
use codex_core::protocol::InputItem;
use codex_core::protocol::Op;
use codex_core::protocol::SandboxPolicy;
use codex_protocol::config_types::ReasoningSummary;
use core_test_support::responses;
use core_test_support::responses::ev_assistant_message;
use core_test_support::responses::ev_completed;
use core_test_support::responses::ev_function_call;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
use core_test_support::test_codex::TestCodex;
use core_test_support::test_codex::test_codex;
use serde_json::Value;
use wiremock::matchers::any;

fn function_call_output(body: &Value) -> Option<&Value> {
    body.get("input")
        .and_then(Value::as_array)
        .and_then(|items| {
            items.iter().find(|item| {
                item.get("type").and_then(Value::as_str) == Some("function_call_output")
            })
        })
}

fn find_image_message(body: &Value) -> Option<&Value> {
    body.get("input")
        .and_then(Value::as_array)
        .and_then(|items| {
            items.iter().find(|item| {
                item.get("type").and_then(Value::as_str) == Some("message")
                    && item
                        .get("content")
                        .and_then(Value::as_array)
                        .map(|content| {
                            content.iter().any(|span| {
                                span.get("type").and_then(Value::as_str) == Some("input_image")
                            })
                        })
                        .unwrap_or(false)
            })
        })
}

fn extract_output_text(item: &Value) -> Option<&str> {
    item.get("output").and_then(|value| match value {
        Value::String(text) => Some(text.as_str()),
        Value::Object(obj) => obj.get("content").and_then(Value::as_str),
        _ => None,
    })
}

fn find_request_with_function_call_output(requests: &[Value]) -> Option<&Value> {
    requests
        .iter()
        .find(|body| function_call_output(body).is_some())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn view_image_tool_attaches_local_image() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = test_codex().build(&server).await?;

    let rel_path = "assets/example.png";
    let abs_path = cwd.path().join(rel_path);
    if let Some(parent) = abs_path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let image_bytes = b"fake_png_bytes".to_vec();
    std::fs::write(&abs_path, &image_bytes)?;

    let call_id = "view-image-call";
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please add the screenshot".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    let mut tool_event = None;
    loop {
        let event = codex.next_event().await.expect("event");
        match event.msg {
            EventMsg::ViewImageToolCall(ev) => tool_event = Some(ev),
            EventMsg::TaskComplete(_) => break,
            _ => {}
        }
    }

    let tool_event = tool_event.expect("view image tool event emitted");
    assert_eq!(tool_event.call_id, call_id);
    assert_eq!(tool_event.path, abs_path);

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    assert_eq!(output_text, "attached local image path");

    let image_message = find_image_message(body_with_tool_output)
        .expect("pending input image message not included in request");
    let image_url = image_message
        .get("content")
        .and_then(Value::as_array)
        .and_then(|content| {
            content.iter().find_map(|span| {
                if span.get("type").and_then(Value::as_str) == Some("input_image") {
                    span.get("image_url").and_then(Value::as_str)
                } else {
                    None
                }
            })
        })
        .expect("image_url present");

    let expected_image_url = format!(
        "data:image/png;base64,{}",
        BASE64_STANDARD.encode(&image_bytes)
    );
    assert_eq!(image_url, expected_image_url);

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn view_image_tool_errors_when_path_is_directory() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = test_codex().build(&server).await?;

    let rel_path = "assets";
    let abs_path = cwd.path().join(rel_path);
    std::fs::create_dir_all(&abs_path)?;

    let call_id = "view-image-directory";
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please attach the folder".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let expected_message = format!("image path `{}` is not a file", abs_path.display());
    assert_eq!(output_text, expected_message);

    assert!(
        find_image_message(body_with_tool_output).is_none(),
        "directory path should not produce an input_image message"
    );

    Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn view_image_tool_errors_when_file_missing() -> anyhow::Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_mock_server().await;

    let TestCodex {
        codex,
        cwd,
        session_configured,
        ..
    } = test_codex().build(&server).await?;

    let rel_path = "missing/example.png";
    let abs_path = cwd.path().join(rel_path);

    let call_id = "view-image-missing";
    let arguments = serde_json::json!({ "path": rel_path }).to_string();

    let first_response = sse(vec![
        serde_json::json!({
            "type": "response.created",
            "response": {"id": "resp-1"}
        }),
        ev_function_call(call_id, "view_image", &arguments),
        ev_completed("resp-1"),
    ]);
    responses::mount_sse_once_match(&server, any(), first_response).await;

    let second_response = sse(vec![
        ev_assistant_message("msg-1", "done"),
        ev_completed("resp-2"),
    ]);
    responses::mount_sse_once_match(&server, any(), second_response).await;

    let session_model = session_configured.model.clone();

    codex
        .submit(Op::UserTurn {
            items: vec![InputItem::Text {
                text: "please attach the missing image".into(),
            }],
            final_output_json_schema: None,
            cwd: cwd.path().to_path_buf(),
            approval_policy: AskForApproval::Never,
            sandbox_policy: SandboxPolicy::DangerFullAccess,
            model: session_model,
            effort: None,
            summary: ReasoningSummary::Auto,
        })
        .await?;

    loop {
        let event = codex.next_event().await.expect("event");
        if matches!(event.msg, EventMsg::TaskComplete(_)) {
            break;
        }
    }

    let requests = server.received_requests().await.expect("recorded requests");
    assert!(
        requests.len() >= 2,
        "expected at least two POST requests, got {}",
        requests.len()
    );
    let request_bodies = requests
        .iter()
        .map(|req| req.body_json::<Value>().expect("request json"))
        .collect::<Vec<_>>();

    let body_with_tool_output = find_request_with_function_call_output(&request_bodies)
        .expect("function_call_output item not found in requests");
    let output_item = function_call_output(body_with_tool_output).expect("tool output item");
    let output_text = extract_output_text(output_item).expect("output text present");
    let expected_prefix = format!("unable to locate image at `{}`:", abs_path.display());
    assert!(
        output_text.starts_with(&expected_prefix),
        "expected error to start with `{expected_prefix}` but got `{output_text}`"
    );

    assert!(
        find_image_message(body_with_tool_output).is_none(),
        "missing file should not produce an input_image message"
    );

    Ok(())
}
@@ -1,7 +1,6 @@
 use codex_common::elapsed::format_duration;
 use codex_common::elapsed::format_elapsed;
 use codex_core::config::Config;
-use codex_core::plan_tool::UpdatePlanArgs;
 use codex_core::protocol::AgentMessageEvent;
 use codex_core::protocol::AgentReasoningRawContentEvent;
 use codex_core::protocol::BackgroundEventEvent;
@@ -35,6 +34,8 @@ use crate::event_processor::CodexStatus;
 use crate::event_processor::EventProcessor;
 use crate::event_processor::handle_last_message;
 use codex_common::create_config_summary_entries;
+use codex_protocol::plan_tool::StepStatus;
+use codex_protocol::plan_tool::UpdatePlanArgs;

 /// This should be configurable. When used in CI, users may not want to impose
 /// a limit so they can see the full transcript.
@@ -456,7 +457,6 @@ impl EventProcessor for EventProcessorWithHumanOutput {

         // Pretty-print the plan items with simple status markers.
         for item in plan {
-            use codex_core::plan_tool::StepStatus;
             match item.status {
                 StepStatus::Completed => {
                     ts_println!(self, " {} {}", "✓".style(self.green), item.step);
@@ -31,8 +31,6 @@ use crate::exec_events::TurnStartedEvent;
 use crate::exec_events::Usage;
 use crate::exec_events::WebSearchItem;
 use codex_core::config::Config;
-use codex_core::plan_tool::StepStatus;
-use codex_core::plan_tool::UpdatePlanArgs;
 use codex_core::protocol::AgentMessageEvent;
 use codex_core::protocol::AgentReasoningEvent;
 use codex_core::protocol::Event;
@@ -48,6 +46,8 @@ use codex_core::protocol::SessionConfiguredEvent;
 use codex_core::protocol::TaskCompleteEvent;
 use codex_core::protocol::TaskStartedEvent;
 use codex_core::protocol::WebSearchEndEvent;
+use codex_protocol::plan_tool::StepStatus;
+use codex_protocol::plan_tool::UpdatePlanArgs;
 use tracing::error;
 use tracing::warn;

@@ -171,7 +171,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
         codex_linux_sandbox_exe,
         base_instructions: None,
         include_plan_tool: Some(include_plan_tool),
-        include_apply_patch_tool: None,
+        include_apply_patch_tool: Some(true),
         include_view_image_tool: None,
         show_raw_agent_reasoning: oss.then_some(true),
         tools_web_search_request: None,
@@ -37,6 +37,9 @@ use codex_exec::exec_events::TurnFailedEvent;
 use codex_exec::exec_events::TurnStartedEvent;
 use codex_exec::exec_events::Usage;
 use codex_exec::exec_events::WebSearchItem;
+use codex_protocol::plan_tool::PlanItemArg;
+use codex_protocol::plan_tool::StepStatus;
+use codex_protocol::plan_tool::UpdatePlanArgs;
 use mcp_types::CallToolResult;
 use pretty_assertions::assert_eq;
 use std::path::PathBuf;
@@ -115,10 +118,6 @@ fn web_search_end_emits_item_completed() {

 #[test]
 fn plan_update_emits_todo_list_started_updated_and_completed() {
-    use codex_core::plan_tool::PlanItemArg;
-    use codex_core::plan_tool::StepStatus;
-    use codex_core::plan_tool::UpdatePlanArgs;
-
     let mut ep = EventProcessorWithJsonOutput::new(None);

     // First plan update => item.started (todo_list)
@@ -339,10 +338,6 @@ fn mcp_tool_call_failure_sets_failed_status() {

 #[test]
 fn plan_update_after_complete_starts_new_todo_list_with_new_id() {
-    use codex_core::plan_tool::PlanItemArg;
-    use codex_core::plan_tool::StepStatus;
-    use codex_core::plan_tool::UpdatePlanArgs;
-
     let mut ep = EventProcessorWithJsonOutput::new(None);

     // First turn: start + complete
@@ -14,6 +14,7 @@ use eventsource_stream::EventStreamError as StreamError;
 use reqwest::Error;
 use reqwest::Response;
 use serde::Serialize;
+use std::borrow::Cow;
 use std::fmt::Display;
 use std::time::Duration;
 use std::time::Instant;
@@ -366,10 +367,10 @@ impl OtelEventManager {
         call_id: &str,
         arguments: &str,
         f: F,
-    ) -> Result<String, E>
+    ) -> Result<(String, bool), E>
     where
         F: FnOnce() -> Fut,
-        Fut: Future<Output = Result<String, E>>,
+        Fut: Future<Output = Result<(String, bool), E>>,
         E: Display,
     {
         let start = Instant::now();
@@ -377,10 +378,12 @@ impl OtelEventManager {
         let duration = start.elapsed();

         let (output, success) = match &result {
-            Ok(content) => (content, true),
+            Ok((preview, success)) => (Cow::Borrowed(preview.as_str()), *success),
-            Err(error) => (&error.to_string(), false),
+            Err(error) => (Cow::Owned(error.to_string()), false),
         };

+        let success_str = if success { "true" } else { "false" };
+
         tracing::event!(
             tracing::Level::INFO,
             event.name = "codex.tool_result",
@@ -396,7 +399,8 @@ impl OtelEventManager {
             call_id = %call_id,
             arguments = %arguments,
             duration_ms = %duration.as_millis(),
-            success = %success,
+            success = %success_str,
+            // `output` is truncated by the tool layer before reaching telemetry.
             output = %output,
         );

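Note on the telemetry hunks above: the OtelEventManager wrapper now awaits futures that resolve to (String, bool), an output preview plus an explicit success flag, instead of a bare String, so the codex.tool_result event can log success separately from the output text. Below is a minimal, self-contained sketch of that contract; the wrapper name with_tool_telemetry is hypothetical and stands in for the real codex-otel method (which this hunk does not name), and the example assumes tokio for the async main.

// Hypothetical sketch, not the codex-otel API: same bounds as the updated method.
use std::fmt::Display;
use std::future::Future;

async fn with_tool_telemetry<F, Fut, E>(f: F) -> Result<(String, bool), E>
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = Result<(String, bool), E>>,
    E: Display,
{
    let result = f().await;
    let (output, success) = match &result {
        Ok((preview, success)) => (preview.clone(), *success),
        Err(error) => (error.to_string(), false),
    };
    // A real implementation would emit these fields as a tracing event instead.
    println!("tool_result success={success} output={output}");
    result
}

#[tokio::main]
async fn main() {
    // Hypothetical tool handler: succeeded, with a short preview of its output.
    let _ = with_tool_telemetry(|| async {
        Ok::<_, std::io::Error>(("Success. Updated notes.txt".to_string(), true))
    })
    .await;
}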
@@ -259,6 +259,7 @@ pub struct ShellToolCallParams {
 #[derive(Debug, Clone, PartialEq, TS)]
 pub struct FunctionCallOutputPayload {
     pub content: String,
+    // TODO(jif) drop this.
     pub success: Option<bool>,
 }

@@ -111,6 +111,7 @@ use codex_git_tooling::GhostCommit;
 use codex_git_tooling::GitToolingError;
 use codex_git_tooling::create_ghost_commit;
 use codex_git_tooling::restore_ghost_commit;
+use codex_protocol::plan_tool::UpdatePlanArgs;
 use strum::IntoEnumIterator;

 const MAX_TRACKED_GHOST_COMMITS: usize = 20;
@@ -508,7 +509,7 @@ impl ChatWidget {
         self.request_redraw();
     }

-    fn on_plan_update(&mut self, update: codex_core::plan_tool::UpdatePlanArgs) {
+    fn on_plan_update(&mut self, update: UpdatePlanArgs) {
         self.add_to_history(history_cell::new_plan_update(update));
     }

@@ -8,9 +8,6 @@ use codex_core::CodexAuth;
 use codex_core::config::Config;
 use codex_core::config::ConfigOverrides;
 use codex_core::config::ConfigToml;
-use codex_core::plan_tool::PlanItemArg;
-use codex_core::plan_tool::StepStatus;
-use codex_core::plan_tool::UpdatePlanArgs;
 use codex_core::protocol::AgentMessageDeltaEvent;
 use codex_core::protocol::AgentMessageEvent;
 use codex_core::protocol::AgentReasoningDeltaEvent;
@@ -37,6 +34,9 @@ use codex_core::protocol::TaskCompleteEvent;
 use codex_core::protocol::TaskStartedEvent;
 use codex_core::protocol::ViewImageToolCallEvent;
 use codex_protocol::ConversationId;
+use codex_protocol::plan_tool::PlanItemArg;
+use codex_protocol::plan_tool::StepStatus;
+use codex_protocol::plan_tool::UpdatePlanArgs;
 use crossterm::event::KeyCode;
 use crossterm::event::KeyEvent;
 use crossterm::event::KeyModifiers;
@@ -21,13 +21,13 @@ use base64::Engine;
 use codex_core::config::Config;
 use codex_core::config_types::McpServerTransportConfig;
 use codex_core::config_types::ReasoningSummaryFormat;
-use codex_core::plan_tool::PlanItemArg;
-use codex_core::plan_tool::StepStatus;
-use codex_core::plan_tool::UpdatePlanArgs;
 use codex_core::protocol::FileChange;
 use codex_core::protocol::McpInvocation;
 use codex_core::protocol::SessionConfiguredEvent;
 use codex_core::protocol_config_types::ReasoningEffort as ReasoningEffortConfig;
+use codex_protocol::plan_tool::PlanItemArg;
+use codex_protocol::plan_tool::StepStatus;
+use codex_protocol::plan_tool::UpdatePlanArgs;
 use image::DynamicImage;
 use image::ImageReader;
 use mcp_types::EmbeddedResourceResource;
7 codex-rs/utils/string/Cargo.toml Normal file
@@ -0,0 +1,7 @@
[package]
edition.workspace = true
name = "codex-utils-string"
version.workspace = true

[lints]
workspace = true
38 codex-rs/utils/string/src/lib.rs Normal file
@@ -0,0 +1,38 @@
// Truncate a &str to a byte budget at a char boundary (prefix)
#[inline]
pub fn take_bytes_at_char_boundary(s: &str, maxb: usize) -> &str {
    if s.len() <= maxb {
        return s;
    }
    let mut last_ok = 0;
    for (i, ch) in s.char_indices() {
        let nb = i + ch.len_utf8();
        if nb > maxb {
            break;
        }
        last_ok = nb;
    }
    &s[..last_ok]
}

// Take a suffix of a &str within a byte budget at a char boundary
#[inline]
pub fn take_last_bytes_at_char_boundary(s: &str, maxb: usize) -> &str {
    if s.len() <= maxb {
        return s;
    }
    let mut start = s.len();
    let mut used = 0usize;
    for (i, ch) in s.char_indices().rev() {
        let nb = ch.len_utf8();
        if used + nb > maxb {
            break;
        }
        start = i;
        used += nb;
        if start == 0 {
            break;
        }
    }
    &s[start..]
}
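A quick usage sketch for the two helpers above, assuming the codex-utils-string crate is available as a dependency. Both functions clamp a string to a byte budget without splitting a multi-byte character; "é" occupies two bytes, so budgets that land inside it back off to the previous boundary.

fn main() {
    let s = "héllo"; // 6 bytes: 'é' is 2 bytes, every other char is 1
    // Prefix truncation: a 2-byte budget cannot include 'é', a 3-byte budget can.
    assert_eq!(codex_utils_string::take_bytes_at_char_boundary(s, 2), "h");
    assert_eq!(codex_utils_string::take_bytes_at_char_boundary(s, 3), "hé");
    // Suffix truncation: a 4-byte budget stops before 'é', a 5-byte budget includes it.
    assert_eq!(codex_utils_string::take_last_bytes_at_char_boundary(s, 4), "llo");
    assert_eq!(codex_utils_string::take_last_bytes_at_char_boundary(s, 5), "éllo");
}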
Block a user