198 lines
6.1 KiB
Rust
198 lines
6.1 KiB
Rust
|
|
use std::collections::HashMap;
|
||
|
|
use std::sync::Arc;
|
||
|
|
use std::time::Duration;
|
||
|
|
|
||
|
|
use async_trait::async_trait;
|
||
|
|
use codex_protocol::models::ResponseInputItem;
|
||
|
|
use tracing::warn;
|
||
|
|
|
||
|
|
use crate::client_common::tools::ToolSpec;
|
||
|
|
use crate::function_tool::FunctionCallError;
|
||
|
|
use crate::tools::context::ToolInvocation;
|
||
|
|
use crate::tools::context::ToolOutput;
|
||
|
|
use crate::tools::context::ToolPayload;
|
||
|
|
|
||
|
|
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||
|
|
pub enum ToolKind {
|
||
|
|
Function,
|
||
|
|
UnifiedExec,
|
||
|
|
Mcp,
|
||
|
|
}
|
||
|
|
|
||
|
|
#[async_trait]
|
||
|
|
pub trait ToolHandler: Send + Sync {
|
||
|
|
fn kind(&self) -> ToolKind;
|
||
|
|
|
||
|
|
fn matches_kind(&self, payload: &ToolPayload) -> bool {
|
||
|
|
matches!(
|
||
|
|
(self.kind(), payload),
|
||
|
|
(ToolKind::Function, ToolPayload::Function { .. })
|
||
|
|
| (ToolKind::UnifiedExec, ToolPayload::UnifiedExec { .. })
|
||
|
|
| (ToolKind::Mcp, ToolPayload::Mcp { .. })
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn handle(&self, invocation: ToolInvocation<'_>)
|
||
|
|
-> Result<ToolOutput, FunctionCallError>;
|
||
|
|
}
|
||
|
|
|
||
|
|
pub struct ToolRegistry {
|
||
|
|
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl ToolRegistry {
|
||
|
|
pub fn new(handlers: HashMap<String, Arc<dyn ToolHandler>>) -> Self {
|
||
|
|
Self { handlers }
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn handler(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
|
||
|
|
self.handlers.get(name).map(Arc::clone)
|
||
|
|
}
|
||
|
|
|
||
|
|
// TODO(jif) for dynamic tools.
|
||
|
|
// pub fn register(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
||
|
|
// let name = name.into();
|
||
|
|
// if self.handlers.insert(name.clone(), handler).is_some() {
|
||
|
|
// warn!("overwriting handler for tool {name}");
|
||
|
|
// }
|
||
|
|
// }
|
||
|
|
|
||
|
|
pub async fn dispatch<'a>(
|
||
|
|
&self,
|
||
|
|
invocation: ToolInvocation<'a>,
|
||
|
|
) -> Result<ResponseInputItem, FunctionCallError> {
|
||
|
|
let tool_name = invocation.tool_name.clone();
|
||
|
|
let call_id_owned = invocation.call_id.clone();
|
||
|
|
let otel = invocation.turn.client.get_otel_event_manager();
|
||
|
|
let payload_for_response = invocation.payload.clone();
|
||
|
|
let log_payload = payload_for_response.log_payload();
|
||
|
|
|
||
|
|
let handler = match self.handler(tool_name.as_ref()) {
|
||
|
|
Some(handler) => handler,
|
||
|
|
None => {
|
||
|
|
let message =
|
||
|
|
unsupported_tool_call_message(&invocation.payload, tool_name.as_ref());
|
||
|
|
otel.tool_result(
|
||
|
|
tool_name.as_ref(),
|
||
|
|
&call_id_owned,
|
||
|
|
log_payload.as_ref(),
|
||
|
|
Duration::ZERO,
|
||
|
|
false,
|
||
|
|
&message,
|
||
|
|
);
|
||
|
|
return Err(FunctionCallError::RespondToModel(message));
|
||
|
|
}
|
||
|
|
};
|
||
|
|
|
||
|
|
if !handler.matches_kind(&invocation.payload) {
|
||
|
|
let message = format!("tool {tool_name} invoked with incompatible payload");
|
||
|
|
otel.tool_result(
|
||
|
|
tool_name.as_ref(),
|
||
|
|
&call_id_owned,
|
||
|
|
log_payload.as_ref(),
|
||
|
|
Duration::ZERO,
|
||
|
|
false,
|
||
|
|
&message,
|
||
|
|
);
|
||
|
|
return Err(FunctionCallError::Fatal(message));
|
||
|
|
}
|
||
|
|
|
||
|
|
let output_cell = tokio::sync::Mutex::new(None);
|
||
|
|
|
||
|
|
let result = otel
|
||
|
|
.log_tool_result(
|
||
|
|
tool_name.as_ref(),
|
||
|
|
&call_id_owned,
|
||
|
|
log_payload.as_ref(),
|
||
|
|
|| {
|
||
|
|
let handler = handler.clone();
|
||
|
|
let output_cell = &output_cell;
|
||
|
|
let invocation = invocation;
|
||
|
|
async move {
|
||
|
|
match handler.handle(invocation).await {
|
||
|
|
Ok(output) => {
|
||
|
|
let preview = output.log_preview();
|
||
|
|
let success = output.success_for_logging();
|
||
|
|
let mut guard = output_cell.lock().await;
|
||
|
|
*guard = Some(output);
|
||
|
|
Ok((preview, success))
|
||
|
|
}
|
||
|
|
Err(err) => Err(err),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
},
|
||
|
|
)
|
||
|
|
.await;
|
||
|
|
|
||
|
|
match result {
|
||
|
|
Ok(_) => {
|
||
|
|
let mut guard = output_cell.lock().await;
|
||
|
|
let output = guard.take().ok_or_else(|| {
|
||
|
|
FunctionCallError::Fatal("tool produced no output".to_string())
|
||
|
|
})?;
|
||
|
|
Ok(output.into_response(&call_id_owned, &payload_for_response))
|
||
|
|
}
|
||
|
|
Err(err) => Err(err),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub struct ToolRegistryBuilder {
|
||
|
|
handlers: HashMap<String, Arc<dyn ToolHandler>>,
|
||
|
|
specs: Vec<ToolSpec>,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl ToolRegistryBuilder {
|
||
|
|
pub fn new() -> Self {
|
||
|
|
Self {
|
||
|
|
handlers: HashMap::new(),
|
||
|
|
specs: Vec::new(),
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn push_spec(&mut self, spec: ToolSpec) {
|
||
|
|
self.specs.push(spec);
|
||
|
|
}
|
||
|
|
|
||
|
|
pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
|
||
|
|
let name = name.into();
|
||
|
|
if self
|
||
|
|
.handlers
|
||
|
|
.insert(name.clone(), handler.clone())
|
||
|
|
.is_some()
|
||
|
|
{
|
||
|
|
warn!("overwriting handler for tool {name}");
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// TODO(jif) for dynamic tools.
|
||
|
|
// pub fn register_many<I>(&mut self, names: I, handler: Arc<dyn ToolHandler>)
|
||
|
|
// where
|
||
|
|
// I: IntoIterator,
|
||
|
|
// I::Item: Into<String>,
|
||
|
|
// {
|
||
|
|
// for name in names {
|
||
|
|
// let name = name.into();
|
||
|
|
// if self
|
||
|
|
// .handlers
|
||
|
|
// .insert(name.clone(), handler.clone())
|
||
|
|
// .is_some()
|
||
|
|
// {
|
||
|
|
// warn!("overwriting handler for tool {name}");
|
||
|
|
// }
|
||
|
|
// }
|
||
|
|
// }
|
||
|
|
|
||
|
|
pub fn build(self) -> (Vec<ToolSpec>, ToolRegistry) {
|
||
|
|
let registry = ToolRegistry::new(self.handlers);
|
||
|
|
(self.specs, registry)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &str) -> String {
|
||
|
|
match payload {
|
||
|
|
ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
|
||
|
|
_ => format!("unsupported call: {tool_name}"),
|
||
|
|
}
|
||
|
|
}
|