chore: introduce ModelFamily abstraction (#1838)

To date, we have had a number of hardcoded OpenAI model slug checks
spread throughout the codebase, making it hard to audit the various
special cases for each model. To mitigate this, this PR introduces a
`ModelFamily` type with fields that represent the existing special
cases, such as `supports_reasoning_summaries` and
`uses_local_shell_tool`.
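
For orientation, here is a minimal sketch of what that type could look
like, using only the fields visible in this diff (the real definition
lives in `model_family.rs` and may carry more):

```rust
/// Bundles a raw model slug with the capabilities the codebase
/// previously special-cased via string comparisons.
/// Sketch only: the field set is taken from this diff.
pub struct ModelFamily {
    /// The raw model slug, e.g. "o3" or "gpt-3.5-turbo".
    pub slug: String,
    /// The family the slug belongs to; the fallback below reuses the slug.
    pub family: String,
    pub needs_special_apply_patch_instructions: bool,
    pub supports_reasoning_summaries: bool,
    pub uses_local_shell_tool: bool,
}
```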

There is a `find_family_for_model()` function that maps the raw model
slug to a `ModelFamily`. This function hardcodes all the knowledge about
the special attributes for each model. This PR then replaces the
hardcoded model name checks with checks against a `ModelFamily`.
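
As a sketch of the shape (the two slugs come from the tests below; the
per-family attribute values here are illustrative assumptions, not the
actual table):

```rust
/// Maps a raw model slug to its `ModelFamily`; returns `None` for
/// slugs this function knows nothing about. Illustrative sketch only.
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
    match slug {
        s if s.starts_with("o3") => Some(ModelFamily {
            slug: s.to_string(),
            family: "o3".to_string(),
            needs_special_apply_patch_instructions: false,
            // Assumed here: o3 is a reasoning model.
            supports_reasoning_summaries: true,
            uses_local_shell_tool: false,
        }),
        s if s.starts_with("gpt-3.5") => Some(ModelFamily {
            slug: s.to_string(),
            family: "gpt-3.5".to_string(),
            needs_special_apply_patch_instructions: false,
            supports_reasoning_summaries: false,
            uses_local_shell_tool: false,
        }),
        _ => None,
    }
}
```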

Note `ModelFamily` is now available as `Config::model_family`. We should
ultimately remove `Config::model` in favor of
`Config::model_family::slug`.
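
The payoff at call sites is that capability checks replace slug checks.
A hypothetical before/after (`attach_reasoning_summary` is a made-up
helper for illustration):

```rust
// Before: slug knowledge duplicated at every call site.
if config.model == "o3" {
    attach_reasoning_summary(&mut request); // hypothetical helper
}

// After: the knowledge lives in find_family_for_model(); call sites
// only ask about the capability they need.
if config.model_family.supports_reasoning_summaries {
    attach_reasoning_summary(&mut request); // hypothetical helper
}
```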

commit 136b3ee5bf (parent fcdb1c4b4d)
Author: Michael Bolin
Date:   2025-08-04 23:50:03 -07:00
Committed-by: GitHub

10 changed files with 161 additions and 75 deletions


@@ -10,6 +10,8 @@ use crate::config_types::ShellEnvironmentPolicyToml;
 use crate::config_types::Tui;
 use crate::config_types::UriBasedFileOpener;
 use crate::flags::OPENAI_DEFAULT_MODEL;
+use crate::model_family::ModelFamily;
+use crate::model_family::find_family_for_model;
 use crate::model_provider_info::ModelProviderInfo;
 use crate::model_provider_info::built_in_model_providers;
 use crate::openai_model_info::get_model_info;
@@ -33,6 +35,8 @@ pub struct Config {
     /// Optional override of model selection.
     pub model: String,
+    pub model_family: ModelFamily,
     /// Size of the context window for the model, in tokens.
     pub model_context_window: Option<u64>,
@@ -134,10 +138,6 @@ pub struct Config {
     /// request using the Responses API.
     pub model_reasoning_summary: ReasoningSummary,
-    /// When set to `true`, overrides the default heuristic and forces
-    /// `model_supports_reasoning_summaries()` to return `true`.
-    pub model_supports_reasoning_summaries: bool,
     /// Base URL for requests to ChatGPT (as opposed to the OpenAI API).
     pub chatgpt_base_url: String,
@@ -465,7 +465,19 @@ impl Config {
             .or(config_profile.model)
             .or(cfg.model)
             .unwrap_or_else(default_model);
-        let openai_model_info = get_model_info(&model);
+        let model_family = find_family_for_model(&model).unwrap_or_else(|| {
+            let supports_reasoning_summaries =
+                cfg.model_supports_reasoning_summaries.unwrap_or(false);
+            ModelFamily {
+                slug: model.clone(),
+                family: model.clone(),
+                needs_special_apply_patch_instructions: false,
+                supports_reasoning_summaries,
+                uses_local_shell_tool: false,
+            }
+        });
+        let openai_model_info = get_model_info(&model_family);
         let model_context_window = cfg
             .model_context_window
             .or_else(|| openai_model_info.as_ref().map(|info| info.context_window));
@@ -490,6 +502,7 @@ impl Config {
         let config = Self {
             model,
+            model_family,
             model_context_window,
             model_max_output_tokens,
             model_provider_id,
@@ -527,10 +540,6 @@ impl Config {
                 .or(cfg.model_reasoning_summary)
                 .unwrap_or_default(),
-            model_supports_reasoning_summaries: cfg
-                .model_supports_reasoning_summaries
-                .unwrap_or(false),
             chatgpt_base_url: config_profile
                 .chatgpt_base_url
                 .or(cfg.chatgpt_base_url)
@@ -871,6 +880,7 @@ disable_response_storage = true
         assert_eq!(
             Config {
                 model: "o3".to_string(),
+                model_family: find_family_for_model("o3").expect("known model slug"),
                 model_context_window: Some(200_000),
                 model_max_output_tokens: Some(100_000),
                 model_provider_id: "openai".to_string(),
@@ -893,7 +903,6 @@ disable_response_storage = true
                 hide_agent_reasoning: false,
                 model_reasoning_effort: ReasoningEffort::High,
                 model_reasoning_summary: ReasoningSummary::Detailed,
-                model_supports_reasoning_summaries: false,
                 chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
                 experimental_resume: None,
                 base_instructions: None,
@@ -921,6 +930,7 @@ disable_response_storage = true
         )?;
         let expected_gpt3_profile_config = Config {
             model: "gpt-3.5-turbo".to_string(),
+            model_family: find_family_for_model("gpt-3.5-turbo").expect("known model slug"),
             model_context_window: Some(16_385),
             model_max_output_tokens: Some(4_096),
             model_provider_id: "openai-chat-completions".to_string(),
@@ -943,7 +953,6 @@ disable_response_storage = true
             hide_agent_reasoning: false,
             model_reasoning_effort: ReasoningEffort::default(),
             model_reasoning_summary: ReasoningSummary::default(),
-            model_supports_reasoning_summaries: false,
             chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
             experimental_resume: None,
             base_instructions: None,
@@ -986,6 +995,7 @@ disable_response_storage = true
         )?;
         let expected_zdr_profile_config = Config {
             model: "o3".to_string(),
+            model_family: find_family_for_model("o3").expect("known model slug"),
             model_context_window: Some(200_000),
             model_max_output_tokens: Some(100_000),
             model_provider_id: "openai".to_string(),
@@ -1008,7 +1018,6 @@ disable_response_storage = true
             hide_agent_reasoning: false,
             model_reasoning_effort: ReasoningEffort::default(),
             model_reasoning_summary: ReasoningSummary::default(),
-            model_supports_reasoning_summaries: false,
             chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
             experimental_resume: None,
             base_instructions: None,