[core] Separate tools config from openai client (#1858)

## Summary
In an effort to make tools easier to work with and more configurable,
I'm introducing `ToolConfig` and updating `Prompt` to take in a general
list of Tools. I think this is simpler and better for a few reasons:
- We can easily assemble tools from various sources (our own harness,
mcp servers, etc.) and we can consolidate the logic for constructing the
logic in one place that is separate from serialization.
- client.rs no longer needs arbitrary config values, it just takes in a
list of tools to serialize

A hefty portion of the PR is now updating our conversion of
`mcp_types::Tool` to `OpenAITool`, but considering that @bolinfest
accurately called this out as a TODO long ago, I think it's time we
tackled it.

## Testing
- [x] Experimented locally, no changes, as expected
- [x] Added additional unit tests
- [x] Responded to rust-review
This commit is contained in:
Dylan
2025-08-05 19:27:52 -07:00
committed by GitHub
parent afa8f0d617
commit aff97ed7dd
7 changed files with 250 additions and 74 deletions

View File

@@ -32,7 +32,6 @@ use crate::util::backoff;
pub(crate) async fn stream_chat_completions(
prompt: &Prompt,
model_family: &ModelFamily,
include_plan_tool: bool,
client: &reqwest::Client,
provider: &ModelProviderInfo,
) -> Result<ResponseStream> {
@@ -112,8 +111,7 @@ pub(crate) async fn stream_chat_completions(
}
}
let tools_json =
create_tools_json_for_chat_completions_api(prompt, model_family, include_plan_tool)?;
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let payload = json!({
"model": model_family.slug,
"messages": messages,

View File

@@ -83,7 +83,6 @@ impl ModelClient {
let response_stream = stream_chat_completions(
prompt,
&self.config.model_family,
self.config.include_plan_tool,
&self.client,
&self.provider,
)
@@ -132,11 +131,7 @@ impl ModelClient {
let store = prompt.store && auth_mode != Some(AuthMode::ChatGPT);
let full_instructions = prompt.get_full_instructions(&self.config.model_family);
let tools_json = create_tools_json_for_responses_api(
prompt,
&self.config.model_family,
self.config.include_plan_tool,
)?;
let tools_json = create_tools_json_for_responses_api(&prompt.tools)?;
let reasoning = create_reasoning_param_for_request(
&self.config.model_family,
self.effort,

View File

@@ -3,12 +3,12 @@ use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::models::ResponseItem;
use crate::openai_tools::OpenAiTool;
use crate::protocol::TokenUsage;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use futures::Stream;
use serde::Serialize;
use std::borrow::Cow;
use std::collections::HashMap;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
@@ -33,10 +33,9 @@ pub struct Prompt {
/// Whether to store response on server side (disable_response_storage = !store).
pub store: bool,
/// Additional tools sourced from external MCP servers. Note each key is
/// the "fully qualified" tool name (i.e., prefixed with the server name),
/// which should be reported to the model in place of Tool::name.
pub extra_tools: HashMap<String, mcp_types::Tool>,
/// Tools available to the model, including additional tools sourced from
/// external MCP servers.
pub tools: Vec<OpenAiTool>,
/// Optional override for the built-in BASE_INSTRUCTIONS.
pub base_instructions_override: Option<String>,

View File

@@ -61,6 +61,8 @@ use crate::models::ReasoningItemReasoningSummary;
use crate::models::ResponseInputItem;
use crate::models::ResponseItem;
use crate::models::ShellToolCallParams;
use crate::openai_tools::ToolsConfig;
use crate::openai_tools::get_openai_tools;
use crate::plan_tool::handle_update_plan;
use crate::project_doc::get_user_instructions;
use crate::protocol::AgentMessageDeltaEvent;
@@ -216,6 +218,7 @@ pub(crate) struct Session {
shell_environment_policy: ShellEnvironmentPolicy,
pub(crate) writable_roots: Mutex<Vec<PathBuf>>,
disable_response_storage: bool,
tools_config: ToolsConfig,
/// Manager for external MCP servers/tools.
mcp_connection_manager: McpConnectionManager,
@@ -810,6 +813,7 @@ async fn submission_loop(
let default_shell = shell::default_user_shell().await;
sess = Some(Arc::new(Session {
client,
tools_config: ToolsConfig::new(&config.model_family, config.include_plan_tool),
tx_event: tx_event.clone(),
ctrl_c: Arc::clone(&ctrl_c),
user_instructions,
@@ -1204,12 +1208,16 @@ async fn run_turn(
sub_id: String,
input: Vec<ResponseItem>,
) -> CodexResult<Vec<ProcessedResponseItem>> {
let extra_tools = sess.mcp_connection_manager.list_all_tools();
let tools = get_openai_tools(
&sess.tools_config,
Some(sess.mcp_connection_manager.list_all_tools()),
);
let prompt = Prompt {
input,
user_instructions: sess.user_instructions.clone(),
store: !sess.disable_response_storage,
extra_tools,
tools,
base_instructions_override: sess.base_instructions.clone(),
};
@@ -1436,7 +1444,7 @@ async fn run_compact_task(
input: turn_input,
user_instructions: None,
store: !sess.disable_response_storage,
extra_tools: HashMap::new(),
tools: Vec::new(),
base_instructions_override: Some(compact_instructions.clone()),
};

View File

@@ -48,6 +48,5 @@ pub mod spawn;
pub mod turn_diff_tracker;
mod user_notification;
pub mod util;
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
pub use safety::get_platform_sandbox;

View File

@@ -1,22 +1,26 @@
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use std::collections::BTreeMap;
use std::collections::HashMap;
use crate::client_common::Prompt;
use crate::model_family::ModelFamily;
use crate::plan_tool::PLAN_TOOL;
#[derive(Debug, Clone, Serialize)]
pub(crate) struct ResponsesApiTool {
pub(crate) name: &'static str,
pub(crate) description: &'static str,
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct ResponsesApiTool {
pub(crate) name: String,
pub(crate) description: String,
/// TODO: Validation. When strict is set to true, the JSON schema,
/// `required` and `additional_properties` must be present. All fields in
/// `properties` must be present in `required`.
pub(crate) strict: bool,
pub(crate) parameters: JsonSchema,
}
/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(tag = "type")]
pub(crate) enum OpenAiTool {
#[serde(rename = "function")]
@@ -25,8 +29,35 @@ pub(crate) enum OpenAiTool {
LocalShell {},
}
#[derive(Debug, Clone)]
pub enum ConfigShellToolType {
DefaultShell,
LocalShell,
}
#[derive(Debug, Clone)]
pub struct ToolsConfig {
pub shell_type: ConfigShellToolType,
pub plan_tool: bool,
}
impl ToolsConfig {
pub fn new(model_family: &ModelFamily, include_plan_tool: bool) -> Self {
let shell_type = if model_family.uses_local_shell_tool {
ConfigShellToolType::LocalShell
} else {
ConfigShellToolType::DefaultShell
};
Self {
shell_type,
plan_tool: include_plan_tool,
}
}
}
/// Generic JSONSchema subset needed for our tool definitions
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub(crate) enum JsonSchema {
String,
@@ -36,13 +67,17 @@ pub(crate) enum JsonSchema {
},
Object {
properties: BTreeMap<String, JsonSchema>,
required: &'static [&'static str],
#[serde(rename = "additionalProperties")]
additional_properties: bool,
#[serde(skip_serializing_if = "Option::is_none")]
required: Option<Vec<String>>,
#[serde(
rename = "additionalProperties",
skip_serializing_if = "Option::is_none"
)]
additional_properties: Option<bool>,
},
}
fn create_shell_tool() -> OpenAiTool {
pub(crate) fn create_shell_tool() -> OpenAiTool {
let mut properties = BTreeMap::new();
properties.insert(
"command".to_string(),
@@ -54,13 +89,13 @@ fn create_shell_tool() -> OpenAiTool {
properties.insert("timeout".to_string(), JsonSchema::Number);
OpenAiTool::Function(ResponsesApiTool {
name: "shell",
description: "Runs a shell command and returns its output",
name: "shell".to_string(),
description: "Runs a shell command and returns its output".to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: &["command"],
additional_properties: false,
required: Some(vec!["command".to_string()]),
additional_properties: Some(false),
},
})
}
@@ -69,31 +104,13 @@ fn create_shell_tool() -> OpenAiTool {
/// Responses API:
/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
pub(crate) fn create_tools_json_for_responses_api(
prompt: &Prompt,
model_family: &ModelFamily,
include_plan_tool: bool,
tools: &Vec<OpenAiTool>,
) -> crate::error::Result<Vec<serde_json::Value>> {
// Assemble tool list: built-in tools + any extra tools from the prompt.
let mut openai_tools = vec![create_shell_tool()];
if model_family.uses_local_shell_tool {
openai_tools.push(OpenAiTool::LocalShell {});
}
let mut tools_json = Vec::new();
let mut tools_json = Vec::with_capacity(openai_tools.len() + prompt.extra_tools.len() + 1);
for tool in openai_tools.iter() {
for tool in tools {
tools_json.push(serde_json::to_value(tool)?);
}
tools_json.extend(
prompt
.extra_tools
.clone()
.into_iter()
.map(|(name, tool)| mcp_tool_to_openai_tool(name, tool)),
);
if include_plan_tool {
tools_json.push(serde_json::to_value(PLAN_TOOL.clone())?);
}
Ok(tools_json)
}
@@ -102,14 +119,11 @@ pub(crate) fn create_tools_json_for_responses_api(
/// Chat Completions API:
/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
pub(crate) fn create_tools_json_for_chat_completions_api(
prompt: &Prompt,
model_family: &ModelFamily,
include_plan_tool: bool,
tools: &Vec<OpenAiTool>,
) -> crate::error::Result<Vec<serde_json::Value>> {
// We start with the JSON for the Responses API and than rewrite it to match
// the chat completions tool call format.
let responses_api_tools_json =
create_tools_json_for_responses_api(prompt, model_family, include_plan_tool)?;
let responses_api_tools_json = create_tools_json_for_responses_api(tools)?;
let tools_json = responses_api_tools_json
.into_iter()
.filter_map(|mut tool| {
@@ -132,10 +146,10 @@ pub(crate) fn create_tools_json_for_chat_completions_api(
Ok(tools_json)
}
fn mcp_tool_to_openai_tool(
pub(crate) fn mcp_tool_to_openai_tool(
fully_qualified_name: String,
tool: mcp_types::Tool,
) -> serde_json::Value {
) -> Result<ResponsesApiTool, serde_json::Error> {
let mcp_types::Tool {
description,
mut input_schema,
@@ -150,12 +164,175 @@ fn mcp_tool_to_openai_tool(
input_schema.properties = Some(serde_json::Value::Object(serde_json::Map::new()));
}
// TODO(mbolin): Change the contract of this function to return
// ResponsesApiTool.
json!({
"name": fully_qualified_name,
"description": description,
"parameters": input_schema,
"type": "function",
let serialized_input_schema = serde_json::to_value(input_schema)?;
let input_schema = serde_json::from_value::<JsonSchema>(serialized_input_schema)?;
Ok(ResponsesApiTool {
name: fully_qualified_name,
description: description.unwrap_or_default(),
strict: false,
parameters: input_schema,
})
}
/// Returns a list of OpenAiTools based on the provided config and MCP tools.
/// Note that the keys of mcp_tools should be fully qualified names. See
/// [`McpConnectionManager`] for more details.
pub(crate) fn get_openai_tools(
config: &ToolsConfig,
mcp_tools: Option<HashMap<String, mcp_types::Tool>>,
) -> Vec<OpenAiTool> {
let mut tools: Vec<OpenAiTool> = Vec::new();
match config.shell_type {
ConfigShellToolType::DefaultShell => {
tools.push(create_shell_tool());
}
ConfigShellToolType::LocalShell => {
tools.push(OpenAiTool::LocalShell {});
}
}
if config.plan_tool {
tools.push(PLAN_TOOL.clone());
}
if let Some(mcp_tools) = mcp_tools {
for (name, tool) in mcp_tools {
match mcp_tool_to_openai_tool(name.clone(), tool.clone()) {
Ok(converted_tool) => tools.push(OpenAiTool::Function(converted_tool)),
Err(e) => {
tracing::error!("Failed to convert {name:?} MCP tool to OpenAI tool: {e:?}");
}
}
}
}
tools
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
use crate::model_family::find_family_for_model;
use mcp_types::ToolInputSchema;
use super::*;
fn assert_eq_tool_names(tools: &[OpenAiTool], expected_names: &[&str]) {
let tool_names = tools
.iter()
.map(|tool| match tool {
OpenAiTool::Function(ResponsesApiTool { name, .. }) => name,
OpenAiTool::LocalShell {} => "local_shell",
})
.collect::<Vec<_>>();
assert_eq!(
tool_names.len(),
expected_names.len(),
"tool_name mismatch, {tool_names:?}, {expected_names:?}",
);
for (name, expected_name) in tool_names.iter().zip(expected_names.iter()) {
assert_eq!(
name, expected_name,
"tool_name mismatch, {name:?}, {expected_name:?}"
);
}
}
#[test]
fn test_get_openai_tools() {
let model_family = find_family_for_model("codex-mini-latest")
.expect("codex-mini-latest should be a valid model family");
let config = ToolsConfig::new(&model_family, true);
let tools = get_openai_tools(&config, Some(HashMap::new()));
assert_eq_tool_names(&tools, &["local_shell", "update_plan"]);
}
#[test]
fn test_get_openai_tools_default_shell() {
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
let config = ToolsConfig::new(&model_family, true);
let tools = get_openai_tools(&config, Some(HashMap::new()));
assert_eq_tool_names(&tools, &["shell", "update_plan"]);
}
#[test]
fn test_get_openai_tools_mcp_tools() {
let model_family = find_family_for_model("o3").expect("o3 should be a valid model family");
let config = ToolsConfig::new(&model_family, false);
let tools = get_openai_tools(
&config,
Some(HashMap::from([(
"test_server/do_something_cool".to_string(),
mcp_types::Tool {
name: "do_something_cool".to_string(),
input_schema: ToolInputSchema {
properties: Some(serde_json::json!({
"string_argument": {
"type": "string",
},
"number_argument": {
"type": "number",
},
"object_argument": {
"type": "object",
"properties": {
"string_property": { "type": "string" },
"number_property": { "type": "number" },
},
"required": [
"string_property",
"number_property"
],
"additionalProperties": Some(false),
},
})),
required: None,
r#type: "object".to_string(),
},
output_schema: None,
title: None,
annotations: None,
description: Some("Do something cool".to_string()),
},
)])),
);
assert_eq_tool_names(&tools, &["shell", "test_server/do_something_cool"]);
assert_eq!(
tools[1],
OpenAiTool::Function(ResponsesApiTool {
name: "test_server/do_something_cool".to_string(),
parameters: JsonSchema::Object {
properties: BTreeMap::from([
("string_argument".to_string(), JsonSchema::String),
("number_argument".to_string(), JsonSchema::Number),
(
"object_argument".to_string(),
JsonSchema::Object {
properties: BTreeMap::from([
("string_property".to_string(), JsonSchema::String),
("number_property".to_string(), JsonSchema::Number),
]),
required: Some(vec![
"string_property".to_string(),
"number_property".to_string(),
]),
additional_properties: Some(false),
},
),
]),
required: None,
additional_properties: None,
},
description: "Do something cool".to_string(),
strict: false,
})
);
}
}

View File

@@ -45,8 +45,8 @@ pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
let plan_items_schema = JsonSchema::Array {
items: Box::new(JsonSchema::Object {
properties: plan_item_props,
required: &["step", "status"],
additional_properties: false,
required: Some(vec!["step".to_string(), "status".to_string()]),
additional_properties: Some(false),
}),
};
@@ -55,7 +55,7 @@ pub(crate) static PLAN_TOOL: LazyLock<OpenAiTool> = LazyLock::new(|| {
properties.insert("plan".to_string(), plan_items_schema);
OpenAiTool::Function(ResponsesApiTool {
name: "update_plan",
name: "update_plan".to_string(),
description: r#"Use the update_plan tool to keep the user updated on the current plan for the task.
After understanding the user's task, call the update_plan tool with an initial plan. An example of a plan:
1. Explore the codebase to find relevant files (status: in_progress)
@@ -66,12 +66,12 @@ Until all the steps are finished, there should always be exactly one in_progress
Call the update_plan tool whenever you finish a step, marking the completed step as `completed` and marking the next step as `in_progress`.
Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step.
Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
When all steps are completed, call update_plan one last time with all steps marked as `completed`."#,
When all steps are completed, call update_plan one last time with all steps marked as `completed`."#.to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: &["plan"],
additional_properties: false,
required: Some(vec!["plan".to_string()]),
additional_properties: Some(false),
},
})
});