feat: added SetDefaultModel to JSON-RPC server (#3512)
This adds `SetDefaultModel`, which takes `model` and `reasoning_effort` as optional fields. If set, the field will overwrite what is in the user's `config.toml`. This reuses logic that was added to support the `/model` command in the TUI: https://github.com/openai/codex/pull/2799.
This commit is contained in:
@@ -495,7 +495,7 @@ fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Base config deserialized from ~/.codex/config.toml.
|
/// Base config deserialized from ~/.codex/config.toml.
|
||||||
#[derive(Deserialize, Debug, Clone, Default)]
|
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
|
||||||
pub struct ConfigToml {
|
pub struct ConfigToml {
|
||||||
/// Optional override of model selection.
|
/// Optional override of model selection.
|
||||||
pub model: Option<String>,
|
pub model: Option<String>,
|
||||||
@@ -627,7 +627,7 @@ pub struct ProjectConfig {
|
|||||||
pub trust_level: Option<String>,
|
pub trust_level: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Debug, Clone, Default)]
|
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
|
||||||
pub struct ToolsToml {
|
pub struct ToolsToml {
|
||||||
#[serde(default, alias = "web_search_request")]
|
#[serde(default, alias = "web_search_request")]
|
||||||
pub web_search: Option<bool>,
|
pub web_search: Option<bool>,
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ use codex_core::config::Config;
|
|||||||
use codex_core::config::ConfigOverrides;
|
use codex_core::config::ConfigOverrides;
|
||||||
use codex_core::config::ConfigToml;
|
use codex_core::config::ConfigToml;
|
||||||
use codex_core::config::load_config_as_toml;
|
use codex_core::config::load_config_as_toml;
|
||||||
|
use codex_core::config_edit::CONFIG_KEY_EFFORT;
|
||||||
|
use codex_core::config_edit::CONFIG_KEY_MODEL;
|
||||||
|
use codex_core::config_edit::persist_non_null_overrides;
|
||||||
use codex_core::default_client::get_codex_user_agent;
|
use codex_core::default_client::get_codex_user_agent;
|
||||||
use codex_core::exec::ExecParams;
|
use codex_core::exec::ExecParams;
|
||||||
use codex_core::exec_env::create_env;
|
use codex_core::exec_env::create_env;
|
||||||
@@ -71,6 +74,8 @@ use codex_protocol::mcp_protocol::SendUserMessageResponse;
|
|||||||
use codex_protocol::mcp_protocol::SendUserTurnParams;
|
use codex_protocol::mcp_protocol::SendUserTurnParams;
|
||||||
use codex_protocol::mcp_protocol::SendUserTurnResponse;
|
use codex_protocol::mcp_protocol::SendUserTurnResponse;
|
||||||
use codex_protocol::mcp_protocol::ServerNotification;
|
use codex_protocol::mcp_protocol::ServerNotification;
|
||||||
|
use codex_protocol::mcp_protocol::SetDefaultModelParams;
|
||||||
|
use codex_protocol::mcp_protocol::SetDefaultModelResponse;
|
||||||
use codex_protocol::mcp_protocol::UserInfoResponse;
|
use codex_protocol::mcp_protocol::UserInfoResponse;
|
||||||
use codex_protocol::mcp_protocol::UserSavedConfig;
|
use codex_protocol::mcp_protocol::UserSavedConfig;
|
||||||
use codex_protocol::models::ContentItem;
|
use codex_protocol::models::ContentItem;
|
||||||
@@ -192,6 +197,9 @@ impl CodexMessageProcessor {
|
|||||||
ClientRequest::GetUserSavedConfig { request_id } => {
|
ClientRequest::GetUserSavedConfig { request_id } => {
|
||||||
self.get_user_saved_config(request_id).await;
|
self.get_user_saved_config(request_id).await;
|
||||||
}
|
}
|
||||||
|
ClientRequest::SetDefaultModel { request_id, params } => {
|
||||||
|
self.set_default_model(request_id, params).await;
|
||||||
|
}
|
||||||
ClientRequest::GetUserAgent { request_id } => {
|
ClientRequest::GetUserAgent { request_id } => {
|
||||||
self.get_user_agent(request_id).await;
|
self.get_user_agent(request_id).await;
|
||||||
}
|
}
|
||||||
@@ -499,6 +507,40 @@ impl CodexMessageProcessor {
|
|||||||
self.outgoing.send_response(request_id, response).await;
|
self.outgoing.send_response(request_id, response).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn set_default_model(&self, request_id: RequestId, params: SetDefaultModelParams) {
|
||||||
|
let SetDefaultModelParams {
|
||||||
|
model,
|
||||||
|
reasoning_effort,
|
||||||
|
} = params;
|
||||||
|
let effort_str = reasoning_effort.map(|effort| effort.to_string());
|
||||||
|
|
||||||
|
let overrides: [(&[&str], Option<&str>); 2] = [
|
||||||
|
(&[CONFIG_KEY_MODEL], model.as_deref()),
|
||||||
|
(&[CONFIG_KEY_EFFORT], effort_str.as_deref()),
|
||||||
|
];
|
||||||
|
|
||||||
|
match persist_non_null_overrides(
|
||||||
|
&self.config.codex_home,
|
||||||
|
self.config.active_profile.as_deref(),
|
||||||
|
&overrides,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(()) => {
|
||||||
|
let response = SetDefaultModelResponse {};
|
||||||
|
self.outgoing.send_response(request_id, response).await;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
let error = JSONRPCErrorError {
|
||||||
|
code: INTERNAL_ERROR_CODE,
|
||||||
|
message: format!("failed to persist overrides: {err}"),
|
||||||
|
data: None,
|
||||||
|
};
|
||||||
|
self.outgoing.send_error(request_id, error).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn exec_one_off_command(&self, request_id: RequestId, params: ExecOneOffCommandParams) {
|
async fn exec_one_off_command(&self, request_id: RequestId, params: ExecOneOffCommandParams) {
|
||||||
tracing::debug!("ExecOneOffCommand params: {params:?}");
|
tracing::debug!("ExecOneOffCommand params: {params:?}");
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ use codex_protocol::mcp_protocol::RemoveConversationListenerParams;
|
|||||||
use codex_protocol::mcp_protocol::ResumeConversationParams;
|
use codex_protocol::mcp_protocol::ResumeConversationParams;
|
||||||
use codex_protocol::mcp_protocol::SendUserMessageParams;
|
use codex_protocol::mcp_protocol::SendUserMessageParams;
|
||||||
use codex_protocol::mcp_protocol::SendUserTurnParams;
|
use codex_protocol::mcp_protocol::SendUserTurnParams;
|
||||||
|
use codex_protocol::mcp_protocol::SetDefaultModelParams;
|
||||||
|
|
||||||
use mcp_types::CallToolRequestParams;
|
use mcp_types::CallToolRequestParams;
|
||||||
use mcp_types::ClientCapabilities;
|
use mcp_types::ClientCapabilities;
|
||||||
@@ -301,6 +302,15 @@ impl McpProcess {
|
|||||||
self.send_request("userInfo", None).await
|
self.send_request("userInfo", None).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Send a `setDefaultModel` JSON-RPC request.
|
||||||
|
pub async fn send_set_default_model_request(
|
||||||
|
&mut self,
|
||||||
|
params: SetDefaultModelParams,
|
||||||
|
) -> anyhow::Result<i64> {
|
||||||
|
let params = Some(serde_json::to_value(params)?);
|
||||||
|
self.send_request("setDefaultModel", params).await
|
||||||
|
}
|
||||||
|
|
||||||
/// Send a `listConversations` JSON-RPC request.
|
/// Send a `listConversations` JSON-RPC request.
|
||||||
pub async fn send_list_conversations_request(
|
pub async fn send_list_conversations_request(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
|||||||
@@ -9,5 +9,6 @@ mod interrupt;
|
|||||||
mod list_resume;
|
mod list_resume;
|
||||||
mod login;
|
mod login;
|
||||||
mod send_message;
|
mod send_message;
|
||||||
|
mod set_default_model;
|
||||||
mod user_agent;
|
mod user_agent;
|
||||||
mod user_info;
|
mod user_info;
|
||||||
|
|||||||
62
codex-rs/mcp-server/tests/suite/set_default_model.rs
Normal file
62
codex-rs/mcp-server/tests/suite/set_default_model.rs
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
use codex_core::config::ConfigToml;
|
||||||
|
use codex_protocol::config_types::ReasoningEffort;
|
||||||
|
use codex_protocol::mcp_protocol::SetDefaultModelParams;
|
||||||
|
use codex_protocol::mcp_protocol::SetDefaultModelResponse;
|
||||||
|
use mcp_test_support::McpProcess;
|
||||||
|
use mcp_test_support::to_response;
|
||||||
|
use mcp_types::JSONRPCResponse;
|
||||||
|
use mcp_types::RequestId;
|
||||||
|
use pretty_assertions::assert_eq;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||||
|
async fn set_default_model_persists_overrides() {
|
||||||
|
let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}"));
|
||||||
|
|
||||||
|
let mut mcp = McpProcess::new(codex_home.path())
|
||||||
|
.await
|
||||||
|
.expect("spawn mcp process");
|
||||||
|
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize())
|
||||||
|
.await
|
||||||
|
.expect("init timeout")
|
||||||
|
.expect("init failed");
|
||||||
|
|
||||||
|
let params = SetDefaultModelParams {
|
||||||
|
model: Some("o4-mini".to_string()),
|
||||||
|
reasoning_effort: Some(ReasoningEffort::High),
|
||||||
|
};
|
||||||
|
|
||||||
|
let request_id = mcp
|
||||||
|
.send_set_default_model_request(params)
|
||||||
|
.await
|
||||||
|
.expect("send setDefaultModel");
|
||||||
|
|
||||||
|
let resp: JSONRPCResponse = timeout(
|
||||||
|
DEFAULT_READ_TIMEOUT,
|
||||||
|
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("setDefaultModel timeout")
|
||||||
|
.expect("setDefaultModel response");
|
||||||
|
|
||||||
|
let _: SetDefaultModelResponse =
|
||||||
|
to_response(resp).expect("deserialize setDefaultModel response");
|
||||||
|
|
||||||
|
let config_path = codex_home.path().join("config.toml");
|
||||||
|
let config_contents = tokio::fs::read_to_string(&config_path)
|
||||||
|
.await
|
||||||
|
.expect("read config.toml");
|
||||||
|
let config_toml: ConfigToml = toml::from_str(&config_contents).expect("parse config.toml");
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
ConfigToml {
|
||||||
|
model: Some("o4-mini".to_string()),
|
||||||
|
model_reasoning_effort: Some(ReasoningEffort::High),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
config_toml,
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -40,6 +40,7 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
|
|||||||
codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?;
|
codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?;
|
||||||
codex_protocol::mcp_protocol::ExecCommandApprovalResponse::export_all_to(out_dir)?;
|
codex_protocol::mcp_protocol::ExecCommandApprovalResponse::export_all_to(out_dir)?;
|
||||||
codex_protocol::mcp_protocol::GetUserSavedConfigResponse::export_all_to(out_dir)?;
|
codex_protocol::mcp_protocol::GetUserSavedConfigResponse::export_all_to(out_dir)?;
|
||||||
|
codex_protocol::mcp_protocol::SetDefaultModelResponse::export_all_to(out_dir)?;
|
||||||
codex_protocol::mcp_protocol::GetUserAgentResponse::export_all_to(out_dir)?;
|
codex_protocol::mcp_protocol::GetUserAgentResponse::export_all_to(out_dir)?;
|
||||||
codex_protocol::mcp_protocol::UserInfoResponse::export_all_to(out_dir)?;
|
codex_protocol::mcp_protocol::UserInfoResponse::export_all_to(out_dir)?;
|
||||||
|
|
||||||
|
|||||||
@@ -153,6 +153,11 @@ pub enum ClientRequest {
|
|||||||
#[serde(rename = "id")]
|
#[serde(rename = "id")]
|
||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
},
|
},
|
||||||
|
SetDefaultModel {
|
||||||
|
#[serde(rename = "id")]
|
||||||
|
request_id: RequestId,
|
||||||
|
params: SetDefaultModelParams,
|
||||||
|
},
|
||||||
GetUserAgent {
|
GetUserAgent {
|
||||||
#[serde(rename = "id")]
|
#[serde(rename = "id")]
|
||||||
request_id: RequestId,
|
request_id: RequestId,
|
||||||
@@ -416,6 +421,19 @@ pub struct GetUserSavedConfigResponse {
|
|||||||
pub config: UserSavedConfig,
|
pub config: UserSavedConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SetDefaultModelParams {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub model: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_effort: Option<ReasoningEffort>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct SetDefaultModelResponse {}
|
||||||
|
|
||||||
/// UserSavedConfig contains a subset of the config. It is meant to expose mcp
|
/// UserSavedConfig contains a subset of the config. It is meant to expose mcp
|
||||||
/// client-configurable settings that can be specified in the NewConversation
|
/// client-configurable settings that can be specified in the NewConversation
|
||||||
/// and SendUserTurn requests.
|
/// and SendUserTurn requests.
|
||||||
|
|||||||
Reference in New Issue
Block a user