From f48dd99f223afdf9197dced40ffb5404d66e7a23 Mon Sep 17 00:00:00 2001
From: Michael Bolin
Date: Fri, 16 May 2025 14:38:08 -0700
Subject: [PATCH] feat: add support for OpenAI tool type, local_shell (#961)

The new `codex-mini-latest` model expects a new tool with
`{"type": "local_shell"}`. Its contract is similar to the existing
`function` tool with `"name": "shell"`, so this maps the `local_shell`
tool call into `ExecParams` and sends it through the existing
`handle_container_exec_with_params()` code path.

This also adds the following logic when adding the default set of tools
to a request:

```rust
let default_tools = if self.model.starts_with("codex") {
    &DEFAULT_CODEX_MODEL_TOOLS
} else {
    &DEFAULT_TOOLS
};
```

That is, if the model name starts with `"codex"`, we add
`{"type": "local_shell"}` to the list of tools; otherwise, we add the
aforementioned `shell` tool.

To test this, I ran the TUI with `-m codex-mini-latest` and verified that
it used the `local_shell` tool. Note that I also had some entries in
`[mcp_servers]` in my personal `config.toml`, and `codex-mini-latest`
seemed eager to try the tools from those MCP servers first, so I have
commented them out for now. Keep an eye on this if you are testing
`codex-mini-latest`.

Perhaps we should include more details with `{"type": "local_shell"}`, or
update the prompt at
https://github.com/openai/codex/blob/fd0b1b020818dfe8aaf7eb68425f09e86ab1b819/codex-rs/core/prompt.md.

For reference, the corresponding change in the TypeScript CLI is
https://github.com/openai/codex/pull/951.
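As a quick illustration of the wire format (a minimal, self-contained
sketch rather than the crate's actual private types), the internally
tagged `#[serde(tag = "type")]` representation used for the new tool enum
in `client.rs` below is what makes the `local_shell` variant serialize to
exactly `{"type": "local_shell"}`:

```rust
// Minimal sketch: `OpenAiTool` here is a stand-in for the private enum
// added to client.rs, not the real type. Requires `serde` and `serde_json`.
use serde::Serialize;

#[derive(Serialize)]
#[serde(tag = "type")]
enum OpenAiTool {
    #[serde(rename = "local_shell")]
    LocalShell {},
}

fn main() {
    // The internally tagged representation folds the variant name into the
    // "type" field, so this prints: {"type":"local_shell"}
    println!(
        "{}",
        serde_json::to_string(&OpenAiTool::LocalShell {}).unwrap()
    );
}
```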
---
 codex-rs/core/src/client.rs               | 31 +++++++++----
 codex-rs/core/src/codex.rs                | 55 ++++++++++++++++++++---
 codex-rs/core/src/conversation_history.rs |  7 +--
 codex-rs/core/src/models.rs               | 33 ++++++++++++++
 codex-rs/core/src/rollout.rs              |  1 +
 5 files changed, 109 insertions(+), 18 deletions(-)

diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index 7316e904..57534e2f 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -40,10 +40,18 @@ use crate::util::backoff;
 
 /// When serialized as JSON, this produces a valid "Tool" in the OpenAI
 /// Responses API.
-#[derive(Debug, Serialize)]
+#[derive(Debug, Clone, Serialize)]
+#[serde(tag = "type")]
+enum OpenAiTool {
+    #[serde(rename = "function")]
+    Function(ResponsesApiTool),
+    #[serde(rename = "local_shell")]
+    LocalShell {},
+}
+
+#[derive(Debug, Clone, Serialize)]
 struct ResponsesApiTool {
     name: &'static str,
-    r#type: &'static str, // "function"
     description: &'static str,
     strict: bool,
     parameters: JsonSchema,
@@ -67,7 +75,7 @@ enum JsonSchema {
 }
 
 /// Tool usage specification
-static DEFAULT_TOOLS: LazyLock<Vec<ResponsesApiTool>> = LazyLock::new(|| {
+static DEFAULT_TOOLS: LazyLock<Vec<OpenAiTool>> = LazyLock::new(|| {
     let mut properties = BTreeMap::new();
     properties.insert(
         "command".to_string(),
@@ -78,9 +86,8 @@ static DEFAULT_TOOLS: LazyLock<Vec<ResponsesApiTool>> = LazyLock::new(|| {
     properties.insert("workdir".to_string(), JsonSchema::String);
     properties.insert("timeout".to_string(), JsonSchema::Number);
 
-    vec![ResponsesApiTool {
+    vec![OpenAiTool::Function(ResponsesApiTool {
         name: "shell",
-        r#type: "function",
         description: "Runs a shell command, and returns its output.",
         strict: false,
         parameters: JsonSchema::Object {
@@ -88,9 +95,12 @@ static DEFAULT_TOOLS: LazyLock<Vec<ResponsesApiTool>> = LazyLock::new(|| {
             required: &["command"],
             additional_properties: false,
         },
-    }]
+    })]
 });
 
+static DEFAULT_CODEX_MODEL_TOOLS: LazyLock<Vec<OpenAiTool>> =
+    LazyLock::new(|| vec![OpenAiTool::LocalShell {}]);
+
 #[derive(Clone)]
 pub struct ModelClient {
     model: String,
@@ -152,8 +162,13 @@ impl ModelClient {
         }
 
         // Assemble tool list: built-in tools + any extra tools from the prompt.
-        let mut tools_json = Vec::with_capacity(DEFAULT_TOOLS.len() + prompt.extra_tools.len());
-        for t in DEFAULT_TOOLS.iter() {
+        let default_tools = if self.model.starts_with("codex") {
+            &DEFAULT_CODEX_MODEL_TOOLS
+        } else {
+            &DEFAULT_TOOLS
+        };
+        let mut tools_json = Vec::with_capacity(default_tools.len() + prompt.extra_tools.len());
+        for t in default_tools.iter() {
             tools_json.push(serde_json::to_value(t)?);
         }
         tools_json.extend(
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index 52dff6d2..705b8260 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -51,6 +51,7 @@ use crate::mcp_connection_manager::try_parse_fully_qualified_tool_name;
 use crate::mcp_tool_call::handle_mcp_tool_call;
 use crate::models::ContentItem;
 use crate::models::FunctionCallOutputPayload;
+use crate::models::LocalShellAction;
 use crate::models::ReasoningItemReasoningSummary;
 use crate::models::ResponseInputItem;
 use crate::models::ResponseItem;
@@ -992,8 +993,7 @@ async fn handle_response_item(
     item: ResponseItem,
 ) -> CodexResult<Option<ResponseInputItem>> {
     debug!(?item, "Output item");
-    let mut output = None;
-    match item {
+    let output = match item {
         ResponseItem::Message { content, .. } => {
             for item in content {
                 if let ContentItem::OutputText { text } = item {
@@ -1004,6 +1004,7 @@ async fn handle_response_item(
                     sess.tx_event.send(event).await.ok();
                 }
             }
+            None
         }
         ResponseItem::Reasoning { id: _, summary } => {
             for item in summary {
@@ -1016,21 +1017,61 @@ async fn handle_response_item(
                 };
                 sess.tx_event.send(event).await.ok();
             }
+            None
         }
         ResponseItem::FunctionCall {
             name,
             arguments,
             call_id,
         } => {
-            output = Some(
-                handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await,
-            );
+            tracing::info!("FunctionCall: {arguments}");
+            Some(handle_function_call(sess, sub_id.to_string(), name, arguments, call_id).await)
+        }
+        ResponseItem::LocalShellCall {
+            id,
+            call_id,
+            status: _,
+            action,
+        } => {
+            let LocalShellAction::Exec(action) = action;
+            tracing::info!("LocalShellCall: {action:?}");
+            let params = ShellToolCallParams {
+                command: action.command,
+                workdir: action.working_directory,
+                timeout_ms: action.timeout_ms,
+            };
+            let effective_call_id = match (call_id, id) {
+                (Some(call_id), _) => call_id,
+                (None, Some(id)) => id,
+                (None, None) => {
+                    error!("LocalShellCall without call_id or id");
+                    return Ok(Some(ResponseInputItem::FunctionCallOutput {
+                        call_id: "".to_string(),
+                        output: FunctionCallOutputPayload {
+                            content: "LocalShellCall without call_id or id".to_string(),
+                            success: None,
+                        },
+                    }));
+                }
+            };
+
+            let exec_params = to_exec_params(params, sess);
+            Some(
+                handle_container_exec_with_params(
+                    exec_params,
+                    sess,
+                    sub_id.to_string(),
+                    effective_call_id,
+                )
+                .await,
+            )
         }
         ResponseItem::FunctionCallOutput { .. } => {
             debug!("unexpected FunctionCallOutput from stream");
+            None
         }
-        ResponseItem::Other => (),
-    }
+        ResponseItem::Other => None,
+    };
     Ok(output)
 }
 
diff --git a/codex-rs/core/src/conversation_history.rs b/codex-rs/core/src/conversation_history.rs
index 8d19e0cb..fdaf8397 100644
--- a/codex-rs/core/src/conversation_history.rs
+++ b/codex-rs/core/src/conversation_history.rs
@@ -41,8 +41,9 @@ impl ConversationHistory {
 fn is_api_message(message: &ResponseItem) -> bool {
     match message {
         ResponseItem::Message { role, .. } => role.as_str() != "system",
-        ResponseItem::FunctionCall { .. } => true,
-        ResponseItem::FunctionCallOutput { .. } => true,
-        _ => false,
+        ResponseItem::FunctionCallOutput { .. }
+        | ResponseItem::FunctionCall { .. }
+        | ResponseItem::LocalShellCall { .. } => true,
+        ResponseItem::Reasoning { .. } | ResponseItem::Other => false,
     }
 }
diff --git a/codex-rs/core/src/models.rs b/codex-rs/core/src/models.rs
index a8817cf7..ab213fd5 100644
--- a/codex-rs/core/src/models.rs
+++ b/codex-rs/core/src/models.rs
@@ -1,3 +1,5 @@
+use std::collections::HashMap;
+
 use base64::Engine;
 use serde::Deserialize;
 use serde::Serialize;
@@ -37,6 +39,14 @@ pub enum ResponseItem {
         id: String,
         summary: Vec<ReasoningItemReasoningSummary>,
     },
+    LocalShellCall {
+        /// Set when using the chat completions API.
+        id: Option<String>,
+        /// Set when using the Responses API.
+        call_id: Option<String>,
+        status: LocalShellStatus,
+        action: LocalShellAction,
+    },
     FunctionCall {
         name: String,
         // The Responses API returns the function call arguments as a *string* that contains
@@ -71,6 +81,29 @@ impl From<ResponseInputItem> for ResponseItem {
     }
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum LocalShellStatus {
+    Completed,
+    InProgress,
+    Incomplete,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum LocalShellAction {
+    Exec(LocalShellExecAction),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct LocalShellExecAction {
+    pub command: Vec<String>,
+    pub timeout_ms: Option<u64>,
+    pub working_directory: Option<String>,
+    pub env: Option<HashMap<String, String>>,
+    pub user: Option<String>,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 #[serde(tag = "type", rename_all = "snake_case")]
 pub enum ReasoningItemReasoningSummary {
diff --git a/codex-rs/core/src/rollout.rs b/codex-rs/core/src/rollout.rs
index 4127b603..c18a58df 100644
--- a/codex-rs/core/src/rollout.rs
+++ b/codex-rs/core/src/rollout.rs
@@ -115,6 +115,7 @@ impl RolloutRecorder {
             // "fully qualified MCP tool calls," so we could consider
             // reformatting them in that case.
             ResponseItem::Message { .. }
+            | ResponseItem::LocalShellCall { .. }
             | ResponseItem::FunctionCall { .. }
             | ResponseItem::FunctionCallOutput { .. } => {}
             ResponseItem::Reasoning { .. } | ResponseItem::Other => {