use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; use codex_protocol::models::ResponseInputItem; use tracing::warn; use crate::client_common::tools::ToolSpec; use crate::function_tool::FunctionCallError; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub enum ToolKind { Function, Mcp, } #[async_trait] pub trait ToolHandler: Send + Sync { fn kind(&self) -> ToolKind; fn matches_kind(&self, payload: &ToolPayload) -> bool { matches!( (self.kind(), payload), (ToolKind::Function, ToolPayload::Function { .. }) | (ToolKind::Mcp, ToolPayload::Mcp { .. }) ) } async fn handle(&self, invocation: ToolInvocation) -> Result; } pub struct ToolRegistry { handlers: HashMap>, } impl ToolRegistry { pub fn new(handlers: HashMap>) -> Self { Self { handlers } } pub fn handler(&self, name: &str) -> Option> { self.handlers.get(name).map(Arc::clone) } // TODO(jif) for dynamic tools. // pub fn register(&mut self, name: impl Into, handler: Arc) { // let name = name.into(); // if self.handlers.insert(name.clone(), handler).is_some() { // warn!("overwriting handler for tool {name}"); // } // } pub async fn dispatch( &self, invocation: ToolInvocation, ) -> Result { let tool_name = invocation.tool_name.clone(); let call_id_owned = invocation.call_id.clone(); let otel = invocation.turn.client.get_otel_event_manager(); let payload_for_response = invocation.payload.clone(); let log_payload = payload_for_response.log_payload(); let handler = match self.handler(tool_name.as_ref()) { Some(handler) => handler, None => { let message = unsupported_tool_call_message(&invocation.payload, tool_name.as_ref()); otel.tool_result( tool_name.as_ref(), &call_id_owned, log_payload.as_ref(), Duration::ZERO, false, &message, ); return Err(FunctionCallError::RespondToModel(message)); } }; if !handler.matches_kind(&invocation.payload) { let message = format!("tool {tool_name} invoked with incompatible payload"); otel.tool_result( tool_name.as_ref(), &call_id_owned, log_payload.as_ref(), Duration::ZERO, false, &message, ); return Err(FunctionCallError::Fatal(message)); } let output_cell = tokio::sync::Mutex::new(None); let result = otel .log_tool_result( tool_name.as_ref(), &call_id_owned, log_payload.as_ref(), || { let handler = handler.clone(); let output_cell = &output_cell; let invocation = invocation; async move { match handler.handle(invocation).await { Ok(output) => { let preview = output.log_preview(); let success = output.success_for_logging(); let mut guard = output_cell.lock().await; *guard = Some(output); Ok((preview, success)) } Err(err) => Err(err), } } }, ) .await; match result { Ok(_) => { let mut guard = output_cell.lock().await; let output = guard.take().ok_or_else(|| { FunctionCallError::Fatal("tool produced no output".to_string()) })?; Ok(output.into_response(&call_id_owned, &payload_for_response)) } Err(err) => Err(err), } } } #[derive(Debug, Clone)] pub struct ConfiguredToolSpec { pub spec: ToolSpec, pub supports_parallel_tool_calls: bool, } impl ConfiguredToolSpec { pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self { Self { spec, supports_parallel_tool_calls, } } } pub struct ToolRegistryBuilder { handlers: HashMap>, specs: Vec, } impl ToolRegistryBuilder { pub fn new() -> Self { Self { handlers: HashMap::new(), specs: Vec::new(), } } pub fn push_spec(&mut self, spec: ToolSpec) { self.push_spec_with_parallel_support(spec, false); } pub fn push_spec_with_parallel_support( &mut self, spec: ToolSpec, supports_parallel_tool_calls: bool, ) { self.specs .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls)); } pub fn register_handler(&mut self, name: impl Into, handler: Arc) { let name = name.into(); if self .handlers .insert(name.clone(), handler.clone()) .is_some() { warn!("overwriting handler for tool {name}"); } } // TODO(jif) for dynamic tools. // pub fn register_many(&mut self, names: I, handler: Arc) // where // I: IntoIterator, // I::Item: Into, // { // for name in names { // let name = name.into(); // if self // .handlers // .insert(name.clone(), handler.clone()) // .is_some() // { // warn!("overwriting handler for tool {name}"); // } // } // } pub fn build(self) -> (Vec, ToolRegistry) { let registry = ToolRegistry::new(self.handlers); (self.specs, registry) } } fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &str) -> String { match payload { ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"), _ => format!("unsupported call: {tool_name}"), } }