chore: refactor tool handling (#4510)
# Tool System Refactor - Centralizes tool definitions and execution in `core/src/tools/*`: specs (`spec.rs`), handlers (`handlers/*`), router (`router.rs`), registry/dispatch (`registry.rs`), and shared context (`context.rs`). One registry now builds the model-visible tool list and binds handlers. - Router converts model responses to tool calls; Registry dispatches with consistent telemetry via `codex-rs/otel` and unified error handling. Function, Local Shell, MCP, and experimental `unified_exec` all flow through this path; legacy shell aliases still work. - Rationale: reduce per‑tool boilerplate, keep spec/handler in sync, and make adding tools predictable and testable. Example: `read_file` - Spec: `core/src/tools/spec.rs` (see `create_read_file_tool`, registered by `build_specs`). - Handler: `core/src/tools/handlers/read_file.rs` (absolute `file_path`, 1‑indexed `offset`, `limit`, `L#: ` prefixes, safe truncation). - E2E test: `core/tests/suite/read_file.rs` validates the tool returns the requested lines. ## Next steps: - Decompose `handle_container_exec_with_params` - Add parallel tool calls
This commit is contained in:
197
codex-rs/core/src/tools/registry.rs
Normal file
197
codex-rs/core/src/tools/registry.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::client_common::tools::ToolSpec;
|
||||
use crate::function_tool::FunctionCallError;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum ToolKind {
|
||||
Function,
|
||||
UnifiedExec,
|
||||
Mcp,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ToolHandler: Send + Sync {
|
||||
fn kind(&self) -> ToolKind;
|
||||
|
||||
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||||
matches!(
|
||||
(self.kind(), payload),
|
||||
(ToolKind::Function, ToolPayload::Function { .. })
|
||||
| (ToolKind::UnifiedExec, ToolPayload::UnifiedExec { .. })
|
||||
| (ToolKind::Mcp, ToolPayload::Mcp { .. })
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle(&self, invocation: ToolInvocation<'_>)
|
||||
-> Result<ToolOutput, FunctionCallError>;
|
||||
}
|
||||
|
||||
pub struct ToolRegistry {
|
||||
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new(handlers: HashMap<String, Arc<dyn ToolHandler>>) -> Self {
|
||||
Self { handlers }
|
||||
}
|
||||
|
||||
pub fn handler(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
|
||||
self.handlers.get(name).map(Arc::clone)
|
||||
}
|
||||
|
||||
// TODO(jif) for dynamic tools.
|
||||
// pub fn register(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
||||
// let name = name.into();
|
||||
// if self.handlers.insert(name.clone(), handler).is_some() {
|
||||
// warn!("overwriting handler for tool {name}");
|
||||
// }
|
||||
// }
|
||||
|
||||
pub async fn dispatch<'a>(
|
||||
&self,
|
||||
invocation: ToolInvocation<'a>,
|
||||
) -> Result<ResponseInputItem, FunctionCallError> {
|
||||
let tool_name = invocation.tool_name.clone();
|
||||
let call_id_owned = invocation.call_id.clone();
|
||||
let otel = invocation.turn.client.get_otel_event_manager();
|
||||
let payload_for_response = invocation.payload.clone();
|
||||
let log_payload = payload_for_response.log_payload();
|
||||
|
||||
let handler = match self.handler(tool_name.as_ref()) {
|
||||
Some(handler) => handler,
|
||||
None => {
|
||||
let message =
|
||||
unsupported_tool_call_message(&invocation.payload, tool_name.as_ref());
|
||||
otel.tool_result(
|
||||
tool_name.as_ref(),
|
||||
&call_id_owned,
|
||||
log_payload.as_ref(),
|
||||
Duration::ZERO,
|
||||
false,
|
||||
&message,
|
||||
);
|
||||
return Err(FunctionCallError::RespondToModel(message));
|
||||
}
|
||||
};
|
||||
|
||||
if !handler.matches_kind(&invocation.payload) {
|
||||
let message = format!("tool {tool_name} invoked with incompatible payload");
|
||||
otel.tool_result(
|
||||
tool_name.as_ref(),
|
||||
&call_id_owned,
|
||||
log_payload.as_ref(),
|
||||
Duration::ZERO,
|
||||
false,
|
||||
&message,
|
||||
);
|
||||
return Err(FunctionCallError::Fatal(message));
|
||||
}
|
||||
|
||||
let output_cell = tokio::sync::Mutex::new(None);
|
||||
|
||||
let result = otel
|
||||
.log_tool_result(
|
||||
tool_name.as_ref(),
|
||||
&call_id_owned,
|
||||
log_payload.as_ref(),
|
||||
|| {
|
||||
let handler = handler.clone();
|
||||
let output_cell = &output_cell;
|
||||
let invocation = invocation;
|
||||
async move {
|
||||
match handler.handle(invocation).await {
|
||||
Ok(output) => {
|
||||
let preview = output.log_preview();
|
||||
let success = output.success_for_logging();
|
||||
let mut guard = output_cell.lock().await;
|
||||
*guard = Some(output);
|
||||
Ok((preview, success))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
let mut guard = output_cell.lock().await;
|
||||
let output = guard.take().ok_or_else(|| {
|
||||
FunctionCallError::Fatal("tool produced no output".to_string())
|
||||
})?;
|
||||
Ok(output.into_response(&call_id_owned, &payload_for_response))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ToolRegistryBuilder {
|
||||
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
||||
specs: Vec<ToolSpec>,
|
||||
}
|
||||
|
||||
impl ToolRegistryBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: HashMap::new(),
|
||||
specs: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_spec(&mut self, spec: ToolSpec) {
|
||||
self.specs.push(spec);
|
||||
}
|
||||
|
||||
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
||||
let name = name.into();
|
||||
if self
|
||||
.handlers
|
||||
.insert(name.clone(), handler.clone())
|
||||
.is_some()
|
||||
{
|
||||
warn!("overwriting handler for tool {name}");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(jif) for dynamic tools.
|
||||
// pub fn register_many<I>(&mut self, names: I, handler: Arc<dyn ToolHandler>)
|
||||
// where
|
||||
// I: IntoIterator,
|
||||
// I::Item: Into<String>,
|
||||
// {
|
||||
// for name in names {
|
||||
// let name = name.into();
|
||||
// if self
|
||||
// .handlers
|
||||
// .insert(name.clone(), handler.clone())
|
||||
// .is_some()
|
||||
// {
|
||||
// warn!("overwriting handler for tool {name}");
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
|
||||
let registry = ToolRegistry::new(self.handlers);
|
||||
(self.specs, registry)
|
||||
}
|
||||
}
|
||||
|
||||
fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &str) -> String {
|
||||
match payload {
|
||||
ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
|
||||
_ => format!("unsupported call: {tool_name}"),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user