chore: introduce ModelFamily abstraction (#1838)

To date, we have a number of hardcoded OpenAI model slug checks spread
throughout the codebase, which makes it hard to audit the various
special cases for each model. To mitigate this issue, this PR introduces
the idea of a `ModelFamily` that has fields to represent the existing
special cases, such as `supports_reasoning_summaries` and
`uses_local_shell_tool`.

There is a `find_family_for_model()` function that maps the raw model
slug to a `ModelFamily`. This function hardcodes all the knowledge about
the special attributes for each model. This PR then replaces the
hardcoded model name checks with checks against a `ModelFamily`.

Note `ModelFamily` is now available as `Config::model_family`. We should
ultimately remove `Config::model` in favor of
`Config::model_family::slug`.
This commit is contained in:
Michael Bolin
2025-08-04 23:50:03 -07:00
committed by GitHub
parent fcdb1c4b4d
commit 136b3ee5bf
10 changed files with 161 additions and 75 deletions

View File

@@ -21,6 +21,7 @@ use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::models::ContentItem;
use crate::models::ResponseItem;
use crate::openai_tools::create_tools_json_for_chat_completions_api;
@@ -29,7 +30,7 @@ use crate::util::backoff;
/// Implementation for the classic Chat Completions API.
pub(crate) async fn stream_chat_completions(
prompt: &Prompt,
model: &str,
model_family: &ModelFamily,
include_plan_tool: bool,
client: &reqwest::Client,
provider: &ModelProviderInfo,
@@ -37,7 +38,7 @@ pub(crate) async fn stream_chat_completions(
// Build messages array
let mut messages = Vec::<serde_json::Value>::new();
let full_instructions = prompt.get_full_instructions(model);
let full_instructions = prompt.get_full_instructions(model_family);
messages.push(json!({"role": "system", "content": full_instructions}));
if let Some(instr) = &prompt.get_formatted_user_instructions() {
@@ -110,9 +111,10 @@ pub(crate) async fn stream_chat_completions(
}
}
let tools_json = create_tools_json_for_chat_completions_api(prompt, model, include_plan_tool)?;
let tools_json =
create_tools_json_for_chat_completions_api(prompt, model_family, include_plan_tool)?;
let payload = json!({
"model": model,
"model": model_family.slug,
"messages": messages,
"stream": true,
"tools": tools_json,

View File

@@ -82,7 +82,7 @@ impl ModelClient {
// Create the raw streaming connection first.
let response_stream = stream_chat_completions(
prompt,
&self.config.model,
&self.config.model_family,
self.config.include_plan_tool,
&self.client,
&self.provider,
@@ -127,13 +127,17 @@ impl ModelClient {
let store = prompt.store && auth_mode != Some(AuthMode::ChatGPT);
let full_instructions = prompt.get_full_instructions(&self.config.model);
let full_instructions = prompt.get_full_instructions(&self.config.model_family);
let tools_json = create_tools_json_for_responses_api(
prompt,
&self.config.model,
&self.config.model_family,
self.config.include_plan_tool,
)?;
let reasoning = create_reasoning_param_for_request(&self.config, self.effort, self.summary);
let reasoning = create_reasoning_param_for_request(
&self.config.model_family,
self.effort,
self.summary,
);
// Request encrypted COT if we are not storing responses,
// otherwise reasoning items will be referenced by ID

View File

@@ -1,6 +1,7 @@
use crate::config_types::ReasoningEffort as ReasoningEffortConfig;
use crate::config_types::ReasoningSummary as ReasoningSummaryConfig;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::models::ResponseItem;
use crate::protocol::TokenUsage;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
@@ -42,13 +43,13 @@ pub struct Prompt {
}
impl Prompt {
pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> {
pub(crate) fn get_full_instructions(&self, model: &ModelFamily) -> Cow<'_, str> {
let base = self
.base_instructions_override
.as_deref()
.unwrap_or(BASE_INSTRUCTIONS);
let mut sections: Vec<&str> = vec![base];
if model.starts_with("gpt-4.1") {
if model.needs_special_apply_patch_instructions {
sections.push(APPLY_PATCH_TOOL_INSTRUCTIONS);
}
Cow::Owned(sections.join("\n"))
@@ -144,14 +145,12 @@ pub(crate) struct ResponsesApiRequest<'a> {
pub(crate) include: Vec<String>,
}
use crate::config::Config;
pub(crate) fn create_reasoning_param_for_request(
config: &Config,
model_family: &ModelFamily,
effort: ReasoningEffortConfig,
summary: ReasoningSummaryConfig,
) -> Option<Reasoning> {
if model_supports_reasoning_summaries(config) {
if model_family.supports_reasoning_summaries {
let effort: Option<OpenAiReasoningEffort> = effort.into();
let effort = effort?;
Some(Reasoning {
@@ -163,27 +162,6 @@ pub(crate) fn create_reasoning_param_for_request(
}
}
pub fn model_supports_reasoning_summaries(config: &Config) -> bool {
// Currently, we hardcode this rule to decide whether to enable reasoning.
// We expect reasoning to apply only to OpenAI models, but we do not want
// users to have to mess with their config to disable reasoning for models
// that do not support it, such as `gpt-4.1`.
//
// Though if a user is using Codex with non-OpenAI models that, say, happen
// to start with "o", then they can set `model_reasoning_effort = "none"` in
// config.toml to disable reasoning.
//
// Converseley, if a user has a non-OpenAI provider that supports reasoning,
// they can set the top-level `model_supports_reasoning_summaries = true`
// config option to enable reasoning.
if config.model_supports_reasoning_summaries {
return true;
}
let model = &config.model;
model.starts_with("o") || model.starts_with("codex")
}
pub(crate) struct ResponseStream {
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}
@@ -198,6 +176,9 @@ impl Stream for ResponseStream {
#[cfg(test)]
mod tests {
#![allow(clippy::expect_used)]
use crate::model_family::find_family_for_model;
use super::*;
#[test]
@@ -207,7 +188,8 @@ mod tests {
..Default::default()
};
let expected = format!("{BASE_INSTRUCTIONS}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}");
let full = prompt.get_full_instructions("gpt-4.1");
let model_family = find_family_for_model("gpt-4.1").expect("known model slug");
let full = prompt.get_full_instructions(&model_family);
assert_eq!(full, expected);
}
}

View File

@@ -10,6 +10,8 @@ use crate::config_types::ShellEnvironmentPolicyToml;
use crate::config_types::Tui;
use crate::config_types::UriBasedFileOpener;
use crate::flags::OPENAI_DEFAULT_MODEL;
use crate::model_family::ModelFamily;
use crate::model_family::find_family_for_model;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
@@ -33,6 +35,8 @@ pub struct Config {
/// Optional override of model selection.
pub model: String,
pub model_family: ModelFamily,
/// Size of the context window for the model, in tokens.
pub model_context_window: Option<u64>,
@@ -134,10 +138,6 @@ pub struct Config {
/// request using the Responses API.
pub model_reasoning_summary: ReasoningSummary,
/// When set to `true`, overrides the default heuristic and forces
/// `model_supports_reasoning_summaries()` to return `true`.
pub model_supports_reasoning_summaries: bool,
/// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
pub chatgpt_base_url: String,
@@ -465,7 +465,19 @@ impl Config {
.or(config_profile.model)
.or(cfg.model)
.unwrap_or_else(default_model);
let openai_model_info = get_model_info(&model);
let model_family = find_family_for_model(&model).unwrap_or_else(|| {
let supports_reasoning_summaries =
cfg.model_supports_reasoning_summaries.unwrap_or(false);
ModelFamily {
slug: model.clone(),
family: model.clone(),
needs_special_apply_patch_instructions: false,
supports_reasoning_summaries,
uses_local_shell_tool: false,
}
});
let openai_model_info = get_model_info(&model_family);
let model_context_window = cfg
.model_context_window
.or_else(|| openai_model_info.as_ref().map(|info| info.context_window));
@@ -490,6 +502,7 @@ impl Config {
let config = Self {
model,
model_family,
model_context_window,
model_max_output_tokens,
model_provider_id,
@@ -527,10 +540,6 @@ impl Config {
.or(cfg.model_reasoning_summary)
.unwrap_or_default(),
model_supports_reasoning_summaries: cfg
.model_supports_reasoning_summaries
.unwrap_or(false),
chatgpt_base_url: config_profile
.chatgpt_base_url
.or(cfg.chatgpt_base_url)
@@ -871,6 +880,7 @@ disable_response_storage = true
assert_eq!(
Config {
model: "o3".to_string(),
model_family: find_family_for_model("o3").expect("known model slug"),
model_context_window: Some(200_000),
model_max_output_tokens: Some(100_000),
model_provider_id: "openai".to_string(),
@@ -893,7 +903,6 @@ disable_response_storage = true
hide_agent_reasoning: false,
model_reasoning_effort: ReasoningEffort::High,
model_reasoning_summary: ReasoningSummary::Detailed,
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,
@@ -921,6 +930,7 @@ disable_response_storage = true
)?;
let expected_gpt3_profile_config = Config {
model: "gpt-3.5-turbo".to_string(),
model_family: find_family_for_model("gpt-3.5-turbo").expect("known model slug"),
model_context_window: Some(16_385),
model_max_output_tokens: Some(4_096),
model_provider_id: "openai-chat-completions".to_string(),
@@ -943,7 +953,6 @@ disable_response_storage = true
hide_agent_reasoning: false,
model_reasoning_effort: ReasoningEffort::default(),
model_reasoning_summary: ReasoningSummary::default(),
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,
@@ -986,6 +995,7 @@ disable_response_storage = true
)?;
let expected_zdr_profile_config = Config {
model: "o3".to_string(),
model_family: find_family_for_model("o3").expect("known model slug"),
model_context_window: Some(200_000),
model_max_output_tokens: Some(100_000),
model_provider_id: "openai".to_string(),
@@ -1008,7 +1018,6 @@ disable_response_storage = true
hide_agent_reasoning: false,
model_reasoning_effort: ReasoningEffort::default(),
model_reasoning_summary: ReasoningSummary::default(),
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,

View File

@@ -31,6 +31,7 @@ mod model_provider_info;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub mod model_family;
mod models;
mod openai_model_info;
mod openai_tools;
@@ -47,5 +48,4 @@ mod user_notification;
pub mod util;
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
pub use client_common::model_supports_reasoning_summaries;
pub use safety::get_platform_sandbox;

View File

@@ -0,0 +1,93 @@
/// A model family is a group of models that share certain characteristics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelFamily {
/// The full model slug used to derive this model family, e.g.
/// "gpt-4.1-2025-04-14".
pub slug: String,
/// The model family name, e.g. "gpt-4.1". Note this should able to be used
/// with [`crate::openai_model_info::get_model_info`].
pub family: String,
/// True if the model needs additional instructions on how to use the
/// "virtual" `apply_patch` CLI.
pub needs_special_apply_patch_instructions: bool,
// Whether the `reasoning` field can be set when making a request to this
// model family. Note it has `effort` and `summary` subfields (though
// `summary` is optional).
pub supports_reasoning_summaries: bool,
// This should be set to true when the model expects a tool named
// "local_shell" to be provided. Its contract must be understood natively by
// the model such that its description can be omitted.
// See https://platform.openai.com/docs/guides/tools-local-shell
pub uses_local_shell_tool: bool,
}
macro_rules! model_family {
(
$slug:expr, $family:expr $(, $key:ident : $value:expr )* $(,)?
) => {{
// defaults
let mut mf = ModelFamily {
slug: $slug.to_string(),
family: $family.to_string(),
needs_special_apply_patch_instructions: false,
supports_reasoning_summaries: false,
uses_local_shell_tool: false,
};
// apply overrides
$(
mf.$key = $value;
)*
Some(mf)
}};
}
macro_rules! simple_model_family {
(
$slug:expr, $family:expr
) => {{
Some(ModelFamily {
slug: $slug.to_string(),
family: $family.to_string(),
needs_special_apply_patch_instructions: false,
supports_reasoning_summaries: false,
uses_local_shell_tool: false,
})
}};
}
/// Returns a `ModelFamily` for the given model slug, or `None` if the slug
/// does not match any known model family.
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
if slug.starts_with("o3") {
model_family!(
slug, "o3",
supports_reasoning_summaries: true,
)
} else if slug.starts_with("o4-mini") {
model_family!(
slug, "o4-mini",
supports_reasoning_summaries: true,
)
} else if slug.starts_with("codex-mini-latest") {
model_family!(
slug, "codex-mini-latest",
supports_reasoning_summaries: true,
uses_local_shell_tool: true,
)
} else if slug.starts_with("gpt-4.1") {
model_family!(
slug, "gpt-4.1",
needs_special_apply_patch_instructions: true,
)
} else if slug.starts_with("gpt-4o") {
simple_model_family!(slug, "gpt-4o")
} else if slug.starts_with("gpt-3.5") {
simple_model_family!(slug, "gpt-3.5")
} else {
None
}
}

View File

@@ -1,3 +1,5 @@
use crate::model_family::ModelFamily;
/// Metadata about a model, particularly OpenAI models.
/// We may want to consider including details like the pricing for
/// input tokens, output tokens, etc., though users will need to be able to
@@ -14,8 +16,8 @@ pub(crate) struct ModelInfo {
/// Note details such as what a model like gpt-4o is aliased to may be out of
/// date.
pub(crate) fn get_model_info(name: &str) -> Option<ModelInfo> {
match name {
pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
match model_family.slug.as_str() {
// https://platform.openai.com/docs/models/o3
"o3" => Some(ModelInfo {
context_window: 200_000,

View File

@@ -1,9 +1,9 @@
use serde::Serialize;
use serde_json::json;
use std::collections::BTreeMap;
use std::sync::LazyLock;
use crate::client_common::Prompt;
use crate::model_family::ModelFamily;
use crate::plan_tool::PLAN_TOOL;
#[derive(Debug, Clone, Serialize)]
@@ -42,8 +42,7 @@ pub(crate) enum JsonSchema {
},
}
/// Tool usage specification
static DEFAULT_TOOLS: LazyLock<Vec<OpenAiTool>> = LazyLock::new(|| {
fn create_shell_tool() -> OpenAiTool {
let mut properties = BTreeMap::new();
properties.insert(
"command".to_string(),
@@ -54,7 +53,7 @@ static DEFAULT_TOOLS: LazyLock<Vec<OpenAiTool>> = LazyLock::new(|| {
properties.insert("workdir".to_string(), JsonSchema::String);
properties.insert("timeout".to_string(), JsonSchema::Number);
vec![OpenAiTool::Function(ResponsesApiTool {
OpenAiTool::Function(ResponsesApiTool {
name: "shell",
description: "Runs a shell command, and returns its output.",
strict: false,
@@ -63,29 +62,26 @@ static DEFAULT_TOOLS: LazyLock<Vec<OpenAiTool>> = LazyLock::new(|| {
required: &["command"],
additional_properties: false,
},
})]
});
static DEFAULT_CODEX_MODEL_TOOLS: LazyLock<Vec<OpenAiTool>> =
LazyLock::new(|| vec![OpenAiTool::LocalShell {}]);
})
}
/// Returns JSON values that are compatible with Function Calling in the
/// Responses API:
/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
pub(crate) fn create_tools_json_for_responses_api(
prompt: &Prompt,
model: &str,
model_family: &ModelFamily,
include_plan_tool: bool,
) -> crate::error::Result<Vec<serde_json::Value>> {
// Assemble tool list: built-in tools + any extra tools from the prompt.
let default_tools = if model.starts_with("codex") {
&DEFAULT_CODEX_MODEL_TOOLS
} else {
&DEFAULT_TOOLS
};
let mut tools_json = Vec::with_capacity(default_tools.len() + prompt.extra_tools.len());
for t in default_tools.iter() {
tools_json.push(serde_json::to_value(t)?);
let mut openai_tools = vec![create_shell_tool()];
if model_family.uses_local_shell_tool {
openai_tools.push(OpenAiTool::LocalShell {});
}
let mut tools_json = Vec::with_capacity(openai_tools.len() + prompt.extra_tools.len() + 1);
for tool in openai_tools.iter() {
tools_json.push(serde_json::to_value(tool)?);
}
tools_json.extend(
prompt
@@ -107,13 +103,13 @@ pub(crate) fn create_tools_json_for_responses_api(
/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
pub(crate) fn create_tools_json_for_chat_completions_api(
prompt: &Prompt,
model: &str,
model_family: &ModelFamily,
include_plan_tool: bool,
) -> crate::error::Result<Vec<serde_json::Value>> {
// We start with the JSON for the Responses API and than rewrite it to match
// the chat completions tool call format.
let responses_api_tools_json =
create_tools_json_for_responses_api(prompt, model, include_plan_tool)?;
create_tools_json_for_responses_api(prompt, model_family, include_plan_tool)?;
let tools_json = responses_api_tools_json
.into_iter()
.filter_map(|mut tool| {

View File

@@ -3,7 +3,6 @@ use std::path::Path;
use codex_common::summarize_sandbox_policy;
use codex_core::WireApi;
use codex_core::config::Config;
use codex_core::model_supports_reasoning_summaries;
use codex_core::protocol::Event;
pub(crate) enum CodexStatus {
@@ -29,7 +28,7 @@ pub(crate) fn create_config_summary_entries(config: &Config) -> Vec<(&'static st
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
];
if config.model_provider.wire_api == WireApi::Responses
&& model_supports_reasoning_summaries(config)
&& config.model_family.supports_reasoning_summaries
{
entries.push((
"reasoning effort",

View File

@@ -7,7 +7,6 @@ use codex_common::elapsed::format_duration;
use codex_common::summarize_sandbox_policy;
use codex_core::WireApi;
use codex_core::config::Config;
use codex_core::model_supports_reasoning_summaries;
use codex_core::plan_tool::PlanItemArg;
use codex_core::plan_tool::StepStatus;
use codex_core::plan_tool::UpdatePlanArgs;
@@ -177,7 +176,7 @@ impl HistoryCell {
("sandbox", summarize_sandbox_policy(&config.sandbox_policy)),
];
if config.model_provider.wire_api == WireApi::Responses
&& model_supports_reasoning_summaries(config)
&& config.model_family.supports_reasoning_summaries
{
entries.push((
"reasoning effort",