Add support for custom base instructions (#1645)

Allows providing custom instructions file as a config parameter and
custom instruction text via MCP tool call.
This commit is contained in:
pakrym-oai
2025-07-22 09:42:22 -07:00
committed by GitHub
parent ed206d5687
commit 6d82907082
12 changed files with 264 additions and 61 deletions

View File

@@ -34,11 +34,18 @@ pub struct Prompt {
/// the "fully qualified" tool name (i.e., prefixed with the server name),
/// which should be reported to the model in place of Tool::name.
pub extra_tools: HashMap<String, mcp_types::Tool>,
/// Optional override for the built-in BASE_INSTRUCTIONS.
pub base_instructions_override: Option<String>,
}
impl Prompt {
pub(crate) fn get_full_instructions(&self, model: &str) -> Cow<'_, str> {
let mut sections: Vec<&str> = vec![BASE_INSTRUCTIONS];
let base = self
.base_instructions_override
.as_deref()
.unwrap_or(BASE_INSTRUCTIONS);
let mut sections: Vec<&str> = vec![base];
if let Some(ref user) = self.user_instructions {
sections.push(user);
}

View File

@@ -108,13 +108,15 @@ impl Codex {
let (tx_sub, rx_sub) = async_channel::bounded(64);
let (tx_event, rx_event) = async_channel::bounded(1600);
let instructions = get_user_instructions(&config).await;
let user_instructions = get_user_instructions(&config).await;
let configure_session = Op::ConfigureSession {
provider: config.model_provider.clone(),
model: config.model.clone(),
model_reasoning_effort: config.model_reasoning_effort,
model_reasoning_summary: config.model_reasoning_summary,
instructions,
user_instructions,
base_instructions: config.base_instructions.clone(),
approval_policy: config.approval_policy,
sandbox_policy: config.sandbox_policy.clone(),
disable_response_storage: config.disable_response_storage,
@@ -183,7 +185,8 @@ pub(crate) struct Session {
/// the model as well as sandbox policies are resolved against this path
/// instead of `std::env::current_dir()`.
cwd: PathBuf,
instructions: Option<String>,
base_instructions: Option<String>,
user_instructions: Option<String>,
approval_policy: AskForApproval,
sandbox_policy: SandboxPolicy,
shell_environment_policy: ShellEnvironmentPolicy,
@@ -577,7 +580,8 @@ async fn submission_loop(
model,
model_reasoning_effort,
model_reasoning_summary,
instructions,
user_instructions,
base_instructions,
approval_policy,
sandbox_policy,
disable_response_storage,
@@ -625,15 +629,17 @@ async fn submission_loop(
let rollout_recorder = match rollout_recorder {
Some(rec) => Some(rec),
None => match RolloutRecorder::new(&config, session_id, instructions.clone())
.await
{
Ok(r) => Some(r),
Err(e) => {
warn!("failed to initialise rollout recorder: {e}");
None
None => {
match RolloutRecorder::new(&config, session_id, user_instructions.clone())
.await
{
Ok(r) => Some(r),
Err(e) => {
warn!("failed to initialise rollout recorder: {e}");
None
}
}
},
}
};
let client = ModelClient::new(
@@ -699,7 +705,8 @@ async fn submission_loop(
client,
tx_event: tx_event.clone(),
ctrl_c: Arc::clone(&ctrl_c),
instructions,
user_instructions,
base_instructions,
approval_policy,
sandbox_policy,
shell_environment_policy: config.shell_environment_policy.clone(),
@@ -1067,9 +1074,10 @@ async fn run_turn(
let prompt = Prompt {
input,
prev_id,
user_instructions: sess.instructions.clone(),
user_instructions: sess.user_instructions.clone(),
store,
extra_tools,
base_instructions_override: sess.base_instructions.clone(),
};
let mut retries = 0;

View File

@@ -63,7 +63,10 @@ pub struct Config {
pub disable_response_storage: bool,
/// User-provided instructions from instructions.md.
pub instructions: Option<String>,
pub user_instructions: Option<String>,
/// Base instructions override.
pub base_instructions: Option<String>,
/// Optional external notifier command. When set, Codex will spawn this
/// program after each completed *turn* (i.e. when the agent finishes
@@ -327,6 +330,9 @@ pub struct ConfigToml {
/// Experimental rollout resume path (absolute path to .jsonl; undocumented).
pub experimental_resume: Option<PathBuf>,
/// Experimental path to a file whose contents replace the built-in BASE_INSTRUCTIONS.
pub experimental_instructions_file: Option<PathBuf>,
}
impl ConfigToml {
@@ -359,6 +365,7 @@ pub struct ConfigOverrides {
pub model_provider: Option<String>,
pub config_profile: Option<String>,
pub codex_linux_sandbox_exe: Option<PathBuf>,
pub base_instructions: Option<String>,
}
impl Config {
@@ -369,7 +376,7 @@ impl Config {
overrides: ConfigOverrides,
codex_home: PathBuf,
) -> std::io::Result<Self> {
let instructions = Self::load_instructions(Some(&codex_home));
let user_instructions = Self::load_instructions(Some(&codex_home));
// Destructure ConfigOverrides fully to ensure all overrides are applied.
let ConfigOverrides {
@@ -380,6 +387,7 @@ impl Config {
model_provider,
config_profile: config_profile_key,
codex_linux_sandbox_exe,
base_instructions,
} = overrides;
let config_profile = match config_profile_key.as_ref().or(cfg.profile.as_ref()) {
@@ -457,6 +465,10 @@ impl Config {
let experimental_resume = cfg.experimental_resume;
let base_instructions = base_instructions.or(Self::get_base_instructions(
cfg.experimental_instructions_file.as_ref(),
));
let config = Self {
model,
model_context_window,
@@ -475,7 +487,8 @@ impl Config {
.or(cfg.disable_response_storage)
.unwrap_or(false),
notify: cfg.notify,
instructions,
user_instructions,
base_instructions,
mcp_servers: cfg.mcp_servers,
model_providers,
project_doc_max_bytes: cfg.project_doc_max_bytes.unwrap_or(PROJECT_DOC_MAX_BYTES),
@@ -525,6 +538,15 @@ impl Config {
}
})
}
fn get_base_instructions(path: Option<&PathBuf>) -> Option<String> {
let path = path.as_ref()?;
std::fs::read_to_string(path)
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}
}
fn default_model() -> String {
@@ -801,7 +823,7 @@ disable_response_storage = true
sandbox_policy: SandboxPolicy::new_read_only_policy(),
shell_environment_policy: ShellEnvironmentPolicy::default(),
disable_response_storage: false,
instructions: None,
user_instructions: None,
notify: None,
cwd: fixture.cwd(),
mcp_servers: HashMap::new(),
@@ -818,6 +840,7 @@ disable_response_storage = true
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,
},
o3_profile_config
);
@@ -848,7 +871,7 @@ disable_response_storage = true
sandbox_policy: SandboxPolicy::new_read_only_policy(),
shell_environment_policy: ShellEnvironmentPolicy::default(),
disable_response_storage: false,
instructions: None,
user_instructions: None,
notify: None,
cwd: fixture.cwd(),
mcp_servers: HashMap::new(),
@@ -865,6 +888,7 @@ disable_response_storage = true
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,
};
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config);
@@ -910,7 +934,7 @@ disable_response_storage = true
sandbox_policy: SandboxPolicy::new_read_only_policy(),
shell_environment_policy: ShellEnvironmentPolicy::default(),
disable_response_storage: true,
instructions: None,
user_instructions: None,
notify: None,
cwd: fixture.cwd(),
mcp_servers: HashMap::new(),
@@ -927,6 +951,7 @@ disable_response_storage = true
model_supports_reasoning_summaries: false,
chatgpt_base_url: "https://chatgpt.com/backend-api/".to_string(),
experimental_resume: None,
base_instructions: None,
};
assert_eq!(expected_zdr_profile_config, zdr_profile_config);

View File

@@ -27,16 +27,16 @@ const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";
/// string of instructions.
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
match find_project_doc(config).await {
Ok(Some(project_doc)) => match &config.instructions {
Ok(Some(project_doc)) => match &config.user_instructions {
Some(original_instructions) => Some(format!(
"{original_instructions}{PROJECT_DOC_SEPARATOR}{project_doc}"
)),
None => Some(project_doc),
},
Ok(None) => config.instructions.clone(),
Ok(None) => config.user_instructions.clone(),
Err(e) => {
error!("error trying to find project doc: {e:#}");
config.instructions.clone()
config.user_instructions.clone()
}
}
}
@@ -159,7 +159,7 @@ mod tests {
config.cwd = root.path().to_path_buf();
config.project_doc_max_bytes = limit;
config.instructions = instructions.map(ToOwned::to_owned);
config.user_instructions = instructions.map(ToOwned::to_owned);
config
}

View File

@@ -44,8 +44,12 @@ pub enum Op {
model_reasoning_effort: ReasoningEffortConfig,
model_reasoning_summary: ReasoningSummaryConfig,
/// Model instructions
instructions: Option<String>,
/// Model instructions that are appended to the base instructions.
user_instructions: Option<String>,
/// Base instructions override.
base_instructions: Option<String>,
/// When to escalate for approval for execution
approval_policy: AskForApproval,
/// How to sandbox commands executed in the system

View File

@@ -1,5 +1,3 @@
use std::time::Duration;
use codex_core::Codex;
use codex_core::ModelProviderInfo;
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
@@ -11,7 +9,6 @@ mod test_support;
use tempfile::TempDir;
use test_support::load_default_config_for_test;
use test_support::load_sse_fixture_with_id;
use tokio::time::timeout;
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
@@ -86,21 +83,15 @@ async fn includes_session_id_and_model_headers_in_request() {
.await
.unwrap();
let mut current_session_id = None;
// Wait for TaskComplete
loop {
let ev = timeout(Duration::from_secs(1), codex.next_event())
let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) =
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_)))
.await
.unwrap()
.unwrap();
else {
unreachable!()
};
if let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = ev.msg {
current_session_id = Some(session_id.to_string());
}
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
break;
}
}
let current_session_id = Some(session_id.to_string());
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
// get request from the server
let request = &server.received_requests().await.unwrap()[0];
@@ -108,6 +99,76 @@ async fn includes_session_id_and_model_headers_in_request() {
let originator = request.headers.get("originator").unwrap();
assert!(current_session_id.is_some());
assert_eq!(request_body.to_str().unwrap(), &current_session_id.unwrap());
assert_eq!(
request_body.to_str().unwrap(),
current_session_id.as_ref().unwrap()
);
assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn includes_base_instructions_override_in_request() {
#![allow(clippy::unwrap_used)]
// Mock server
let server = MockServer::start().await;
// First request must NOT include `previous_response_id`.
let first = ResponseTemplate::new(200)
.insert_header("content-type", "text/event-stream")
.set_body_raw(sse_completed("resp1"), "text/event-stream");
Mock::given(method("POST"))
.and(path("/v1/responses"))
.respond_with(first)
.expect(1)
.mount(&server)
.await;
let model_provider = ModelProviderInfo {
name: "openai".into(),
base_url: format!("{}/v1", server.uri()),
// Environment variable that should exist in the test environment.
// ModelClient will return an error if the environment variable for the
// provider is not set.
env_key: Some("PATH".into()),
env_key_instructions: None,
wire_api: codex_core::WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: None,
};
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home);
config.base_instructions = Some("test instructions".to_string());
config.model_provider = model_provider;
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
let (codex, ..) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
codex
.submit(Op::UserInput {
items: vec![InputItem::Text {
text: "hello".into(),
}],
})
.await
.unwrap();
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
let request = &server.received_requests().await.unwrap()[0];
let request_body = request.body_json::<serde_json::Value>().unwrap();
assert!(
request_body["instructions"]
.as_str()
.unwrap()
.contains("test instructions")
);
}

View File

@@ -76,3 +76,24 @@ pub fn load_sse_fixture_with_id(path: impl AsRef<std::path::Path>, id: &str) ->
})
.collect()
}
#[allow(dead_code)]
pub async fn wait_for_event<F>(
codex: &codex_core::Codex,
mut predicate: F,
) -> codex_core::protocol::EventMsg
where
F: FnMut(&codex_core::protocol::EventMsg) -> bool,
{
use tokio::time::Duration;
use tokio::time::timeout;
loop {
let ev = timeout(Duration::from_secs(1), codex.next_event())
.await
.expect("timeout waiting for event")
.expect("stream ended unexpectedly");
if predicate(&ev.msg) {
return ev.msg;
}
}
}