Add support for custom base instructions (#1645)
Allows providing custom instructions file as a config parameter and custom instruction text via MCP tool call.
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_core::Codex;
|
||||
use codex_core::ModelProviderInfo;
|
||||
use codex_core::exec::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
|
||||
@@ -11,7 +9,6 @@ mod test_support;
|
||||
use tempfile::TempDir;
|
||||
use test_support::load_default_config_for_test;
|
||||
use test_support::load_sse_fixture_with_id;
|
||||
use tokio::time::timeout;
|
||||
use wiremock::Mock;
|
||||
use wiremock::MockServer;
|
||||
use wiremock::ResponseTemplate;
|
||||
@@ -86,21 +83,15 @@ async fn includes_session_id_and_model_headers_in_request() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut current_session_id = None;
|
||||
// Wait for TaskComplete
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) =
|
||||
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::SessionConfigured(_)))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
if let EventMsg::SessionConfigured(SessionConfiguredEvent { session_id, .. }) = ev.msg {
|
||||
current_session_id = Some(session_id.to_string());
|
||||
}
|
||||
if matches!(ev.msg, EventMsg::TaskComplete(_)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let current_session_id = Some(session_id.to_string());
|
||||
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// get request from the server
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
@@ -108,6 +99,76 @@ async fn includes_session_id_and_model_headers_in_request() {
|
||||
let originator = request.headers.get("originator").unwrap();
|
||||
|
||||
assert!(current_session_id.is_some());
|
||||
assert_eq!(request_body.to_str().unwrap(), ¤t_session_id.unwrap());
|
||||
assert_eq!(
|
||||
request_body.to_str().unwrap(),
|
||||
current_session_id.as_ref().unwrap()
|
||||
);
|
||||
assert_eq!(originator.to_str().unwrap(), "codex_cli_rs");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn includes_base_instructions_override_in_request() {
|
||||
#![allow(clippy::unwrap_used)]
|
||||
|
||||
// Mock server
|
||||
let server = MockServer::start().await;
|
||||
|
||||
// First request – must NOT include `previous_response_id`.
|
||||
let first = ResponseTemplate::new(200)
|
||||
.insert_header("content-type", "text/event-stream")
|
||||
.set_body_raw(sse_completed("resp1"), "text/event-stream");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/v1/responses"))
|
||||
.respond_with(first)
|
||||
.expect(1)
|
||||
.mount(&server)
|
||||
.await;
|
||||
|
||||
let model_provider = ModelProviderInfo {
|
||||
name: "openai".into(),
|
||||
base_url: format!("{}/v1", server.uri()),
|
||||
// Environment variable that should exist in the test environment.
|
||||
// ModelClient will return an error if the environment variable for the
|
||||
// provider is not set.
|
||||
env_key: Some("PATH".into()),
|
||||
env_key_instructions: None,
|
||||
wire_api: codex_core::WireApi::Responses,
|
||||
query_params: None,
|
||||
http_headers: None,
|
||||
env_http_headers: None,
|
||||
request_max_retries: Some(0),
|
||||
stream_max_retries: Some(0),
|
||||
stream_idle_timeout_ms: None,
|
||||
};
|
||||
|
||||
let codex_home = TempDir::new().unwrap();
|
||||
let mut config = load_default_config_for_test(&codex_home);
|
||||
|
||||
config.base_instructions = Some("test instructions".to_string());
|
||||
config.model_provider = model_provider;
|
||||
|
||||
let ctrl_c = std::sync::Arc::new(tokio::sync::Notify::new());
|
||||
let (codex, ..) = Codex::spawn(config, ctrl_c.clone()).await.unwrap();
|
||||
|
||||
codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![InputItem::Text {
|
||||
text: "hello".into(),
|
||||
}],
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
test_support::wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
let request = &server.received_requests().await.unwrap()[0];
|
||||
let request_body = request.body_json::<serde_json::Value>().unwrap();
|
||||
|
||||
assert!(
|
||||
request_body["instructions"]
|
||||
.as_str()
|
||||
.unwrap()
|
||||
.contains("test instructions")
|
||||
);
|
||||
}
|
||||
|
||||
@@ -76,3 +76,24 @@ pub fn load_sse_fixture_with_id(path: impl AsRef<std::path::Path>, id: &str) ->
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn wait_for_event<F>(
|
||||
codex: &codex_core::Codex,
|
||||
mut predicate: F,
|
||||
) -> codex_core::protocol::EventMsg
|
||||
where
|
||||
F: FnMut(&codex_core::protocol::EventMsg) -> bool,
|
||||
{
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
loop {
|
||||
let ev = timeout(Duration::from_secs(1), codex.next_event())
|
||||
.await
|
||||
.expect("timeout waiting for event")
|
||||
.expect("stream ended unexpectedly");
|
||||
if predicate(&ev.msg) {
|
||||
return ev.msg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user