diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs index 440d250b..d8684648 100644 --- a/codex-rs/core/src/client_common.rs +++ b/codex-rs/core/src/client_common.rs @@ -5,15 +5,11 @@ use crate::model_family::ModelFamily; use crate::models::ContentItem; use crate::models::ResponseItem; use crate::openai_tools::OpenAiTool; -use crate::protocol::AskForApproval; -use crate::protocol::SandboxPolicy; use crate::protocol::TokenUsage; use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS; use futures::Stream; use serde::Serialize; use std::borrow::Cow; -use std::fmt::Display; -use std::path::PathBuf; use std::pin::Pin; use std::task::Context; use std::task::Poll; @@ -23,62 +19,19 @@ use tokio::sync::mpsc; /// with this content. const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md"); -/// wraps environment context message in a tag for the model to parse more easily. -const ENVIRONMENT_CONTEXT_START: &str = "\n\n"; -const ENVIRONMENT_CONTEXT_END: &str = "\n\n"; - /// wraps user instructions message in a tag for the model to parse more easily. const USER_INSTRUCTIONS_START: &str = "\n\n"; const USER_INSTRUCTIONS_END: &str = "\n\n"; -#[derive(Debug, Clone)] -pub(crate) struct EnvironmentContext { - pub cwd: PathBuf, - pub approval_policy: AskForApproval, - pub sandbox_policy: SandboxPolicy, -} - -impl Display for EnvironmentContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!( - f, - "Current working directory: {}", - self.cwd.to_string_lossy() - )?; - writeln!(f, "Approval policy: {}", self.approval_policy)?; - writeln!(f, "Sandbox policy: {}", self.sandbox_policy)?; - - let network_access = match self.sandbox_policy.clone() { - SandboxPolicy::DangerFullAccess => "enabled", - SandboxPolicy::ReadOnly => "restricted", - SandboxPolicy::WorkspaceWrite { network_access, .. } => { - if network_access { - "enabled" - } else { - "restricted" - } - } - }; - writeln!(f, "Network access: {network_access}")?; - Ok(()) - } -} - -/// API request payload for a single model turn. +/// API request payload for a single model turn #[derive(Default, Debug, Clone)] pub struct Prompt { /// Conversation context input items. pub input: Vec, - /// Optional instructions from the user to amend to the built-in agent - /// instructions. - pub user_instructions: Option, + /// Whether to store response on server side (disable_response_storage = !store). pub store: bool, - /// A list of key-value pairs that will be added as a developer message - /// for the model to use - pub environment_context: Option, - /// Tools available to the model, including additional tools sourced from /// external MCP servers. pub tools: Vec, @@ -100,36 +53,19 @@ impl Prompt { Cow::Owned(sections.join("\n")) } - fn get_formatted_user_instructions(&self) -> Option { - self.user_instructions - .as_ref() - .map(|ui| format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}")) - } - - fn get_formatted_environment_context(&self) -> Option { - self.environment_context - .as_ref() - .map(|ec| format!("{ENVIRONMENT_CONTEXT_START}{ec}{ENVIRONMENT_CONTEXT_END}")) - } - pub(crate) fn get_formatted_input(&self) -> Vec { - let mut input_with_instructions = Vec::with_capacity(self.input.len() + 2); - if let Some(ec) = self.get_formatted_environment_context() { - input_with_instructions.push(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { text: ec }], - }); + self.input.clone() + } + + /// Creates a formatted user instructions message from a string + pub(crate) fn format_user_instructions_message(ui: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!("{USER_INSTRUCTIONS_START}{ui}{USER_INSTRUCTIONS_END}"), + }], } - if let Some(ui) = self.get_formatted_user_instructions() { - input_with_instructions.push(ResponseItem::Message { - id: None, - role: "user".to_string(), - content: vec![ContentItem::InputText { text: ui }], - }); - } - input_with_instructions.extend(self.input.clone()); - input_with_instructions } } @@ -259,7 +195,6 @@ mod tests { #[test] fn get_full_instructions_no_user_content() { let prompt = Prompt { - user_instructions: Some("custom instruction".to_string()), ..Default::default() }; let expected = format!("{BASE_INSTRUCTIONS}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"); diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 2ae1db5b..bca5af43 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -38,7 +38,6 @@ use crate::apply_patch::convert_apply_patch_to_protocol; use crate::apply_patch::get_writable_roots; use crate::apply_patch::{self}; use crate::client::ModelClient; -use crate::client_common::EnvironmentContext; use crate::client_common::Prompt; use crate::client_common::ResponseEvent; use crate::config::Config; @@ -46,6 +45,7 @@ use crate::config_types::ReasoningEffort as ReasoningEffortConfig; use crate::config_types::ReasoningSummary as ReasoningSummaryConfig; use crate::config_types::ShellEnvironmentPolicy; use crate::conversation_history::ConversationHistory; +use crate::environment_context::EnvironmentContext; use crate::error::CodexErr; use crate::error::Result as CodexResult; use crate::error::SandboxErr; @@ -437,6 +437,20 @@ impl Session { show_raw_agent_reasoning: config.show_raw_agent_reasoning, }); + // record the initial user instructions and environment context, regardless of whether we restored items. + if let Some(user_instructions) = sess.get_user_instructions().clone() { + sess.record_conversation_items(&[Prompt::format_user_instructions_message( + &user_instructions, + )]) + .await; + } + sess.record_conversation_items(&[ResponseItem::from(EnvironmentContext::new( + sess.get_cwd().to_path_buf(), + sess.get_approval_policy(), + sess.get_sandbox_policy().clone(), + ))]) + .await; + // Gather history metadata for SessionConfiguredEvent. let (history_log_id, history_entry_count) = crate::message_history::history_metadata(&config).await; @@ -473,6 +487,14 @@ impl Session { &self.cwd } + pub(crate) fn get_user_instructions(&self) -> Option { + self.user_instructions.clone() + } + + pub(crate) fn get_sandbox_policy(&self) -> &SandboxPolicy { + &self.sandbox_policy + } + fn resolve_path(&self, path: Option) -> PathBuf { path.as_ref() .map(PathBuf::from) @@ -1237,15 +1259,9 @@ async fn run_turn( let prompt = Prompt { input, - user_instructions: sess.user_instructions.clone(), store: !sess.disable_response_storage, tools, base_instructions_override: sess.base_instructions.clone(), - environment_context: Some(EnvironmentContext { - cwd: sess.cwd.clone(), - approval_policy: sess.approval_policy, - sandbox_policy: sess.sandbox_policy.clone(), - }), }; let mut retries = 0; @@ -1483,9 +1499,7 @@ async fn run_compact_task( let prompt = Prompt { input: turn_input, - user_instructions: None, store: !sess.disable_response_storage, - environment_context: None, tools: Vec::new(), base_instructions_override: Some(compact_instructions.clone()), }; diff --git a/codex-rs/core/src/config_types.rs b/codex-rs/core/src/config_types.rs index 291dcb64..cbbc6b49 100644 --- a/codex-rs/core/src/config_types.rs +++ b/codex-rs/core/src/config_types.rs @@ -78,8 +78,9 @@ pub enum HistoryPersistence { #[derive(Deserialize, Debug, Clone, PartialEq, Default)] pub struct Tui {} -#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Default, Serialize)] +#[derive(Deserialize, Debug, Clone, Copy, PartialEq, Default, Serialize, Display)] #[serde(rename_all = "kebab-case")] +#[strum(serialize_all = "kebab-case")] pub enum SandboxMode { #[serde(rename = "read-only")] #[default] diff --git a/codex-rs/core/src/environment_context.rs b/codex-rs/core/src/environment_context.rs new file mode 100644 index 00000000..d31ddbc1 --- /dev/null +++ b/codex-rs/core/src/environment_context.rs @@ -0,0 +1,86 @@ +use serde::Deserialize; +use serde::Serialize; +use strum_macros::Display as DeriveDisplay; + +use crate::config_types::SandboxMode; +use crate::models::ContentItem; +use crate::models::ResponseItem; +use crate::protocol::AskForApproval; +use crate::protocol::SandboxPolicy; +use std::fmt::Display; +use std::path::PathBuf; + +/// wraps environment context message in a tag for the model to parse more easily. +pub(crate) const ENVIRONMENT_CONTEXT_START: &str = "\n"; +pub(crate) const ENVIRONMENT_CONTEXT_END: &str = ""; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeriveDisplay)] +#[serde(rename_all = "kebab-case")] +#[strum(serialize_all = "kebab-case")] +pub enum NetworkAccess { + Restricted, + Enabled, +} +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename = "environment_context", rename_all = "snake_case")] +pub(crate) struct EnvironmentContext { + pub cwd: PathBuf, + pub approval_policy: AskForApproval, + pub sandbox_mode: SandboxMode, + pub network_access: NetworkAccess, +} + +impl EnvironmentContext { + pub fn new( + cwd: PathBuf, + approval_policy: AskForApproval, + sandbox_policy: SandboxPolicy, + ) -> Self { + Self { + cwd, + approval_policy, + sandbox_mode: match sandbox_policy { + SandboxPolicy::DangerFullAccess => SandboxMode::DangerFullAccess, + SandboxPolicy::ReadOnly => SandboxMode::ReadOnly, + SandboxPolicy::WorkspaceWrite { .. } => SandboxMode::WorkspaceWrite, + }, + network_access: match sandbox_policy { + SandboxPolicy::DangerFullAccess => NetworkAccess::Enabled, + SandboxPolicy::ReadOnly => NetworkAccess::Restricted, + SandboxPolicy::WorkspaceWrite { network_access, .. } => { + if network_access { + NetworkAccess::Enabled + } else { + NetworkAccess::Restricted + } + } + }, + } + } +} + +impl Display for EnvironmentContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "Current working directory: {}", + self.cwd.to_string_lossy() + )?; + writeln!(f, "Approval policy: {}", self.approval_policy)?; + writeln!(f, "Sandbox mode: {}", self.sandbox_mode)?; + writeln!(f, "Network access: {}", self.network_access)?; + Ok(()) + } +} + +impl From for ResponseItem { + fn from(ec: EnvironmentContext) -> Self { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ContentItem::InputText { + text: format!("{ENVIRONMENT_CONTEXT_START}{ec}{ENVIRONMENT_CONTEXT_END}"), + }], + } + } +} diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index aeab4970..d19fbbdb 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -17,6 +17,7 @@ pub mod config; pub mod config_profile; pub mod config_types; mod conversation_history; +mod environment_context; pub mod error; pub mod exec; pub mod exec_env; diff --git a/codex-rs/core/tests/client.rs b/codex-rs/core/tests/client.rs index 1bcddf07..10c6c66f 100644 --- a/codex-rs/core/tests/client.rs +++ b/codex-rs/core/tests/client.rs @@ -373,11 +373,11 @@ async fn includes_user_instructions_message_in_request() { .contains("be nice") ); assert_message_role(&request_body["input"][0], "user"); - assert_message_starts_with(&request_body["input"][0], "\n\n"); - assert_message_ends_with(&request_body["input"][0], ""); + assert_message_starts_with(&request_body["input"][0], ""); + assert_message_ends_with(&request_body["input"][0], ""); assert_message_role(&request_body["input"][1], "user"); - assert_message_starts_with(&request_body["input"][1], "\n\n"); - assert_message_ends_with(&request_body["input"][1], ""); + assert_message_starts_with(&request_body["input"][1], ""); + assert_message_ends_with(&request_body["input"][1], ""); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] diff --git a/codex-rs/core/tests/prompt_caching.rs b/codex-rs/core/tests/prompt_caching.rs index 8df7ea35..0c2552ee 100644 --- a/codex-rs/core/tests/prompt_caching.rs +++ b/codex-rs/core/tests/prompt_caching.rs @@ -85,7 +85,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests assert_eq!(requests.len(), 2, "expected two POST requests"); let expected_env_text = format!( - "\n\nCurrent working directory: {}\nApproval policy: on-request\nSandbox policy: read-only\nNetwork access: restricted\n\n\n", + "\nCurrent working directory: {}\nApproval policy: on-request\nSandbox mode: read-only\nNetwork access: restricted\n", cwd.path().to_string_lossy() ); let expected_ui_text = @@ -113,7 +113,7 @@ async fn prefixes_context_and_instructions_once_and_consistently_across_requests let body1 = requests[0].body_json::().unwrap(); assert_eq!( body1["input"], - serde_json::json!([expected_env_msg, expected_ui_msg, expected_user_message_1]) + serde_json::json!([expected_ui_msg, expected_env_msg, expected_user_message_1]) ); let expected_user_message_2 = serde_json::json!({