diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index 98ef7f26..dae140bc 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -41,11 +41,9 @@ pub(crate) async fn stream_chat_completions( let full_instructions = prompt.get_full_instructions(model_family); messages.push(json!({"role": "system", "content": full_instructions})); - if let Some(instr) = &prompt.get_formatted_user_instructions() { - messages.push(json!({"role": "user", "content": instr})); - } + let input = prompt.get_formatted_input(); - for item in &prompt.input { + for item in &input { match item { ResponseItem::Message { role, content, .. } => { let mut text = String::new(); diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs index e4bb30da..9748cde7 100644 --- a/codex-rs/core/src/client.rs +++ b/codex-rs/core/src/client.rs @@ -34,7 +34,6 @@ use crate::error::Result; use crate::flags::CODEX_RS_SSE_FIXTURE; use crate::model_provider_info::ModelProviderInfo; use crate::model_provider_info::WireApi; -use crate::models::ContentItem; use crate::models::ResponseItem; use crate::openai_tools::create_tools_json_for_responses_api; use crate::protocol::TokenUsage; @@ -146,15 +145,7 @@ impl ModelClient { vec![] }; - let mut input_with_instructions = Vec::with_capacity(prompt.input.len() + 1); - if let Some(ui) = prompt.get_formatted_user_instructions() { - input_with_instructions.push(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { text: ui }], - }); - } - input_with_instructions.extend(prompt.input.clone()); + let input_with_instructions = prompt.get_formatted_input(); let payload = ResponsesApiRequest { model: &self.config.model, diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 60164f5f..2ca060f4 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -1,14 +1,20 @@ use crate::config_types::ReasoningEffort as ReasoningEffortConfig; use crate::config_types::ReasoningSummary as ReasoningSummaryConfig; use crate::error::Result; +use crate::git_info::GitInfo; use crate::model_family::ModelFamily; +use crate::models::ContentItem; use crate::models::ResponseItem; use crate::openai_tools::OpenAiTool; +use crate::protocol::AskForApproval; +use crate::protocol::SandboxPolicy; use crate::protocol::TokenUsage; use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS; use futures::Stream; use serde::Serialize; use std::borrow::Cow; +use std::fmt::Display; +use std::path::PathBuf; use std::pin::Pin; use std::task::Context; use std::task::Poll; @@ -18,10 +24,49 @@ use tokio::sync::mpsc; /// with this content. const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md"); +/// wraps environment context message in a tag for the model to parse more easily. +const ENVIRONMENT_CONTEXT_START: &str = "\n\n"; +const ENVIRONMENT_CONTEXT_END: &str = "\n\n"; + /// wraps user instructions message in a tag for the model to parse more easily. const USER_INSTRUCTIONS_START: &str = "\n\n"; const USER_INSTRUCTIONS_END: &str = "\n\n"; +#[derive(Debug, Clone)] +pub(crate) struct EnvironmentContext { + pub cwd: PathBuf, + pub git_info: Option, + pub approval_policy: AskForApproval, + pub sandbox_policy: SandboxPolicy, +} + +impl Display for EnvironmentContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "Current working directory: {}", + self.cwd.to_string_lossy() + )?; + writeln!(f, "Is directory a git repo: {}", self.git_info.is_some())?; + writeln!(f, "Approval policy: {}", self.approval_policy)?; + writeln!(f, "Sandbox policy: {}", self.sandbox_policy)?; + + let network_access = match self.sandbox_policy.clone() { + SandboxPolicy::DangerFullAccess => "enabled", + SandboxPolicy::ReadOnly => "restricted", + SandboxPolicy::WorkspaceWrite { network_access, .. } => { + if network_access { + "enabled" + } else { + "restricted" + } + } + }; + writeln!(f, "Network access: {network_access}")?; + Ok(()) + } +} + /// API request payload for a single model turn. #[derive(Default, Debug, Clone)] pub struct Prompt { @@ -33,6 +78,10 @@ pub struct Prompt { /// Whether to store response on server side (disable_response_storage = !store). pub store: bool, + /// A list of key-value pairs that will be added as a developer message + /// for the model to use + pub environment_context: Option, + /// Tools available to the model, including additional tools sourced from /// external MCP servers. pub tools: Vec, @@ -54,11 +103,37 @@ impl Prompt { Cow::Owned(sections.join("\n")) } - pub(crate) fn get_formatted_user_instructions(&self) -> Option { + fn get_formatted_user_instructions(&self) -> Option { self.user_instructions .as_ref() .map(|ui| format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}")) } + + fn get_formatted_environment_context(&self) -> Option { + self.environment_context + .as_ref() + .map(|ec| format!("{ENVIRONMENT_CONTEXT_START}{ec}{ENVIRONMENT_CONTEXT_END}")) + } + + pub(crate) fn get_formatted_input(&self) -> Vec { + let mut input_with_instructions = Vec::with_capacity(self.input.len() + 2); + if let Some(ec) = self.get_formatted_environment_context() { + input_with_instructions.push(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { text: ec }], + }); + } + if let Some(ui) = self.get_formatted_user_instructions() { + input_with_instructions.push(ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { text: ui }], + }); + } + input_with_instructions.extend(self.input.clone()); + input_with_instructions + } } #[derive(Debug)] diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index a7ab664e..c85b1ce2 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -37,6 +37,7 @@ use crate::apply_patch::convert_apply_patch_to_protocol; use crate::apply_patch::get_writable_roots; use crate::apply_patch::{self}; use crate::client::ModelClient; +use crate::client_common::EnvironmentContext; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; use crate::config::Config; @@ -51,6 +52,7 @@ use crate::exec::SandboxType; use crate::exec::StdoutStream; use crate::exec::process_exec_tool_call; use crate::exec_env::create_env; +use crate::git_info::collect_git_info; use crate::mcp_connection_manager::McpConnectionManager; use crate::mcp_tool_call::handle_mcp_tool_call; use crate::models::ContentItem; @@ -1224,6 +1226,12 @@ async fn run_turn( store: !sess.disable_response_storage, tools, base_instructions_override: sess.base_instructions.clone(), + environment_context: Some(EnvironmentContext { + cwd: sess.cwd.clone(), + git_info: collect_git_info(&sess.cwd).await, + approval_policy: sess.approval_policy, + sandbox_policy: sess.sandbox_policy.clone(), + }), }; let mut retries = 0; @@ -1449,6 +1457,7 @@ async fn run_compact_task( input: turn_input, user_instructions: None, store: !sess.disable_response_storage, + environment_context: None, tools: Vec::new(), base_instructions_override: Some(compact_instructions.clone()), }; diff --git a/codex-rs/core/src/git_info.rs b/codex-rs/core/src/git_info.rs index f5dc016e..52d029f6 100644 --- a/codex-rs/core/src/git_info.rs +++ b/codex-rs/core/src/git_info.rs @@ -9,7 +9,7 @@ use tokio::time::timeout; /// Timeout for git commands to prevent freezing on large repositories const GIT_COMMAND_TIMEOUT: TokioDuration = TokioDuration::from_secs(5); -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct GitInfo { /// Current commit hash (SHA) #[serde(skip_serializing_if = "Option::is_none")] diff --git a/codex-rs/core/src/protocol.rs b/codex-rs/core/src/protocol.rs index 9bf85ec4..55000fb6 100644 --- a/codex-rs/core/src/protocol.rs +++ b/codex-rs/core/src/protocol.rs @@ -159,7 +159,8 @@ pub enum AskForApproval { } /// Determines execution restrictions for model shell commands. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Display)] +#[strum(serialize_all = "kebab-case")] #[serde(tag = "mode", rename_all = "kebab-case")] pub enum SandboxPolicy { /// No restrictions whatsoever. Use with caution. diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs index f4930202..00f91a87 100644 --- a/codex-rs/core/tests/client.rs +++ b/codex-rs/core/tests/client.rs @@ -1,3 +1,5 @@ +#![allow(clippy::expect_used)] +#![allow(clippy::unwrap_used)] use std::path::PathBuf; use chrono::Utc; @@ -32,6 +34,32 @@ fn sse_completed(id: &str) -> String { load_sse_fixture_with_id("tests/fixtures/completed_template.json", id) } +fn assert_message_role(request_body: &serde_json::Value, role: &str) { + assert_eq!(request_body["role"].as_str().unwrap(), role); +} + +fn assert_message_starts_with(request_body: &serde_json::Value, text: &str) { + let content = request_body["content"][0]["text"] + .as_str() + .expect("invalid message content"); + + assert!( + content.starts_with(text), + "expected message content '{content}' to start with '{text}'" + ); +} + +fn assert_message_ends_with(request_body: &serde_json::Value, text: &str) { + let content = request_body["content"][0]["text"] + .as_str() + .expect("invalid message content"); + + assert!( + content.ends_with(text), + "expected message content '{content}' to end with '{text}'" + ); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn includes_session_id_and_model_headers_in_request() { #![allow(clippy::unwrap_used)] @@ -371,19 +399,12 @@ async fn includes_user_instructions_message_in_request() { .unwrap() .contains("be nice") ); - assert_eq!(request_body["input"][0]["role"], "user"); - assert!( - request_body["input"][0]["content"][0]["text"] - .as_str() - .unwrap() - .starts_with("\n\nbe nice") - ); - assert!( - request_body["input"][0]["content"][0]["text"] - .as_str() - .unwrap() - .ends_with("") - ); + assert_message_role(&request_body["input"][0], "user"); + assert_message_starts_with(&request_body["input"][0], "\n\n"); + assert_message_ends_with(&request_body["input"][0], ""); + assert_message_role(&request_body["input"][1], "user"); + assert_message_starts_with(&request_body["input"][1], "\n\n"); + assert_message_ends_with(&request_body["input"][1], ""); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] diff --git a/codex-rs/mcp-server/tests/send_message.rs b/codex-rs/mcp-server/tests/send_message.rs index fd3718e8..6e138909 100644 --- a/codex-rs/mcp-server/tests/send_message.rs +++ b/codex-rs/mcp-server/tests/send_message.rs @@ -99,7 +99,7 @@ async fn test_send_message_success() { response ); // wait for the server to hear the user message - sleep(Duration::from_secs(1)); + sleep(Duration::from_secs(10)); // Ensure the server and tempdir live until end of test drop(server);