From 6d82907082a7317e72976e625ecd647a6f439128 Mon Sep 17 00:00:00 2001 From: pakrym-oai Date: Tue, 22 Jul 2025 09:42:22 -0700 Subject: [PATCH] Add support for custom base instructions (#1645) Allows providing custom instructions file as a config parameter and custom instruction text via MCP tool call. --- codex-rs/core/src/client_common.rs | 9 +- codex-rs/core/src/codex.rs | 36 ++++--- codex-rs/core/src/config.rs | 37 ++++++-- codex-rs/core/src/project_doc.rs | 8 +- codex-rs/core/src/protocol.rs | 8 +- codex-rs/core/tests/client.rs | 95 +++++++++++++++---- codex-rs/core/tests/test_support.rs | 21 ++++ codex-rs/exec/src/lib.rs | 1 + codex-rs/mcp-server/src/codex_tool_config.rs | 12 ++- .../tests/{elicitation.rs => codex_tool.rs} | 84 +++++++++++++++- .../mcp-server/tests/common/mcp_process.rs | 13 +-- codex-rs/tui/src/lib.rs | 1 + 12 files changed, 264 insertions(+), 61 deletions(-) rename codex-rs/mcp-server/tests/{elicitation.rs => codex_tool.rs} (81%) diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 3e3c2e7e..94d09e7f 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -34,11 +34,18 @@ pub struct Prompt { /// the "fully qualified" tool name (i.e., prefixed with the server name), /// which should be reported to the model in place of Tool::name. pub extra_tools: HashMap, + + /// Optional override for the built-in BASE_INSTRUCTIONS. + pub base_instructions_override: Option, } impl Prompt { pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> { - let mut sections: Vec<&str> = vec![BASE_INSTRUCTIONS]; + let base = self + .base_instructions_override + .as_deref() + .unwrap_or(BASE_INSTRUCTIONS); + let mut sections: Vec<&str> = vec![base]; if let Some(ref user) = self.user_instructions { sections.push(user); } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 392e84ea..6eb1715f 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -108,13 +108,15 @@ impl Codex { let (tx_sub, rx_sub) = async_channel::bounded(64); let (tx_event, rx_event) = async_channel::bounded(1600); - let instructions = get_user_instructions(&config).await; + let user_instructions = get_user_instructions(&config).await; + let configure_session = Op::ConfigureSession { provider: config.model_provider.clone(), model: config.model.clone(), model_reasoning_effort: config.model_reasoning_effort, model_reasoning_summary: config.model_reasoning_summary, - instructions, + user_instructions, + base_instructions: config.base_instructions.clone(), approval_policy: config.approval_policy, sandbox_policy: config.sandbox_policy.clone(), disable_response_storage: config.disable_response_storage, @@ -183,7 +185,8 @@ pub(crate) struct Session { /// the model as well as sandbox policies are resolved against this path /// instead of `std::env::current_dir()`. cwd: PathBuf, - instructions: Option, + base_instructions: Option, + user_instructions: Option, approval_policy: AskForApproval, sandbox_policy: SandboxPolicy, shell_environment_policy: ShellEnvironmentPolicy, @@ -577,7 +580,8 @@ async fn submission_loop( model, model_reasoning_effort, model_reasoning_summary, - instructions, + user_instructions, + base_instructions, approval_policy, sandbox_policy, disable_response_storage, @@ -625,15 +629,17 @@ async fn submission_loop( let rollout_recorder = match rollout_recorder { Some(rec) => Some(rec), - None => match RolloutRecorder::new(&config, session_id, instructions.clone()) - .await - { - Ok(r) => Some(r), - Err(e) => { - warn!("failed to initialise rollout recorder: {e}"); - None + None => { + match RolloutRecorder::new(&config, session_id, user_instructions.clone()) + .await + { + Ok(r) => Some(r), + Err(e) => { + warn!("failed to initialise rollout recorder: {e}"); + None + } } - }, + } }; let client = ModelClient::new( @@ -699,7 +705,8 @@ async fn submission_loop( client, tx_event: tx_event.clone(), ctrl_c: Arc::clone(&ctrl_c), - instructions, + user_instructions, + base_instructions, approval_policy, sandbox_policy, shell_environment_policy: config.shell_environment_policy.clone(), @@ -1067,9 +1074,10 @@ async fn run_turn( let prompt = Prompt { input, prev_id, - user_instructions: sess.instructions.clone(), + user_instructions: sess.user_instructions.clone(), store, extra_tools, + base_instructions_override: sess.base_instructions.clone(), }; let mut retries = 0; diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index f1d0dd9d..8ed06c45 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -63,7 +63,10 @@ pub struct Config { pub disable_response_storage: bool, /// User-provided instructions from instructions.md. - pub instructions: Option, + pub user_instructions: Option, + + /// Base instructions override. + pub base_instructions: Option, /// Optional external notifier command. When set, Codex will spawn this /// program after each completed *turn* (i.e. when the agent finishes @@ -327,6 +330,9 @@ pub struct ConfigToml { /// Experimental rollout resume path (absolute path to .jsonl; undocumented). pub experimental_resume: Option, + + /// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS. + pub experimental_instructions_file: Option, } impl ConfigToml { @@ -359,6 +365,7 @@ pub struct ConfigOverrides { pub model_provider: Option, pub config_profile: Option, pub codex_linux_sandbox_exe: Option, + pub base_instructions: Option, } impl Config { @@ -369,7 +376,7 @@ impl Config { overrides: ConfigOverrides, codex_home: PathBuf, ) -> std::io::Result { - let instructions = Self::load_instructions(Some(&codex_home)); + let user_instructions = Self::load_instructions(Some(&codex_home)); // Destructure ConfigOverrides fully to ensure all overrides are applied. let ConfigOverrides { @@ -380,6 +387,7 @@ impl Config { model_provider, config_profile: config_profile_key, codex_linux_sandbox_exe, + base_instructions, } = overrides; let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) { @@ -457,6 +465,10 @@ impl Config { let experimental_resume = cfg.experimental_resume; + let base_instructions = base_instructions.or(Self::get_base_instructions( + cfg.experimental_instructions_file.as_ref(), + )); + let config = Self { model, model_context_window, @@ -475,7 +487,8 @@ impl Config { .or(cfg.disable_response_storage) .unwrap_or(false), notify: cfg.notify, - instructions, + user_instructions, + base_instructions, mcp_servers: cfg.mcp_servers, model_providers, project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES), @@ -525,6 +538,15 @@ impl Config { } }) } + + fn get_base_instructions(path: Option<&PathBuf>) -> Option { + let path = path.as_ref()?; + + std::fs::read_to_string(path) + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + } } fn default_model() -> String { @@ -801,7 +823,7 @@ disable_response_storage = true sandbox_policy: SandboxPolicy::new_read_only_policy(), shell_environment_policy: ShellEnvironmentPolicy::default(), disable_response_storage: false, - instructions: None, + user_instructions: None, notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), @@ -818,6 +840,7 @@ disable_response_storage = true model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), experimental_resume: None, + base_instructions: None, }, o3_profile_config ); @@ -848,7 +871,7 @@ disable_response_storage = true sandbox_policy: SandboxPolicy::new_read_only_policy(), shell_environment_policy: ShellEnvironmentPolicy::default(), disable_response_storage: false, - instructions: None, + user_instructions: None, notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), @@ -865,6 +888,7 @@ disable_response_storage = true model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), experimental_resume: None, + base_instructions: None, }; assert_eq!(expected_gpt3_profile_config, gpt3_profile_config); @@ -910,7 +934,7 @@ disable_response_storage = true sandbox_policy: SandboxPolicy::new_read_only_policy(), shell_environment_policy: ShellEnvironmentPolicy::default(), disable_response_storage: true, - instructions: None, + user_instructions: None, notify: None, cwd: fixture.cwd(), mcp_servers: HashMap::new(), @@ -927,6 +951,7 @@ disable_response_storage = true model_supports_reasoning_summaries: false, chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(), experimental_resume: None, + base_instructions: None, }; assert_eq!(expected_zdr_profile_config, zdr_profile_config); diff --git a/codex-rs/core/src/project_doc.rs b/codex-rs/core/src/project_doc.rs index ab9d4618..9f46159d 100644 --- a/codex-rs/core/src/project_doc.rs +++ b/codex-rs/core/src/project_doc.rs @@ -27,16 +27,16 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n"; /// string of instructions. pub(crate) async fn get_user_instructions(config: &Config) -> Option { match find_project_doc(config).await { - Ok(Some(project_doc)) => match &config.instructions { + Ok(Some(project_doc)) => match &config.user_instructions { Some(original_instructions) => Some(format!( "{original_instructions}{PROJECT_DOC_SEPARATOR}{project_doc}" )), None => Some(project_doc), }, - Ok(None) => config.instructions.clone(), + Ok(None) => config.user_instructions.clone(), Err(e) => { error!("error trying to find project doc: {e:#}"); - config.instructions.clone() + config.user_instructions.clone() } } } @@ -159,7 +159,7 @@ mod tests { config.cwd = root.path().to_path_buf(); config.project_doc_max_bytes = limit; - config.instructions = instructions.map(ToOwned::to_owned); + config.user_instructions = instructions.map(ToOwned::to_owned); config } diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs index 08d55b97..9f6e004b 100644 --- a/codex-rs/core/src/protocol.rs +++ b/codex-rs/core/src/protocol.rs @@ -44,8 +44,12 @@ pub enum Op { model_reasoning_effort: ReasoningEffortConfig, model_reasoning_summary: ReasoningSummaryConfig, - /// Model instructions - instructions: Option, + /// Model instructions that are appended to the base instructions. + user_instructions: Option, + + /// Base instructions override. + base_instructions: Option, + /// When to escalate for approval for execution approval_policy: AskForApproval, /// How to sandbox commands executed in the system diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs index fe4710c8..5a6b6100 100644 --- a/codex-rs/core/tests/client.rs +++ b/codex-rs/core/tests/client.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use codex_core::Codex; use codex_core::ModelProviderInfo; use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; @@ -11,7 +9,6 @@ mod test_support; use tempfile::TempDir; use test_support::load_default_config_for_test; use test_support::load_sse_fixture_with_id; -use tokio::time::timeout; use wiremock::Mock; use wiremock::MockServer; use wiremock::ResponseTemplate; @@ -86,21 +83,15 @@ async fn includes_session_id_and_model_headers_in_request() { .await .unwrap(); - let mut current_session_id = None; - // Wait for TaskComplete - loop { - let ev = timeout(Duration::from_secs(1), codex.next_event()) + let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = + test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_))) .await - .unwrap() - .unwrap(); + else { + unreachable!() + }; - if let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = ev.msg { - current_session_id = Some(session_id.to_string()); - } - if matches!(ev.msg, EventMsg::TaskComplete(_)) { - break; - } - } + let current_session_id = Some(session_id.to_string()); + test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; // get request from the server let request = &server.received_requests().await.unwrap()[0]; @@ -108,6 +99,76 @@ async fn includes_session_id_and_model_headers_in_request() { let originator = request.headers.get("originator").unwrap(); assert!(current_session_id.is_some()); - assert_eq!(request_body.to_str().unwrap(), ¤t_session_id.unwrap()); + assert_eq!( + request_body.to_str().unwrap(), + current_session_id.as_ref().unwrap() + ); assert_eq!(originator.to_str().unwrap(), "codex_cli_rs"); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn includes_base_instructions_override_in_request() { + #![allow(clippy::unwrap_used)] + + // Mock server + let server = MockServer::start().await; + + // First request – must NOT include `previous_response_id`. + let first = ResponseTemplate::new(200) + .insert_header("content-type", "text/event-stream") + .set_body_raw(sse_completed("resp1"), "text/event-stream"); + + Mock::given(method("POST")) + .and(path("/v1/responses")) + .respond_with(first) + .expect(1) + .mount(&server) + .await; + + let model_provider = ModelProviderInfo { + name: "openai".into(), + base_url: format!("{}/v1", server.uri()), + // Environment variable that should exist in the test environment. + // ModelClient will return an error if the environment variable for the + // provider is not set. + env_key: Some("PATH".into()), + env_key_instructions: None, + wire_api: codex_core::WireApi::Responses, + query_params: None, + http_headers: None, + env_http_headers: None, + request_max_retries: Some(0), + stream_max_retries: Some(0), + stream_idle_timeout_ms: None, + }; + + let codex_home = TempDir::new().unwrap(); + let mut config = load_default_config_for_test(&codex_home); + + config.base_instructions = Some("test instructions".to_string()); + config.model_provider = model_provider; + + let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new()); + let (codex, ..) = Codex::spawn(config, ctrl_c.clone()).await.unwrap(); + + codex + .submit(Op::UserInput { + items: vec![InputItem::Text { + text: "hello".into(), + }], + }) + .await + .unwrap(); + + test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + let request = &server.received_requests().await.unwrap()[0]; + let request_body = request.body_json::().unwrap(); + + assert!( + request_body["instructions"] + .as_str() + .unwrap() + .contains("test instructions") + ); +} diff --git a/codex-rs/core/tests/test_support.rs b/codex-rs/core/tests/test_support.rs index 7d1e3a7f..83b8a147 100644 --- a/codex-rs/core/tests/test_support.rs +++ b/codex-rs/core/tests/test_support.rs @@ -76,3 +76,24 @@ pub fn load_sse_fixture_with_id(path: impl AsRef, id: &str) -> }) .collect() } + +#[allow(dead_code)] +pub async fn wait_for_event( + codex: &codex_core::Codex, + mut predicate: F, +) -> codex_core::protocol::EventMsg +where + F: FnMut(&codex_core::protocol::EventMsg) -> bool, +{ + use tokio::time::Duration; + use tokio::time::timeout; + loop { + let ev = timeout(Duration::from_secs(1), codex.next_event()) + .await + .expect("timeout waiting for event") + .expect("stream ended unexpectedly"); + if predicate(&ev.msg) { + return ev.msg; + } + } +} diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 769d3c3b..620ab823 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -110,6 +110,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> any cwd: cwd.map(|p| p.canonicalize().unwrap_or(p)), model_provider: None, codex_linux_sandbox_exe, + base_instructions: None, }; // Parse `-c` overrides. let cli_kv_overrides = match config_overrides.parse_overrides() { diff --git a/codex-rs/mcp-server/src/codex_tool_config.rs b/codex-rs/mcp-server/src/codex_tool_config.rs index 54d108c0..6357c94b 100644 --- a/codex-rs/mcp-server/src/codex_tool_config.rs +++ b/codex-rs/mcp-server/src/codex_tool_config.rs @@ -14,7 +14,7 @@ use std::path::PathBuf; use crate::json_to_toml::json_to_toml; /// Client-supplied configuration for a `codex` tool-call. -#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)] #[serde(rename_all = "kebab-case")] pub struct CodexToolCallParam { /// The *initial user prompt* to start the Codex conversation. @@ -46,6 +46,10 @@ pub struct CodexToolCallParam { /// CODEX_HOME/config.toml. #[serde(default, skip_serializing_if = "Option::is_none")] pub config: Option>, + + /// The set of instructions to use instead of the default ones. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub base_instructions: Option, } /// Custom enum mirroring [`AskForApproval`], but has an extra dependency on @@ -135,6 +139,7 @@ impl CodexToolCallParam { approval_policy, sandbox, config: cli_overrides, + base_instructions, } = self; // Build the `ConfigOverrides` recognised by codex-core. @@ -146,6 +151,7 @@ impl CodexToolCallParam { sandbox_mode: sandbox.map(Into::into), model_provider: None, codex_linux_sandbox_exe, + base_instructions, }; let cli_overrides = cli_overrides @@ -268,6 +274,10 @@ mod tests { "description": "The *initial user prompt* to start the Codex conversation.", "type": "string" }, + "base-instructions": { + "description": "The set of instructions to use instead of the default ones.", + "type": "string" + }, }, "required": [ "prompt" diff --git a/codex-rs/mcp-server/tests/elicitation.rs b/codex-rs/mcp-server/tests/codex_tool.rs similarity index 81% rename from codex-rs/mcp-server/tests/elicitation.rs rename to codex-rs/mcp-server/tests/codex_tool.rs index ac9435e8..d36813ce 100644 --- a/codex-rs/mcp-server/tests/elicitation.rs +++ b/codex-rs/mcp-server/tests/codex_tool.rs @@ -8,6 +8,7 @@ use std::path::PathBuf; use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; use codex_core::protocol::FileChange; use codex_core::protocol::ReviewDecision; +use codex_mcp_server::CodexToolCallParam; use codex_mcp_server::ExecApprovalElicitRequestParams; use codex_mcp_server::ExecApprovalResponse; use codex_mcp_server::PatchApprovalElicitRequestParams; @@ -76,7 +77,10 @@ async fn shell_command_approval_triggers_elicitation() -> anyhow::Result<()> { // In turn, it should reply with a tool call, which the MCP should forward // as an elicitation. let codex_request_id = mcp_process - .send_codex_tool_call(None, "run `git init`") + .send_codex_tool_call(CodexToolCallParam { + prompt: "run `git init`".to_string(), + ..Default::default() + }) .await?; let elicitation_request = timeout( DEFAULT_READ_TIMEOUT, @@ -209,10 +213,11 @@ async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> { // Send a "codex" tool request that will trigger the apply_patch command let codex_request_id = mcp_process - .send_codex_tool_call( - Some(cwd.path().to_string_lossy().to_string()), - "please modify the test file", - ) + .send_codex_tool_call(CodexToolCallParam { + cwd: Some(cwd.path().to_string_lossy().to_string()), + prompt: "please modify the test file".to_string(), + ..Default::default() + }) .await?; let elicitation_request = timeout( DEFAULT_READ_TIMEOUT, @@ -279,6 +284,75 @@ async fn patch_approval_triggers_elicitation() -> anyhow::Result<()> { Ok(()) } +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_codex_tool_passes_base_instructions() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + // Apparently `#[tokio::test]` must return `()`, so we create a helper + // function that returns `Result` so we can use `?` in favor of `unwrap`. + if let Err(err) = codex_tool_passes_base_instructions().await { + panic!("failure: {err}"); + } +} + +async fn codex_tool_passes_base_instructions() -> anyhow::Result<()> { + #![allow(clippy::unwrap_used)] + + let server = + create_mock_chat_completions_server(vec![create_final_assistant_message_sse_response( + "Enjoy!", + )?]) + .await; + + // Run `codex mcp` with a specific config.toml. + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri())?; + let mut mcp_process = McpProcess::new(codex_home.path()).await?; + timeout(DEFAULT_READ_TIMEOUT, mcp_process.initialize()).await??; + + // Send a "codex" tool request, which should hit the completions endpoint. + let codex_request_id = mcp_process + .send_codex_tool_call(CodexToolCallParam { + prompt: "How are you?".to_string(), + base_instructions: Some("You are a helpful assistant.".to_string()), + ..Default::default() + }) + .await?; + + let codex_response = timeout( + DEFAULT_READ_TIMEOUT, + mcp_process.read_stream_until_response_message(RequestId::Integer(codex_request_id)), + ) + .await??; + assert_eq!( + JSONRPCResponse { + jsonrpc: JSONRPC_VERSION.into(), + id: RequestId::Integer(codex_request_id), + result: json!({ + "content": [ + { + "text": "Enjoy!", + "type": "text" + } + ] + }), + }, + codex_response + ); + + let requests = server.received_requests().await.unwrap(); + let request = requests[0].body_json::().unwrap(); + let instructions = request["messages"][0]["content"].as_str().unwrap(); + assert!(instructions.starts_with("You are a helpful assistant.")); + + Ok(()) +} + fn create_expected_patch_approval_elicitation_request( elicitation_request_id: RequestId, changes: HashMap, diff --git a/codex-rs/mcp-server/tests/common/mcp_process.rs b/codex-rs/mcp-server/tests/common/mcp_process.rs index df9cc98a..a86deaab 100644 --- a/codex-rs/mcp-server/tests/common/mcp_process.rs +++ b/codex-rs/mcp-server/tests/common/mcp_process.rs @@ -141,20 +141,11 @@ impl McpProcess { /// correlating notifications. pub async fn send_codex_tool_call( &mut self, - cwd: Option, - prompt: &str, + params: CodexToolCallParam, ) -> anyhow::Result { let codex_tool_call_params = CallToolRequestParams { name: "codex".to_string(), - arguments: Some(serde_json::to_value(CodexToolCallParam { - cwd, - prompt: prompt.to_string(), - model: None, - profile: None, - approval_policy: None, - sandbox: None, - config: None, - })?), + arguments: Some(serde_json::to_value(params)?), }; self.send_request( mcp_types::CallToolRequest::METHOD, diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 4ca305b3..05a55edc 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -75,6 +75,7 @@ pub fn run_main(cli: Cli, codex_linux_sandbox_exe: Option) -> std::io:: model_provider: None, config_profile: cli.config_profile.clone(), codex_linux_sandbox_exe, + base_instructions: None, }; // Parse `-c` overrides from the CLI. let cli_kv_overrides = match cli.config_overrides.parse_overrides() {