From c172e8e997f794c7e8bff5df781fc2b87117bae6 Mon Sep 17 00:00:00 2001 From: Michael Bolin Date: Thu, 11 Sep 2025 23:44:17 -0700 Subject: [PATCH] feat: added SetDefaultModel to JSON-RPC server (#3512) This adds `SetDefaultModel`, which takes `model` and `reasoning_effort` as optional fields. If set, the field will overwrite what is in the user's `config.toml`. This reuses logic that was added to support the `/model` command in the TUI: https://github.com/openai/codex/pull/2799. --- codex-rs/core/src/config.rs | 4 +- .../mcp-server/src/codex_message_processor.rs | 42 +++++++++++++ .../mcp-server/tests/common/mcp_process.rs | 10 +++ codex-rs/mcp-server/tests/suite/mod.rs | 1 + .../tests/suite/set_default_model.rs | 62 +++++++++++++++++++ codex-rs/protocol-ts/src/lib.rs | 1 + codex-rs/protocol/src/mcp_protocol.rs | 18 ++++++ 7 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 codex-rs/mcp-server/tests/suite/set_default_model.rs diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index c2b1fd93..2cfd15fc 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -495,7 +495,7 @@ fn apply_toml_override(root: &mut TomlValue, path: &str, value: TomlValue) { } /// Base config deserialized from ~/.codex/config.toml. -#[derive(Deserialize, Debug, Clone, Default)] +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] pub struct ConfigToml { /// Optional override of model selection. pub model: Option, @@ -627,7 +627,7 @@ pub struct ProjectConfig { pub trust_level: Option, } -#[derive(Deserialize, Debug, Clone, Default)] +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] pub struct ToolsToml { #[serde(default, alias = "web_search_request")] pub web_search: Option, diff --git a/codex-rs/mcp-server/src/codex_message_processor.rs b/codex-rs/mcp-server/src/codex_message_processor.rs index b7b2b0f9..6cd6fdec 100644 --- a/codex-rs/mcp-server/src/codex_message_processor.rs +++ b/codex-rs/mcp-server/src/codex_message_processor.rs @@ -18,6 +18,9 @@ use codex_core::config::Config; use codex_core::config::ConfigOverrides; use codex_core::config::ConfigToml; use codex_core::config::load_config_as_toml; +use codex_core::config_edit::CONFIG_KEY_EFFORT; +use codex_core::config_edit::CONFIG_KEY_MODEL; +use codex_core::config_edit::persist_non_null_overrides; use codex_core::default_client::get_codex_user_agent; use codex_core::exec::ExecParams; use codex_core::exec_env::create_env; @@ -71,6 +74,8 @@ use codex_protocol::mcp_protocol::SendUserMessageResponse; use codex_protocol::mcp_protocol::SendUserTurnParams; use codex_protocol::mcp_protocol::SendUserTurnResponse; use codex_protocol::mcp_protocol::ServerNotification; +use codex_protocol::mcp_protocol::SetDefaultModelParams; +use codex_protocol::mcp_protocol::SetDefaultModelResponse; use codex_protocol::mcp_protocol::UserInfoResponse; use codex_protocol::mcp_protocol::UserSavedConfig; use codex_protocol::models::ContentItem; @@ -192,6 +197,9 @@ impl CodexMessageProcessor { ClientRequest::GetUserSavedConfig { request_id } => { self.get_user_saved_config(request_id).await; } + ClientRequest::SetDefaultModel { request_id, params } => { + self.set_default_model(request_id, params).await; + } ClientRequest::GetUserAgent { request_id } => { self.get_user_agent(request_id).await; } @@ -499,6 +507,40 @@ impl CodexMessageProcessor { self.outgoing.send_response(request_id, response).await; } + async fn set_default_model(&self, request_id: RequestId, params: SetDefaultModelParams) { + let SetDefaultModelParams { + model, + reasoning_effort, + } = params; + let effort_str = reasoning_effort.map(|effort| effort.to_string()); + + let overrides: [(&[&str], Option<&str>); 2] = [ + (&[CONFIG_KEY_MODEL], model.as_deref()), + (&[CONFIG_KEY_EFFORT], effort_str.as_deref()), + ]; + + match persist_non_null_overrides( + &self.config.codex_home, + self.config.active_profile.as_deref(), + &overrides, + ) + .await + { + Ok(()) => { + let response = SetDefaultModelResponse {}; + self.outgoing.send_response(request_id, response).await; + } + Err(err) => { + let error = JSONRPCErrorError { + code: INTERNAL_ERROR_CODE, + message: format!("failed to persist overrides: {err}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + } + } + } + async fn exec_one_off_command(&self, request_id: RequestId, params: ExecOneOffCommandParams) { tracing::debug!("ExecOneOffCommand params: {params:?}"); diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs index a7969f95..a30e9817 100644 --- a/codex-rs/mcp-server/tests/common/mcp_process.rs +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -24,6 +24,7 @@ use codex_protocol::mcp_protocol::RemoveConversationListenerParams; use codex_protocol::mcp_protocol::ResumeConversationParams; use codex_protocol::mcp_protocol::SendUserMessageParams; use codex_protocol::mcp_protocol::SendUserTurnParams; +use codex_protocol::mcp_protocol::SetDefaultModelParams; use mcp_types::CallToolRequestParams; use mcp_types::ClientCapabilities; @@ -301,6 +302,15 @@ impl McpProcess { self.send_request("userInfo", None).await } + /// Send a `setDefaultModel` JSON-RPC request. + pub async fn send_set_default_model_request( + &mut self, + params: SetDefaultModelParams, + ) -> anyhow::Result { + let params = Some(serde_json::to_value(params)?); + self.send_request("setDefaultModel", params).await + } + /// Send a `listConversations` JSON-RPC request. pub async fn send_list_conversations_request( &mut self, diff --git a/codex-rs/mcp-server/tests/suite/mod.rs b/codex-rs/mcp-server/tests/suite/mod.rs index 4e603e17..97e53709 100644 --- a/codex-rs/mcp-server/tests/suite/mod.rs +++ b/codex-rs/mcp-server/tests/suite/mod.rs @@ -9,5 +9,6 @@ mod interrupt; mod list_resume; mod login; mod send_message; +mod set_default_model; mod user_agent; mod user_info; diff --git a/codex-rs/mcp-server/tests/suite/set_default_model.rs b/codex-rs/mcp-server/tests/suite/set_default_model.rs new file mode 100644 index 00000000..e11b69ed --- /dev/null +++ b/codex-rs/mcp-server/tests/suite/set_default_model.rs @@ -0,0 +1,62 @@ +use codex_core::config::ConfigToml; +use codex_protocol::config_types::ReasoningEffort; +use codex_protocol::mcp_protocol::SetDefaultModelParams; +use codex_protocol::mcp_protocol::SetDefaultModelResponse; +use mcp_test_support::McpProcess; +use mcp_test_support::to_response; +use mcp_types::JSONRPCResponse; +use mcp_types::RequestId; +use pretty_assertions::assert_eq; +use tempfile::TempDir; +use tokio::time::timeout; + +const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn set_default_model_persists_overrides() { + let codex_home = TempDir::new().unwrap_or_else(|e| panic!("create tempdir: {e}")); + + let mut mcp = McpProcess::new(codex_home.path()) + .await + .expect("spawn mcp process"); + timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()) + .await + .expect("init timeout") + .expect("init failed"); + + let params = SetDefaultModelParams { + model: Some("o4-mini".to_string()), + reasoning_effort: Some(ReasoningEffort::High), + }; + + let request_id = mcp + .send_set_default_model_request(params) + .await + .expect("send setDefaultModel"); + + let resp: JSONRPCResponse = timeout( + DEFAULT_READ_TIMEOUT, + mcp.read_stream_until_response_message(RequestId::Integer(request_id)), + ) + .await + .expect("setDefaultModel timeout") + .expect("setDefaultModel response"); + + let _: SetDefaultModelResponse = + to_response(resp).expect("deserialize setDefaultModel response"); + + let config_path = codex_home.path().join("config.toml"); + let config_contents = tokio::fs::read_to_string(&config_path) + .await + .expect("read config.toml"); + let config_toml: ConfigToml = toml::from_str(&config_contents).expect("parse config.toml"); + + assert_eq!( + ConfigToml { + model: Some("o4-mini".to_string()), + model_reasoning_effort: Some(ReasoningEffort::High), + ..Default::default() + }, + config_toml, + ); +} diff --git a/codex-rs/protocol-ts/src/lib.rs b/codex-rs/protocol-ts/src/lib.rs index 6fda7074..967d9b1f 100644 --- a/codex-rs/protocol-ts/src/lib.rs +++ b/codex-rs/protocol-ts/src/lib.rs @@ -40,6 +40,7 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ExecCommandApprovalResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::GetUserSavedConfigResponse::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::SetDefaultModelResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::GetUserAgentResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::UserInfoResponse::export_all_to(out_dir)?; diff --git a/codex-rs/protocol/src/mcp_protocol.rs b/codex-rs/protocol/src/mcp_protocol.rs index a382b8bc..42f38d46 100644 --- a/codex-rs/protocol/src/mcp_protocol.rs +++ b/codex-rs/protocol/src/mcp_protocol.rs @@ -153,6 +153,11 @@ pub enum ClientRequest { #[serde(rename = "id")] request_id: RequestId, }, + SetDefaultModel { + #[serde(rename = "id")] + request_id: RequestId, + params: SetDefaultModelParams, + }, GetUserAgent { #[serde(rename = "id")] request_id: RequestId, @@ -416,6 +421,19 @@ pub struct GetUserSavedConfigResponse { pub config: UserSavedConfig, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct SetDefaultModelParams { + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct SetDefaultModelResponse {} + /// UserSavedConfig contains a subset of the config. It is meant to expose mcp /// client-configurable settings that can be specified in the NewConversation /// and SendUserTurn requests.